diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 5173683d25a27..c72ec11b5cb0d 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -637,3 +637,72 @@ def mock_sample(probs, *args, **kwargs): hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_repetition_penalty_mixed(device: str): + + vocab_size = 8 + + def test_sampling_params(sampling_params: List[SamplingParams]): + + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + seq_lens: List[int] = [] + for i in range(2): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=sampling_params[i], + block_tables={0: [1]}, + )) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=device, + pin_memory=is_pin_memory_available()) + + fake_logits = torch.full((2, vocab_size), + 1e-2, + device=device, + dtype=torch.float16) + + fake_logits[:, 5] = 1.1e-2 + fake_logits[:, 1] = 1.2e-2 + + sampler = MockLogitsSampler(fake_logits) + + sampler_output = sampler(logits=fake_logits, + sampling_metadata=sampling_metadata) + + generated_tokens = [] + for output in sampler_output: + generated_tokens.append(output.samples[0].output_token) + + return generated_tokens + + # one configuration is greedy with repetition_penalty + sampling_params_rep = SamplingParams( + temperature=0.0, + repetition_penalty=2.0, + ) + + # other configuration is sampling w/o repetition_penalty + sampling_params_sample = SamplingParams( + temperature=1.0, + top_k=1, + seed=42, + ) + + tokens1 = test_sampling_params( + [sampling_params_rep, sampling_params_sample]) + + tokens2 = test_sampling_params( + [sampling_params_sample, sampling_params_rep]) + + assert tokens1[0] == tokens2[1] + assert tokens1[1] == tokens2[0]