From 2ab5b59c6ebeb9bb555ed53bf66a40600d436a7e Mon Sep 17 00:00:00 2001 From: Justin Date: Fri, 25 Aug 2023 19:04:15 -0400 Subject: [PATCH] Fix sequence generator test. (#3546) --- ludwig/features/text_feature.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ludwig/features/text_feature.py b/ludwig/features/text_feature.py index 56bedfa07b5..ec5c6e5c6c9 100644 --- a/ludwig/features/text_feature.py +++ b/ludwig/features/text_feature.py @@ -302,11 +302,12 @@ def update_metrics( decoded_targets, decoded_predictions = get_decoded_targets_and_predictions(targets, predictions, tokenizer) for metric_name, metric_fn in self._metric_functions.items(): prediction_key = get_metric_tensor_input(metric_name) - if prediction_key == RESPONSE and tokenizer is not None: - # RESPONSE metrics cannot be computed if decoded texts are not provided. - # Decoded texts are only provided using the LLM model type. - if decoded_targets is not None and decoded_predictions is not None: - metric_fn.update(decoded_predictions, decoded_targets) + if prediction_key == RESPONSE: + if tokenizer is not None: + # RESPONSE metrics cannot be computed if decoded texts are not provided. + # Decoded texts are only provided using the LLM model type. + if decoded_targets is not None and decoded_predictions is not None: + metric_fn.update(decoded_predictions, decoded_targets) else: metric_fn = metric_fn.to(predictions[prediction_key].device) metric_fn.update(predictions[prediction_key].detach(), targets)