Remove compute_on_step from metrics #6981

Merged
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer.py
@@ -1224,7 +1224,7 @@ def validation_epoch_end(self, outputs):
     def __init__(
         self, decoding: RNNTDecoding, batch_dim_index=0, use_cer=False, log_prediction=True, dist_sync_on_step=False
     ):
-        super(RNNTWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
+        super(RNNTWER, self).__init__(dist_sync_on_step=dist_sync_on_step)
         self.decoding = decoding
         self.batch_dim_index = batch_dim_index
         self.use_cer = use_cer
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer_bpe.py
@@ -359,7 +359,7 @@ def __init__(
         log_prediction: bool = True,
         dist_sync_on_step=False,
     ):
-        super(RNNTBPEWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
+        super(RNNTBPEWER, self).__init__(dist_sync_on_step=dist_sync_on_step)
         self.decoding = decoding
         self.batch_dim_index = batch_dim_index
         self.use_cer = use_cer
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/wer.py
@@ -1125,7 +1125,7 @@ def __init__(
         fold_consecutive=True,
         dist_sync_on_step=False,
     ):
-        super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
+        super().__init__(dist_sync_on_step=dist_sync_on_step)

         self.decoding = decoding
         self.use_cer = use_cer
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/wer_bpe.py
@@ -247,7 +247,7 @@ def __init__(
         fold_consecutive=True,
         dist_sync_on_step=False,
     ):
-        super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
         self.decoding = decoding
         self.tokenizer = self.decoding.tokenizer
         self.blank_id = self.decoding.tokenizer.tokenizer.vocab_size
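All four ASR metrics make the same one-line change: compute_on_step=False is dropped from the super().__init__() call, since the torchmetrics Metric base class no longer accepts that argument. A minimal sketch of the resulting constructor pattern for a metric subclass (class name, state names, and the update/compute bodies below are illustrative, not code from this PR):

import torch
from torchmetrics import Metric


class MyWordErrorRate(Metric):
    """Illustrative subclass following the pattern used by the NeMo WER-style metrics."""

    def __init__(self, dist_sync_on_step=False):
        # No compute_on_step here; recent torchmetrics removed the argument.
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("scores", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("words", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, errors: torch.Tensor, words: torch.Tensor):
        # Accumulate per-batch error and word counts into the metric state.
        self.scores += errors
        self.words += words

    def compute(self):
        # Aggregate word error rate over everything seen since the last reset().
        return self.scores / self.words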
9 changes: 2 additions & 7 deletions nemo/collections/common/metrics/global_average_loss_metric.py
@@ -28,9 +28,6 @@ class GlobalAverageLossMetric(Metric):
     See :doc:`PyTorch Lightning Metrics<pytorch-lightning:metrics>` for the metric usage instruction.

     Args:
-        compute_on_step:
-            The method :meth:`forward` only calls ``update()`` and returns ``None`` if this is set to ``False``.
-            default: ``True``
         dist_sync_on_step:
             Synchronize metric state across processes at each method :meth:`forward` call before returning the
             value at the step
@@ -44,10 +41,8 @@ class GlobalAverageLossMetric(Metric):

     full_state_update = True

-    def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True):
-        super().__init__(
-            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
-        )
+    def __init__(self, dist_sync_on_step=False, process_group=None, take_avg_loss=True):
+        super().__init__(dist_sync_on_step=dist_sync_on_step, process_group=process_group)
         self.add_state("loss_sum", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum')
         self.add_state("num_measurements", torch.tensor(0, dtype=torch.int64), dist_reduce_fx='sum')
         self.take_avg_loss = take_avg_loss
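GlobalAverageLossMetric keeps two summed states, loss_sum and num_measurements, and averages them in compute(). A hedged usage sketch, assuming an update(loss, num_measurements) signature that mirrors those states (check the class for the exact argument names):

import torch
from nemo.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric

# compute_on_step is gone; only dist_sync_on_step, process_group, and take_avg_loss remain.
loss_metric = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)

for batch_loss in (torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.5)):
    # Assumed signature: the per-batch (already averaged) loss plus the number of samples it covers.
    loss_metric.update(loss=batch_loss, num_measurements=torch.tensor(8))

print(loss_metric.compute())  # global average loss over all accumulated measurements
loss_metric.reset()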
8 changes: 2 additions & 6 deletions nemo/collections/common/metrics/perplexity.py
@@ -29,8 +29,6 @@ class Perplexity(Metric):
     See `PyTorch Lightning Metrics <https://pytorch-lightning.readthedocs.io/en/stable/ecosystem/metrics.html>`_ for the metric usage instructions.

     Args:
-        compute_on_step:
-            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
         dist_sync_on_step:
             Synchronize metric state across processes at each ``forward()``
             before returning the value at the step.
@@ -44,10 +42,8 @@ class Perplexity(Metric):

     full_state_update = True

-    def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, validate_args=True):
-        super().__init__(
-            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
-        )
+    def __init__(self, dist_sync_on_step=False, process_group=None, validate_args=True):
+        super().__init__(dist_sync_on_step=dist_sync_on_step, process_group=process_group)
         self.validate_args = validate_args
         self.add_state('perplexities_sum', torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum')
         # Total number of distributions seen since last reset
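After the change, Perplexity is configured only through dist_sync_on_step, process_group, and validate_args. A hedged usage sketch, assuming update() accepts exactly one of probs or logits over the last (vocabulary) dimension, as the updated test helpers further down exercise:

import torch
from nemo.collections.common.metrics.perplexity import Perplexity

ppl = Perplexity(validate_args=True)   # no compute_on_step argument anymore

logits = torch.randn(2, 5, 10)                          # (batch, seq, vocab) scores
ppl.update(logits=logits)                               # pass logits...

probs = torch.softmax(torch.randn(2, 5, 10), dim=-1)
ppl.update(probs=probs)                                 # ...or probs, but not both in one call

print(ppl.compute())   # mean perplexity over all distributions seen since the last reset()
ppl.reset()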
9 changes: 2 additions & 7 deletions nemo/collections/nlp/metrics/sequence_perplexity.py
@@ -31,8 +31,6 @@ class SequencePerplexity(Metric):
     See :doc:`PyTorch Lightning Metrics<pytorch-lightning:metrics>` for the metric usage instructions.

     Args:
-        compute_on_step:
-            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
         dist_sync_on_step:
             Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
         process_group:
@@ -43,12 +41,9 @@ class SequencePerplexity(Metric):
             to perform the allgather.
     """

-    def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None):
+    def __init__(self, dist_sync_on_step=False, process_group=None, dist_sync_fn=None):
         super().__init__(
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            process_group=process_group,
-            dist_sync_fn=dist_sync_fn,
+            dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn,
         )

     # Total sum of exponentiated average negative log likelihoods
@@ -116,7 +116,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         # create extra bias

         # setup to track metrics
-        self.validation_perplexity = Perplexity(compute_on_step=False)
+        self.validation_perplexity = Perplexity()

         self.setup_optimization(cfg.optim)
@@ -100,7 +100,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             decoder=cfg.language_model.pretrained_decoder_model_name,
         )

-        self.validation_perplexity = Perplexity(compute_on_step=False)
+        self.validation_perplexity = Perplexity()

        self.setup_optimization(cfg.optim)
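With compute_on_step gone, the two model classes above construct Perplexity() with defaults, and the usual LightningModule pattern still applies: update the metric on each validation step and read the aggregated value at epoch end. A hypothetical sketch of that pattern (the module, its forward, and the logits keyword are illustrative, not taken from these model files):

import torch
from pytorch_lightning import LightningModule
from nemo.collections.common.metrics.perplexity import Perplexity


class ToyLMModule(LightningModule):
    """Illustrative only; not one of the NeMo model classes touched by this PR."""

    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.head = torch.nn.Linear(8, vocab_size)
        # setup to track metrics (compute_on_step no longer passed)
        self.validation_perplexity = Perplexity()

    def validation_step(self, batch, batch_idx):
        logits = self.head(batch)                        # (batch, seq, vocab) logits
        self.validation_perplexity.update(logits=logits)  # accumulate state per batch

    def validation_epoch_end(self, outputs):
        self.log('val_perplexity', self.validation_perplexity.compute())
        self.validation_perplexity.reset()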
1 change: 0 additions & 1 deletion nemo/core/optim/optimizers.py
@@ -51,7 +51,6 @@
     AVAILABLE_OPTIMIZERS['fused_adam'] = FusedAdam
 except ModuleNotFoundError:
     HAVE_APEX = False
-    logging.warning("Apex was not found. Using the lamb or fused_adam optimizer will error out.")

 HAVE_APEX_DISTRIBUTED_ADAM = False
 if HAVE_APEX:
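The optimizers change is unrelated to metrics: the import-time warning for a missing Apex install is dropped, leaving a silent try/except guard. A minimal sketch of that optional-dependency pattern (the contents of the try block beyond the visible fused_adam registration are an assumption):

# Sketch of the guarded-import pattern used in nemo/core/optim/optimizers.py.
AVAILABLE_OPTIMIZERS = {}

HAVE_APEX = True
try:
    from apex.optimizers import FusedAdam, FusedLAMB  # optional NVIDIA Apex optimizers

    AVAILABLE_OPTIMIZERS['lamb'] = FusedLAMB
    AVAILABLE_OPTIMIZERS['fused_adam'] = FusedAdam
except ModuleNotFoundError:
    # After this PR there is no warning here; lamb/fused_adam simply stay unregistered.
    HAVE_APEX = False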
8 changes: 3 additions & 5 deletions tests/collections/common/pl_utils.py
@@ -90,7 +90,7 @@ def _class_test(
        calculated across devices for each batch (and not just at the end)
    """
    # Instanciate lightning metric
-    metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
+    metric = metric_class(dist_sync_on_step=dist_sync_on_step, **metric_args)

    # verify metrics work after being loaded from pickled state
    pickled_metric = pickle.dumps(metric)
@@ -303,7 +303,7 @@ def _perplexity_class_test(
        calculated across devices for each batch (and not just at the end)
    """
    # Instanciate lightning metric
-    perplexity = Perplexity(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
+    perplexity = Perplexity(dist_sync_on_step=dist_sync_on_step, **metric_args)
    if (probs is None) == (logits is None):
        with pytest.raises(ValueError):
            perplexity(probs, logits)
@@ -464,9 +464,7 @@ def _loss_class_test(
        calculated across devices for each batch (and not just at the end)
    """
    # Instantiate lightning metric
-    loss_metric = GlobalAverageLossMetric(
-        compute_on_step=True, dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss
-    )
+    loss_metric = GlobalAverageLossMetric(dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss)

    # verify loss works after being loaded from pickled state
    pickled_metric = pickle.dumps(loss_metric)
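The test helpers now build the metrics without compute_on_step and otherwise keep the same flow: pickle round-trip, then per-batch calls. A hedged sketch of that flow (the placeholder batches and the final compute are illustrative, not the helpers' actual comparison logic):

import pickle
import torch
from nemo.collections.common.metrics.perplexity import Perplexity

metric = Perplexity(dist_sync_on_step=False)

# verify the metric still works after being loaded from pickled state
metric = pickle.loads(pickle.dumps(metric))

batches = [torch.randn(2, 6, 10) for _ in range(4)]   # placeholder logit batches
for logits in batches:
    batch_value = metric(logits=logits)   # forward() returns the batch-local value while accumulating state

final_value = metric.compute()            # aggregated perplexity over all four batches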