@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
946946 b_coeff_flipped = b_coeffs .flip (1 ).contiguous ()
947947 padded_waveform = F .pad (waveform , (n_order - 1 , 0 ))
948948 output = F .conv1d (padded_waveform , b_coeff_flipped .unsqueeze (1 ), groups = n_channel )
949- ctx .save_for_backward (waveform , b_coeffs , output )
949+ if not torch .jit .is_scripting ():
950+ ctx .save_for_backward (waveform , b_coeffs , output )
950951 return output
951952
952953 @staticmethod
@@ -956,32 +957,41 @@ def backward(ctx, dy):
956957 n_channel = x .size (1 )
957958 n_order = b_coeffs .size (1 )
958959 db = F .conv1d (
959- F .pad (x , (n_order - 1 , 0 )).view (1 , n_batch * n_channel , - 1 ),
960- dy .view (n_batch * n_channel , 1 , - 1 ),
961- groups = n_batch * n_channel
962- ).view (
963- n_batch , n_channel , - 1
964- ).sum (0 ).flip (1 ) if b_coeffs .requires_grad else None
960+ F .pad (x , (n_order - 1 , 0 )).view (1 , n_batch * n_channel , - 1 ),
961+ dy .view (n_batch * n_channel , 1 , - 1 ),
962+ groups = n_batch * n_channel
963+ ).view (
964+ n_batch , n_channel , - 1
965+ ).sum (0 ).flip (1 ) if b_coeffs .requires_grad else None
965966 dx = F .conv1d (
966- F .pad (dy , (0 , n_order - 1 )),
967- b_coeffs .unsqueeze (1 ),
968- groups = n_channel
969- ) if x .requires_grad else None
967+ F .pad (dy , (0 , n_order - 1 )),
968+ b_coeffs .unsqueeze (1 ),
969+ groups = n_channel
970+ ) if x .requires_grad else None
970971 return (dx , db )
971972
973+ @staticmethod
974+ def ts_apply (waveform , b_coeffs ):
975+ if torch .jit .is_scripting ():
976+ return DifferentiableFIR .forward (torch .empty (0 ), waveform , b_coeffs )
977+ else :
978+ return DifferentiableFIR .apply (waveform , b_coeffs )
979+
980+
972981class DifferentiableIIR (torch .autograd .Function ):
973982 @staticmethod
974983 def forward (ctx , waveform , a_coeffs_normalized ):
975984 n_batch , n_channel , n_sample = waveform .shape
976985 n_order = a_coeffs_normalized .size (1 )
977986 n_sample_padded = n_sample + n_order - 1
978987
979- a_coeff_flipped = a_coeffs_normalized .flip (1 ).contiguous ();
988+ a_coeff_flipped = a_coeffs_normalized .flip (1 ).contiguous ()
980989 padded_output_waveform = torch .zeros (n_batch , n_channel , n_sample_padded ,
981- device = waveform .device , dtype = waveform .dtype )
990+ device = waveform .device , dtype = waveform .dtype )
982991 _lfilter_core_loop (waveform , a_coeff_flipped , padded_output_waveform )
983- output = padded_output_waveform [:,:,n_order - 1 :]
984- ctx .save_for_backward (waveform , a_coeffs_normalized , output )
992+ output = padded_output_waveform [:, :, n_order - 1 :]
993+ if not torch .jit .is_scripting ():
994+ ctx .save_for_backward (waveform , a_coeffs_normalized , output )
985995 return output
986996
987997 @staticmethod
@@ -992,15 +1002,23 @@ def backward(ctx, dy):
9921002 tmp = DifferentiableIIR .apply (dy .flip (2 ).contiguous (), a_coeffs_normalized ).flip (2 )
9931003 dx = tmp if x .requires_grad else None
9941004 da = - (tmp .transpose (0 , 1 ).reshape (n_channel , 1 , - 1 ) @
995- F .pad (y , (n_order - 1 , 0 )).unfold (2 , n_order , 1 ).transpose (0 ,1 )
996- .reshape (n_channel , - 1 , n_order )
997- ).squeeze (1 ).flip (1 ) if a_coeffs_normalized .requires_grad else None
1005+ F .pad (y , (n_order - 1 , 0 )).unfold (2 , n_order , 1 ).transpose (0 , 1 )
1006+ .reshape (n_channel , - 1 , n_order )
1007+ ).squeeze (1 ).flip (1 ) if a_coeffs_normalized .requires_grad else None
9981008 return (dx , da )
9991009
1010+ @staticmethod
1011+ def ts_apply (waveform , a_coeffs_normalized ):
1012+ if torch .jit .is_scripting ():
1013+ return DifferentiableIIR .forward (torch .empty (0 ), waveform , a_coeffs_normalized )
1014+ else :
1015+ return DifferentiableIIR .apply (waveform , a_coeffs_normalized )
1016+
1017+
10001018def _lfilter (waveform , a_coeffs , b_coeffs ):
1001- n_order = b_coeffs . size ( 1 )
1002- filtered_waveform = DifferentiableFIR . apply ( waveform , b_coeffs / a_coeffs [:, 0 :1 ])
1003- return DifferentiableIIR . apply ( filtered_waveform , a_coeffs / a_coeffs [:, 0 : 1 ])
1019+ filtered_waveform = DifferentiableFIR . ts_apply ( waveform , b_coeffs / a_coeffs [:, 0 : 1 ] )
1020+ return DifferentiableIIR . ts_apply ( filtered_waveform , a_coeffs / a_coeffs [:, 0 :1 ])
1021+
10041022
10051023def lfilter (waveform : Tensor , a_coeffs : Tensor , b_coeffs : Tensor , clamp : bool = True , batching : bool = True ) -> Tensor :
10061024 r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation
@@ -1071,6 +1089,7 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool =
10711089
10721090 return output
10731091
1092+
10741093def lowpass_biquad (waveform : Tensor , sample_rate : int , cutoff_freq : float , Q : float = 0.707 ) -> Tensor :
10751094 r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
10761095
0 commit comments