
Commit 6f709ca

Potabk authored and Angazenn committed
[Bugfix] Fix broken CI (vllm-project#2825)
### What this PR does / why we need it?

1. Initial support for disabling TP, integrating with [vllm-commit](vllm-project/vllm#23024).
2. [vllm@commit](vllm-project/vllm#23673) now uses `bytes` to store the `BlockHash` to reduce GC overhead; this PR adds the corresponding integration.

- vLLM version: main
- vLLM main: vllm-project/vllm@e408272

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent f7a9c42 commit 6f709ca

File tree: 3 files changed (+21, −8 lines)

tests/ut/core/test_scheduler.py
2 additions, 1 deletion

@@ -8,6 +8,7 @@
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.sampling_params import SamplingParams
+from vllm.utils import sha256
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                           init_none_hash)
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -38,7 +39,7 @@ def create_requests(
     max_tokens: int = 16,
     stop_token_ids: Optional[list[int]] = None,
     block_size: int = 3,
-    hash_fn=hash,
+    hash_fn=sha256,
 ):
     init_none_hash(hash_fn)
     prompt_logprobs = PROMPT_LOGPROBS
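The switch from the built-in `hash` to `vllm.utils.sha256` matters because, per the vLLM change referenced in the description, the `BlockHash` produced by the request block hasher is now stored as `bytes` rather than a Python integer. Below is a minimal standard-library sketch of a `bytes`-returning hash function and of chaining block digests for prefix caching; the helper name `sha256_bytes` and the exact layout being hashed are illustrative, not vLLM's internal structure:

```python
import hashlib
import pickle
from typing import Any


def sha256_bytes(value: Any) -> bytes:
    """Illustrative bytes-returning hash_fn (stand-in for vllm.utils.sha256):
    pickle the value and return the raw 32-byte SHA-256 digest."""
    return hashlib.sha256(pickle.dumps(value)).digest()


# Prefix caching hashes each block together with its parent block's digest,
# so two requests sharing a prefix produce identical block hashes.
parent_digest = sha256_bytes("none-hash-placeholder")
block_token_ids = (1, 2, 3)
block_hash = sha256_bytes((parent_digest, block_token_ids))
print(type(block_hash), len(block_hash))  # <class 'bytes'> 32
```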

tests/ut/kv_connector/utils.py
3 additions, 2 deletions

@@ -10,6 +10,7 @@
 from vllm import SamplingParams
 from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
                          ModelConfig, SchedulerConfig, VllmConfig)
+from vllm.utils import sha256
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                           init_none_hash)
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -129,10 +130,10 @@ def create_request(
     """Make dummy request for testing."""
     global _none_hash_initialized
     if not _none_hash_initialized:
-        init_none_hash(hash)
+        init_none_hash(sha256)
         _none_hash_initialized = True

-    block_hasher = get_request_block_hasher(block_size, hash)
+    block_hasher = get_request_block_hasher(block_size, sha256)

     kv_transfer_params: Optional[dict[str, Any]] = None
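The same `hash` → `sha256` swap is applied here; in addition, the test keeps a module-level flag so `init_none_hash` runs only once per process before block hashers are created. A small sketch of that init-once pattern, using the same vLLM helpers imported in the diff; the wrapper name `make_block_hasher` and the block size are hypothetical, not part of the test file:

```python
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                         init_none_hash)

# init_none_hash() sets module-level state, so guard it with a process-wide
# flag and call it exactly once, mirroring create_request() in the test.
_none_hash_initialized = False


def make_block_hasher(block_size: int, hash_fn=sha256):
    """Hypothetical wrapper showing the init-once pattern from the test."""
    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(hash_fn)
        _none_hash_initialized = True
    return get_request_block_hasher(block_size, hash_fn)


# The resulting hasher can then be attached to dummy requests, as the test does
# (block_size=16 here is arbitrary).
block_hasher = make_block_hasher(block_size=16)
```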

vllm_ascend/ops/linear.py
16 additions, 5 deletions

@@ -62,6 +62,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         self.comm_group = None
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
@@ -88,7 +89,8 @@ def __init__(
                          params_dtype,
                          quant_config,
                          prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

         self.gather_output = gather_output

@@ -137,6 +139,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if prefix.find("down_proj") != -1 and mlp_tp_enable():
             comm_group = get_mlp_tp_group()
@@ -156,6 +159,7 @@ def __init__(
             self.forward_type = "normal"
         self.comm_group = comm_group

+        # TODO: check for disable_tp
         self.tp_size = self.comm_group.world_size
         self.tp_rank = self.comm_group.rank_in_group

@@ -171,7 +175,8 @@ def __init__(
                          params_dtype,
                          quant_config,
                          prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
@@ -392,6 +397,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
             comm_group = get_mlp_tp_group()
@@ -403,6 +409,7 @@ def __init__(
             comm_group = get_tp_group()
             self.forward_type = "normal_tp"
         self.comm_group = comm_group
+        # TODO: check for disable_tp
         self.tp_rank = comm_group.rank_in_group
         self.tp_size = comm_group.world_size

@@ -418,7 +425,8 @@ def __init__(
                          params_dtype=params_dtype,
                          quant_config=quant_config,
                          prefix=prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

     def forward(
         self,
@@ -498,6 +506,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if dense_optim_enable():
             self.forward_type = "dense_optim"
@@ -511,6 +520,7 @@ def __init__(
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
+        # TODO: check for disable_tp
         tp_size = self.comm_group.world_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
@@ -537,7 +547,8 @@ def __init__(
                          params_dtype=params_dtype,
                          quant_config=quant_config,
                          prefix=prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

     def forward(
         self,
@@ -611,4 +622,4 @@ def __init__(
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
         self.return_bias = return_bias
-        self.disable_tp = disable_tp
+        self.disable_tp = disable_tp
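The `linear.py` change threads a new keyword-only `disable_tp` flag through each Ascend linear-layer wrapper into the shared base class, which stores it as `self.disable_tp`; the `# TODO` comments mark the places where `tp_size`/`tp_rank` are still taken from the comm group without consulting the flag. A minimal sketch of the forwarding pattern, with illustrative class names rather than the real vllm-ascend layers:

```python
# Illustrative stand-ins for the real base/parallel linear classes.
class LinearBaseSketch:
    def __init__(self, *, return_bias: bool = True,
                 disable_tp: bool = False) -> None:
        self.return_bias = return_bias
        # Stored so later sharding/communication logic can bypass tensor
        # parallelism for this layer when requested.
        self.disable_tp = disable_tp


class ColumnParallelSketch(LinearBaseSketch):
    def __init__(self, gather_output: bool = False, *,
                 return_bias: bool = True, disable_tp: bool = False) -> None:
        # Forward the new keyword to the base class, exactly as each
        # super().__init__() call in the diff now passes disable_tp=disable_tp.
        super().__init__(return_bias=return_bias, disable_tp=disable_tp)
        self.gather_output = gather_output


layer = ColumnParallelSketch(disable_tp=True)
print(layer.disable_tp)  # True
```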
