diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index 21913ad5450..16970e6aca7 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -292,7 +292,8 @@ def from_mask(
         cls, metric_per_token: torch.Tensor, mask: torch.Tensor
     ) -> List[Metric]:
         """
-        From token-level metrics, returns an aggregate MyMetric per example in the batch.
+        From token-level metrics, returns an aggregate MyMetric per example in the
+        batch.
 
         :param metric_per_token:
             a (batchsize x num_tokens) Tensor
@@ -1097,6 +1098,8 @@ def _consume_user_metrics(self, observation):
         # User-reported metrics
         if 'metrics' in observation:
            for uk, v in observation['metrics'].items():
+                if v is None:
+                    continue
                if uk in ALL_METRICS:
                    # don't let the user override our metrics
                    uk = f'USER_{uk}'
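
The second hunk makes the user-metric loop tolerate None values. Below is a minimal, self-contained sketch of that behavior, assuming hypothetical stand-ins (the ALL_METRICS contents, a consume_user_metrics helper, and a plain dict as the sink) rather than ParlAI's actual internals: a None-valued metric is skipped instead of being forwarded to aggregation, and a name that clashes with a built-in metric is prefixed with USER_.

# Hypothetical sketch of the patched loop's behavior; ALL_METRICS and
# consume_user_metrics are illustrative stand-ins, not ParlAI's real API.
ALL_METRICS = {'ppl', 'token_acc'}  # assumed set of reserved metric names


def consume_user_metrics(observation: dict, recorded: dict) -> None:
    """Collect user-reported metrics, ignoring None values."""
    for uk, v in observation.get('metrics', {}).items():
        if v is None:
            # The new guard: silently drop metrics the user left unset.
            continue
        if uk in ALL_METRICS:
            # Don't let the user override built-in metrics.
            uk = f'USER_{uk}'
        recorded[uk] = v


# Usage: the None-valued metric is dropped, the clashing name is prefixed.
recorded = {}
consume_user_metrics({'metrics': {'ppl': 12.3, 'my_metric': None}}, recorded)
assert recorded == {'USER_ppl': 12.3}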