Skip to content

Commit

Permalink
Remove redundant functions in transformerUNet. Fix some comments in f…
Browse files Browse the repository at this point in the history
…low.py and change the naming of variables.

Signed-off-by: Pin-Jui Ku <pku9@gatech.edu>
  • Loading branch information
Pin-Jui Ku authored and Kuray107 committed Aug 12, 2024
1 parent 47893a9 commit 4b39a1f
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 103 deletions.
6 changes: 3 additions & 3 deletions nemo/collections/audio/models/enhancement.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,10 +772,10 @@ def _step(self, target_signal, input_signal, input_length=None):
keep_conditions = einops.rearrange((torch.rand(batch_size) < self.p_cond).float(), 'B -> B 1 1 1')
input_enc = input_enc * keep_conditions.to(input_enc.device)

start_state = torch.zeros_like(input_enc)
x_start = torch.zeros_like(input_enc)

time = self.flow.generate_time(batch_size=batch_size).to(device=input_enc.device)
sample = self.flow.sample(time=time, start_state=start_state, end_state=target_enc)
sample = self.flow.sample(time=time, x_start=x_start, x_end=target_enc)

# we want to get a vector field estimate given current state
# at training time, current state is sampled from the conditional path
Expand All @@ -786,7 +786,7 @@ def _step(self, target_signal, input_signal, input_length=None):
estimate, estimate_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=time)

conditional_vector_field = self.flow.vector_field(
time=time, start_state=start_state, end_state=target_enc, point=sample
time=time, x_start=x_start, x_end=target_enc, point=sample
)

return self.loss(estimate=estimate, target=conditional_vector_field, input_length=input_enc_len)
Expand Down
5 changes: 2 additions & 3 deletions nemo/collections/audio/modules/ssl_pretrain_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SSLPretrainWithMaskedPatch(NeuralModule):
Args:
patch_size (int): up to how many time steps does one patch consist of.
Defaults to 48.
Defaults to 10.
mask_fraction (float): how much fraction in each sample to be masked (number of patches is rounded up).
Range from 0.0 to 1.0. Defaults to 0.7.
"""
Expand Down Expand Up @@ -72,8 +72,7 @@ def forward(self, input_spec, length):
if min_len < self.patch_size * mask_patches:
mask_patches = min_len // self.patch_size

for idx in range(input_spec.shape[0]):
cur_len = length[idx]
for idx, cur_len in enumerate(length.tolist()):
patches = range(cur_len // self.patch_size)
masked_patches = random.sample(patches, mask_patches)
for mp in masked_patches:
Expand Down
106 changes: 37 additions & 69 deletions nemo/collections/audio/parts/submodules/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,25 @@ def __init__(self, time_min: float = 1e-8, time_max: float = 1.0):
self.time_max = time_max

@abstractmethod
def mean(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
def mean(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor:
"""
Return the mean of p_t(x | start_state, end_state) at time t
Return the mean of p_t(x | x_start, x_end) at time t
"""
pass

@abstractmethod
def std(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
def std(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor:
"""
Return the standard deviation of p_t(x | start_state, end_state) at time t
Return the standard deviation of p_t(x | x_start, x_end) at time t
"""
pass

@abstractmethod
def d_mean(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
"""
Return the time derivatives of mean of p_t(x | start_state, end_state) at time t
"""
pass

@abstractmethod
def d_std(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
def vector_field(
self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
"""
Return the time derivatives of standard deviation of p_t(x | start_state, end_state) at time t
Compute the conditional vector field v_t( point | x_start, x_end)
"""
pass

Expand All @@ -84,70 +79,45 @@ def generate_time(self, batch_size: int) -> torch.Tensor:
"""
return torch.clamp(torch.rand((batch_size,)), self.time_min, self.time_max)

