Commit 86a8902

Use a float instead of a tensor for model_forward_time & model_execute_time

sfc-gh-mkeralapura committed Aug 14, 2024
1 parent ac3db8d commit 86a8902
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions vllm/worker/model_runner.py
@@ -1563,9 +1563,9 @@ def execute_model(
                 orig_model_forward_time = 0.0
                 if intermediate_tensors is not None:
                     orig_model_forward_time = intermediate_tensors.tensors.get(
-                        "model_forward_time", torch.tensor(0.0)).item()
+                        "model_forward_time", 0.0)
                 hidden_or_intermediate_states.tensors["model_forward_time"] = (
-                    torch.tensor(model_forward_time + orig_model_forward_time))
+                    model_forward_time + orig_model_forward_time)
             return hidden_or_intermediate_states

         logits = self.model.compute_logits(hidden_or_intermediate_states,
@@ -1588,7 +1588,7 @@ def execute_model(
             orig_model_forward_time = 0.0
             if intermediate_tensors is not None:
                 orig_model_forward_time = intermediate_tensors.tensors.get(
-                    "model_forward_time", torch.tensor(0.0)).item()
+                    "model_forward_time", 0.0)
             # If there are multiple workers, we are still tracking the latency
             # from the start time of the driver worker to the end time of the
             # driver worker. The model forward time will then end up covering
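The model_runner.py hunks drop the torch.tensor(...) wrapping and the .item() read when the forward time is threaded through the intermediate tensors, so the value now travels as a plain Python float. A minimal sketch of that accumulation pattern, using a hypothetical stage_forward helper and an ordinary dict rather than vLLM's IntermediateTensors:

import time
from typing import Any, Dict


def stage_forward(tensors: Dict[str, Any]) -> Dict[str, Any]:
    """Simulate one pipeline stage adding its forward time to the dict.

    Mirrors the pattern in the diff: read any upstream value with a plain
    0.0 default (no torch.tensor / .item() round-trip) and store the sum
    back as a Python float.
    """
    start = time.perf_counter()
    time.sleep(0.01)  # stand-in for the actual model forward pass
    model_forward_time = time.perf_counter() - start

    orig_model_forward_time = tensors.get("model_forward_time", 0.0)
    tensors["model_forward_time"] = (
        model_forward_time + orig_model_forward_time)
    return tensors


if __name__ == "__main__":
    # Chaining two "stages" accumulates a single float across both.
    out = stage_forward(stage_forward({}))
    print(f"total model_forward_time: {out['model_forward_time']:.4f}s")

Keeping the scalar as a float avoids building a fresh CPU tensor on every step just to carry a timing value alongside the real hidden-state tensors.
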
4 changes: 2 additions & 2 deletions vllm/worker/worker_base.py
@@ -307,7 +307,7 @@ def execute_model(
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                orig_model_execute_time = intermediate_tensors.tensors.get(
-                    "model_execute_time", torch.tensor(0)).item()
+                    "model_execute_time", 0.0)

        output = self.model_runner.execute_model(
            model_input, self.kv_cache[worker_input.virtual_engine]
@@ -318,7 +318,7 @@ def execute_model(
            # output is IntermediateTensors
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
-                output.tensors["model_execute_time"] = torch.tensor(
+                output.tensors["model_execute_time"] = (
                    model_execute_time + orig_model_execute_time)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
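The worker_base.py hunks apply the same change on the worker side: the accumulated execute time is read with a 0.0 default and written back as a float before the tensors dict is forwarded to the next pipeline rank. A rough sketch of that bookkeeping, with a hypothetical run_worker_step standing in for WorkerBase.execute_model and a plain dict in place of the payload passed to send_tensor_dict:

import time
from typing import Any, Dict


def run_worker_step(intermediate: Dict[str, Any],
                    collect_model_execute_time: bool = True) -> Dict[str, Any]:
    """Accumulate model_execute_time across workers as a plain float."""
    orig_model_execute_time = 0.0
    if collect_model_execute_time:
        # Float default instead of torch.tensor(0) / .item(), as in the diff.
        orig_model_execute_time = intermediate.get("model_execute_time", 0.0)

    start_time = time.perf_counter()
    time.sleep(0.01)  # stand-in for model_runner.execute_model(...)
    model_execute_time = time.perf_counter() - start_time

    output = dict(intermediate)
    if collect_model_execute_time:
        # Stored as a float so the next rank can keep adding to it directly.
        output["model_execute_time"] = (
            model_execute_time + orig_model_execute_time)
    return output


if __name__ == "__main__":
    out = run_worker_step(run_worker_step({}))
    print(f"accumulated model_execute_time: {out['model_execute_time']:.4f}s")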
