Skip to content

Commit 93f2c0a

Browse files
authored
[Models] Improve iteration over layers (#26425)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
1 parent 4ebc910 commit 93f2c0a

File tree

8 files changed

+23
-22
lines changed

8 files changed

+23
-22
lines changed

vllm/model_executor/models/apertus.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"""Inference-only Apertus model compatible with HuggingFace weights."""
2727

2828
from collections.abc import Iterable
29+
from itertools import islice
2930
from typing import Any, Optional, Union
3031

3132
import torch
@@ -412,7 +413,9 @@ def forward(
412413
residual = intermediate_tensors["residual"]
413414

414415
aux_hidden_states = []
415-
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
416+
for idx, layer in enumerate(
417+
islice(self.layers, self.start_layer, self.end_layer)
418+
):
416419
if idx in self.aux_hidden_state_layers:
417420
aux_hidden_states.append(hidden_states + residual)
418421
hidden_states, residual = layer(positions, hidden_states, residual)

vllm/model_executor/models/falcon_h1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""Inference-only FalconH1 model."""
44

55
from collections.abc import Iterable
6+
from itertools import islice
67
from typing import Optional
78

89
import torch
@@ -480,8 +481,7 @@ def forward(
480481
assert intermediate_tensors is not None
481482
hidden_states = intermediate_tensors["hidden_states"]
482483

483-
for i in range(self.start_layer, self.end_layer):
484-
layer = self.layers[i]
484+
for layer in islice(self.layers, self.start_layer, self.end_layer):
485485
hidden_states = layer(
486486
positions=positions,
487487
hidden_states=hidden_states,

vllm/model_executor/models/hunyuan_v1.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import typing
2828
from collections.abc import Callable, Iterable
29+
from itertools import islice
2930
from typing import Any, Optional, Union
3031

3132
import regex as re
@@ -672,19 +673,17 @@ def forward(
672673

673674
cla_factor = _get_cla_factor(self.config)
674675
prev_kv_states = None
675-
for i in range(self.start_layer, self.end_layer):
676-
layer = self.layers[i]
676+
for i, layer in enumerate(
677+
islice(self.layers, self.start_layer, self.end_layer)
678+
):
677679
hidden_states, residual, kv_states = layer(
678680
positions,
679681
hidden_states,
680682
residual,
681683
prev_kv_states,
682684
)
683685

684-
if (
685-
getattr(self.config, "use_cla", False)
686-
and (i - self.start_layer) % cla_factor == 0
687-
):
686+
if getattr(self.config, "use_cla", False) and i % cla_factor == 0:
688687
prev_kv_states = kv_states
689688
else:
690689
prev_kv_states = None

vllm/model_executor/models/lfm2_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
from collections.abc import Iterable
4+
from itertools import islice
45
from typing import Any, Optional
56

67
import torch
@@ -492,7 +493,7 @@ def forward(
492493
hidden_states = intermediate_tensors["hidden_states"]
493494
residual = intermediate_tensors["residual"]
494495

495-
for layer in self.layers[self.start_layer : self.end_layer]:
496+
for layer in islice(self.layers, self.start_layer, self.end_layer):
496497
hidden_states, residual = layer(
497498
positions=positions,
498499
hidden_states=hidden_states,

vllm/model_executor/models/longcat_flash.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import typing
3737
from collections.abc import Callable, Iterable
38+
from itertools import islice
3839
from typing import Optional, Union
3940

4041
import torch
@@ -519,8 +520,7 @@ def forward(
519520
hidden_states = intermediate_tensors["hidden_states"]
520521
residual = intermediate_tensors["residual"]
521522

522-
for i in range(self.start_layer, self.end_layer):
523-
layer = self.layers[i]
523+
for layer in islice(self.layers, self.start_layer, self.end_layer):
524524
hidden_states, residual = layer(
525525
positions,
526526
hidden_states,

vllm/model_executor/models/mamba.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""PyTorch MAMBA model."""
44

55
from collections.abc import Iterable
6+
from itertools import islice
67
from typing import Optional
78

89
import torch
@@ -162,8 +163,7 @@ def forward(
162163
hidden_states = intermediate_tensors["hidden_states"]
163164
residual = intermediate_tensors["residual"]
164165

165-
for i in range(self.start_layer, self.end_layer):
166-
layer = self.layers[i]
166+
for layer in islice(self.layers, self.start_layer, self.end_layer):
167167
hidden_states, residual = layer(
168168
positions=positions, hidden_states=hidden_states, residual=residual
169169
)

vllm/model_executor/models/qwen3_vl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from collections.abc import Iterable, Mapping, Sequence
2828
from functools import partial
29+
from itertools import islice
2930
from typing import Any, Callable, Optional, Union
3031

3132
import numpy as np
@@ -1106,11 +1107,9 @@ def forward(
11061107
assert intermediate_tensors is not None
11071108
hidden_states = intermediate_tensors["hidden_states"]
11081109
residual = intermediate_tensors["residual"]
1109-
for layer_idx, layer in enumerate(
1110-
self.layers[self.start_layer : self.end_layer]
1110+
for layer_idx, layer in islice(
1111+
enumerate(self.layers), self.start_layer, self.end_layer
11111112
):
1112-
layer_idx = layer_idx + self.start_layer
1113-
11141113
hidden_states, residual = layer(
11151114
positions,
11161115
hidden_states,

vllm/model_executor/models/qwen3_vl_moe.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import typing
2828
from collections.abc import Iterable
29+
from itertools import islice
2930
from typing import Callable, Optional, Union
3031

3132
import torch
@@ -103,11 +104,9 @@ def forward(
103104
assert intermediate_tensors is not None
104105
hidden_states = intermediate_tensors["hidden_states"]
105106
residual = intermediate_tensors["residual"]
106-
for layer_idx, layer in enumerate(
107-
self.layers[self.start_layer : self.end_layer]
107+
for layer_idx, layer in islice(
108+
enumerate(self.layers), self.start_layer, self.end_layer
108109
):
109-
layer_idx = layer_idx + self.start_layer
110-
111110
hidden_states, residual = layer(
112111
positions,
113112
hidden_states,

0 commit comments

Comments
 (0)