
Commit b614e0f

DarkLight1337 authored and yewentao256 committed
[Misc] Improve type annotations for jsontree (#25577)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 44d6701 commit b614e0f

5 files changed: 88 additions, 39 deletions
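
The commit removes call-site cast(...) workarounds by giving json_map_leaves and json_reduce_leaves container-specific @overload signatures. Below is a minimal, self-contained sketch of that pattern (map_leaves is an illustrative name, not vLLM code):

    from typing import Callable, TypeVar, Union, overload

    _T = TypeVar("_T")
    _U = TypeVar("_U")


    @overload
    def map_leaves(func: Callable[[_T], _U], value: list[_T]) -> list[_U]: ...


    @overload
    def map_leaves(func: Callable[[_T], _U],
                   value: tuple[_T, ...]) -> tuple[_U, ...]: ...


    def map_leaves(
        func: Callable[[_T], _U],
        value: Union[list[_T], tuple[_T, ...]],
    ) -> Union[list[_U], tuple[_U, ...]]:
        # One implementation serves all overloads; the type checker picks
        # the overload matching the argument, so callers keep the container
        # type without a cast().
        if isinstance(value, list):
            return [func(v) for v in value]
        return tuple(func(v) for v in value)

With such overloads, map_leaves(str, (1, 2)) is inferred as tuple[str, ...] rather than a broad union that the caller must cast back down, which is exactly why the cast() calls in the model files below can be dropped.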

vllm/model_executor/models/aya_vision.py

Lines changed: 12 additions & 11 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Literal, Optional, Union, cast
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -347,23 +347,24 @@ def load_weights(self, weights: Iterable[tuple[str,
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
-    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
-                                  pixel_values: torch.Tensor,
-                                  **kwargs) -> torch.Tensor:
-        target_dtype = vision_tower.get_input_embeddings().weight.dtype
-        image_features = vision_tower(pixel_values.to(dtype=target_dtype),
-                                      **kwargs)
+    def _image_pixels_to_features(
+        self,
+        vision_tower: SiglipVisionModel,
+        pixel_values: torch.Tensor,
+        **kwargs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+        target_dtype: torch.dtype = \
+            vision_tower.get_input_embeddings().weight.dtype
+        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
 
         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
                 leaf,
                 strategy=self.config.vision_feature_select_strategy,
             )
 
-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
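
Because the method can now return either a Tensor or a tuple of Tensors, the overloads on json_map_leaves preserve the container type at the call site. A hedged sketch of that behavior (drop_cls_token is a hypothetical stand-in for _select_image_features with the "default" strategy):

    import torch

    from vllm.utils.jsontree import json_map_leaves


    def drop_cls_token(leaf: torch.Tensor) -> torch.Tensor:
        # Stand-in for _select_image_features with the "default" strategy,
        # which drops the leading CLS position from each leaf.
        return leaf[:, 1:]


    single = torch.randn(1, 5, 8)
    batched = (torch.randn(1, 5, 8), torch.randn(1, 5, 8))

    out_single = json_map_leaves(drop_cls_token, single)    # Tensor in -> Tensor out
    out_batched = json_map_leaves(drop_cls_token, batched)  # tuple in -> tuple out, no cast()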

vllm/model_executor/models/llava.py

Lines changed: 4 additions & 6 deletions
@@ -4,7 +4,7 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
-                    Union, cast)
+                    Union)
 
 import torch
 import torch.nn as nn
@@ -623,18 +623,16 @@ def _image_pixels_to_features(
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
+        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values)
 
         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
                 leaf,
                 strategy=self.config.vision_feature_select_strategy,
             )
 
-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)
 
     def _process_image_pixels(
         self,

vllm/model_executor/models/minimax_vl_01.py

Lines changed: 4 additions & 6 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping
-from typing import Annotated, Literal, Optional, Union, cast
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -254,18 +254,16 @@ def _image_pixels_to_features(
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = tuple(vision_tower(p) for p in pixel_values)
+        image_features: tuple[torch.Tensor, ...] = \
+            tuple(vision_tower(p) for p in pixel_values)
 
         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
                 leaf,
                 strategy=self.config.vision_feature_select_strategy,
             )
 
-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)
 
     # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
     def pack_image_features(self, image_features: list[torch.Tensor],

vllm/model_executor/models/tarsier.py

Lines changed: 4 additions & 11 deletions
@@ -4,7 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
-                    Union, cast)
+                    Union)
 
 import torch
 import torch.nn as nn
@@ -490,23 +490,16 @@ def _image_pixels_to_features(
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # From vLLM LLaVA, vision tower output handling
-        image_hidden_states = vision_tower(pixel_values)
-        if not isinstance(image_hidden_states, torch.Tensor):
-            raise TypeError(
-                f"image_hidden_states type: {type(image_hidden_states)}"
-                " is not supported")
+        image_hidden_states: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values)
 
         def select_features_fn(leaf: torch.Tensor):
             return self._select_image_features(
                 leaf,
                 strategy=self.config.vision_feature_select_strategy,
             )
 
-        selected_features = cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features_fn, image_hidden_states),
-        )
-        return selected_features
+        return json_map_leaves(select_features_fn, image_hidden_states)
 
     def _add_tarsier_split_tokens(
         self, projected_image_features: torch.Tensor) -> torch.Tensor:

vllm/utils/jsontree.py

Lines changed: 64 additions & 5 deletions
@@ -4,7 +4,7 @@
 
 from collections.abc import Iterable
 from functools import reduce
-from typing import Callable, TypeVar, Union, overload
+from typing import Callable, TypeVar, Union, cast, overload
 
 _T = TypeVar("_T")
 _U = TypeVar("_U")
@@ -30,10 +30,42 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
         yield value
 
 
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, dict[str, _T]],
+) -> Union[_U, dict[str, _U]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, list[_T]],
+) -> Union[_U, list[_U]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, tuple[_T, ...]],
+) -> Union[_U, tuple[_U, ...]]:
+    ...
+
+
+@overload
 def json_map_leaves(
     func: Callable[[_T], _U],
     value: JSONTree[_T],
 ) -> JSONTree[_U]:
+    ...
+
+
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
     """Apply a function to each leaf in a nested JSON structure."""
     if isinstance(value, dict):
         return {k: json_map_leaves(func, v) for k, v in value.items()}
@@ -45,6 +77,33 @@ def json_map_leaves(
         return func(value)
 
 
+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, dict[str, _T]],
+    /,
+) -> _T:
+    ...
+
+
+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, list[_T]],
+    /,
+) -> _T:
+    ...
+
+
+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, tuple[_T, ...]],
+    /,
+) -> _T:
+    ...
+
+
 @overload
 def json_reduce_leaves(
     func: Callable[[_T, _T], _T],
@@ -65,10 +124,10 @@ def json_reduce_leaves(
 
 
 def json_reduce_leaves(
-    func: Callable[..., Union[_T, _U]],
-    value: JSONTree[_T],
-    initial: _U = ...,  # type: ignore[assignment]
-    /,
+    func: Callable[..., Union[_T, _U]],
+    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+    initial: _U = cast(_U, ...),  # noqa: B008
+    /,
 ) -> Union[_T, _U]:
     """
     Apply a function of two arguments cumulatively to each leaf in a
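
A quick usage sketch of the annotated helpers, assuming the module is importable as vllm.utils.jsontree as shown in the diff above (values chosen for illustration):

    from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves

    tree = {"a": [1, 2], "b": (3, 4)}

    doubled = json_map_leaves(lambda x: x * 2, tree)
    # {"a": [2, 4], "b": (6, 8)} -- container structure is preserved

    total = json_reduce_leaves(lambda x, y: x + y, tree)
    # 10 -- typed via the new overloads, no cast() at the call site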
