Skip to content

Commit 715c4b0

Browse files
author
Sean Kim
committed
modifications addressing comments, will benchmark and possibly change after merge
1 parent 2b48d4c commit 715c4b0

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchaudio/functional/functional.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,14 +1445,14 @@ def _get_sinc_resample_kernel(
14451445
# future work.
14461446
idx_dtype = dtype if dtype is not None else torch.float64
14471447

1448-
idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype)[None, None] / orig_freq
1448+
idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[None, None] / orig_freq
14491449

14501450
t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
14511451
t *= base_freq
14521452
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
14531453

1454-
scale = base_freq / orig_freq
1455-
1454+
# we do not use built in torch windows here as we need to evaluate the window
1455+
# at specific positions, not over a regular grid.
14561456
if resampling_method == "sinc_interpolation":
14571457
window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
14581458
else:
@@ -1462,11 +1462,11 @@ def _get_sinc_resample_kernel(
14621462
beta_tensor = torch.tensor(float(beta))
14631463
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
14641464

1465-
t *= torch.pi
1465+
t *= math.pi
14661466

1467+
scale = base_freq / orig_freq
14671468
kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
14681469
kernels *= window * scale
1469-
kernels.to(device)
14701470

14711471
if dtype is None:
14721472
kernels = kernels.to(dtype=torch.float32)

0 commit comments

Comments
 (0)