From 3e71064019bb9431af8e618723890ee79182446c Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Wed, 5 Jul 2023 14:11:49 -0700 Subject: [PATCH] Remove `compute_on_step` from metrics (#6979) * Remove `compute_on_step` from metrics Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove confusing log message Signed-off-by: smajumdar * Update tests Signed-off-by: smajumdar --------- Signed-off-by: smajumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- nemo/collections/asr/metrics/rnnt_wer.py | 2 +- nemo/collections/asr/metrics/rnnt_wer_bpe.py | 2 +- nemo/collections/asr/metrics/wer.py | 2 +- nemo/collections/asr/metrics/wer_bpe.py | 2 +- .../common/metrics/global_average_loss_metric.py | 9 ++------- nemo/collections/common/metrics/perplexity.py | 8 ++------ nemo/collections/nlp/metrics/sequence_perplexity.py | 9 ++------- .../nlp/models/language_modeling/bert_lm_model.py | 2 +- .../nlp/models/text2sparql/text2sparql_model.py | 2 +- nemo/core/optim/optimizers.py | 1 - tests/collections/common/pl_utils.py | 8 +++----- 11 files changed, 15 insertions(+), 32 deletions(-) diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 55f9f4b5ea9f..7e5636191a1d 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -1224,7 +1224,7 @@ def validation_epoch_end(self, outputs): def __init__( self, decoding: RNNTDecoding, batch_dim_index=0, use_cer=False, log_prediction=True, dist_sync_on_step=False ): - super(RNNTWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + super(RNNTWER, self).__init__(dist_sync_on_step=dist_sync_on_step) self.decoding = decoding self.batch_dim_index = batch_dim_index self.use_cer = use_cer diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index 0870eb180776..d2e2c3cc5923 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -359,7 +359,7 @@ def __init__( log_prediction: bool = True, dist_sync_on_step=False, ): - super(RNNTBPEWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + super(RNNTBPEWER, self).__init__(dist_sync_on_step=dist_sync_on_step) self.decoding = decoding self.batch_dim_index = batch_dim_index self.use_cer = use_cer diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index 7f7f853d307d..4d90810cc3df 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -1125,7 +1125,7 @@ def __init__( fold_consecutive=True, dist_sync_on_step=False, ): - super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + super().__init__(dist_sync_on_step=dist_sync_on_step) self.decoding = decoding self.use_cer = use_cer diff --git a/nemo/collections/asr/metrics/wer_bpe.py b/nemo/collections/asr/metrics/wer_bpe.py index 762acf172a16..8a92e4745a1b 100644 --- a/nemo/collections/asr/metrics/wer_bpe.py +++ b/nemo/collections/asr/metrics/wer_bpe.py @@ -247,7 +247,7 @@ def __init__( fold_consecutive=True, dist_sync_on_step=False, ): - super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + super().__init__(dist_sync_on_step=dist_sync_on_step) self.decoding = decoding self.tokenizer = self.decoding.tokenizer self.blank_id = self.decoding.tokenizer.tokenizer.vocab_size diff --git a/nemo/collections/common/metrics/global_average_loss_metric.py b/nemo/collections/common/metrics/global_average_loss_metric.py index fae1dbfea5e8..3bbd4d13abf4 100644 --- a/nemo/collections/common/metrics/global_average_loss_metric.py +++ b/nemo/collections/common/metrics/global_average_loss_metric.py @@ -28,9 +28,6 @@ class GlobalAverageLossMetric(Metric): See :doc:`PyTorch Lightning Metrics` for the metric usage instruction. Args: - compute_on_step: - The method :meth:`forward` only calls ``update()`` and returns ``None`` if this is set to ``False``. - default: ``True`` dist_sync_on_step: Synchronize metric state across processes at each method :meth:`forward` call before returning the value at the step @@ -44,10 +41,8 @@ class GlobalAverageLossMetric(Metric): full_state_update = True - def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True): - super().__init__( - compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group - ) + def __init__(self, dist_sync_on_step=False, process_group=None, take_avg_loss=True): + super().__init__(dist_sync_on_step=dist_sync_on_step, process_group=process_group) self.add_state("loss_sum", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum') self.add_state("num_measurements", torch.tensor(0, dtype=torch.int64), dist_reduce_fx='sum') self.take_avg_loss = take_avg_loss diff --git a/nemo/collections/common/metrics/perplexity.py b/nemo/collections/common/metrics/perplexity.py index 1158e3408611..9e1c21737ec8 100644 --- a/nemo/collections/common/metrics/perplexity.py +++ b/nemo/collections/common/metrics/perplexity.py @@ -29,8 +29,6 @@ class Perplexity(Metric): See `PyTorch Lightning Metrics `_ for the metric usage instructions. Args: - compute_on_step: - Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True`` dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. @@ -44,10 +42,8 @@ class Perplexity(Metric): full_state_update = True - def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, validate_args=True): - super().__init__( - compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group - ) + def __init__(self, dist_sync_on_step=False, process_group=None, validate_args=True): + super().__init__(dist_sync_on_step=dist_sync_on_step, process_group=process_group) self.validate_args = validate_args self.add_state('perplexities_sum', torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum') # Total number of distributions seen since last reset diff --git a/nemo/collections/nlp/metrics/sequence_perplexity.py b/nemo/collections/nlp/metrics/sequence_perplexity.py index 688f9db87ea6..339f062f7cc1 100644 --- a/nemo/collections/nlp/metrics/sequence_perplexity.py +++ b/nemo/collections/nlp/metrics/sequence_perplexity.py @@ -31,8 +31,6 @@ class SequencePerplexity(Metric): See :doc:`PyTorch Lightning Metrics` for the metric usage instructions. Args: - compute_on_step: - Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True`` dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. process_group: @@ -43,12 +41,9 @@ class SequencePerplexity(Metric): to perform the allgather. """ - def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None): + def __init__(self, dist_sync_on_step=False, process_group=None, dist_sync_fn=None): super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, + dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn, ) # Total sum of exponentiated average negative log likelihoods diff --git a/nemo/collections/nlp/models/language_modeling/bert_lm_model.py b/nemo/collections/nlp/models/language_modeling/bert_lm_model.py index 4c9d43c20d54..5cf509e77846 100644 --- a/nemo/collections/nlp/models/language_modeling/bert_lm_model.py +++ b/nemo/collections/nlp/models/language_modeling/bert_lm_model.py @@ -116,7 +116,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # create extra bias # setup to track metrics - self.validation_perplexity = Perplexity(compute_on_step=False) + self.validation_perplexity = Perplexity() self.setup_optimization(cfg.optim) diff --git a/nemo/collections/nlp/models/text2sparql/text2sparql_model.py b/nemo/collections/nlp/models/text2sparql/text2sparql_model.py index 5290209b0c95..50046aef0344 100644 --- a/nemo/collections/nlp/models/text2sparql/text2sparql_model.py +++ b/nemo/collections/nlp/models/text2sparql/text2sparql_model.py @@ -100,7 +100,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): decoder=cfg.language_model.pretrained_decoder_model_name, ) - self.validation_perplexity = Perplexity(compute_on_step=False) + self.validation_perplexity = Perplexity() self.setup_optimization(cfg.optim) diff --git a/nemo/core/optim/optimizers.py b/nemo/core/optim/optimizers.py index 76e47e20e0cc..9473ef0af969 100644 --- a/nemo/core/optim/optimizers.py +++ b/nemo/core/optim/optimizers.py @@ -51,7 +51,6 @@ AVAILABLE_OPTIMIZERS['fused_adam'] = FusedAdam except ModuleNotFoundError: HAVE_APEX = False - logging.warning("Apex was not found. Using the lamb or fused_adam optimizer will error out.") HAVE_APEX_DISTRIBUTED_ADAM = False if HAVE_APEX: diff --git a/tests/collections/common/pl_utils.py b/tests/collections/common/pl_utils.py index 395c8cef5969..a2e9609c8492 100644 --- a/tests/collections/common/pl_utils.py +++ b/tests/collections/common/pl_utils.py @@ -90,7 +90,7 @@ def _class_test( calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric - metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) + metric = metric_class(dist_sync_on_step=dist_sync_on_step, **metric_args) # verify metrics work after being loaded from pickled state pickled_metric = pickle.dumps(metric) @@ -303,7 +303,7 @@ def _perplexity_class_test( calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric - perplexity = Perplexity(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) + perplexity = Perplexity(dist_sync_on_step=dist_sync_on_step, **metric_args) if (probs is None) == (logits is None): with pytest.raises(ValueError): perplexity(probs, logits) @@ -464,9 +464,7 @@ def _loss_class_test( calculated across devices for each batch (and not just at the end) """ # Instantiate lightning metric - loss_metric = GlobalAverageLossMetric( - compute_on_step=True, dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss - ) + loss_metric = GlobalAverageLossMetric(dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss) # verify loss works after being loaded from pickled state pickled_metric = pickle.dumps(loss_metric)