From 0ba979b64cfac115e563cd0ecc8f6bacd85c0f26 Mon Sep 17 00:00:00 2001
From: YQ
Date: Thu, 19 Oct 2023 22:22:06 +0800
Subject: [PATCH 1/3] fix logit to multi-hot conversion

---
 examples/pytorch/text-classification/run_classification.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py
index 3033a61404e839..60951830fcdcad 100755
--- a/examples/pytorch/text-classification/run_classification.py
+++ b/examples/pytorch/text-classification/run_classification.py
@@ -655,7 +655,7 @@ def compute_metrics(p: EvalPrediction):
             preds = np.squeeze(preds)
             result = metric.compute(predictions=preds, references=p.label_ids)
         elif is_multi_label:
-            preds = np.array([np.where(p > 0.5, 1, 0) for p in preds])
+            preds = np.array([np.where(p > 0, 1, 0) for p in preds])  # convert logits to multi-hot encoding
             # Micro F1 is commonly used in multi-label classification
             result = metric.compute(predictions=preds, references=p.label_ids, average="micro")
         else:
@@ -721,7 +721,9 @@ def compute_metrics(p: EvalPrediction):
         if is_regression:
             predictions = np.squeeze(predictions)
         elif is_multi_label:
-            predictions = np.array([np.where(p > 0.5, 1, 0) for p in predictions])
+            predictions = np.array(
+                [np.where(p > 0, 1, 0) for p in predictions]
+            )  # convert logits to multi-hot encoding
         else:
             predictions = np.argmax(predictions, axis=1)
         output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")

From e49de734557ae3b3f8d1913d2a19675b1a7da9bf Mon Sep 17 00:00:00 2001
From: YQ
Date: Fri, 20 Oct 2023 21:00:56 +0800
Subject: [PATCH 2/3] add comments

---
 examples/pytorch/text-classification/run_classification.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py
index 60951830fcdcad..9e29e43daf674f 100755
--- a/examples/pytorch/text-classification/run_classification.py
+++ b/examples/pytorch/text-classification/run_classification.py
@@ -721,9 +721,10 @@ def compute_metrics(p: EvalPrediction):
         if is_regression:
             predictions = np.squeeze(predictions)
         elif is_multi_label:
-            predictions = np.array(
-                [np.where(p > 0, 1, 0) for p in predictions]
-            )  # convert logits to multi-hot encoding
+            # Convert logits to multi-hot encoding. We compare the logits to 0 instead of 0.5, because the sigmoid is not applied.
+            # You can aso pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer
+            # and set p > 0.5 below (less efficient in this case)
+            predictions = np.array([np.where(p > 0, 1, 0) for p in predictions])
         else:
             predictions = np.argmax(predictions, axis=1)
         output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")

From 1fd9c617dcdeec95ae4bda21135ad1e6d72c372f Mon Sep 17 00:00:00 2001
From: YQ
Date: Fri, 20 Oct 2023 21:15:57 +0800
Subject: [PATCH 3/3] typo

---
 examples/pytorch/text-classification/run_classification.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py
index 9e29e43daf674f..1bc4dbe5fa3712 100755
--- a/examples/pytorch/text-classification/run_classification.py
+++ b/examples/pytorch/text-classification/run_classification.py
@@ -722,7 +722,7 @@ def compute_metrics(p: EvalPrediction):
             predictions = np.squeeze(predictions)
         elif is_multi_label:
             # Convert logits to multi-hot encoding. We compare the logits to 0 instead of 0.5, because the sigmoid is not applied.
-            # You can aso pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer
+            # You can also pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer
             # and set p > 0.5 below (less efficient in this case)
             predictions = np.array([np.where(p > 0, 1, 0) for p in predictions])
         else:
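
Note for readers of this series: the equivalence the fix relies on is that sigmoid is monotonic and sigmoid(0) == 0.5, so thresholding raw logits at 0 selects exactly the same labels as thresholding sigmoid probabilities at 0.5, while skipping the sigmoid computation. A minimal standalone sketch of this (not taken from run_classification.py; the logit values are made up for illustration):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    # Hypothetical multi-label logits: 2 examples, 3 labels.
    logits = np.array([[-1.2, 0.3, 2.5], [0.0, -0.7, 1.1]])

    # The patch's approach: threshold raw logits at 0 (no sigmoid needed).
    from_logits = np.where(logits > 0, 1, 0)

    # The alternative mentioned in the comments: apply sigmoid, threshold at 0.5.
    from_probs = np.where(sigmoid(logits) > 0.5, 1, 0)

    assert (from_logits == from_probs).all()  # identical multi-hot encodings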