
Commit ef2dd78
Address PR comments
Signed-off-by: Marc Romeyn <marcromeyn@gmail.com>
marcromeyn committed Apr 22, 2024
1 parent 323b071 commit ef2dd78
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions nemo/lightning/megatron_parallel.py
@@ -237,7 +237,6 @@ def forward(
         if microbatch_outputs:
             self.callbacks.event("on_megatron_reduce_microbatches_start", **context)
 
-            # TODO: Can this lead to issues?
             if isinstance(_loss_reduction, _ModuleStepFunction):
                 _loss_reduction = _loss_reduction(self[0])
 
@@ -251,8 +250,6 @@ def forward(
         else:
             loss_mean = torch.tensor(0.0).cuda()
 
-        # TODO: How to handle all-reduce callbacks?
-
         self.callbacks.event("on_megatron_log_step_end", **context)
         self.callbacks.event("on_megatron_step_end", **context)
 
@@ -310,8 +307,6 @@ def wrapped_forward_step_func(dataloader_iter, model):
 
             output_tensor = _forward_step(model, batch)
 
-            # TODO: handle forward_post
-
             # callback
             self._setup_module(
                 forward_callback, batch=batch, model=self, forward_module=model, tensor=output_tensor,
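For context on the pattern the diff relies on: self.callbacks.event(...) dispatches a named event, together with the current step's context, to every registered callback that implements a hook of that name, and a _ModuleStepFunction is resolved against a concrete module (self[0]) before it is called. The sketch below is a minimal, hypothetical illustration of that dispatch pattern, not NeMo's actual implementation; the CallbackConnector and LossLogger names and signatures here are assumptions made for the example.

    # Hypothetical sketch of the event-dispatch pattern (names are
    # assumptions, not NeMo's actual code).
    import inspect


    class CallbackConnector:
        """Dispatches named events to every registered callback that defines a matching hook."""

        def __init__(self, callbacks=None):
            self.callbacks = list(callbacks or [])

        def add(self, callback) -> None:
            self.callbacks.append(callback)

        def event(self, name: str, **context) -> None:
            for callback in self.callbacks:
                hook = getattr(callback, name, None)
                if not callable(hook):
                    continue
                params = inspect.signature(hook).parameters
                if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
                    # Hook accepts **kwargs: forward the full context.
                    hook(**context)
                else:
                    # Pass only the keyword arguments the hook declares.
                    hook(**{k: v for k, v in context.items() if k in params})


    class LossLogger:
        """Example callback: only interested in the step-end event."""

        def on_megatron_step_end(self, loss_mean=None, **_):
            print(f"megatron step finished, loss_mean={loss_mean}")


    connector = CallbackConnector([LossLogger()])
    connector.event("on_megatron_step_end", loss_mean=0.42, batch=None)

Dispatching by hook name keeps call sites such as "on_megatron_step_end" decoupled from whichever callbacks happen to be registered, which is why deleting the TODO comments in this commit changes nothing about the control flow.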
