Skip to content

Commit 4b2dd14

Browse files
ShangmingCaiAkshat-Tripathi
authored andcommitted
[Bugfix] Fix device ordinal for multi-node spec decode (vllm-project#13269)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
1 parent b6e3057 commit 4b2dd14

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

vllm/spec_decode/spec_decode_worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
1212
from vllm.distributed.communication_op import (broadcast_tensor_dict,
13+
get_tp_group,
1314
tensor_model_parallel_gather)
1415
from vllm.logger import init_logger
1516
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -365,7 +366,7 @@ def init_device(self) -> None:
365366
target_lm_head_weight)
366367

367368
self._metrics.init_tensors(self.rank, device_type=self.device)
368-
self.spec_decode_sampler.init_tensors(self.rank,
369+
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
369370
device_type=self.device)
370371

371372
scorer_cls: Type[SpeculativeScorer]

0 commit comments

Comments
 (0)