diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py
index 58ed7e3cfcefc..9fcd501b4f0c0 100644
--- a/nemo/collections/audio/models/enhancement.py
+++ b/nemo/collections/audio/models/enhancement.py
@@ -772,10 +772,10 @@ def _step(self, target_signal, input_signal, input_length=None):
         keep_conditions = einops.rearrange((torch.rand(batch_size) < self.p_cond).float(), 'B -> B 1 1 1')
         input_enc = input_enc * keep_conditions.to(input_enc.device)
 
-        start_state = torch.zeros_like(input_enc)
+        x_start = torch.zeros_like(input_enc)
         time = self.flow.generate_time(batch_size=batch_size).to(device=input_enc.device)
-        sample = self.flow.sample(time=time, start_state=start_state, end_state=target_enc)
+        sample = self.flow.sample(time=time, x_start=x_start, x_end=target_enc)
 
         # we want to get a vector field estimate given current state
         # at training time, current state is sampled from the conditional path
@@ -786,7 +786,7 @@ def _step(self, target_signal, input_signal, input_length=None):
         estimate, estimate_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=time)
 
         conditional_vector_field = self.flow.vector_field(
-            time=time, start_state=start_state, end_state=target_enc, point=sample
+            time=time, x_start=x_start, x_end=target_enc, point=sample
         )
 
         return self.loss(estimate=estimate, target=conditional_vector_field, input_length=input_enc_len)
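For context, the conditioning dropout in `_step` above follows the usual classifier-free-guidance recipe: each batch element keeps its conditioning encoding with probability `p_cond` and has it zeroed otherwise. A minimal, self-contained sketch of that step (the `(B, C, D, T)` layout and concrete shapes are illustrative assumptions, not taken from the patch):

```python
import torch
import einops


def drop_conditions(input_enc: torch.Tensor, p_cond: float) -> torch.Tensor:
    """Zero the conditioning encoding for a random subset of the batch."""
    batch_size = input_enc.shape[0]
    # keep_conditions[b] is 1.0 with probability p_cond, else 0.0
    keep_conditions = einops.rearrange((torch.rand(batch_size) < p_cond).float(), 'B -> B 1 1 1')
    return input_enc * keep_conditions.to(input_enc.device)


# With p_cond=0.9, roughly 10% of batch elements train unconditioned.
x = torch.randn(8, 2, 256, 100)  # assumed (B, C, D, T) layout
y = drop_conditions(x, p_cond=0.9)
```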
""" @@ -72,8 +72,7 @@ def forward(self, input_spec, length): if min_len < self.patch_size * mask_patches: mask_patches = min_len // self.patch_size - for idx in range(input_spec.shape[0]): - cur_len = length[idx] + for idx, cur_len in enumerate(length.tolist()): patches = range(cur_len // self.patch_size) masked_patches = random.sample(patches, mask_patches) for mp in masked_patches: diff --git a/nemo/collections/audio/parts/submodules/flow.py b/nemo/collections/audio/parts/submodules/flow.py index 272708159422d..331e527fb644e 100644 --- a/nemo/collections/audio/parts/submodules/flow.py +++ b/nemo/collections/audio/parts/submodules/flow.py @@ -40,30 +40,25 @@ def __init__(self, time_min: float = 1e-8, time_max: float = 1.0): self.time_max = time_max @abstractmethod - def mean(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor: + def mean(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: """ - Return the mean of p_t(x | start_state, end_state) at time t + Return the mean of p_t(x | x_start, x_end) at time t """ pass @abstractmethod - def std(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor: + def std(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: """ - Return the standard deviation of p_t(x | start_state, end_state) at time t + Return the standard deviation of p_t(x | x_start, x_end) at time t """ pass @abstractmethod - def d_mean(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor: - """ - Return the time derivatives of mean of p_t(x | start_state, end_state) at time t - """ - pass - - @abstractmethod - def d_std(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor: + def vector_field( + self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor + ) -> torch.Tensor: """ - Return the time derivatives of standard deviation of p_t(x | start_state, end_state) at time t + Compute the conditional vector field v_t( point | x_start, x_end) """ pass @@ -84,70 +79,45 @@ def generate_time(self, batch_size: int) -> torch.Tensor: """ return torch.clamp(torch.rand((batch_size,)), self.time_min, self.time_max) - def sample(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor: + def sample(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: """ - Generate a sample from p_t(x | start_state, end_state) at time t + Generate a sample from p_t(x | x_start, x_end) at time t. + Note that this implementation assumes all path marginals are normally distributed. """ - time = self._broadcast_time(time, n_dim=start_state.ndim) + time = self._broadcast_time(time, n_dim=x_start.ndim) - mean = self.mean(time=time, start_state=start_state, end_state=end_state) - std = self.std(time=time, start_state=start_state, end_state=end_state) + mean = self.mean(time=time, x_start=x_start, x_end=x_end) + std = self.std(time=time, x_start=x_start, x_end=x_end) return mean + std * torch.randn_like(mean) - def vector_field( - self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor - ) -> torch.Tensor: - """ - Compute the conditional vector field v_t( point | start_state, end_state) - """ - # vector field conditioned on `start_state`, `end_state` - # and evaluated at `point` - # !!! 
diff --git a/nemo/collections/audio/parts/submodules/flow.py b/nemo/collections/audio/parts/submodules/flow.py
index 272708159422d..331e527fb644e 100644
--- a/nemo/collections/audio/parts/submodules/flow.py
+++ b/nemo/collections/audio/parts/submodules/flow.py
@@ -40,30 +40,25 @@ def __init__(self, time_min: float = 1e-8, time_max: float = 1.0):
         self.time_max = time_max
 
     @abstractmethod
-    def mean(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
+    def mean(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor:
         """
-        Return the mean of p_t(x | start_state, end_state) at time t
+        Return the mean of p_t(x | x_start, x_end) at time t
         """
         pass
 
     @abstractmethod
-    def std(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
+    def std(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor:
         """
-        Return the standard deviation of p_t(x | start_state, end_state) at time t
+        Return the standard deviation of p_t(x | x_start, x_end) at time t
         """
         pass
 
     @abstractmethod
-    def d_mean(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
-        """
-        Return the time derivatives of mean of p_t(x | start_state, end_state) at time t
-        """
-        pass
-
-    @abstractmethod
-    def d_std(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
+    def vector_field(
+        self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor
+    ) -> torch.Tensor:
         """
-        Return the time derivatives of standard deviation of p_t(x | start_state, end_state) at time t
+        Compute the conditional vector field v_t(point | x_start, x_end)
         """
         pass
 
@@ -84,70 +79,45 @@ def generate_time(self, batch_size: int) -> torch.Tensor:
         """
         return torch.clamp(torch.rand((batch_size,)), self.time_min, self.time_max)
 
-    def sample(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
+    def sample(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor:
         """
-        Generate a sample from p_t(x | start_state, end_state) at time t
+        Generate a sample from p_t(x | x_start, x_end) at time t.
+        Note that this implementation assumes all path marginals are normally distributed.
         """
-        time = self._broadcast_time(time, n_dim=start_state.ndim)
+        time = self._broadcast_time(time, n_dim=x_start.ndim)
 
-        mean = self.mean(time=time, start_state=start_state, end_state=end_state)
-        std = self.std(time=time, start_state=start_state, end_state=end_state)
+        mean = self.mean(time=time, x_start=x_start, x_end=x_end)
+        std = self.std(time=time, x_start=x_start, x_end=x_end)
         return mean + std * torch.randn_like(mean)
 
-    def vector_field(
-        self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Compute the conditional vector field v_t( point | start_state, end_state)
-        """
-        # vector field conditioned on `start_state`, `end_state`
-        # and evaluated at `point`
-        # !!! general form, may cause numerical issues
-        time = self._broadcast_time(time, n_dim=start_state.ndim)
-
-        mean = self.mean(time=time, start_state=start_state, end_state=end_state)
-        std = self.std(time=time, start_state=start_state, end_state=end_state)
-        d_mean = self.d_mean(time=time, start_state=start_state, end_state=end_state)
-        d_std = self.d_std(time=time, start_state=start_state, end_state=end_state)
-        return d_std * (point - mean) / std + d_mean
-
     def flow(
-        self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
+        self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor
    ) -> torch.Tensor:
         """
-        Compute the conditional flow phi_t( point | start_state, end_state)
+        Compute the conditional flow phi_t(point | x_start, x_end).
+        This is an affine flow.
         """
-        mean = self.mean(time=time, start_state=start_state, end_state=end_state)
-        std = self.std(time=time, start_state=start_state, end_state=end_state)
-        return mean + std * (point - start_state)
-
-    def d_flow(
-        self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Compute the time derivatives of conditional flow
-        """
-        d_mean = self.d_mean(time=time, start_state=start_state, end_state=end_state)
-        d_std = self.d_std(time=time, start_state=start_state, end_state=end_state)
-        return d_mean + d_std * (point - start_state)
+        mean = self.mean(time=time, x_start=x_start, x_end=x_end)
+        std = self.std(time=time, x_start=x_start, x_end=x_end)
+        return mean + std * (point - x_start)
 
 
 class OptimalTransportFlow(ConditionalFlow):
     """The OT-CFM model from [Lipman et at, 2023]
 
     Every conditional path the following holds:
-    p_0 = N(start_state, sigma_start)
-    p_1 = N(end_state, sigma_end),
+    p_0 = N(x_start, sigma_start)
+    p_1 = N(x_end, sigma_end),
 
-    mean(x, t) = (time_max - t) * start_state + t * end_state
-        (linear interpolation between start_state and end_state)
+    mean(x, t) = (time_max - t) * x_start + t * x_end
+        (linear interpolation between x_start and x_end)
 
     std(x, t) = (time_max - t) * sigma_start + t * sigma_end
 
-    Every conditional path is optimal transport map from p_0(start_state, end_state) to p_1(start_state, end_state)
+    Every conditional path is optimal transport map from p_0(x_start, x_end) to p_1(x_start, x_end)
     Marginal path is not guaranteed to be an optimal transport map from p_0 to p_1
 
-    To get the OT-CFM model from [Lipman et at, 2023] just pass zeroes for start_state
+    To get the OT-CFM model from [Lipman et al., 2023] just pass zeroes for x_start
     To get the I-CFM model, set sigma_min=sigma_max
     To get the rectified flow model, set sigma_min=sigma_max=0
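A quick numeric check of the docstring formulas above (with the default `time_max = 1.0`): the mean interpolates linearly between `x_start` and `x_end`, and when `sigma_start == sigma_end` the conditional vector field collapses to `x_end - x_start`, matching the early-return branch of `vector_field` in the next hunk. Values are illustrative:

```python
import torch

x_start, x_end = torch.zeros(3), torch.ones(3)
sigma_start = sigma_end = 0.1
t = torch.tensor(0.3)

mean = (1 - t) * x_start + t * x_end         # tensor([0.3, 0.3, 0.3])
std = (1 - t) * sigma_start + t * sigma_end  # 0.1, constant in t here

point = mean + std * torch.randn(3)          # sample from p_t(x | x_start, x_end)
num = sigma_end * (point - x_start) - sigma_start * (point - x_end)
denom = (1 - t) * sigma_start + t * sigma_end
assert torch.allclose(num / denom, x_end - x_start, atol=1e-5)
```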
@@ -171,33 +141,33 @@ def __init__(
         logging.debug('\tsgima_start: %s', self.sigma_start)
         logging.debug('\tsigma_end: %s', self.sigma_end)
 
-    def mean(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
-        return (self.time_max - time) * start_state + time * end_state
+    def mean(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
+        return (self.time_max - time) * x_start + time * x_end
 
-    def std(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
+    def std(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
         return (self.time_max - time) * self.sigma_start + time * self.sigma_end
 
-    def d_mean(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
-        return end_state - start_state
+    def d_mean(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
+        return x_end - x_start
 
-    def d_std(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
+    def d_std(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
         return self.sigma_end - self.sigma_start
 
     def vector_field(
         self,
         *,
-        start_state: torch.Tensor,
-        end_state: torch.Tensor,
+        x_start: torch.Tensor,
+        x_end: torch.Tensor,
         time: torch.Tensor,
         point: torch.Tensor,
         eps: float = 1e-6,
     ) -> torch.Tensor:
-        time = self._broadcast_time(time, n_dim=start_state.ndim)
+        time = self._broadcast_time(time, n_dim=x_start.ndim)
 
         if self.sigma_start == self.sigma_end:
-            return end_state - start_state
+            return x_end - x_start
 
-        num = self.sigma_end * (point - start_state) - self.sigma_start * (point - end_state)
+        num = self.sigma_end * (point - x_start) - self.sigma_start * (point - x_end)
         denom = (1 - time) * self.sigma_start + time * self.sigma_end
         return num / (denom + eps)
@@ -287,6 +257,4 @@ def forward(
 
         if state_length is not None:
             state = mask_sequence_tensor(state, state_length)
 
-        if state_length is not None:
-            state = mask_sequence_tensor(state, state_length)
         return state, state_length
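The last hunk above drops a duplicated masking call: masking multiplies padded positions by zero, so applying it twice is a no-op. A minimal stand-in for `mask_sequence_tensor` (the real helper lives in NeMo's common utilities; this version only illustrates the semantics and the idempotence that makes the duplicate safe to remove):

```python
import torch


def mask_sequence_tensor(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Zero x[b, ..., t] for all t >= lengths[b], where x has shape (B, ..., T)."""
    time = torch.arange(x.shape[-1], device=x.device)
    mask = (time[None, :] < lengths[:, None]).to(x.dtype)  # (B, T)
    return x * mask.view(x.shape[0], *([1] * (x.ndim - 2)), x.shape[-1])


state = torch.randn(2, 4, 6)
once = mask_sequence_tensor(state, torch.tensor([6, 3]))
twice = mask_sequence_tensor(once, torch.tensor([6, 3]))
assert torch.equal(once, twice)  # masking is idempotent
```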
diff --git a/nemo/collections/audio/parts/submodules/transformerunet.py b/nemo/collections/audio/parts/submodules/transformerunet.py
index 1992daea85a87..86754038b7b55 100644
--- a/nemo/collections/audio/parts/submodules/transformerunet.py
+++ b/nemo/collections/audio/parts/submodules/transformerunet.py
@@ -51,24 +51,12 @@
 __all__ = ['TransformerUNet']
 
 
-def exists(val):
-    return val is not None
-
-
-def default(val, d):
-    return val if exists(val) else d
-
-
-def divisible_by(num, den):
-    return (num % den) == 0
-
-
 class LearnedSinusoidalPosEmb(Module):
     """The sinusoidal Embedding to encode time conditional information"""
 
     def __init__(self, dim: int):
         super().__init__()
-        if not divisible_by(dim, 2):
+        if (dim % 2) != 0:
             raise ValueError(f"Input dimension {dim} is not divisible by 2!")
         half_dim = dim // 2
         self.weights = nn.Parameter(torch.randn(half_dim))
@@ -92,10 +80,11 @@ class ConvPositionEmbed(Module):
 
     def __init__(self, dim: int, kernel_size: int, groups: Optional[int] = None):
         super().__init__()
-        if divisible_by(kernel_size, 2):
+        if (kernel_size % 2) == 0:
             raise ValueError(f"Kernel size {kernel_size} is divisible by 2!")
 
-        groups = default(groups, dim)  # full depthwise conv by default
+        if groups is None:
+            groups = dim
 
         self.dw_conv1d = nn.Sequential(
             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), nn.GELU()
@@ -110,7 +99,7 @@ def forward(self, x, mask=None):
             out: output tensor with the same shape (B, T, D)
         """
 
-        if exists(mask):
+        if mask is not None:
             mask = mask[..., None]
             x = x.masked_fill(mask, 0.0)
 
@@ -118,7 +107,7 @@
         x = self.dw_conv1d(x)
         out = einops.rearrange(x, 'b c n -> b n c')
 
-        if exists(mask):
+        if mask is not None:
             out = out.masked_fill(mask, 0.0)
 
         return out
@@ -148,7 +137,8 @@ class AdaptiveRMSNorm(Module):
 
     def __init__(self, dim: int, cond_dim: Optional[int] = None):
         super().__init__()
-        cond_dim = default(cond_dim, dim)
+        if cond_dim is None:
+            cond_dim = dim
 
         self.scale = dim**0.5
         self.to_gamma = nn.Linear(cond_dim, dim)
@@ -166,7 +156,8 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor):
         normed = F.normalize(x, dim=-1) * self.scale
 
         gamma, beta = self.to_gamma(cond), self.to_beta(cond)
-        gamma, beta = map(lambda t: einops.rearrange(t, 'b d -> b 1 d'), (gamma, beta))
+        gamma = einops.rearrange(gamma, 'B D -> B 1 D')
+        beta = einops.rearrange(beta, 'B D -> B 1 D')
 
         return normed * gamma + beta
 
@@ -179,7 +170,7 @@ def forward(self, x: torch.Tensor):
         return F.gelu(gate) * x
 
 
-def FeedForward(dim: int, mult: int = 4, dropout: float = 0.0):
+def get_feedforward_layer(dim: int, mult: int = 4, dropout: float = 0.0):
     """
     Return a Feed-Forward layer for the Transformer Layer.
     GeGLU activation is used in this FF layer
@@ -228,7 +219,8 @@ def __init__(
             skip_connect_scale: The scale of the U-Net connection.
         """
         super().__init__()
-        assert divisible_by(depth, 2)
+        if (depth % 2) != 0:
+            raise ValueError(f"Number of layers {depth} is not divisible by 2!")
         self.layers = nn.ModuleList([])
         self.init_alibi(max_positions=max_positions, heads=heads)
 
@@ -237,7 +229,10 @@ def __init__(
         else:
             rmsnorm_class = RMSNorm
 
-        self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)
+        if skip_connect_scale is None:
+            self.skip_connect_scale = 2**-0.5
+        else:
+            self.skip_connect_scale = skip_connect_scale
 
         for ind in range(depth):
             layer = ind + 1
@@ -255,7 +250,7 @@ def __init__(
                         batch_first=True,
                     ),
                     rmsnorm_class(dim=dim),
-                    FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
+                    get_feedforward_layer(dim=dim, mult=ff_mult, dropout=ff_dropout),
                 ]
             )
         )
@@ -335,12 +330,12 @@ def forward(self, x, key_padding_mask: Optional[torch.Tensor] = None, adaptive_r
         alibi_bias = self.get_alibi_bias(batch_size=batch_size, seq_len=seq_len)
 
         rmsnorm_kwargs = dict()
-        if exists(adaptive_rmsnorm_cond):
+        if adaptive_rmsnorm_cond is not None:
             rmsnorm_kwargs = dict(cond=adaptive_rmsnorm_cond)
 
         for skip_combiner, attn_prenorm, attn, ff_prenorm, ff in self.layers:
 
-            if not exists(skip_combiner):
+            if skip_combiner is None:
                 skip_connects.append(x)
             else:
                 skip_connect = skip_connects.pop() * self.skip_connect_scale
@@ -348,7 +343,7 @@ def forward(self, x, key_padding_mask: Optional[torch.Tensor] = None, adaptive_r
                 x = skip_combiner(x)
 
             attn_input = attn_prenorm(x, **rmsnorm_kwargs)
-            if exists(key_padding_mask):
+            if key_padding_mask is not None:
                 # Since Alibi_bias is a float-type attn_mask, the padding_mask need to be float-type.
                 float_key_padding_mask = key_padding_mask.float()
                 float_key_padding_mask = float_key_padding_mask.masked_fill(key_padding_mask, float('-inf'))
@@ -410,7 +405,9 @@ def __init__(
         self.out_channels = out_channels
 
         dim_in = freq_dim * in_channels * 2
-        time_hidden_dim = default(time_hidden_dim, dim * 4)
+        if time_hidden_dim is None:
+            time_hidden_dim = dim * 4
+
         self.proj_in = nn.Linear(dim_in, dim)
         self.sinu_pos_emb = nn.Sequential(LearnedSinusoidalPosEmb(dim), nn.Linear(dim, time_hidden_dim), nn.SiLU())
@@ -496,7 +493,7 @@ def forward(self, input, input_length=None, condition=None):
         key_padding_mask = self._get_key_padding_mask(input_length, max_length=T)
         x = self.conv_embed(x, mask=key_padding_mask) + x
 
-        if not exists(condition):
+        if condition is None:
             raise NotImplementedError
 
         time_emb = self.sinu_pos_emb(condition)
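One detail worth noting in the `forward` changes above: because the ALiBi bias is passed to attention as a float `attn_mask`, the boolean `key_padding_mask` is converted to an additive float mask, `0.0` at valid positions and `-inf` at padded ones. A small sketch of that conversion (tensor values are illustrative):

```python
import torch

key_padding_mask = torch.tensor([[False, False, True],
                                 [False, True, True]])  # True marks padding
float_mask = key_padding_mask.float().masked_fill(key_padding_mask, float('-inf'))
# float_mask is 0.0 at valid positions and -inf at padded ones,
# so it can be combined with the float ALiBi bias in the attention logits.
print(float_mask)
```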