Skip to content

Commit d88918e

Browse files
[Core] Enable sharded state loader for V1 engine and enhance test coverage (#25308)
Signed-off-by: pengdrumli <pengdrumli@tencent.com>
1 parent 3c713a9 commit d88918e

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

tests/test_sharded_state_loader.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,19 @@ def llama_3p2_1b_files():
5757

5858
def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
5959
llm_sharded_writer = LLM(model=input_dir, **kwargs)
60-
60+
# Check which engine version is being used
61+
is_v1_engine = hasattr(llm_sharded_writer.llm_engine, "engine_core")
6162
# Dump worker states to output directory
62-
llm_sharded_writer.llm_engine.model_executor.save_sharded_state(
63-
path=output_dir)
63+
if is_v1_engine:
64+
# For V1 engine, we need to use engine_core.save_sharded_state
65+
print("Using V1 engine save path")
66+
llm_sharded_writer.llm_engine.engine_core.save_sharded_state(
67+
path=output_dir)
68+
else:
69+
# For V0 engine
70+
print("Using V0 engine save path")
71+
model_executor = llm_sharded_writer.llm_engine.model_executor
72+
model_executor.save_sharded_state(path=output_dir)
6473

6574
# Copy metadata files to output directory
6675
for file in os.listdir(input_dir):
@@ -91,16 +100,13 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
91100
gpu_memory_utilization = 0.8
92101
input_dir = llama_3p2_1b_files
93102
ctx = mp.get_context("spawn")
94-
# The interface in the v1 engine has changed; running in the v1 engine will hang.
95-
monkeypatch.setenv("VLLM_USE_V1", "0")
96103

97104
# Run in separate processes for memory & CUDA isolation
98105
with TemporaryDirectory() as output_dir:
99106
p = ctx.Process(target=_run_writer,
100107
args=(input_dir, output_dir, weights_patterns),
101108
kwargs=dict(
102109
tensor_parallel_size=tp_size,
103-
distributed_executor_backend="mp",
104110
gpu_memory_utilization=gpu_memory_utilization,
105111
enforce_eager=True,
106112
))
@@ -112,7 +118,6 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
112118
p = ctx.Process(target=_run_generate,
113119
args=(input_dir, queue),
114120
kwargs=dict(
115-
distributed_executor_backend="mp",
116121
enable_lora=enable_lora,
117122
gpu_memory_utilization=gpu_memory_utilization,
118123
tensor_parallel_size=tp_size,
@@ -133,7 +138,6 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
133138
p = ctx.Process(target=_run_generate,
134139
args=(output_dir, queue),
135140
kwargs=dict(
136-
distributed_executor_backend="mp",
137141
enable_lora=enable_lora,
138142
gpu_memory_utilization=gpu_memory_utilization,
139143
tensor_parallel_size=tp_size,

vllm/engine/arg_utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,12 +1486,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
14861486
#############################################################
14871487
# Unsupported Feature Flags on V1.
14881488

1489-
if self.load_format == "sharded_state":
1490-
_raise_or_fallback(
1491-
feature_name=f"--load_format {self.load_format}",
1492-
recommend_to_remove=False)
1493-
return False
1494-
14951489
if (self.logits_processor_pattern
14961490
!= EngineArgs.logits_processor_pattern):
14971491
_raise_or_fallback(feature_name="--logits-processor-pattern",

0 commit comments

Comments
 (0)