Skip to content

Commit

Permalink
Compute torch.angle of STFT only when needed
Browse files Browse the repository at this point in the history
This seems to be an expensive operation, at least on some architectures
  • Loading branch information
simonschwaer committed Dec 19, 2023
1 parent bba3748 commit cf7b090
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions auraloss/freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def __init__(
self.mag_distance = mag_distance
self.device = device

self.phs_used = bool(self.w_phs)

self.spectralconv = SpectralConvergenceLoss()
self.logstft = STFTMagnitudeLoss(
log=True,
Expand Down Expand Up @@ -218,7 +220,13 @@ def stft(self, x):
return_complex=True,
)
x_mag = torch.clamp(torch.abs(x_stft), min=self.eps)
x_phs = torch.angle(x_stft)

# torch.angle is expensive, so it is only evaluated if the values are used in the loss
if self.phs_used:
x_phs = torch.angle(x_stft)
else:
x_phs = None

return x_mag, x_phs

def forward(self, input: torch.Tensor, target: torch.Tensor):
Expand All @@ -239,6 +247,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):

# compute the magnitude and phase spectra of input and target
self.window = self.window.to(input.device)

x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))

Expand All @@ -257,7 +266,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.w_phs else 0.0
phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0

# combine loss terms
loss = (
Expand Down

0 comments on commit cf7b090

Please sign in to comment.