Skip to content

Commit f39d18a

Browse files
authored
fix(GOLDTrainer): Resolve incorrect attribute access and VLLMClient.generate() output type (#4526)
1 parent d45eaab commit f39d18a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

trl/experimental/gold/gold_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,7 +1662,7 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_
1662 1662   # prompts_text = [p.replace(target_system_prompt, system_prompt) for p in prompts_text]
1663 1663   # Add system prompt to prompts
1664 1664
1665      - max_completion_length = generation_config.max_completion_length
     1665 + max_completion_length = generation_config.max_new_tokens
1666 1666   temperature = generation_config.temperature
1667 1667   # vLLM uses top_k=-1 for no top_k, transformers uses 0 or None.
1668 1668   top_k = generation_config.top_k if generation_config.top_k and generation_config.top_k > 0 else -1
@@ -1684,7 +1684,7 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_
1684 1684   min_p=min_p,
1685 1685   max_tokens=max_completion_length,
1686 1686   guided_decoding_regex=self.vllm_guided_decoding_regex,
1687      - )
     1687 + )["completion_ids"]
1688 1688   else:
1689 1689   completion_ids = [None] * len(all_prompts_text)
1690 1690   completion_ids = broadcast_object_list(completion_ids, from_process=0)

0 commit comments

Comments
 (0)