From 2ba3457bf6d26346d97e415f90a50fad56fa8fbf Mon Sep 17 00:00:00 2001 From: LiChenda Date: Thu, 10 Feb 2022 16:48:46 +0800 Subject: [PATCH 01/34] add skim model --- espnet2/enh/layers/dprnn.py | 1 + espnet2/enh/layers/skim.py | 132 ++++++++++++++++++++++++++++++++++++ espnet2/enh/layers/tcn.py | 44 +++++++++--- 3 files changed, 169 insertions(+), 8 deletions(-) create mode 100644 espnet2/enh/layers/skim.py diff --git a/espnet2/enh/layers/dprnn.py b/espnet2/enh/layers/dprnn.py index 827c754ac86..aae6040f74f 100644 --- a/espnet2/enh/layers/dprnn.py +++ b/espnet2/enh/layers/dprnn.py @@ -4,6 +4,7 @@ # # The code is based on: # https://github.com/yluo42/TAC/blob/master/utility/models.py +# Licensed under CC BY-NC-SA 3.0 US. # diff --git a/espnet2/enh/layers/skim.py b/espnet2/enh/layers/skim.py new file mode 100644 index 00000000000..0e231a80c7e --- /dev/null +++ b/espnet2/enh/layers/skim.py @@ -0,0 +1,132 @@ +# An implementation of SkiM model described in +# "SkiM: Skipping Memory LSTM for Low-Latency Real-Time Continuous Speech Separation" +# (https://arxiv.org/abs/2201.10800) +# + + +from turtle import forward, shape +import torch +import torch.nn as nn + +from espnet2.enh.layers.dprnn import SingleRNN, split_feature, merge_feature +from espnet2.enh.layers.tcn import chose_norm + + +class MemLSTM(nn.Module): + """ the Mem-LSTM of SkiM + + args: + hidden_size: int, dimension of the hidden state. + dropout: float, dropout ratio. Default is 0. + bidirectional: bool, whether the LSTM layers are bidirectional. Default is False. + mem_type: 'hc', 'h', 'c' or 'id'. + It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. + In 'id' mode, both the hidden and cell states will be identically returned. + norm_type: gLN, cLN. cLN is for causal implementation. 
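+
+    Shape note: forward(hc, S) takes the SegLSTM states h and c, each of
+    shape (d, B*S, H), reshapes them to (B, S, d*H) to model the dependency
+    across the S segments, and returns a state tuple with the original
+    (d, B*S, H) shapes.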
+ """ + + def __init__(self,hidden_size, dropout=0.0, bidirectional=False, mem_type='hc', norm_type='gLN'): + super().__init__() + self.hidden_size = hidden_size + self.bidirectional = bidirectional + self.input_size = (int(bidirectional) + 1) * hidden_size + self.mem_type = mem_type + + assert mem_type in ["hc", "h", 'c', 'id'], f"only support 'hc', 'h', 'c' and 'id', current type: {mem_type}" + + if mem_type in ["hc", 'h']: + self.h_net = SingleRNN('LSTM', input_size=self.input_size, hidden_size=self.hidden_size, dropout=dropout, bidirectional=bidirectional) + self.h_norm = chose_norm(norm_type=norm_type, channel_size=self.input_size, shape='BTD') + if mem_type in ["hc", 'c']: + self.c_net = SingleRNN('LSTM', input_size=self.input_size, hidden_size=self.hidden_size, dropout=dropout, bidirectional=bidirectional) + self.c_norm = chose_norm(norm_type=norm_type, channel_size=self.input_size, shape='BTD') + + def extra_repr(self) -> str: + return f"Mem_type: {self.mem_type}, bidirectional: {self.bidirectional}" + + def forward(self, hc, S): + # hc = (h, c), tuple of hidden and cell states from SegLSTM + # shape of h and c: (d, B*S, H) + # S: number of segments in SegLSTM + + if self.mem_type == 'id': + ret_val = hc + else: + h, c = hc + d, BS, H = h.shape + B = BS // S + h = h.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH + c = c.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH + if self.mem_type == 'hc': + h = h + self.h_norm(self.h_net(h)) + c = c + self.c_norm(self.c_net(c)) + elif self.mem_type == 'h': + h = h + self.h_norm(self.h_net(h)) + c = torch.zeros_like(c) + elif self.mem_type == 'c': + h = torch.zeros_like(h) + c = c + self.c_norm(self.c_net(c)) + + h = h.view(B * S, d, H).transpose(1, 0).contiguous() + c = c.view(B * S, d, H).transpose(1, 0).contiguous() + ret_val = (h, c) + + if not self.bidirectional: + # for causal setup + causal_ret_val = [] + for x in ret_val: + x_ = torch.zeros_like(x) + x_[:, 1:, :] = x[:, :-1,:] + causal_ret_val.append(x_) + ret_val = tuple(causal_ret_val) + + return ret_val + + + + +class SegLSTM(nn.Module): + + """ the Seg-LSTM of SkiM + + args: + input_size: int, dimension of the input feature. The input should have shape + (batch, seq_len, input_size). + hidden_size: int, dimension of the hidden state. + dropout: float, dropout ratio. Default is 0. + bidirectional: bool, whether the LSTM layers are bidirectional. Default is False. 
+ """ + + def __init__(self, input_size, hidden_size, dropout=0.0, bidirectional=False): + super().__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.num_direction = int(bidirectional) + 1 + + self.lstm = nn.LSTM(input_size, hidden_size, 1,batch_first=True,bidirectional=bidirectional,) + self.dropout = nn.Dropout(p=dropout) + self.proj = nn.Linear(hidden_size * self.num_direction, input_size) + + def forward(self, input, hc): + # input shape: B, T, H + + B, T, H = input.shape + + if hc == None: + # In fist input SkiM block, h and c are not available + d = self.num_direction + h = torch.zeros(d, B, self.hidden_size).to(input.device) + c = torch.zeros(d, B, self.hidden_size).to(input.device) + else: + h, c = hc + + output, (h, c) = self.lstm(input, (h, c)) + output = self.dropout(output) + output = self.proj( + output.contiguous().view(-1, output.shape[2]) + ).view(output.shape) + + return output, (h, c) + + diff --git a/espnet2/enh/layers/tcn.py b/espnet2/enh/layers/tcn.py index 8fe8cd17036..b41b108731c 100644 --- a/espnet2/enh/layers/tcn.py +++ b/espnet2/enh/layers/tcn.py @@ -4,6 +4,7 @@ # # The code is based on: # https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py +# Licensed under MIT. # @@ -46,7 +47,7 @@ def __init__( for r in range(R): blocks = [] for x in range(X): - dilation = 2**x + dilation = 2 ** x padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 blocks += [ TemporalBlock( @@ -86,9 +87,9 @@ def forward(self, mixture_w): elif self.mask_nonlinear == "relu": est_mask = F.relu(score) elif self.mask_nonlinear == "sigmoid": - est_mask = torch.sigmoid(score) + est_mask = F.sigmoid(score) elif self.mask_nonlinear == "tanh": - est_mask = torch.tanh(score) + est_mask = F.tanh(score) else: raise ValueError("Unsupported mask non-linear function") return est_mask @@ -214,19 +215,21 @@ def check_nonlinear(nolinear_type): raise ValueError("Unsupported nonlinear type") -def chose_norm(norm_type, channel_size): +def chose_norm(norm_type, channel_size, shape='BDT'): """The input of normalization will be (M, C, K), where M is batch size. C is channel size and K is sequence length. """ if norm_type == "gLN": - return GlobalLayerNorm(channel_size) + return GlobalLayerNorm(channel_size, shape=shape) elif norm_type == "cLN": - return ChannelwiseLayerNorm(channel_size) + return ChannelwiseLayerNorm(channel_size, shape=shape) elif norm_type == "BN": # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics # along M and K, so this BN usage is right. 
return nn.BatchNorm1d(channel_size) + elif norm_type == "GN": + return nn.GroupNorm(1, channel_size, eps=1e-8) else: raise ValueError("Unsupported normalization type") @@ -234,11 +237,13 @@ def chose_norm(norm_type, channel_size): class ChannelwiseLayerNorm(nn.Module): """Channel-wise Layer Normalization (cLN).""" - def __init__(self, channel_size): + def __init__(self, channel_size, shape='BDT'): super().__init__() self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] self.reset_parameters() + assert shape in ['BDT', 'BTD'] + self.shape = shape def reset_parameters(self): self.gamma.data.fill_(1) @@ -253,20 +258,37 @@ def forward(self, y): Returns: cLN_y: [M, N, K] """ + dim = 3 + if y.dim() == 4: + dim = 4 + M, N, K, L = y.shape + y = y.view(M, N, K * L) + + if self.shape == 'BTD': + y = y.transpose(1, 2).contiguous() + mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + + if self.shape == 'BTD': + cLN_y = cLN_y.transpose(1, 2).contiguous() + + if dim == 4: + cLN_y = cLN_y.view(M, N, K, L) return cLN_y class GlobalLayerNorm(nn.Module): """Global Layer Normalization (gLN).""" - def __init__(self, channel_size): + def __init__(self, channel_size, shape='BDT'): super().__init__() self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] self.reset_parameters() + assert shape in ['BDT', 'BTD'] + self.shape = shape def reset_parameters(self): self.gamma.data.fill_(1) @@ -281,9 +303,15 @@ def forward(self, y): Returns: gLN_y: [M, N, K] """ + if self.shape == 'BTD': + y = y.transpose(1, 2).contiguous() + mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] var = ( (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) ) gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + + if self.shape == 'BTD': + gLN_y = gLN_y.transpose(1, 2).contiguous() return gLN_y From 7d7686fec615bef79724eef1f11fd2c6dce56087 Mon Sep 17 00:00:00 2001 From: LiChenda Date: Thu, 10 Feb 2022 19:10:07 +0800 Subject: [PATCH 02/34] update skim.py --- espnet2/enh/layers/skim.py | 210 ++++++++++++++++++++++++++++++------- 1 file changed, 174 insertions(+), 36 deletions(-) diff --git a/espnet2/enh/layers/skim.py b/espnet2/enh/layers/skim.py index 0e231a80c7e..6f247524283 100644 --- a/espnet2/enh/layers/skim.py +++ b/espnet2/enh/layers/skim.py @@ -1,10 +1,11 @@ -# An implementation of SkiM model described in +# An implementation of SkiM model described in # "SkiM: Skipping Memory LSTM for Low-Latency Real-Time Continuous Speech Separation" # (https://arxiv.org/abs/2201.10800) -# +# from turtle import forward, shape +from black import main import torch import torch.nn as nn @@ -13,34 +14,62 @@ class MemLSTM(nn.Module): - """ the Mem-LSTM of SkiM + """the Mem-LSTM of SkiM args: hidden_size: int, dimension of the hidden state. dropout: float, dropout ratio. Default is 0. bidirectional: bool, whether the LSTM layers are bidirectional. Default is False. - mem_type: 'hc', 'h', 'c' or 'id'. + mem_type: 'hc', 'h', 'c' or 'id'. It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. - In 'id' mode, both the hidden and cell states will be identically returned. + In 'id' mode, both the hidden and cell states will be identically returned. 
norm_type: gLN, cLN. cLN is for causal implementation. """ - def __init__(self,hidden_size, dropout=0.0, bidirectional=False, mem_type='hc', norm_type='gLN'): + def __init__( + self, + hidden_size, + dropout=0.0, + bidirectional=False, + mem_type="hc", + norm_type="cLN", + ): super().__init__() self.hidden_size = hidden_size self.bidirectional = bidirectional self.input_size = (int(bidirectional) + 1) * hidden_size self.mem_type = mem_type - assert mem_type in ["hc", "h", 'c', 'id'], f"only support 'hc', 'h', 'c' and 'id', current type: {mem_type}" + assert mem_type in [ + "hc", + "h", + "c", + "id", + ], f"only support 'hc', 'h', 'c' and 'id', current type: {mem_type}" + + if mem_type in ["hc", "h"]: + self.h_net = SingleRNN( + "LSTM", + input_size=self.input_size, + hidden_size=self.hidden_size, + dropout=dropout, + bidirectional=bidirectional, + ) + self.h_norm = chose_norm( + norm_type=norm_type, channel_size=self.input_size, shape="BTD" + ) + if mem_type in ["hc", "c"]: + self.c_net = SingleRNN( + "LSTM", + input_size=self.input_size, + hidden_size=self.hidden_size, + dropout=dropout, + bidirectional=bidirectional, + ) + self.c_norm = chose_norm( + norm_type=norm_type, channel_size=self.input_size, shape="BTD" + ) - if mem_type in ["hc", 'h']: - self.h_net = SingleRNN('LSTM', input_size=self.input_size, hidden_size=self.hidden_size, dropout=dropout, bidirectional=bidirectional) - self.h_norm = chose_norm(norm_type=norm_type, channel_size=self.input_size, shape='BTD') - if mem_type in ["hc", 'c']: - self.c_net = SingleRNN('LSTM', input_size=self.input_size, hidden_size=self.hidden_size, dropout=dropout, bidirectional=bidirectional) - self.c_norm = chose_norm(norm_type=norm_type, channel_size=self.input_size, shape='BTD') - def extra_repr(self) -> str: return f"Mem_type: {self.mem_type}, bidirectional: {self.bidirectional}" @@ -48,25 +77,25 @@ def forward(self, hc, S): # hc = (h, c), tuple of hidden and cell states from SegLSTM # shape of h and c: (d, B*S, H) # S: number of segments in SegLSTM - - if self.mem_type == 'id': + + if self.mem_type == "id": ret_val = hc else: h, c = hc d, BS, H = h.shape B = BS // S - h = h.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH - c = c.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH - if self.mem_type == 'hc': + h = h.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH + c = c.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH + if self.mem_type == "hc": h = h + self.h_norm(self.h_net(h)) c = c + self.c_norm(self.c_net(c)) - elif self.mem_type == 'h': + elif self.mem_type == "h": h = h + self.h_norm(self.h_net(h)) c = torch.zeros_like(c) - elif self.mem_type == 'c': + elif self.mem_type == "c": h = torch.zeros_like(h) c = c + self.c_norm(self.c_net(c)) - + h = h.view(B * S, d, H).transpose(1, 0).contiguous() c = c.view(B * S, d, H).transpose(1, 0).contiguous() ret_val = (h, c) @@ -76,38 +105,48 @@ def forward(self, hc, S): causal_ret_val = [] for x in ret_val: x_ = torch.zeros_like(x) - x_[:, 1:, :] = x[:, :-1,:] + x_[:, 1:, :] = x[:, :-1, :] causal_ret_val.append(x_) ret_val = tuple(causal_ret_val) - + return ret_val - - class SegLSTM(nn.Module): - """ the Seg-LSTM of SkiM + """the Seg-LSTM of SkiM args: input_size: int, dimension of the input feature. The input should have shape - (batch, seq_len, input_size). + (batch, seq_len, input_size). hidden_size: int, dimension of the hidden state. dropout: float, dropout ratio. Default is 0. bidirectional: bool, whether the LSTM layers are bidirectional. Default is False. 
+ norm_type: gLN, cLN. cLN is for causal implementation. """ - def __init__(self, input_size, hidden_size, dropout=0.0, bidirectional=False): + def __init__( + self, input_size, hidden_size, dropout=0.0, bidirectional=False, norm_type="cLN" + ): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.num_direction = int(bidirectional) + 1 - self.lstm = nn.LSTM(input_size, hidden_size, 1,batch_first=True,bidirectional=bidirectional,) + self.lstm = nn.LSTM( + input_size, + hidden_size, + 1, + batch_first=True, + bidirectional=bidirectional, + ) self.dropout = nn.Dropout(p=dropout) - self.proj = nn.Linear(hidden_size * self.num_direction, input_size) - + self.proj = nn.Linear(hidden_size * self.num_direction, input_size) + self.norm = chose_norm( + norm_type=norm_type, channel_size=hidden_size, shape="BTD" + ) + def forward(self, input, hc): # input shape: B, T, H @@ -121,12 +160,111 @@ def forward(self, input, hc): else: h, c = hc - output, (h, c) = self.lstm(input, (h, c)) + output, (h, c) = self.lstm(input, (h, c)) output = self.dropout(output) - output = self.proj( - output.contiguous().view(-1, output.shape[2]) - ).view(output.shape) + output = self.proj(output.contiguous().view(-1, output.shape[2])).view( + output.shape + ) + output = input + self.norm(output) return output, (h, c) +class SkiM(nn.Module): + def __init__( + self, + input_size, + hidden_size, + output_size, + dropout=0.0, + num_blocks=2, + segment_size=20, + bidirectional=True, + mem_type="hc", + norm_type="gLN", + seg_overlap=False, + ): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.segment_size = segment_size + self.dropout = dropout + self.num_blocks = num_blocks + self.mem_type = mem_type + self.norm_type = norm_type + self.seg_overlap = seg_overlap + + self.seg_lstms = nn.ModuleList([]) + for i in range(num_blocks): + self.seg_lstms.append( + SegLSTM( + input_size=input_size, + hidden_size=hidden_size, + dropout=dropout, + bidirectional=bidirectional, + norm_type=norm_type, + ) + ) + if self.mem_type is not None: + self.mem_lstms = nn.ModuleList([]) + for i in range(num_blocks - 1): + self.mem_lstms.append( + MemLSTM( + hidden_size, + dropout=dropout, + bidirectional=bidirectional, + mem_type=mem_type, + norm_type=norm_type, + ) + ) + self.output_fc = nn.Sequential(nn.PReLU(), nn.Conv1d(input_size, output_size, 1)) + + def forward(self, input): + # input shape: B, T (S*K), D + B, T, D = input.shape + + if self.seg_overlap: + input, rest = split_feature(input.transpose(1, 2), + segment_size=self.segment_size) # B, D, K, S + input = input.permute(0, 3, 2, 1).contiguous() # B, S, K, D + else: + input, rest = self._padfeature(input=input) + input = input.view(B, -1, self.segment_size, D) # B, S, K, D + B, S, K, D = input.shape + + assert K == self.segment_size + + output = input.view(B * S, K, D).contiguous() # BS, K, D + hc = None + for i in range(self.num_blocks): + output, hc = self.seg_lstms[i](output, hc) # BS, K, D + if self.mem_type and i < self.num_blocks - 1: + hc = self.mem_lstms[i](hc, S) + + if self.seg_overlap: + output = output.view(B, S, K, D).permute(0, 3, 2, 1) # B, D, K, S + output = merge_feature(output, rest) # B, D, T + output = self.output(output).transpose(1, 2) + + else: + output = output.view(B, S * K, D)[:, :T, :] # B, T, D + output = self.output(output.transpose(1, 2)).transpose(1, 2) + + return output + + def _padfeature(self, input): + B, T, D = input.shape + rest = self.segment_size - T 
% self.segment_size + + if rest > 0: + pad = torch.zeros(B, rest, D, device=input.device) + input = torch.cat([input, pad], dim=1) + return input, rest + + +if __name__ == "__main__": + + model = SkiM(256, 123, 345, dropout=0.1, num_blocks=3, segment_size=20,bidirectional=True, mem_type='hc', norm_type='gLN', seg_overlap=False) + input = torch.randn(2, 1002, 256) + print(model(input).shape) From d103494a29b76c73520f40a92b43e17a9fea16ed Mon Sep 17 00:00:00 2001 From: LiChenda Date: Fri, 11 Feb 2022 13:50:57 +0800 Subject: [PATCH 03/34] add skim separator --- .../conf/tuning/train_enh_skim_tasnet.yaml | 72 +++++++++++ espnet2/enh/layers/skim.py | 69 ++++++++--- espnet2/enh/separator/skim_separator.py | 117 ++++++++++++++++++ espnet2/tasks/enh.py | 2 + 4 files changed, 246 insertions(+), 14 deletions(-) create mode 100644 egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml create mode 100644 espnet2/enh/separator/skim_separator.py diff --git a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml new file mode 100644 index 00000000000..0937dd6024c --- /dev/null +++ b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml @@ -0,0 +1,72 @@ +optim: adam +init: xavier_uniform +max_epoch: 150 +batch_type: folded +batch_size: 1 # batch_size 16 can be trained on 4 RTX 2080ti +iterator_type: chunk +chunk_length: 32000 +num_workers: 4 +optim_conf: + lr: 1.0e-03 + eps: 1.0e-08 + weight_decay: 0 +patience: 4 +val_scheduler_criterion: +- valid +- loss +best_model_criterion: +- - valid + - si_snr + - max +- - valid + - loss + - min +keep_nbest_models: 1 +scheduler: reducelronplateau +scheduler_conf: + mode: min + factor: 0.7 + patience: 1 + +encoder: conv +encoder_conf: + channel: 64 + kernel_size: 2 + stride: 1 +decoder: conv +decoder_conf: + channel: 64 + kernel_size: 2 + stride: 1 +separator: skim +separator_conf: + casual: False + num_spk: 2 + layer: 6 + nonlinear: relu + unit: 128 + segment_size: 250 + dropout: 0.1 + mem_type: hc + seg_overlap: False + +# A list for criterions +# The overlall loss in the multi-task learning will be: +# loss = weight_1 * loss_1 + ... + weight_N * loss_N +# The default `weight` for each sub-loss is 1.0 +criterions: + # The first criterion + - name: si_snr + conf: + eps: 1.0e-7 + wrapper: pit + wrapper_conf: + weight: 1.0 + independent_perm: True + + + + + + + diff --git a/espnet2/enh/layers/skim.py b/espnet2/enh/layers/skim.py index 6f247524283..2e3d506cfbd 100644 --- a/espnet2/enh/layers/skim.py +++ b/espnet2/enh/layers/skim.py @@ -144,7 +144,7 @@ def __init__( self.dropout = nn.Dropout(p=dropout) self.proj = nn.Linear(hidden_size * self.num_direction, input_size) self.norm = chose_norm( - norm_type=norm_type, channel_size=hidden_size, shape="BTD" + norm_type=norm_type, channel_size=input_size, shape="BTD" ) def forward(self, input, hc): @@ -163,7 +163,7 @@ def forward(self, input, hc): output, (h, c) = self.lstm(input, (h, c)) output = self.dropout(output) output = self.proj(output.contiguous().view(-1, output.shape[2])).view( - output.shape + input.shape ) output = input + self.norm(output) @@ -171,6 +171,26 @@ def forward(self, input, hc): class SkiM(nn.Module): + """Skipping Memory Net + + args: + input_size: int, dimension of the input feature. + Input shape shoud be (batch, length, input_size) + hidden_size: int, dimension of the hidden state. + output_size: int, dimension of the output size. + dropout: float, dropout ratio. Default is 0. 
+ num_blocks: number of basic SkiM blocks + segment_size: segmentation size for splitting long features + bidirectional: bool, whether the RNN layers are bidirectional. + mem_type: 'hc', 'h', 'c', 'id' or None. + It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. + In 'id' mode, both the hidden and cell states will be identically returned. + When mem_type is None, the MemLSTM will be removed. + norm_type: gLN, cLN. cLN is for causal implementation. + seg_overlap: Bool, whether the segmentation will reserve 50% overlap for adjacent segments. + Default is False. + """ + def __init__( self, input_size, @@ -195,6 +215,13 @@ def __init__( self.norm_type = norm_type self.seg_overlap = seg_overlap + assert mem_type in [ + "hc", + "h", + "c", + None, + ], f"only support 'hc', 'h', 'c', 'id', and None, current type: {mem_type}" + self.seg_lstms = nn.ModuleList([]) for i in range(num_blocks): self.seg_lstms.append( @@ -218,16 +245,19 @@ def __init__( norm_type=norm_type, ) ) - self.output_fc = nn.Sequential(nn.PReLU(), nn.Conv1d(input_size, output_size, 1)) + self.output_fc = nn.Sequential( + nn.PReLU(), nn.Conv1d(input_size, output_size, 1) + ) def forward(self, input): # input shape: B, T (S*K), D B, T, D = input.shape if self.seg_overlap: - input, rest = split_feature(input.transpose(1, 2), - segment_size=self.segment_size) # B, D, K, S - input = input.permute(0, 3, 2, 1).contiguous() # B, S, K, D + input, rest = split_feature( + input.transpose(1, 2), segment_size=self.segment_size + ) # B, D, K, S + input = input.permute(0, 3, 2, 1).contiguous() # B, S, K, D else: input, rest = self._padfeature(input=input) input = input.view(B, -1, self.segment_size, D) # B, S, K, D @@ -241,16 +271,16 @@ def forward(self, input): output, hc = self.seg_lstms[i](output, hc) # BS, K, D if self.mem_type and i < self.num_blocks - 1: hc = self.mem_lstms[i](hc, S) - + if self.seg_overlap: - output = output.view(B, S, K, D).permute(0, 3, 2, 1) # B, D, K, S - output = merge_feature(output, rest) # B, D, T - output = self.output(output).transpose(1, 2) + output = output.view(B, S, K, D).permute(0, 3, 2, 1) # B, D, K, S + output = merge_feature(output, rest) # B, D, T + output = self.output_fc(output).transpose(1, 2) else: output = output.view(B, S * K, D)[:, :T, :] # B, T, D - output = self.output(output.transpose(1, 2)).transpose(1, 2) - + output = self.output_fc(output.transpose(1, 2)).transpose(1, 2) + return output def _padfeature(self, input): @@ -265,6 +295,17 @@ def _padfeature(self, input): if __name__ == "__main__": - model = SkiM(256, 123, 345, dropout=0.1, num_blocks=3, segment_size=20,bidirectional=True, mem_type='hc', norm_type='gLN', seg_overlap=False) - input = torch.randn(2, 1002, 256) + model = SkiM( + 333, + 111, + 222, + dropout=0.1, + num_blocks=3, + segment_size=20, + bidirectional=False, + mem_type="hc", + norm_type="cLN", + seg_overlap=True, + ) + input = torch.randn(2, 1002, 333) print(model(input).shape) diff --git a/espnet2/enh/separator/skim_separator.py b/espnet2/enh/separator/skim_separator.py new file mode 100644 index 00000000000..bf291b2daec --- /dev/null +++ b/espnet2/enh/separator/skim_separator.py @@ -0,0 +1,117 @@ +from collections import OrderedDict +from typing import List +from typing import Tuple +from typing import Union + +import torch +from torch_complex.tensor import ComplexTensor + +from espnet2.enh.layers.skim import SkiM +from espnet2.enh.separator.abs_separator import AbsSeparator + + +class SkiMSeparator(AbsSeparator): + def __init__( 
+ self, + input_dim: int, + casual: bool = True, + num_spk: int = 2, + nonlinear: str = "relu", + layer: int = 3, + unit: int = 512, + segment_size: int = 20, + dropout: float = 0.0, + mem_type: str = 'hc', + seg_overlap: bool = False, + ): + """Skipping Memory (SkiM) Separator + + Args: + input_dim: input feature dimension + casual: bool, whether the system is casual. + num_spk: number of target speakers. + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + layer: int, number of SkiM blocks. Default is 3. + unit: int, dimension of the hidden state. + segment_size: segmentation size for splitting long features + dropout: float, dropout ratio. Default is 0. + mem_type: 'hc', 'h', 'c', 'id' or None. + It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. + In 'id' mode, both the hidden and cell states will be identically returned. + When mem_type is None, the MemLSTM will be removed. + seg_overlap: Bool, whether the segmentation will reserve 50% overlap for adjacent segments. + Default is False. + """ + super().__init__() + + self._num_spk = num_spk + + self.segment_size = segment_size + + self.skim = SkiM( + input_size=input_dim, + hidden_size=unit, + output_size=input_dim * num_spk, + dropout=dropout, + num_blocks=layer, + bidirectional=(not casual), + norm_type='cLN' if casual else 'gLN', + segment_size=segment_size, + seg_overlap=seg_overlap, + mem_type=mem_type, + ) + + if nonlinear not in ("sigmoid", "relu", "tanh"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.nonlinear = { + "sigmoid": torch.nn.Sigmoid(), + "relu": torch.nn.ReLU(), + "tanh": torch.nn.Tanh(), + }[nonlinear] + + def forward( + self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N] + ilens (torch.Tensor): input lengths [Batch] + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...] + ilens (torch.Tensor): (B,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... 
+ 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + """ + + # if complex spectrum, + if isinstance(input, ComplexTensor): + feature = abs(input) + else: + feature = input + + B, T, N = feature.shape + + processed = self.skim(feature) # B,T, N + + processed = processed.view(B, T, N, self.num_spk) + masks = self.nonlinear(processed).unbind(dim=3) + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + + return masked, ilens, others + + @property + def num_spk(self): + return self._num_spk diff --git a/espnet2/tasks/enh.py b/espnet2/tasks/enh.py index 2a722cba554..7b1aab863d7 100644 --- a/espnet2/tasks/enh.py +++ b/espnet2/tasks/enh.py @@ -35,6 +35,7 @@ from espnet2.enh.separator.dprnn_separator import DPRNNSeparator from espnet2.enh.separator.neural_beamformer import NeuralBeamformer from espnet2.enh.separator.rnn_separator import RNNSeparator +from espnet2.enh.separator.skim_separator import SkiMSeparator from espnet2.enh.separator.tcn_separator import TCNSeparator from espnet2.enh.separator.transformer_separator import TransformerSeparator from espnet2.tasks.abs_task import AbsTask @@ -58,6 +59,7 @@ name="separator", classes=dict( rnn=RNNSeparator, + skim=SkiMSeparator, tcn=TCNSeparator, dprnn=DPRNNSeparator, transformer=TransformerSeparator, From cf535c2661c20f91810a7c86f528386a3a0bdaf4 Mon Sep 17 00:00:00 2001 From: LiChenda Date: Fri, 11 Feb 2022 14:24:01 +0800 Subject: [PATCH 04/34] add causal config for skim --- .../conf/tuning/train_enh_skim_tasnet.yaml | 4 +- .../tuning/train_enh_skim_tasnet_causal.yaml | 72 +++++++++++++++++++ espnet2/enh/separator/skim_separator.py | 42 +++++------ 3 files changed, 96 insertions(+), 22 deletions(-) create mode 100644 egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml diff --git a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml index 0937dd6024c..5fdd978f95c 100644 --- a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml +++ b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml @@ -2,9 +2,9 @@ optim: adam init: xavier_uniform max_epoch: 150 batch_type: folded -batch_size: 1 # batch_size 16 can be trained on 4 RTX 2080ti +batch_size: 8 iterator_type: chunk -chunk_length: 32000 +chunk_length: 16000 num_workers: 4 optim_conf: lr: 1.0e-03 diff --git a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml new file mode 100644 index 00000000000..5681dab1996 --- /dev/null +++ b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml @@ -0,0 +1,72 @@ +optim: adam +init: xavier_uniform +max_epoch: 150 +batch_type: folded +batch_size: 8 +iterator_type: chunk +chunk_length: 16000 +num_workers: 4 +optim_conf: + lr: 1.0e-03 + eps: 1.0e-08 + weight_decay: 0 +patience: 4 +val_scheduler_criterion: +- valid +- loss +best_model_criterion: +- - valid + - si_snr + - max +- - valid + - loss + - min +keep_nbest_models: 1 +scheduler: reducelronplateau +scheduler_conf: + mode: min + factor: 0.7 + patience: 1 + +encoder: conv +encoder_conf: + channel: 64 + kernel_size: 2 + stride: 1 +decoder: conv +decoder_conf: + channel: 64 + kernel_size: 2 + stride: 1 +separator: skim +separator_conf: + casual: True + num_spk: 2 + layer: 6 + nonlinear: relu + unit: 128 + segment_size: 250 + dropout: 0.1 + mem_type: hc + seg_overlap: False + +# A list for criterions +# The overlall loss in the multi-task 
learning will be: +# loss = weight_1 * loss_1 + ... + weight_N * loss_N +# The default `weight` for each sub-loss is 1.0 +criterions: + # The first criterion + - name: si_snr + conf: + eps: 1.0e-7 + wrapper: pit + wrapper_conf: + weight: 1.0 + independent_perm: True + + + + + + + diff --git a/espnet2/enh/separator/skim_separator.py b/espnet2/enh/separator/skim_separator.py index bf291b2daec..0f76a305d8b 100644 --- a/espnet2/enh/separator/skim_separator.py +++ b/espnet2/enh/separator/skim_separator.py @@ -11,6 +11,26 @@ class SkiMSeparator(AbsSeparator): + """Skipping Memory (SkiM) Separator + + Args: + input_dim: input feature dimension + casual: bool, whether the system is casual. + num_spk: number of target speakers. + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + layer: int, number of SkiM blocks. Default is 3. + unit: int, dimension of the hidden state. + segment_size: segmentation size for splitting long features + dropout: float, dropout ratio. Default is 0. + mem_type: 'hc', 'h', 'c', 'id' or None. + It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. + In 'id' mode, both the hidden and cell states will be identically returned. + When mem_type is None, the MemLSTM will be removed. + seg_overlap: Bool, whether the segmentation will reserve 50% overlap for adjacent segments. + Default is False. + """ + def __init__( self, input_dim: int, @@ -21,28 +41,10 @@ def __init__( unit: int = 512, segment_size: int = 20, dropout: float = 0.0, - mem_type: str = 'hc', + mem_type: str = "hc", seg_overlap: bool = False, ): - """Skipping Memory (SkiM) Separator - Args: - input_dim: input feature dimension - casual: bool, whether the system is casual. - num_spk: number of target speakers. - nonlinear: the nonlinear function for mask estimation, - select from 'relu', 'tanh', 'sigmoid' - layer: int, number of SkiM blocks. Default is 3. - unit: int, dimension of the hidden state. - segment_size: segmentation size for splitting long features - dropout: float, dropout ratio. Default is 0. - mem_type: 'hc', 'h', 'c', 'id' or None. - It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. - In 'id' mode, both the hidden and cell states will be identically returned. - When mem_type is None, the MemLSTM will be removed. - seg_overlap: Bool, whether the segmentation will reserve 50% overlap for adjacent segments. - Default is False. - """ super().__init__() self._num_spk = num_spk @@ -56,7 +58,7 @@ def __init__( dropout=dropout, num_blocks=layer, bidirectional=(not casual), - norm_type='cLN' if casual else 'gLN', + norm_type="cLN" if casual else "gLN", segment_size=segment_size, seg_overlap=seg_overlap, mem_type=mem_type, From 016b12e3474397a43756a8f7a6b50256d820034e Mon Sep 17 00:00:00 2001 From: LiChenda Date: Fri, 11 Feb 2022 14:40:48 +0800 Subject: [PATCH 05/34] add unit test --- espnet2/enh/separator/skim_separator.py | 11 +- .../enh/separator/test_skim_separator.py | 142 ++++++++++++++++++ 2 files changed, 149 insertions(+), 4 deletions(-) create mode 100644 test/espnet2/enh/separator/test_skim_separator.py diff --git a/espnet2/enh/separator/skim_separator.py b/espnet2/enh/separator/skim_separator.py index 0f76a305d8b..aa279a919bd 100644 --- a/espnet2/enh/separator/skim_separator.py +++ b/espnet2/enh/separator/skim_separator.py @@ -15,7 +15,7 @@ class SkiMSeparator(AbsSeparator): Args: input_dim: input feature dimension - casual: bool, whether the system is casual. 
+ causal: bool, whether the system is causal. num_spk: number of target speakers. nonlinear: the nonlinear function for mask estimation, select from 'relu', 'tanh', 'sigmoid' @@ -34,7 +34,7 @@ class SkiMSeparator(AbsSeparator): def __init__( self, input_dim: int, - casual: bool = True, + causal: bool = True, num_spk: int = 2, nonlinear: str = "relu", layer: int = 3, @@ -51,14 +51,17 @@ def __init__( self.segment_size = segment_size + if mem_type not in ("hc", "h", "c", "id", None): + raise ValueError("Not supporting mem_type={}".format(mem_type)) + self.skim = SkiM( input_size=input_dim, hidden_size=unit, output_size=input_dim * num_spk, dropout=dropout, num_blocks=layer, - bidirectional=(not casual), - norm_type="cLN" if casual else "gLN", + bidirectional=(not causal), + norm_type="cLN" if causal else "gLN", segment_size=segment_size, seg_overlap=seg_overlap, mem_type=mem_type, diff --git a/test/espnet2/enh/separator/test_skim_separator.py b/test/espnet2/enh/separator/test_skim_separator.py new file mode 100644 index 00000000000..c1417e3b28a --- /dev/null +++ b/test/espnet2/enh/separator/test_skim_separator.py @@ -0,0 +1,142 @@ +import pytest + +import torch +from torch import Tensor +from torch_complex import ComplexTensor + +from espnet2.enh.separator.skim_separator import SkiMSeparator + + +@pytest.mark.parametrize("input_dim", [5]) +@pytest.mark.parametrize("layer", [1, 3]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("unit", [8]) +@pytest.mark.parametrize("dropout", [0.0, 0.2]) +@pytest.mark.parametrize("num_spk", [1, 2]) +@pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"]) +@pytest.mark.parametrize("mem_type", ["hc", "c", "h", None]) +@pytest.mark.parametrize("segment_size", [2, 4]) +@pytest.mark.parametrize("seg_overlap", [False, True]) +def test_skim_separator_forward_backward_complex( + input_dim, + layer, + causal, + unit, + dropout, + num_spk, + nonlinear, + mem_type, + segment_size, + seg_overlap, +): + model = SkiMSeparator( + input_dim=input_dim, + causal=causal, + num_spk=num_spk, + nonlinear=nonlinear, + layer=layer, + unit=unit, + segment_size=segment_size, + dropout=dropout, + mem_type=mem_type, + seg_overlap=seg_overlap + ) + model.train() + + real = torch.rand(2, 10, input_dim) + imag = torch.rand(2, 10, input_dim) + x = ComplexTensor(real, imag) + x_lens = torch.tensor([10, 8], dtype=torch.long) + + masked, flens, others = model(x, ilens=x_lens) + + assert isinstance(masked[0], ComplexTensor) + assert len(masked) == num_spk + + masked[0].abs().mean().backward() + + +@pytest.mark.parametrize("input_dim", [5]) +@pytest.mark.parametrize("layer", [1, 3]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("unit", [8]) +@pytest.mark.parametrize("dropout", [0.0, 0.2]) +@pytest.mark.parametrize("num_spk", [1, 2]) +@pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"]) +@pytest.mark.parametrize("mem_type", ["hc", "c", "h", None]) +@pytest.mark.parametrize("segment_size", [2, 4]) +@pytest.mark.parametrize("seg_overlap", [False, True]) +def test_skim_separator_forward_backward_real( + input_dim, + layer, + causal, + unit, + dropout, + num_spk, + nonlinear, + mem_type, + segment_size, + seg_overlap, +): + model = SkiMSeparator( + input_dim=input_dim, + causal=causal, + num_spk=num_spk, + nonlinear=nonlinear, + layer=layer, + unit=unit, + segment_size=segment_size, + dropout=dropout, + mem_type=mem_type, + seg_overlap=seg_overlap + ) + model.train() + + x = torch.rand(2, 10, input_dim) + x_lens = 
torch.tensor([10, 8], dtype=torch.long) + + masked, flens, others = model(x, ilens=x_lens) + + assert isinstance(masked[0], Tensor) + assert len(masked) == num_spk + + masked[0].abs().mean().backward() + + +def test_skim_separator_invalid_type(): + with pytest.raises(ValueError): + SkiMSeparator( + input_dim=10, + layer=2, + unit=10, + dropout=0.1, + num_spk=2, + nonlinear="fff", + mem_type='aaa', + segment_size=2, + ) + + +def test_skim_separator_output(): + + x = torch.rand(2, 10, 10) + x_lens = torch.tensor([10, 8], dtype=torch.long) + + for num_spk in range(1, 3): + model = SkiMSeparator( + input_dim=10, + layer=2, + unit=10, + dropout=0.1, + num_spk=2, + nonlinear="relu", + segment_size=2, + ) + model.eval() + specs, _, others = model(x, x_lens) + assert isinstance(specs, list) + assert isinstance(others, dict) + assert x.shape == specs[0].shape + for n in range(num_spk): + assert "mask_spk{}".format(n + 1) in others + assert specs[n].shape == others["mask_spk{}".format(n + 1)].shape From 8310b735c503132e3dbe02d30883158a696a0935 Mon Sep 17 00:00:00 2001 From: LiChenda Date: Fri, 11 Feb 2022 14:41:03 +0800 Subject: [PATCH 06/34] update config --- .../enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml index 5681dab1996..cb59580d9d8 100644 --- a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml +++ b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet_causal.yaml @@ -40,7 +40,7 @@ decoder_conf: stride: 1 separator: skim separator_conf: - casual: True + causal: True num_spk: 2 layer: 6 nonlinear: relu From f1ac9259da4159be127e7ea351c209fa6dcc5965 Mon Sep 17 00:00:00 2001 From: LiChenda Date: Fri, 11 Feb 2022 14:50:23 +0800 Subject: [PATCH 07/34] fix for testing --- espnet2/enh/layers/skim.py | 35 +++++++++++-------- espnet2/enh/layers/tcn.py | 28 +++++++-------- espnet2/enh/separator/skim_separator.py | 14 ++++---- .../enh/separator/test_skim_separator.py | 6 ++-- 4 files changed, 45 insertions(+), 38 deletions(-) diff --git a/espnet2/enh/layers/skim.py b/espnet2/enh/layers/skim.py index 2e3d506cfbd..46f14c13227 100644 --- a/espnet2/enh/layers/skim.py +++ b/espnet2/enh/layers/skim.py @@ -3,13 +3,12 @@ # (https://arxiv.org/abs/2201.10800) # - -from turtle import forward, shape -from black import main import torch import torch.nn as nn -from espnet2.enh.layers.dprnn import SingleRNN, split_feature, merge_feature +from espnet2.enh.layers.dprnn import merge_feature +from espnet2.enh.layers.dprnn import SingleRNN +from espnet2.enh.layers.dprnn import split_feature from espnet2.enh.layers.tcn import chose_norm @@ -19,10 +18,13 @@ class MemLSTM(nn.Module): args: hidden_size: int, dimension of the hidden state. dropout: float, dropout ratio. Default is 0. - bidirectional: bool, whether the LSTM layers are bidirectional. Default is False. + bidirectional: bool, whether the LSTM layers are bidirectional. + Default is False. mem_type: 'hc', 'h', 'c' or 'id'. - It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. - In 'id' mode, both the hidden and cell states will be identically returned. + It controls whether the hidden (or cell) state of + SegLSTM will be processed by MemLSTM. + In 'id' mode, both the hidden and cell states will + be identically returned. norm_type: gLN, cLN. cLN is for causal implementation. 
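+
+        Note: in the causal setup (bidirectional=False), the returned
+        states are shifted one step along the segment axis, so that each
+        segment only receives memory computed from the segments before it.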
""" @@ -117,11 +119,12 @@ class SegLSTM(nn.Module): """the Seg-LSTM of SkiM args: - input_size: int, dimension of the input feature. The input should have shape - (batch, seq_len, input_size). + input_size: int, dimension of the input feature. + The input should have shape (batch, seq_len, input_size). hidden_size: int, dimension of the hidden state. dropout: float, dropout ratio. Default is 0. - bidirectional: bool, whether the LSTM layers are bidirectional. Default is False. + bidirectional: bool, whether the LSTM layers are bidirectional. + Default is False. norm_type: gLN, cLN. cLN is for causal implementation. """ @@ -152,7 +155,7 @@ def forward(self, input, hc): B, T, H = input.shape - if hc == None: + if hc is None: # In fist input SkiM block, h and c are not available d = self.num_direction h = torch.zeros(d, B, self.hidden_size).to(input.device) @@ -183,12 +186,14 @@ class SkiM(nn.Module): segment_size: segmentation size for splitting long features bidirectional: bool, whether the RNN layers are bidirectional. mem_type: 'hc', 'h', 'c', 'id' or None. - It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. - In 'id' mode, both the hidden and cell states will be identically returned. + It controls whether the hidden (or cell) state of SegLSTM + will be processed by MemLSTM. + In 'id' mode, both the hidden and cell states will + be identically returned. When mem_type is None, the MemLSTM will be removed. norm_type: gLN, cLN. cLN is for causal implementation. - seg_overlap: Bool, whether the segmentation will reserve 50% overlap for adjacent segments. - Default is False. + seg_overlap: Bool, whether the segmentation will reserve 50% + overlap for adjacent segments.Default is False. """ def __init__( diff --git a/espnet2/enh/layers/tcn.py b/espnet2/enh/layers/tcn.py index b41b108731c..acc2ba6e309 100644 --- a/espnet2/enh/layers/tcn.py +++ b/espnet2/enh/layers/tcn.py @@ -47,7 +47,7 @@ def __init__( for r in range(R): blocks = [] for x in range(X): - dilation = 2 ** x + dilation = 2**x padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 blocks += [ TemporalBlock( @@ -215,7 +215,7 @@ def check_nonlinear(nolinear_type): raise ValueError("Unsupported nonlinear type") -def chose_norm(norm_type, channel_size, shape='BDT'): +def chose_norm(norm_type, channel_size, shape="BDT"): """The input of normalization will be (M, C, K), where M is batch size. C is channel size and K is sequence length. 
@@ -237,12 +237,12 @@ def chose_norm(norm_type, channel_size, shape='BDT'): class ChannelwiseLayerNorm(nn.Module): """Channel-wise Layer Normalization (cLN).""" - def __init__(self, channel_size, shape='BDT'): + def __init__(self, channel_size, shape="BDT"): super().__init__() self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] self.reset_parameters() - assert shape in ['BDT', 'BTD'] + assert shape in ["BDT", "BTD"] self.shape = shape def reset_parameters(self): @@ -264,17 +264,17 @@ def forward(self, y): M, N, K, L = y.shape y = y.view(M, N, K * L) - if self.shape == 'BTD': + if self.shape == "BTD": y = y.transpose(1, 2).contiguous() mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta - if self.shape == 'BTD': + if self.shape == "BTD": cLN_y = cLN_y.transpose(1, 2).contiguous() - if dim == 4: + if dim == 4: cLN_y = cLN_y.view(M, N, K, L) return cLN_y @@ -282,12 +282,12 @@ def forward(self, y): class GlobalLayerNorm(nn.Module): """Global Layer Normalization (gLN).""" - def __init__(self, channel_size, shape='BDT'): + def __init__(self, channel_size, shape="BDT"): super().__init__() self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] self.reset_parameters() - assert shape in ['BDT', 'BTD'] + assert shape in ["BDT", "BTD"] self.shape = shape def reset_parameters(self): @@ -303,15 +303,15 @@ def forward(self, y): Returns: gLN_y: [M, N, K] """ - if self.shape == 'BTD': - y = y.transpose(1, 2).contiguous() + if self.shape == "BTD": + y = y.transpose(1, 2).contiguous() mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] var = ( (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) ) gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta - - if self.shape == 'BTD': - gLN_y = gLN_y.transpose(1, 2).contiguous() + + if self.shape == "BTD": + gLN_y = gLN_y.transpose(1, 2).contiguous() return gLN_y diff --git a/espnet2/enh/separator/skim_separator.py b/espnet2/enh/separator/skim_separator.py index aa279a919bd..9df8e517a96 100644 --- a/espnet2/enh/separator/skim_separator.py +++ b/espnet2/enh/separator/skim_separator.py @@ -18,17 +18,19 @@ class SkiMSeparator(AbsSeparator): causal: bool, whether the system is causal. num_spk: number of target speakers. nonlinear: the nonlinear function for mask estimation, - select from 'relu', 'tanh', 'sigmoid' + select from 'relu', 'tanh', 'sigmoid' layer: int, number of SkiM blocks. Default is 3. unit: int, dimension of the hidden state. segment_size: segmentation size for splitting long features dropout: float, dropout ratio. Default is 0. mem_type: 'hc', 'h', 'c', 'id' or None. - It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. - In 'id' mode, both the hidden and cell states will be identically returned. - When mem_type is None, the MemLSTM will be removed. - seg_overlap: Bool, whether the segmentation will reserve 50% overlap for adjacent segments. - Default is False. + It controls whether the hidden (or cell) state of + SegLSTM will be processed by MemLSTM. + In 'id' mode, both the hidden and cell states + will be identically returned. + When mem_type is None, the MemLSTM will be removed. 
+ seg_overlap: Bool, whether the segmentation will reserve 50% + overlap for adjacent segments. Default is False. """ def __init__( diff --git a/test/espnet2/enh/separator/test_skim_separator.py b/test/espnet2/enh/separator/test_skim_separator.py index c1417e3b28a..7cb8f4bc253 100644 --- a/test/espnet2/enh/separator/test_skim_separator.py +++ b/test/espnet2/enh/separator/test_skim_separator.py @@ -39,7 +39,7 @@ def test_skim_separator_forward_backward_complex( segment_size=segment_size, dropout=dropout, mem_type=mem_type, - seg_overlap=seg_overlap + seg_overlap=seg_overlap, ) model.train() @@ -88,7 +88,7 @@ def test_skim_separator_forward_backward_real( segment_size=segment_size, dropout=dropout, mem_type=mem_type, - seg_overlap=seg_overlap + seg_overlap=seg_overlap, ) model.train() @@ -112,7 +112,7 @@ def test_skim_separator_invalid_type(): dropout=0.1, num_spk=2, nonlinear="fff", - mem_type='aaa', + mem_type="aaa", segment_size=2, ) From ac3c10cfe4faf82c0bb30f8b32d9e8692363e0a9 Mon Sep 17 00:00:00 2001 From: LiChenda Date: Fri, 11 Feb 2022 16:22:52 +0800 Subject: [PATCH 08/34] fixing an assertion missing --- egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml | 2 +- espnet2/enh/layers/skim.py | 1 + test/espnet2/enh/separator/test_skim_separator.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml index 5fdd978f95c..2eced345061 100644 --- a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml +++ b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_skim_tasnet.yaml @@ -40,7 +40,7 @@ decoder_conf: stride: 1 separator: skim separator_conf: - casual: False + causal: False num_spk: 2 layer: 6 nonlinear: relu diff --git a/espnet2/enh/layers/skim.py b/espnet2/enh/layers/skim.py index 46f14c13227..3560d13b56e 100644 --- a/espnet2/enh/layers/skim.py +++ b/espnet2/enh/layers/skim.py @@ -224,6 +224,7 @@ def __init__( "hc", "h", "c", + "id", None, ], f"only support 'hc', 'h', 'c', 'id', and None, current type: {mem_type}" diff --git a/test/espnet2/enh/separator/test_skim_separator.py b/test/espnet2/enh/separator/test_skim_separator.py index 7cb8f4bc253..e1594cd5620 100644 --- a/test/espnet2/enh/separator/test_skim_separator.py +++ b/test/espnet2/enh/separator/test_skim_separator.py @@ -63,7 +63,7 @@ def test_skim_separator_forward_backward_complex( @pytest.mark.parametrize("dropout", [0.0, 0.2]) @pytest.mark.parametrize("num_spk", [1, 2]) @pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"]) -@pytest.mark.parametrize("mem_type", ["hc", "c", "h", None]) +@pytest.mark.parametrize("mem_type", ["hc", "c", "h", "id", None]) @pytest.mark.parametrize("segment_size", [2, 4]) @pytest.mark.parametrize("seg_overlap", [False, True]) def test_skim_separator_forward_backward_real( From 2554f265c818d608c378a8febac42fc1ffa10ca3 Mon Sep 17 00:00:00 2001 From: Siddharth Dalmia Date: Mon, 14 Feb 2022 18:04:07 -0500 Subject: [PATCH 09/34] initial MT files from ST for tracking --- egs2/TEMPLATE/mt1/cmd.sh | 110 ++ egs2/TEMPLATE/mt1/conf/fbank.conf | 2 + egs2/TEMPLATE/mt1/conf/pbs.conf | 11 + egs2/TEMPLATE/mt1/conf/pitch.conf | 1 + egs2/TEMPLATE/mt1/conf/queue.conf | 12 + egs2/TEMPLATE/mt1/conf/slurm.conf | 14 + egs2/TEMPLATE/mt1/db.sh | 1 + egs2/TEMPLATE/mt1/local/path.sh | 0 egs2/TEMPLATE/mt1/mt.sh | 1703 +++++++++++++++++++++++++++++ egs2/TEMPLATE/mt1/path.sh | 22 + egs2/TEMPLATE/mt1/pyscripts | 1 + egs2/TEMPLATE/mt1/scripts | 1 + egs2/TEMPLATE/mt1/setup.sh | 58 + 
egs2/TEMPLATE/mt1/steps | 1 + egs2/TEMPLATE/mt1/utils | 1 + egs2/iwslt14/mt1/cmd.sh | 110 ++ egs2/iwslt14/mt1/conf/fbank.conf | 2 + egs2/iwslt14/mt1/conf/pbs.conf | 11 + egs2/iwslt14/mt1/conf/pitch.conf | 1 + egs2/iwslt14/mt1/conf/queue.conf | 12 + egs2/iwslt14/mt1/conf/slurm.conf | 14 + egs2/iwslt14/mt1/db.sh | 1 + egs2/iwslt14/mt1/local/data.sh | 34 + egs2/iwslt14/mt1/local/path.sh | 0 egs2/iwslt14/mt1/mt.sh | 1 + egs2/iwslt14/mt1/path.sh | 1 + egs2/iwslt14/mt1/pyscripts | 1 + egs2/iwslt14/mt1/run.sh | 55 + egs2/iwslt14/mt1/scripts | 1 + egs2/iwslt14/mt1/steps | 1 + egs2/iwslt14/mt1/utils | 1 + espnet2/mt/__init__.py | 0 espnet2/mt/espnet_model.py | 452 ++++++++ 33 files changed, 2636 insertions(+) create mode 100644 egs2/TEMPLATE/mt1/cmd.sh create mode 100644 egs2/TEMPLATE/mt1/conf/fbank.conf create mode 100644 egs2/TEMPLATE/mt1/conf/pbs.conf create mode 100644 egs2/TEMPLATE/mt1/conf/pitch.conf create mode 100644 egs2/TEMPLATE/mt1/conf/queue.conf create mode 100644 egs2/TEMPLATE/mt1/conf/slurm.conf create mode 120000 egs2/TEMPLATE/mt1/db.sh create mode 100644 egs2/TEMPLATE/mt1/local/path.sh create mode 100755 egs2/TEMPLATE/mt1/mt.sh create mode 100755 egs2/TEMPLATE/mt1/path.sh create mode 120000 egs2/TEMPLATE/mt1/pyscripts create mode 120000 egs2/TEMPLATE/mt1/scripts create mode 100755 egs2/TEMPLATE/mt1/setup.sh create mode 120000 egs2/TEMPLATE/mt1/steps create mode 120000 egs2/TEMPLATE/mt1/utils create mode 100644 egs2/iwslt14/mt1/cmd.sh create mode 100644 egs2/iwslt14/mt1/conf/fbank.conf create mode 100644 egs2/iwslt14/mt1/conf/pbs.conf create mode 100644 egs2/iwslt14/mt1/conf/pitch.conf create mode 100644 egs2/iwslt14/mt1/conf/queue.conf create mode 100644 egs2/iwslt14/mt1/conf/slurm.conf create mode 120000 egs2/iwslt14/mt1/db.sh create mode 100644 egs2/iwslt14/mt1/local/data.sh create mode 100644 egs2/iwslt14/mt1/local/path.sh create mode 120000 egs2/iwslt14/mt1/mt.sh create mode 120000 egs2/iwslt14/mt1/path.sh create mode 120000 egs2/iwslt14/mt1/pyscripts create mode 100755 egs2/iwslt14/mt1/run.sh create mode 120000 egs2/iwslt14/mt1/scripts create mode 120000 egs2/iwslt14/mt1/steps create mode 120000 egs2/iwslt14/mt1/utils create mode 100644 espnet2/mt/__init__.py create mode 100644 espnet2/mt/espnet_model.py diff --git a/egs2/TEMPLATE/mt1/cmd.sh b/egs2/TEMPLATE/mt1/cmd.sh new file mode 100644 index 00000000000..2aae6919fef --- /dev/null +++ b/egs2/TEMPLATE/mt1/cmd.sh @@ -0,0 +1,110 @@ +# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ====== +# Usage: .pl [options] JOB=1: +# e.g. +# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB +# +# Options: +# --time