Skip to content

Commit

Permalink
Optimize fftfreq logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 6, 2024
1 parent b484821 commit be5262e
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions brainunit/fft/_fft_change_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,20 +1094,20 @@ def fftfreq(
time_scale = _find_closet_scale(d.unit.scale)
try:
time_unit, freq_unit = _time_freq_map[time_scale]
except:
except KeyError:
time_unit = d.unit
freq_unit_scale = -d.unit.scale
freq_unit = Unit.create(get_or_create_dimension(s=-1),
name=f'10^{freq_unit_scale} hertz',
dispname=f'10^{freq_unit_scale} Hz',
scale=freq_unit_scale,)
try:
if sys.version_info >= (3, 10):
return Quantity(jnpfft.fftfreq(n, d.to_decimal(time_unit), dtype=dtype, device=device), unit=freq_unit)
except:
else:
return Quantity(jnpfft.fftfreq(n, d.to_decimal(time_unit), dtype=dtype), unit=freq_unit)
try:
if sys.version_info >= (3, 10):
return jnpfft.fftfreq(n, d, dtype=dtype, device=device)
except:
else:
return jnpfft.fftfreq(n, d, dtype=dtype)


Expand Down Expand Up @@ -1146,18 +1146,18 @@ def rfftfreq(
time_scale = _find_closet_scale(d.unit.scale)
try:
time_unit, freq_unit = _time_freq_map[time_scale]
except:
except KeyError:
time_unit = d.unit
freq_unit_scale = -d.unit.scale
freq_unit = Unit.create(get_or_create_dimension(s=-1),
name=f'10^{freq_unit_scale} hertz',
dispname=f'10^{freq_unit_scale} Hz',
scale=freq_unit_scale, )
try:
if sys.version_info >= (3, 10):
return Quantity(jnpfft.rfftfreq(n, d.to_decimal(time_unit), dtype=dtype, device=device), unit=freq_unit)
except:
else:
return Quantity(jnpfft.rfftfreq(n, d.to_decimal(time_unit), dtype=dtype), unit=freq_unit)
try:
if sys.version_info >= (3, 10):
return jnpfft.rfftfreq(n, d, dtype=dtype, device=device)
except:
else:
return jnpfft.rfftfreq(n, d, dtype=dtype)

0 comments on commit be5262e

Please sign in to comment.