Skip to content

Commit

Permalink
Add fft keep unit test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 6, 2024
1 parent b494d8c commit 3b9a84c
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 41 deletions.
2 changes: 1 addition & 1 deletion brainunit/fft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
_fft_keep_unit_all)

del (_fft_change_unit_all,
_fft_keep_unit_all,)
_fft_keep_unit_all,)
68 changes: 41 additions & 27 deletions brainunit/fft/_fft_change_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
from typing import Callable, Union, Sequence

import jax
from jaxlib import xla_client
import jax.numpy as jnp
from jax.numpy import fft as jnpfft
from jaxlib import xla_client

from .._base import Quantity, maybe_decimal, UNITLESS, Unit
from .. import _unit_common as u
from .._unit_common import second, hertz
from .._base import Quantity, Unit
from .._misc import set_module_as
from ..math._fun_change_unit import _fun_change_unit_unary, _fun_change_unit_binary
from .._unit_common import second
from ..math._fun_change_unit import _fun_change_unit_unary

__all__ = [
# return original unit * time unit
Expand All @@ -40,6 +40,7 @@
'fftfreq', 'rfftfreq',
]


def unit_change(
unit_change_fun: Callable
):
Expand All @@ -49,12 +50,14 @@ def actual_decorator(func):

return actual_decorator


Shape = Sequence[int]


# return original unit * time unit
# --------------------------------

def _calculate_fftn_dimension(input_dim: int, axes:Sequence[int] | None = None) -> int:
def _calculate_fftn_dimension(input_dim: int, axes: Sequence[int] | None = None) -> int:
if axes is None:
return input_dim
return len(axes)
Expand Down Expand Up @@ -129,6 +132,7 @@ def fft(
lambda u: u * second,
a, n=n, axis=axis, norm=norm)


@unit_change(lambda u: u * second)
def rfft(
a: Union[Quantity, jax.typing.ArrayLike],
Expand Down Expand Up @@ -195,6 +199,7 @@ def rfft(
lambda u: u * second,
a, n=n, axis=axis, norm=norm)


# return original unit / time unit (inverse)
# ------------------------------------------

Expand Down Expand Up @@ -257,8 +262,9 @@ def ifft(
[ 0.67-0.58j -0.5 -1.44j 0.17-2.02j 1.83-0.29j]]
"""
return _fun_change_unit_unary(jnpfft.ifft,
lambda u: u / second,
a, n=n, axis=axis, norm=norm)
lambda u: u / second,
a, n=n, axis=axis, norm=norm)


@unit_change(lambda u: u / second)
def irfft(
Expand Down Expand Up @@ -325,6 +331,7 @@ def irfft(
lambda u: u / second,
a, n=n, axis=axis, norm=norm)


# return original unit * (time unit ^ n)
# --------------------------------------

Expand Down Expand Up @@ -409,8 +416,9 @@ def fft2(
Array(True, dtype=bool)
"""
return _fun_change_unit_unary(jnpfft.fft2,
lambda u: u * (second ** 2),
a, s=s, axes=axes, norm=norm)
lambda u: u * (second ** 2),
a, s=s, axes=axes, norm=norm)


@unit_change(lambda u: u * (second ** 2))
def rfft2(
Expand Down Expand Up @@ -493,8 +501,9 @@ def rfft2(
[ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64)
"""
return _fun_change_unit_unary(jnpfft.rfft2,
lambda u: u * (second ** 2),
a, s=s, axes=axes, norm=norm)
lambda u: u * (second ** 2),
a, s=s, axes=axes, norm=norm)


@set_module_as('brainunit.fft')
def fftn(
Expand Down Expand Up @@ -576,8 +585,9 @@ def fftn(
# TODO: may cause computation overhead?
fftn._unit_change_fun = _unit_change_fun
return _fun_change_unit_unary(jnpfft.fftn,
_unit_change_fun,
a, s=s, axes=axes, norm=norm)
_unit_change_fun,
a, s=s, axes=axes, norm=norm)


@set_module_as('brainunit.fft')
def rfftn(
Expand Down Expand Up @@ -676,9 +686,8 @@ def rfftn(
# TODO: may cause computation overhead?
rfftn._unit_change_fun = _unit_change_fun
return _fun_change_unit_unary(jnpfft.rfftn,
_unit_change_fun,
a, s=s, axes=axes, norm=norm)

_unit_change_fun,
a, s=s, axes=axes, norm=norm)


# return original unit / (time unit ^ n) (inverse)
Expand Down Expand Up @@ -758,8 +767,9 @@ def ifft2(
[-0.33+0.58j, -0.33+0.58j]]], dtype=complex64)
"""
return _fun_change_unit_unary(jnpfft.ifft2,
lambda u: u / (second ** 2),
a, s=s, axes=axes, norm=norm)
lambda u: u / (second ** 2),
a, s=s, axes=axes, norm=norm)


@unit_change(lambda u: u / (second ** 2))
def irfft2(
Expand Down Expand Up @@ -838,8 +848,8 @@ def irfft2(
[ 0. , 0. , 0. ]]], dtype=float32)
"""
return _fun_change_unit_unary(jnpfft.irfft2,
lambda u: u / (second ** 2),
a, s=s, axes=axes, norm=norm)
lambda u: u / (second ** 2),
a, s=s, axes=axes, norm=norm)


@set_module_as('brainunit.fft')
Expand Down Expand Up @@ -916,8 +926,9 @@ def ifftn(
# TODO: may cause computation overhead?
ifftn._unit_change_fun = _unit_change_fun
return _fun_change_unit_unary(jnpfft.ifftn,
_unit_change_fun,
a, s=s, axes=axes, norm=norm)
_unit_change_fun,
a, s=s, axes=axes, norm=norm)


@set_module_as('brainunit.fft')
def irfftn(
Expand Down Expand Up @@ -1002,8 +1013,9 @@ def irfftn(
# TODO: may cause computation overhead?
irfftn._unit_change_fun = _unit_change_fun
return _fun_change_unit_unary(jnpfft.irfftn,
_unit_change_fun,
a, s=s, axes=axes, norm=norm)
_unit_change_fun,
a, s=s, axes=axes, norm=norm)


# return frequency unit
# ---------------------
Expand Down Expand Up @@ -1032,6 +1044,7 @@ def irfftn(
u.Ysecond: u.yhertz,
}


@set_module_as('brainunit.fft')
def fftfreq(
n: int,
Expand Down Expand Up @@ -1070,10 +1083,12 @@ def fftfreq(
try:
return Quantity(jnpfft.fftfreq(n, d.mantissa, dtype=dtype, device=device), unit=_time_freq_map[d.unit])
except:
raise TypeError(f"Cannot convert {d.unit} to common frequency unit, please specify the target frequency unit"
f"by passing the `target_freq_unit` argument.")
raise TypeError(
f"Cannot convert {d.unit} to common frequency unit, please specify the target frequency unit"
f"by passing the `target_freq_unit` argument.")
return jnpfft.fftfreq(n, d, dtype=dtype, device=device)


@set_module_as('brainunit.fft')
def rfftfreq(
n: int,
Expand Down Expand Up @@ -1117,4 +1132,3 @@ def rfftfreq(
f"Cannot convert {d.unit} to common frequency unit, please specify the target frequency unit"
f"by passing the `target_freq_unit` argument.")
return jnpfft.rfftfreq(n, d, dtype=dtype, device=device)

5 changes: 1 addition & 4 deletions brainunit/fft/_fft_change_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,14 @@
# ==============================================================================


import itertools

import jax.numpy as jnp
import jax.numpy.fft as jnpfft
import numpy as np
import pytest
from absl.testing import parameterized

import brainunit as u
import brainunit.fft as ufft
from brainunit import meter, second, volt
from brainunit import meter, second
from brainunit._base import assert_quantity, Unit, get_or_create_dimension

fft_change_1d = [
Expand Down
14 changes: 5 additions & 9 deletions brainunit/fft/_fft_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,14 @@
# ==============================================================================
from __future__ import annotations

import builtins
from typing import Union, Sequence, Callable
from typing import Union, Sequence

import jax
import numpy as np
import jax.numpy as jnp
from jax.numpy import fft as jnpfft
from jax._src.typing import Shape

from .._base import Quantity, maybe_decimal
from .._base import Quantity
from .._misc import set_module_as
from ..math._fun_keep_unit import _fun_keep_unit_unary, _fun_keep_unit_binary
from ..math._fun_keep_unit import _fun_keep_unit_unary

__all__ = [
# keep unit
Expand All @@ -35,7 +31,6 @@
# keep unit
# ---------

jnpfft.fftshift

@set_module_as('brainunit.fft')
def fftshift(
Expand Down Expand Up @@ -78,6 +73,7 @@ def fftshift(
"""
return _fun_keep_unit_unary(jnpfft.fftshift, x, axes=axes)


@set_module_as('brainunit.fft')
def ifftshift(
x: Union[Quantity, jax.typing.ArrayLike],
Expand Down Expand Up @@ -118,4 +114,4 @@ def ifftshift(
>>> brainunit.fft.ifftshift(shifted_freq)
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
"""
return _fun_keep_unit_unary(jnpfft.ifftshift, x, axes=axes)
return _fun_keep_unit_unary(jnpfft.ifftshift, x, axes=axes)
70 changes: 70 additions & 0 deletions brainunit/fft/_fft_keep_unit_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import jax.numpy as jnp
import pytest
from absl.testing import parameterized

import brainunit as bu
import brainunit.math as bm
from brainunit import second, meter, ms
from brainunit._base import assert_quantity

import jax.numpy as jnp
import jax.numpy.fft as jnpfft
import pytest
from absl.testing import parameterized

import brainunit as u
import brainunit.fft as ufft
from brainunit import meter, second
from brainunit._base import assert_quantity, Unit, get_or_create_dimension

fft_keep_unit = [
'fftshift', 'ifftshift',
]

class TestFftKeepUnit(parameterized.TestCase):
def __init__(self, *args, **kwargs):
super(TestFftKeepUnit, self).__init__(*args, **kwargs)

print()

@parameterized.product(
value_axes=[
([[1, 2, 3], [4, 5, 6]], (0, 1)),
([[1, 2, 3], [4, 5, 6]], (1, 0)),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], (0, 1)),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], (1, 0)),
],
unit=[meter, second],
)
def test_fft_keep_unit(self, value_axes, unit):
value = value_axes[0]
axes = value_axes[1]
ufft_fun_list = [getattr(ufft, fun) for fun in fft_keep_unit]
jnpfft_fun_list = [getattr(jnpfft, fun) for fun in fft_keep_unit]

for ufft_fun, jnpfft_fun in zip(ufft_fun_list, jnpfft_fun_list):
print(f'fun: {ufft_fun.__name__}')

result = ufft_fun(jnp.array(value), axes=axes)
expected = ufft_fun(jnp.array(value), axes=axes)
assert_quantity(result, expected)

q = value * unit
result = ufft_fun(q, axes=axes)
expected = ufft_fun(jnp.array(value), axes=axes)
assert_quantity(result, expected, unit=unit)
15 changes: 15 additions & 0 deletions brainunit/math/_fun_keep_unit_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import jax.numpy as jnp
import pytest
from absl.testing import parameterized
Expand Down

0 comments on commit 3b9a84c

Please sign in to comment.