def sample(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
def sample(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor:
"""
Generate a sample from p_t(x | start_state, end_state) at time t
Generate a sample from p_t(x | x_start, x_end) at time t.
Note that this implementation assumes all path marginals are normally distributed.
"""
time = self._broadcast_time(time, n_dim=start_state.ndim)
time = self._broadcast_time(time, n_dim=x_start.ndim)

mean = self.mean(time=time, start_state=start_state, end_state=end_state)
std = self.std(time=time, start_state=start_state, end_state=end_state)
mean = self.mean(time=time, x_start=x_start, x_end=x_end)
std = self.std(time=time, x_start=x_start, x_end=x_end)
return mean + std * torch.randn_like(mean)

def vector_field(
self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
"""
Compute the conditional vector field v_t( point | start_state, end_state)
"""
# vector field conditioned on `start_state`, `end_state`
# and evaluated at `point`
# !!! general form, may cause numerical issues
time = self._broadcast_time(time, n_dim=start_state.ndim)

mean = self.mean(time=time, start_state=start_state, end_state=end_state)
std = self.std(time=time, start_state=start_state, end_state=end_state)
d_mean = self.d_mean(time=time, start_state=start_state, end_state=end_state)
d_std = self.d_std(time=time, start_state=start_state, end_state=end_state)
return d_std * (point - mean) / std + d_mean

def flow(
self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
"""
Compute the conditional flow phi_t( point | start_state, end_state)
Compute the conditional flow phi_t( point | x_start, x_end).
This is an affine flow.
"""
mean = self.mean(time=time, start_state=start_state, end_state=end_state)
std = self.std(time=time, start_state=start_state, end_state=end_state)
return mean + std * (point - start_state)

def d_flow(
self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
"""
Compute the time derivatives of conditional flow
"""
d_mean = self.d_mean(time=time, start_state=start_state, end_state=end_state)
d_std = self.d_std(time=time, start_state=start_state, end_state=end_state)
return d_mean + d_std * (point - start_state)
mean = self.mean(time=time, x_start=x_start, x_end=x_end)
std = self.std(time=time, x_start=x_start, x_end=x_end)
return mean + std * (point - x_start)


class OptimalTransportFlow(ConditionalFlow):
"""The OT-CFM model from [Lipman et at, 2023]
Every conditional path the following holds:
p_0 = N(start_state, sigma_start)
p_1 = N(end_state, sigma_end),
p_0 = N(x_start, sigma_start)
p_1 = N(x_end, sigma_end),
mean(x, t) = (time_max - t) * start_state + t * end_state
(linear interpolation between start_state and end_state)
mean(x, t) = (time_max - t) * x_start + t * x_end
(linear interpolation between x_start and x_end)
std(x, t) = (time_max - t) * sigma_start + t * sigma_end
Every conditional path is optimal transport map from p_0(start_state, end_state) to p_1(start_state, end_state)
Every conditional path is optimal transport map from p_0(x_start, x_end) to p_1(x_start, x_end)
Marginal path is not guaranteed to be an optimal transport map from p_0 to p_1
To get the OT-CFM model from [Lipman et at, 2023] just pass zeroes for start_state
To get the OT-CFM model from [Lipman et at, 2023] just pass zeroes for x_start
To get the I-CFM model, set sigma_min=sigma_max
To get the rectified flow model, set sigma_min=sigma_max=0
Expand All @@ -171,33 +141,33 @@ def __init__(
logging.debug('\tsgima_start: %s', self.sigma_start)
logging.debug('\tsigma_end: %s', self.sigma_end)

def mean(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
return (self.time_max - time) * start_state + time * end_state
def mean(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
return (self.time_max - time) * x_start + time * x_end

def std(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
def std(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
return (self.time_max - time) * self.sigma_start + time * self.sigma_end

def d_mean(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
return end_state - start_state
def d_mean(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
return x_end - x_start

def d_std(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
def d_std(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
return self.sigma_end - self.sigma_start

def vector_field(
self,
*,
start_state: torch.Tensor,
end_state: torch.Tensor,
x_start: torch.Tensor,
x_end: torch.Tensor,
time: torch.Tensor,
point: torch.Tensor,
eps: float = 1e-6,
) -> torch.Tensor:
time = self._broadcast_time(time, n_dim=start_state.ndim)
time = self._broadcast_time(time, n_dim=x_start.ndim)

if self.sigma_start == self.sigma_end:
return end_state - start_state
return x_end - x_start

num = self.sigma_end * (point - start_state) - self.sigma_start * (point - end_state)
num = self.sigma_end * (point - x_start) - self.sigma_start * (point - x_end)
denom = (1 - time) * self.sigma_start + time * self.sigma_end
return num / (denom + eps)

Expand Down Expand Up @@ -287,6 +257,4 @@ def forward(
if state_length is not None:
state = mask_sequence_tensor(state, state_length)

if state_length is not None:
state = mask_sequence_tensor(state, state_length)
return state, state_length
53 changes: 25 additions & 28 deletions nemo/collections/audio/parts/submodules/transformerunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,12 @@
__all__ = ['TransformerUNet']


def exists(val):
return val is not None


def default(val, d):
return val if exists(val) else d


def divisible_by(num, den):
return (num % den) == 0


class LearnedSinusoidalPosEmb(Module):
"""The sinusoidal Embedding to encode time conditional information"""

def __init__(self, dim: int):
super().__init__()
if not divisible_by(dim, 2):
if (dim % 2) != 0:
raise ValueError(f"Input dimension {dim} is not divisible by 2!")
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
Expand All @@ -92,10 +80,11 @@ class ConvPositionEmbed(Module):

def __init__(self, dim: int, kernel_size: int, groups: Optional[int] = None):
super().__init__()
if divisible_by(kernel_size, 2):
if (kernel_size % 2) == 0:
raise ValueError(f"Kernel size {kernel_size} is divisible by 2!")

groups = default(groups, dim) # full depthwise conv by default
if groups is None:
groups = dim

self.dw_conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), nn.GELU()
Expand All @@ -110,15 +99,15 @@ def forward(self, x, mask=None):
out: output tensor with the same shape (B, T, D)
"""

if exists(mask):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(mask, 0.0)

x = einops.rearrange(x, 'b n c -> b c n')
x = self.dw_conv1d(x)
out = einops.rearrange(x, 'b c n -> b n c')

if exists(mask):
if mask is not None:
out = out.masked_fill(mask, 0.0)

return out
Expand Down Expand Up @@ -148,7 +137,8 @@ class AdaptiveRMSNorm(Module):

def __init__(self, dim: int, cond_dim: Optional[int] = None):
super().__init__()
cond_dim = default(cond_dim, dim)
if cond_dim is None:
cond_dim = dim
self.scale = dim**0.5

self.to_gamma = nn.Linear(cond_dim, dim)
Expand All @@ -166,7 +156,8 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor):
normed = F.normalize(x, dim=-1) * self.scale

gamma, beta = self.to_gamma(cond), self.to_beta(cond)
gamma, beta = map(lambda t: einops.rearrange(t, 'b d -> b 1 d'), (gamma, beta))
gamma = einops.rearrange(gamma, 'B D -> B 1 D')
beta = einops.rearrange(beta, 'B D -> B 1 D')

return normed * gamma + beta

Expand All @@ -179,7 +170,7 @@ def forward(self, x: torch.Tensor):
return F.gelu(gate) * x


def FeedForward(dim: int, mult: int = 4, dropout: float = 0.0):
def get_feedforward_layer(dim: int, mult: int = 4, dropout: float = 0.0):
"""
Return a Feed-Forward layer for the Transformer Layer.
GeGLU activation is used in this FF layer
Expand Down Expand Up @@ -228,7 +219,8 @@ def __init__(
skip_connect_scale: The scale of the U-Net connection.
"""
super().__init__()
assert divisible_by(depth, 2)
if (depth % 2) != 0:
raise ValueError(f"Number of layers {depth} is not divisible by 2!")
self.layers = nn.ModuleList([])
self.init_alibi(max_positions=max_positions, heads=heads)

Expand All @@ -237,7 +229,10 @@ def __init__(
else:
rmsnorm_class = RMSNorm

self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)
if skip_connect_scale is None:
self.skip_connect_scale = 2**-0.5
else:
self.skip_connect_scale = skip_connect_scale

for ind in range(depth):
layer = ind + 1
Expand All @@ -255,7 +250,7 @@ def __init__(
batch_first=True,
),
rmsnorm_class(dim=dim),
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
get_feedforward_layer(dim=dim, mult=ff_mult, dropout=ff_dropout),
]
)
)
Expand Down Expand Up @@ -335,20 +330,20 @@ def forward(self, x, key_padding_mask: Optional[torch.Tensor] = None, adaptive_r
alibi_bias = self.get_alibi_bias(batch_size=batch_size, seq_len=seq_len)

rmsnorm_kwargs = dict()
if exists(adaptive_rmsnorm_cond):
if adaptive_rmsnorm_cond is not None:
rmsnorm_kwargs = dict(cond=adaptive_rmsnorm_cond)

for skip_combiner, attn_prenorm, attn, ff_prenorm, ff in self.layers:

if not exists(skip_combiner):
if skip_combiner is None:
skip_connects.append(x)
else:
skip_connect = skip_connects.pop() * self.skip_connect_scale
x = torch.cat((x, skip_connect), dim=-1)
x = skip_combiner(x)

attn_input = attn_prenorm(x, **rmsnorm_kwargs)
if exists(key_padding_mask):
if key_padding_mask is not None:
# Since Alibi_bias is a float-type attn_mask, the padding_mask need to be float-type.
float_key_padding_mask = key_padding_mask.float()
float_key_padding_mask = float_key_padding_mask.masked_fill(key_padding_mask, float('-inf'))
Expand Down Expand Up @@ -410,7 +405,9 @@ def __init__(
self.out_channels = out_channels
dim_in = freq_dim * in_channels * 2

time_hidden_dim = default(time_hidden_dim, dim * 4)
if time_hidden_dim is None:
time_hidden_dim = dim * 4

self.proj_in = nn.Linear(dim_in, dim)

self.sinu_pos_emb = nn.Sequential(LearnedSinusoidalPosEmb(dim), nn.Linear(dim, time_hidden_dim), nn.SiLU())
Expand Down Expand Up @@ -496,7 +493,7 @@ def forward(self, input, input_length=None, condition=None):
key_padding_mask = self._get_key_padding_mask(input_length, max_length=T)
x = self.conv_embed(x, mask=key_padding_mask) + x

if not exists(condition):
if condition is None:
raise NotImplementedError

time_emb = self.sinu_pos_emb(condition)
Expand Down

0 comments on commit 4b39a1f

Please sign in to comment.