Commit ed5c7bb

[Bug Fix] OnlineDPOTrainer with vLLM Server Mode (#4500)
1 parent ded9bc6 commit ed5c7bb

File tree

3 files changed: +24 −5 lines changed

tests/test_online_dpo_trainer.py
trl/extras/vllm_client.py
trl/trainer/online_dpo_trainer.py


tests/test_online_dpo_trainer.py

Lines changed: 15 additions & 3 deletions

@@ -274,6 +274,14 @@ def test_training_with_judge(self, config_name):
     @require_vllm
     @pytest.mark.slow
     def test_training_with_vllm(self, config_name):
+        def cleanup_vllm_communicator(trainer):
+            """Clean up vLLM communicator to avoid conflicts between test runs"""
+            try:
+                if hasattr(trainer, "vllm_client") and trainer.vllm_client is not None:
+                    trainer.vllm_client.close_communicator()
+            except Exception:
+                pass  # Continue if cleanup fails
+
         model_id = "trl-internal-testing/small-Qwen2ForCausalLM-2.5"  # We need a bigger model
         model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -295,10 +303,14 @@ def test_training_with_vllm(self, config_name):
             processing_class=tokenizer,
             reward_processing_classes=self.reward_tokenizer,
         )
-        trainer.train()
 
-        # Check if training loss is available
-        assert "train_loss" in trainer.state.log_history[-1]
+        # Ensure cleanup of vLLM communicator after the test
+        try:
+            trainer.train()
+            # Check if training loss is available
+            assert "train_loss" in trainer.state.log_history[-1]
+        finally:
+            cleanup_vllm_communicator(trainer)
 
     @require_vllm
     def test_training_with_vllm_colocate(self):
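
The helper and the try/finally wrapper guarantee that the communicator is torn down even when trainer.train() raises, so a failed run can no longer leave a stale process group behind for the next vLLM test. The same guarantee could also be written as a pytest fixture; here is a minimal sketch, where build_trainer is a hypothetical stand-in for the trainer construction shown in the diff:

import pytest

@pytest.fixture
def online_dpo_trainer():
    """Yield a trainer and always tear down its vLLM communicator."""
    trainer = build_trainer()  # hypothetical factory for the setup in the test above
    try:
        yield trainer
    finally:
        try:
            if getattr(trainer, "vllm_client", None) is not None:
                trainer.vllm_client.close_communicator()
        except Exception:
            pass  # best-effort cleanup, mirroring cleanup_vllm_communicator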

trl/extras/vllm_client.py

Lines changed: 3 additions & 0 deletions

@@ -495,6 +495,9 @@ def close_communicator(self):
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
 
+        if self.communicator is not None:
+            self.communicator = None
+
 
 # Example usage
 if __name__ == "__main__":
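
Clearing self.communicator once the HTTP shutdown request has succeeded makes close_communicator effectively idempotent: a second call finds nothing left to tear down. A toy sketch of the pattern (illustrative names, not the TRL class):

class CommunicatorOwner:
    """Toy model of the idempotent-close pattern added above."""

    def __init__(self):
        self.communicator = object()  # stand-in for a live communicator handle

    def close_communicator(self):
        # Drop the stale reference so repeated closes are harmless no-ops.
        if self.communicator is not None:
            self.communicator = None

owner = CommunicatorOwner()
owner.close_communicator()
owner.close_communicator()  # safe: the handle was already cleared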

trl/trainer/online_dpo_trainer.py

Lines changed: 6 additions & 2 deletions

@@ -463,7 +463,11 @@ def __init__(
                 else:
                     base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
                 self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
-                self.vllm_client.init_communicator(device=torch.cuda.current_device())
+
+                # Determine device type (supports cuda, xpu, etc.)
+                accelerator_type = torch.accelerator.current_accelerator().type
+                current_device = getattr(torch, accelerator_type).current_device()
+                self.vllm_client.init_communicator(device=current_device)
             else:
                 self.vllm_client = None
         elif self.vllm_mode == "colocate":
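
The first hunk drops the hard-coded torch.cuda.current_device() call and resolves the active backend at runtime, so server mode also works on XPU and other accelerators. A standalone sketch of that lookup, assuming a PyTorch recent enough to ship the torch.accelerator namespace (the RuntimeError fallback is illustrative, not part of the commit):

import torch

def current_accelerator_index() -> int:
    """Resolve the active backend (cuda, xpu, ...) and return its device index."""
    accelerator = torch.accelerator.current_accelerator()  # e.g. device(type='cuda')
    if accelerator is None:
        raise RuntimeError("no accelerator available")  # illustrative fallback
    # torch.cuda, torch.xpu, etc. each expose current_device(), so the
    # backend module can be looked up by the device type string.
    return getattr(torch, accelerator.type).current_device()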
@@ -755,7 +759,7 @@ def _generate_vllm_server(self, prompts, images=None):
                 max_tokens=self.generation_config.max_tokens,
                 guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None,
                 generation_kwargs=self.args.generation_kwargs,
-            )
+            )["completion_ids"]
             # Flatten: each prompt generates 2 completions
             completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
         else:
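
The second hunk is the actual server-mode fix: the client's generate call returns a mapping that includes a "completion_ids" field, and without the added index the flattening comprehension iterated over the dict's keys rather than the generated token-id lists. A minimal sketch of that failure mode, with an illustrative payload:

# Illustrative response shape, inferred from the index added above.
response = {"completion_ids": [[9, 7, 5], [4, 2, 0]]}

# Before the fix: iterating the mapping itself yields its keys (strings).
assert [x for x in response] == ["completion_ids"]

# After the fix: index the field first, then iterate the token-id lists.
assert [x for x in response["completion_ids"]] == [[9, 7, 5], [4, 2, 0]]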
