18 changes: 15 additions & 3 deletions tests/test_online_dpo_trainer.py
@@ -274,6 +274,14 @@ def test_training_with_judge(self, config_name):
     @require_vllm
     @pytest.mark.slow
     def test_training_with_vllm(self, config_name):
+        def cleanup_vllm_communicator(trainer):
+            """Clean up vLLM communicator to avoid conflicts between test runs"""
+            try:
+                if hasattr(trainer, "vllm_client") and trainer.vllm_client is not None:
+                    trainer.vllm_client.close_communicator()
+            except Exception:
+                pass  # Continue if cleanup fails
+
         model_id = "trl-internal-testing/small-Qwen2ForCausalLM-2.5"  # We need a bigger model
         model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -295,10 +303,14 @@ def test_training_with_vllm(self, config_name):
             processing_class=tokenizer,
             reward_processing_classes=self.reward_tokenizer,
         )
-        trainer.train()
 
-        # Check if training loss is available
-        assert "train_loss" in trainer.state.log_history[-1]
+        # Ensure cleanup of vLLM communicator after the test
+        try:
+            trainer.train()
+            # Check if training loss is available
+            assert "train_loss" in trainer.state.log_history[-1]
+        finally:
+            cleanup_vllm_communicator(trainer)
 
     @require_vllm
     def test_training_with_vllm_colocate(self):
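The test change above wraps training in try/finally so the communicator is torn down whether the run succeeds, the assertion fails, or training raises. A minimal sketch of that pattern in isolation, assuming only a trainer object with an optional vllm_client attribute; make_trainer is a hypothetical factory standing in for the OnlineDPOTrainer setup in the real test:

# Minimal sketch of the cleanup pattern, not the test suite itself.
# `make_trainer` is a hypothetical factory standing in for the real
# OnlineDPOTrainer construction.

def cleanup_vllm_communicator(trainer):
    """Close the trainer's vLLM communicator, swallowing cleanup errors."""
    try:
        if hasattr(trainer, "vllm_client") and trainer.vllm_client is not None:
            trainer.vllm_client.close_communicator()
    except Exception:
        pass  # A failed cleanup must not mask the original test outcome


def run_training_test(make_trainer):
    trainer = make_trainer()
    try:
        trainer.train()
        assert "train_loss" in trainer.state.log_history[-1]
    finally:
        # Runs on success, on assertion failure, and on training errors alike
        cleanup_vllm_communicator(trainer)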
3 changes: 3 additions & 0 deletions trl/extras/vllm_client.py
@@ -495,6 +495,9 @@ def close_communicator(self):
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
 
+        if self.communicator is not None:
+            self.communicator = None
+
 
 # Example usage
 if __name__ == "__main__":
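Resetting self.communicator to None after the server round-trip makes close_communicator() safe to call more than once, which is what the test cleanup above relies on. A rough sketch of that idempotent-close idea; notify_server and DummyComm are hypothetical stand-ins, not the real VLLMClient internals:

# Rough sketch of an idempotent close; notify_server and DummyComm are
# hypothetical stand-ins, not the real VLLMClient API.

def notify_server():
    """Stands in for the POST the real client sends to the vLLM server."""


class DummyComm:
    pass


class Client:
    def __init__(self):
        self.communicator = DummyComm()

    def close_communicator(self):
        if self.communicator is None:
            return  # Already closed; a second call is a harmless no-op
        notify_server()
        self.communicator = None  # Drop the handle so stale reuse fails fast


client = Client()
client.close_communicator()
client.close_communicator()  # Safe: the second call returns immediately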
8 changes: 6 additions & 2 deletions trl/trainer/online_dpo_trainer.py
@@ -463,7 +463,11 @@ def __init__(
                 else:
                     base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
                 self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
-                self.vllm_client.init_communicator(device=torch.cuda.current_device())
+
+                # Determine device type (supports cuda, xpu, etc.)
+                accelerator_type = torch.accelerator.current_accelerator().type
+                current_device = getattr(torch, accelerator_type).current_device()
+                self.vllm_client.init_communicator(device=current_device)
             else:
                 self.vllm_client = None
         elif self.vllm_mode == "colocate":
@@ -755,7 +759,7 @@ def _generate_vllm_server(self, prompts, images=None):
                 max_tokens=self.generation_config.max_tokens,
                 guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None,
                 generation_kwargs=self.args.generation_kwargs,
-            )
+            )["completion_ids"]
             # Flatten: each prompt generates 2 completions
             completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
         else:
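The first hunk replaces the CUDA-only torch.cuda.current_device() with a lookup through torch.accelerator, so the same code path works for cuda, xpu, and other backends. A sketch of the detection pattern with a guarded fallback; the torch.accelerator API only exists in recent PyTorch releases, so the version check here is an assumption rather than part of the merged change:

import torch

# Sketch of backend-agnostic device detection; the fallback branch for
# older PyTorch builds is an assumption, not part of the merged change.
if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
    accelerator_type = torch.accelerator.current_accelerator().type  # "cuda", "xpu", ...
    current_device = getattr(torch, accelerator_type).current_device()
else:
    current_device = torch.cuda.current_device()  # old CUDA-only behaviour

print(f"communicator device index: {current_device}")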
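The second hunk's trailing ["completion_ids"] indicates that vllm_client.generate now returns a mapping rather than a bare list of ids, and the existing flatten step then runs on the extracted value. A toy illustration of that flatten comprehension; the token ids and the two-completions-per-prompt nesting are made up, inferred from the inline comment in the diff:

# Toy data; the exact response shape is an assumption inferred from the
# diff's "each prompt generates 2 completions" comment.
response = {"completion_ids": [[[1, 2, 3], [4, 5]], [[6], [7, 8]]]}  # 2 prompts x 2 completions

completion_ids = response["completion_ids"]
# The flatten step from the hunk: one singleton group per completion.
completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
assert completion_ids == [[[1, 2, 3]], [[4, 5]], [[6]], [[7, 8]]]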