Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FFT2 and FFT2 inverse #2845

Merged
merged 11 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -3393,6 +3393,48 @@ NDArray stft(
boolean normalize,
boolean returnComplex);

/**
* Computes the two-dimensional Discrete Fourier Transform.
*
* @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to
* this size.
* @param axes Axes over which to compute the 2D-FFT.
* @return The truncated or zero-padded input, transformed along the axes.
*/
NDArray fft2(long[] sizes, long[] axes);

/**
* Computes the two-dimensional Discrete Fourier Transform along the last 2 axes.
*
* @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to
* this size.
* @return The truncated or zero-padded input, transformed along the last two axes
*/
default NDArray fft2(long[] sizes) {
return fft2(sizes, new long[] {-2, -1});
}

/**
* Computes the two-dimensional inverse Discrete Fourier Transform.
*
* @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to
* this size.
* @param axes Axes over which to compute the 2D-Inverse-FFT.
* @return The truncated or zero-padded input, transformed along the axes.
*/
NDArray ifft2(long[] sizes, long[] axes);

/**
* Computes the two-dimensional inverse Discrete Fourier Transform along the last 2 axes.
*
* @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to
* this size.
* @return The truncated or zero-padded input, transformed along the axes.
*/
default NDArray ifft2(long[] sizes) {
return ifft2(sizes, new long[] {-2, -1});
}

/**
* Reshapes this {@code NDArray} to the given {@link Shape}.
*
Expand Down
12 changes: 12 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,18 @@ public NDArray stft(
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray fft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray reshape(Shape shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,18 @@ public NDArray stft(
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray fft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray reshape(Shape shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,18 @@ public NDArray stft(
this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex);
}

/** {@inheritDoc} */
@Override
public NDArray fft2(long[] sizes, long[] axes) {
return JniUtils.fft2(this, sizes, axes);
}

/** {@inheritDoc} */
@Override
public NDArray ifft2(long[] sizes, long[] axes) {
return JniUtils.ifft2(this, sizes, axes);
}

/** {@inheritDoc} */
@Override
public PtNDArray reshape(Shape shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,18 @@ public static PtNDArray stft(
return new PtNDArray(ndArray.getManager(), handle);
}

public static PtNDArray fft2(PtNDArray ndArray, long[] sizes, long[] axes) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchFft2(ndArray.getHandle(), sizes, axes));
}

public static PtNDArray ifft2(PtNDArray ndArray, long[] sizes, long[] axes) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchIfft2(ndArray.getHandle(), sizes, axes));
}

public static PtNDArray real(PtNDArray ndArray) {
long handle = PyTorchLibrary.LIB.torchViewAsReal(ndArray.getHandle());
if (handle == -1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ native long torchStft(
boolean normalize,
boolean returnComplex);

native long torchFft2(long handle, long[] sizes, long[] axes);

native long torchIfft2(long handle, long[] sizes, long[] axes);

native long torchViewAsReal(long handle);

native long torchViewAsComplex(long handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,28 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft2(
JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const std::vector<int64_t> sizes = djl::utils::jni::GetVecFromJLongArray(env, js);
const std::vector<int64_t> axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes);
const auto* result_ptr = new torch::Tensor(torch::fft_fft2(*tensor_ptr, sizes, axes));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIfft2(
JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const std::vector<int64_t> sizes = djl::utils::jni::GetVecFromJLongArray(env, js);
const std::vector<int64_t> axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes);
const auto* result_ptr = new torch::Tensor(torch::fft_ifft2(*tensor_ptr, sizes, axes));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchStft(JNIEnv* env, jobject jthis, jlong jhandle,
jlong jn_fft, jlong jhop_length, jlong jwindow, jboolean jcenter, jboolean jnormalize, jboolean jreturn_complex) {
#ifdef V1_11_X
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,18 @@ public NDArray stft(
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray fft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft2(long[] sizes, long[] axes) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray reshape(Shape shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1087,4 +1087,58 @@ public void testStft() {
Assertions.assertAlmostEquals(result.real().flatten(), expected);
}
}

@Test
public void testFft2() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
NDArray array =
manager.create(
new float[][] {
{1f, 6.6f, 4.315f, 2.0f},
{16.9f, 6.697f, 2.399f, 67.9f},
{0f, 5f, 67.09f, 9.87f}
});
NDArray result = array.fft2(new long[] {3, 4}, new long[] {0, 1});
result = result.real().flatten(1, 2); // flatten complex numbers
NDArray expected =
manager.create(
new float[][] {
{189.771f, 0f, -55.904f, 61.473f, -6.363f, 0f, -55.904f, -61.473f},
{
-74.013f,
-10.3369f,
71.7653f,
-108.2964f,
-1.746f,
93.1133f,
-25.8063f,
-33.0234f
},
{
-74.013f, 10.3369f, -25.8063f, 33.0234f, -1.746f, -93.1133f,
71.7653f, 108.2964f
}
});
Assertions.assertAlmostEquals(result, expected);
}
}

@Test
public void testIfft2() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
NDArray array =
manager.create(
new float[][] {
{1f, 6.6f, 4.315f, 2.0f},
{16.9f, 6.697f, 2.399f, 67.9f},
{0f, 5f, 67.09f, 9.87f}
});
long[] sizes = {3, 4};
long[] axes = {0, 1};
NDArray fft2 = array.fft2(sizes, axes);
NDArray actual = fft2.ifft2(sizes, axes).real();
NDArray expected = array.toType(DataType.COMPLEX64, true).real();
Assertions.assertAlmostEquals(expected, actual);
}
}
}
Loading