@@ -1445,14 +1445,14 @@ def _get_sinc_resample_kernel(
1445
1445
# future work.
1446
1446
idx_dtype = dtype if dtype is not None else torch .float64
1447
1447
1448
- idx = torch .arange (- width , width + orig_freq , dtype = idx_dtype )[None , None ] / orig_freq
1448
+ idx = torch .arange (- width , width + orig_freq , dtype = idx_dtype , device = device )[None , None ] / orig_freq
1449
1449
1450
1450
t = torch .arange (0 , - new_freq , - 1 , dtype = dtype )[:, None , None ] / new_freq + idx
1451
1451
t *= base_freq
1452
1452
t = t .clamp_ (- lowpass_filter_width , lowpass_filter_width )
1453
1453
1454
- scale = base_freq / orig_freq
1455
-
1454
+ # we do not use built in torch windows here as we need to evaluate the window
1455
+ # at specific positions, not over a regular grid.
1456
1456
if resampling_method == "sinc_interpolation" :
1457
1457
window = torch .cos (t * math .pi / lowpass_filter_width / 2 ) ** 2
1458
1458
else :
@@ -1462,11 +1462,11 @@ def _get_sinc_resample_kernel(
1462
1462
beta_tensor = torch .tensor (float (beta ))
1463
1463
window = torch .i0 (beta_tensor * torch .sqrt (1 - (t / lowpass_filter_width ) ** 2 )) / torch .i0 (beta_tensor )
1464
1464
1465
- t *= torch .pi
1465
+ t *= math .pi
1466
1466
1467
+ scale = base_freq / orig_freq
1467
1468
kernels = torch .where (t == 0 , torch .tensor (1.0 ).to (t ), t .sin () / t )
1468
1469
kernels *= window * scale
1469
- kernels .to (device )
1470
1470
1471
1471
if dtype is None :
1472
1472
kernels = kernels .to (dtype = torch .float32 )
0 commit comments