From 86a89023e2d712701680cfc007eec2deef10afb0 Mon Sep 17 00:00:00 2001 From: Mahesh Keralapura Date: Wed, 14 Aug 2024 10:29:28 -0700 Subject: [PATCH] Use a float instead of a tensor for model_forward_time & model_execute_time --- vllm/worker/model_runner.py | 6 +++--- vllm/worker/worker_base.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a541831ab4601..5547f382f8b8e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1563,9 +1563,9 @@ def execute_model( orig_model_forward_time = 0.0 if intermediate_tensors is not None: orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() + "model_forward_time", 0.0) hidden_or_intermediate_states.tensors["model_forward_time"] = ( - torch.tensor(model_forward_time + orig_model_forward_time)) + model_forward_time + orig_model_forward_time) return hidden_or_intermediate_states logits = self.model.compute_logits(hidden_or_intermediate_states, @@ -1588,7 +1588,7 @@ def execute_model( orig_model_forward_time = 0.0 if intermediate_tensors is not None: orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() + "model_forward_time", 0.0) # If there are multiple workers, we are still tracking the latency # from the start time of the driver worker to the end time of the # driver worker. The model forward time will then end up covering diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 85ab0d348e03d..d593d6eeb834c 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -307,7 +307,7 @@ def execute_model( if (self.observability_config is not None and self.observability_config.collect_model_execute_time): orig_model_execute_time = intermediate_tensors.tensors.get( - "model_execute_time", torch.tensor(0)).item() + "model_execute_time", 0.0) output = self.model_runner.execute_model( model_input, self.kv_cache[worker_input.virtual_engine] @@ -318,7 +318,7 @@ def execute_model( # output is IntermediateTensors if (self.observability_config is not None and self.observability_config.collect_model_execute_time): - output.tensors["model_execute_time"] = torch.tensor( + output.tensors["model_execute_time"] = ( model_execute_time + orig_model_execute_time) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group())