Skip to content

Commit

Permalink
[bug] Quick fix for XUMX in torch 2.0 (#684)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidDiazGuerra authored Mar 29, 2024
1 parent 08ffc2a commit d10b407
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions asteroid/models/x_umx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,6 @@


class XUMX(BaseModel):
def __init__(self, *args, **kwargs):
raise RuntimeError(
"XUMX is broken in torch 2.0, use torch<2.0 with asteroid<0.7 to use it until it's fixed."
)


class BrokenXUMX(BaseModel):
r"""CrossNet-Open-Unmix (X-UMX) for Music Source Separation introduced in [1].
There are two notable contributions with no effect on inference:
a) Multi Domain Losses
Expand Down Expand Up @@ -352,8 +345,9 @@ def forward(self, x):
normalized=False,
onesided=True,
pad_mode="reflect",
return_complex=False,
return_complex=True,
)
stft_f = torch.view_as_real(stft_f)

# reshape back to channel dimension
stft_f = stft_f.contiguous().view(nb_samples, nb_channels, self.n_fft // 2 + 1, -1, 2)
Expand Down Expand Up @@ -405,6 +399,7 @@ def forward(self, spec, ang):
x_i = spec * torch.sin(ang)
x = torch.stack([x_r, x_i], dim=-1)
x = x.view(sources * bsize * channels, fbins, frames, 2)
x = torch.view_as_complex(x)
wav = torch.istft(
x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=self.center
)
Expand Down

0 comments on commit d10b407

Please sign in to comment.