From b522ae52667bc3c42244bebaf8db3ede7a0f4ff6 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Fri, 19 Jul 2024 17:56:35 +0200 Subject: [PATCH] Add test --- tests/unittests/text/test_bertscore.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index 0812808426f..dfd6d60a0e5 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -170,3 +170,24 @@ def test_bertscore_differentiability( metric_args=metric_args, key=metric_key, ) + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +@pytest.mark.parametrize( + "idf", + [(False,), (True,)], +) +def test_bertscore_sorting(idf: bool): + """Test that BERTScore is invariant to the order of the inputs.""" + short = "Short text" + long = "This is a longer text" + + preds = [long, long] + targets = [long, short] + + metric = BERTScore(idf=idf) + score = metric(preds, targets) + + # First index should be the self-comparison - sorting by length should not shuffle this + assert score["f1"][0] > score["f1"][1]