From 779bc360ff4f3965a1ac29fdc02c43db7ede08c0 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 28 May 2024 17:07:42 +0500 Subject: [PATCH] Watermark: fix tests (#30961) * fix tests * style * Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- tests/generation/test_utils.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 7d654312a3a069..57b6c6d188105d 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2148,6 +2148,8 @@ def test_watermark_generation(self): watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") _ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15) + # We will not check watermarked text, since we check it in `logits_processors` tests + # Checking if generated ids are as expected fails on different hardware args = { "bias": 2.0, "context_width": 1, @@ -2158,19 +2160,11 @@ def test_watermark_generation(self): output = model.generate(**model_inputs, do_sample=False, max_length=15) output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15) - # check that the watermarked text is generating what is should - self.assertListEqual( - output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]] - ) - self.assertListEqual( - output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]] - ) - + # Check that the detector is detecting watermarked text detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args) detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True) detection_out = detector(output[:, input_len:], return_dict=True) - # check that the detector is detecting watermarked text self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True]) self.assertListEqual(detection_out.prediction.tolist(), [False])