feat(train): add optional accumulate_grad_batches config param (#306)
* feat(train): add accumulate_grad_batches hparam

Add an `accumulate_grad_batches` param to the `train` section
of the config to enable gradient accumulation. The optimizers now
step once every `accumulate_grad_batches` batches, accumulating
gradients in between; the default value of 1 keeps existing
configs working unchanged (see the sketch below).

* fix(train): normalize loss when using gradient accumulation
guranon authored Apr 13, 2023
1 parent 0f6794a commit 1172b23
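The loss normalization in the second commit keeps the accumulated gradient on the same scale as a single optimizer step: gradients from each micro-batch are summed across `accumulate_grad_batches` backward passes, so each loss is divided by that count first. A minimal standalone sketch of this pattern in plain PyTorch (not the repository's Lightning module; the model, data, and batch sizes are made up for illustration, and it assumes equal micro-batch sizes with a mean-reduced loss):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_batches = 4
micro_batch = 8

# Toy data: 4 micro-batches of 8 samples each.
x = torch.randn(micro_batch * accumulate_grad_batches, 4)
y = torch.randn(micro_batch * accumulate_grad_batches, 1)

for batch_idx in range(accumulate_grad_batches):
    xb = x[batch_idx * micro_batch : (batch_idx + 1) * micro_batch]
    yb = y[batch_idx * micro_batch : (batch_idx + 1) * micro_batch]
    loss = torch.nn.functional.mse_loss(model(xb), yb)
    # Normalize before accumulating so the summed gradient matches the
    # gradient of one mean-reduced loss over all 32 samples.
    (loss / accumulate_grad_batches).backward()
    if (batch_idx + 1) % accumulate_grad_batches == 0:
        opt.step()       # one weight update for the whole accumulation window
        opt.zero_grad()  # start the next window with clean gradients
```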
1 changed file with 10 additions and 6 deletions: src/so_vits_svc_fork/train.py
@@ -392,10 +392,13 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
             }
         )
 
+        accumulate_grad_batches = self.hparams.train.get("accumulate_grad_batches", 1)
+        should_update = (batch_idx + 1) % accumulate_grad_batches == 0
         # optimizer
-        optim_g.zero_grad()
-        self.manual_backward(loss_gen_all)
-        optim_g.step()
+        self.manual_backward(loss_gen_all / accumulate_grad_batches)
+        if should_update:
+            optim_g.step()
+            optim_g.zero_grad()
         self.untoggle_optimizer(optim_g)
 
         # Discriminator
@@ -417,9 +420,10 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
         )
 
         # optimizer
-        optim_d.zero_grad()
-        self.manual_backward(loss_disc_all)
-        optim_d.step()
+        self.manual_backward(loss_disc_all / accumulate_grad_batches)
+        if should_update:
+            optim_d.step()
+            optim_d.zero_grad()
         self.untoggle_optimizer(optim_d)
 
         # end of epoch
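For reference, a short sketch of how the new key behaves from the config side, assuming only that the `train` section is exposed as a dict-like object (the `train_cfg` name and the `batch_size` key are illustrative, not taken from the repository):

```python
# Hypothetical `train` config section; only `accumulate_grad_batches` is the
# key introduced by this commit.
train_cfg = {"batch_size": 16}

# Same backward-compatible lookup as in the diff: a missing key falls back to 1,
# i.e. one optimizer step per batch, exactly as before this change.
accumulate_grad_batches = train_cfg.get("accumulate_grad_batches", 1)
assert accumulate_grad_batches == 1

# Opting in: gradients from 4 consecutive batches are accumulated before each
# optimizer step, for an effective batch size of 16 * 4 = 64.
train_cfg["accumulate_grad_batches"] = 4
effective_batch_size = train_cfg["batch_size"] * train_cfg["accumulate_grad_batches"]
print(effective_batch_size)  # 64
```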
