diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index 09f986ad558..53c6524db52 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -274,6 +274,14 @@ def test_training_with_judge(self, config_name):
     @require_vllm
     @pytest.mark.slow
     def test_training_with_vllm(self, config_name):
+        def cleanup_vllm_communicator(trainer):
+            """Clean up vLLM communicator to avoid conflicts between test runs"""
+            try:
+                if hasattr(trainer, "vllm_client") and trainer.vllm_client is not None:
+                    trainer.vllm_client.close_communicator()
+            except Exception:
+                pass  # Continue if cleanup fails
+
         model_id = "trl-internal-testing/small-Qwen2ForCausalLM-2.5"  # We need a bigger model
         model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -295,10 +303,14 @@ def test_training_with_vllm(self, config_name):
             processing_class=tokenizer,
             reward_processing_classes=self.reward_tokenizer,
         )
-        trainer.train()
-
-        # Check if training loss is available
-        assert "train_loss" in trainer.state.log_history[-1]
+        # Ensure cleanup of vLLM communicator after the test
+        try:
+            trainer.train()
+
+            # Check if training loss is available
+            assert "train_loss" in trainer.state.log_history[-1]
+        finally:
+            cleanup_vllm_communicator(trainer)
 
     @require_vllm
     def test_training_with_vllm_colocate(self):
diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py
index 5d76d422c16..5ddc6150a59 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/extras/vllm_client.py
@@ -495,6 +495,9 @@ def close_communicator(self):
             if response.status_code != 200:
                 raise Exception(f"Request failed: {response.status_code}, {response.text}")
 
+        if self.communicator is not None:
+            self.communicator = None
+
 
 # Example usage
 if __name__ == "__main__":
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 70190eddff5..cdbd5458fe8 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -463,7 +463,11 @@ def __init__(
                     else:
                         base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
                     self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
-                    self.vllm_client.init_communicator(device=torch.cuda.current_device())
+
+                    # Determine device type (supports cuda, xpu, etc.)
+                    accelerator_type = torch.accelerator.current_accelerator().type
+                    current_device = getattr(torch, accelerator_type).current_device()
+                    self.vllm_client.init_communicator(device=current_device)
                 else:
                     self.vllm_client = None
             elif self.vllm_mode == "colocate":
@@ -755,7 +759,7 @@ def _generate_vllm_server(self, prompts, images=None):
                 max_tokens=self.generation_config.max_tokens,
                 guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None,
                 generation_kwargs=self.args.generation_kwargs,
-            )
+            )["completion_ids"]
             # Flatten: each prompt generates 2 completions
             completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
         else: