diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 7d654312a3a069..57b6c6d188105d 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2148,6 +2148,8 @@ def test_watermark_generation(self): watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") _ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15) + # We will not check watermarked text, since we check it in `logits_processors` tests + # Checking if generated ids are as expected fails on different hardware args = { "bias": 2.0, "context_width": 1, @@ -2158,19 +2160,11 @@ def test_watermark_generation(self): output = model.generate(**model_inputs, do_sample=False, max_length=15) output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15) - # check that the watermarked text is generating what is should - self.assertListEqual( - output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]] - ) - self.assertListEqual( - output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]] - ) - + # Check that the detector is detecting watermarked text detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args) detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True) detection_out = detector(output[:, input_len:], return_dict=True) - # check that the detector is detecting watermarked text self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True]) self.assertListEqual(detection_out.prediction.tolist(), [False])