diff --git a/dnn/torch/osce/models/bwe_net.py b/dnn/torch/osce/models/bwe_net.py index 1cc4e3171..99362e1df 100644 --- a/dnn/torch/osce/models/bwe_net.py +++ b/dnn/torch/osce/models/bwe_net.py @@ -75,7 +75,9 @@ def __init__(self, kernel_size32=15, kernel_size48=15, conv_gain_limits_db=[-12, 12], - activation="AdaShape" + activation="AdaShape", + avg_pool_k32 = 8, + avg_pool_k48=12 ): super().__init__() @@ -99,8 +101,8 @@ def __init__(self, # non-linear transforms if activation == "AdaShape": - self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=8) - self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=12) + self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k32) + self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k48) self.act1 = self.tdshape1 self.act2 = self.tdshape2 elif activation == "ReLU":