From fa73ff326ec5b9abb4539dcb01d9478d61e281df Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 18 Jun 2024 08:07:00 -0400 Subject: [PATCH 1/4] Calculate do_penalties upfront Signed-off-by: Thomas Parnell --- vllm/model_executor/sampling_metadata.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 7ad84f51b7e4c..63b1d39fc61d8 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -340,6 +340,18 @@ def from_sampling_metadata( get_num_triton_sampler_splits(vocab_size)) assert sampling_metadata.seq_groups is not None + + # first pass through to determine do_penalties + for seq_group in sampling_metadata.seq_groups: + sampling_params = seq_group.sampling_params + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + r = sampling_params.repetition_penalty + if not do_penalties and (abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True + for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids sampling_params = seq_group.sampling_params @@ -366,10 +378,6 @@ def from_sampling_metadata( do_top_p_top_k = True if not do_min_p and min_p > _SAMPLING_EPS: do_min_p = True - if not do_penalties and (abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS): - do_penalties = True is_prompt = seq_group.is_prompt if (seq_group.is_prompt From 176d61123a4892deb7061a676d15cdb660797d47 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 18 Jun 2024 08:30:03 -0400 Subject: [PATCH 2/4] Format Signed-off-by: Thomas Parnell --- vllm/model_executor/sampling_metadata.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 63b1d39fc61d8..485f91da060b7 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -343,14 +343,14 @@ def from_sampling_metadata( # first pass through to determine do_penalties for seq_group in sampling_metadata.seq_groups: - sampling_params = seq_group.sampling_params - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - if not do_penalties and (abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS): - do_penalties = True + sampling_params = seq_group.sampling_params + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + r = sampling_params.repetition_penalty + if not do_penalties and (abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids From 81f3b5bff1fd1ea28a92e5d1e35000967dbbf886 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 18 Jun 2024 14:55:26 -0400 Subject: [PATCH 3/4] Compute prompt_tokens and output_tokens in second pass so that do_penalties is consistently computed across sequence groups. Signed-off-by: Thomas Parnell --- vllm/model_executor/sampling_metadata.py | 40 ++++++++++++------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 485f91da060b7..679931af3664c 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -340,18 +340,6 @@ def from_sampling_metadata( get_num_triton_sampler_splits(vocab_size)) assert sampling_metadata.seq_groups is not None - - # first pass through to determine do_penalties - for seq_group in sampling_metadata.seq_groups: - sampling_params = seq_group.sampling_params - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - if not do_penalties and (abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS): - do_penalties = True - for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids sampling_params = seq_group.sampling_params @@ -378,6 +366,10 @@ def from_sampling_metadata( do_top_p_top_k = True if not do_min_p and min_p > _SAMPLING_EPS: do_min_p = True + if not do_penalties and (abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True is_prompt = seq_group.is_prompt if (seq_group.is_prompt @@ -394,18 +386,10 @@ def from_sampling_metadata( presence_penalties += [0] * prefill_len frequency_penalties += [0] * prefill_len repetition_penalties += [1] * prefill_len - if do_penalties: - prompt_tokens.extend([] for _ in range(prefill_len)) - output_tokens.extend([] for _ in range(prefill_len)) if seq_group.do_sample: sample_lens = len(seq_group.sample_indices) assert sample_lens == len(seq_ids) - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - if do_penalties: - prompt_tokens.append(seq_data.prompt_token_ids) - output_tokens.append(seq_data.output_token_ids) temperatures += [temperature] * len(seq_ids) top_ps += [top_p] * len(seq_ids) top_ks += [top_k] * len(seq_ids) @@ -432,6 +416,22 @@ def from_sampling_metadata( sampling_seeds.append(seq_seeds) sample_indices.extend(seq_group.sample_indices) + if do_penalties: + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + + if (seq_group.is_prompt + and sampling_params.prompt_logprobs is not None): + prefill_len = len(seq_group.prompt_logprob_indices) + prompt_tokens.extend([] for _ in range(prefill_len)) + output_tokens.extend([] for _ in range(prefill_len)) + + if seq_group.do_sample: + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) + sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, frequency_penalties, repetition_penalties, sampling_seeds, From 1166740f9bce044ec6a5f251ccd1e1374f7528a2 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 18 Jun 2024 14:56:52 -0400 Subject: [PATCH 4/4] Minor cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/sampling_metadata.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 679931af3664c..f95de56f39b57 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -419,13 +419,11 @@ def from_sampling_metadata( if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids - if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) prompt_tokens.extend([] for _ in range(prefill_len)) output_tokens.extend([] for _ in range(prefill_len)) - if seq_group.do_sample: for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id]