 # SPDX-License-Identifier: Apache-2.0
 
+from collections.abc import Iterable
 from dataclasses import dataclass
 from functools import cached_property
-from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
-                    Optional, Tuple, Union, cast)
+from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
 
 import torch
 from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
@@ -26,7 +26,7 @@ class TextPrompt(TypedDict):
     if the model supports it.
     """
 
-    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+    mm_processor_kwargs: NotRequired[dict[str, Any]]
     """
     Optional multi-modal processor kwargs to be forwarded to the
     multimodal input mapper & processor. Note that if multiple modalities
@@ -38,10 +38,10 @@ class TextPrompt(TypedDict):
 class TokensPrompt(TypedDict):
     """Schema for a tokenized prompt."""
 
-    prompt_token_ids: List[int]
+    prompt_token_ids: list[int]
     """A list of token IDs to pass to the model."""
 
-    token_type_ids: NotRequired[List[int]]
+    token_type_ids: NotRequired[list[int]]
     """A list of token type IDs to pass to the cross encoder model."""
 
     multi_modal_data: NotRequired["MultiModalDataDict"]
@@ -50,7 +50,7 @@ class TokensPrompt(TypedDict):
     if the model supports it.
     """
 
-    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+    mm_processor_kwargs: NotRequired[dict[str, Any]]
     """
     Optional multi-modal processor kwargs to be forwarded to the
     multimodal input mapper & processor. Note that if multiple modalities
@@ -115,7 +115,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
 
     decoder_prompt: Optional[_T2_co]
 
-    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+    mm_processor_kwargs: NotRequired[dict[str, Any]]
 
 
 PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
@@ -136,10 +136,10 @@ class TokenInputs(TypedDict):
     type: Literal["token"]
     """The type of inputs."""
 
-    prompt_token_ids: List[int]
+    prompt_token_ids: list[int]
     """The token IDs of the prompt."""
 
-    token_type_ids: NotRequired[List[int]]
+    token_type_ids: NotRequired[list[int]]
     """The token type IDs of the prompt."""
 
     prompt: NotRequired[str]
@@ -164,12 +164,12 @@ class TokenInputs(TypedDict):
     Placeholder ranges for the multi-modal data.
     """
 
-    multi_modal_hashes: NotRequired[List[str]]
+    multi_modal_hashes: NotRequired[list[str]]
     """
     The hashes of the multi-modal data.
     """
 
-    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+    mm_processor_kwargs: NotRequired[dict[str, Any]]
     """
     Optional multi-modal processor kwargs to be forwarded to the
     multimodal input mapper & processor. Note that if multiple modalities
@@ -179,14 +179,14 @@ class TokenInputs(TypedDict):
 
 
 def token_inputs(
-    prompt_token_ids: List[int],
-    token_type_ids: Optional[List[int]] = None,
+    prompt_token_ids: list[int],
+    token_type_ids: Optional[list[int]] = None,
     prompt: Optional[str] = None,
     multi_modal_data: Optional["MultiModalDataDict"] = None,
     multi_modal_inputs: Optional["MultiModalKwargs"] = None,
-    multi_modal_hashes: Optional[List[str]] = None,
+    multi_modal_hashes: Optional[list[str]] = None,
     multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
-    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+    mm_processor_kwargs: Optional[dict[str, Any]] = None,
 ) -> TokenInputs:
     """Construct :class:`TokenInputs` from optional values."""
     inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
@@ -255,7 +255,7 @@ def prompt(self) -> Optional[str]:
         assert_never(inputs)  # type: ignore[arg-type]
 
     @cached_property
-    def prompt_token_ids(self) -> List[int]:
+    def prompt_token_ids(self) -> list[int]:
         inputs = self.inputs
 
         if inputs["type"] == "token" or inputs["type"] == "multimodal":
@@ -264,7 +264,7 @@ def prompt_token_ids(self) -> List[int]:
         assert_never(inputs)  # type: ignore[arg-type]
 
     @cached_property
-    def token_type_ids(self) -> List[int]:
+    def token_type_ids(self) -> list[int]:
         inputs = self.inputs
 
         if inputs["type"] == "token" or inputs["type"] == "multimodal":
@@ -294,7 +294,7 @@ def multi_modal_data(self) -> "MultiModalDataDict":
         assert_never(inputs)  # type: ignore[arg-type]
 
     @cached_property
-    def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
+    def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]:
         inputs = self.inputs
 
         if inputs["type"] == "token":
@@ -306,7 +306,7 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
         assert_never(inputs)  # type: ignore[arg-type]
 
     @cached_property
-    def multi_modal_hashes(self) -> List[str]:
+    def multi_modal_hashes(self) -> list[str]:
         inputs = self.inputs
 
         if inputs["type"] == "token":
@@ -331,7 +331,7 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
         assert_never(inputs)  # type: ignore[arg-type]
 
     @cached_property
-    def mm_processor_kwargs(self) -> Dict[str, Any]:
+    def mm_processor_kwargs(self) -> dict[str, Any]:
         inputs = self.inputs
 
         if inputs["type"] == "token":
@@ -355,7 +355,7 @@ def mm_processor_kwargs(self) -> Dict[str, Any]:
 def build_explicit_enc_dec_prompt(
     encoder_prompt: _T1,
     decoder_prompt: Optional[_T2],
-    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+    mm_processor_kwargs: Optional[dict[str, Any]] = None,
 ) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
     if mm_processor_kwargs is None:
         mm_processor_kwargs = {}
@@ -368,9 +368,9 @@ def build_explicit_enc_dec_prompt(
 def zip_enc_dec_prompts(
     enc_prompts: Iterable[_T1],
     dec_prompts: Iterable[Optional[_T2]],
-    mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
-                                        Dict[str, Any]]] = None,
-) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
+    mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]],
+                                        dict[str, Any]]] = None,
+) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
     """
     Zip encoder and decoder prompts together into a list of
     :class:`ExplicitEncoderDecoderPrompt` instances.
@@ -380,12 +380,12 @@ def zip_enc_dec_prompts(
     provided, it will be zipped with the encoder/decoder prompts.
     """
     if mm_processor_kwargs is None:
-        mm_processor_kwargs = cast(Dict[str, Any], {})
+        mm_processor_kwargs = cast(dict[str, Any], {})
     if isinstance(mm_processor_kwargs, dict):
         return [
             build_explicit_enc_dec_prompt(
                 encoder_prompt, decoder_prompt,
-                cast(Dict[str, Any], mm_processor_kwargs))
+                cast(dict[str, Any], mm_processor_kwargs))
             for (encoder_prompt,
                  decoder_prompt) in zip(enc_prompts, dec_prompts)
         ]
@@ -399,7 +399,7 @@ def zip_enc_dec_prompts(
 
 def to_enc_dec_tuple_list(
     enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
-) -> List[Tuple[_T1, Optional[_T2]]]:
+) -> list[tuple[_T1, Optional[_T2]]]:
     return [(enc_dec_prompt["encoder_prompt"],
              enc_dec_prompt["decoder_prompt"])
             for enc_dec_prompt in enc_dec_prompts]
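
The annotation changes above (Dict -> dict, List -> list, Tuple -> tuple, and Iterable moving from typing to collections.abc) follow PEP 585: the builtin container types are usable as generics in annotations at runtime from Python 3.9 onward. Below is a minimal sketch of the pattern, assuming only the standard library plus typing_extensions; TokensPromptSketch is an illustrative stand-in, not the real vLLM class.

from collections.abc import Iterable
from typing import Any

from typing_extensions import NotRequired, TypedDict


class TokensPromptSketch(TypedDict):
    """Illustrative stand-in mirroring the TokensPrompt fields above."""

    prompt_token_ids: list[int]  # builtin generic (PEP 585), no typing.List needed
    token_type_ids: NotRequired[list[int]]
    mm_processor_kwargs: NotRequired[dict[str, Any]]


def total_tokens(prompts: Iterable[TokensPromptSketch]) -> int:
    # Iterable is imported from collections.abc rather than typing.
    return sum(len(p["prompt_token_ids"]) for p in prompts)


assert total_tokens([{"prompt_token_ids": [1, 2, 3]}]) == 3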
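
A hedged usage sketch for the encoder/decoder helpers touched in the last hunks, assuming zip_enc_dec_prompts and to_enc_dec_tuple_list are re-exported from vllm.inputs; the prompt strings are made up.

# Assumed import path; the helpers themselves are the ones shown in this diff.
from vllm.inputs import to_enc_dec_tuple_list, zip_enc_dec_prompts

# Zip encoder and decoder prompts into ExplicitEncoderDecoderPrompt dicts.
# Passing a single dict shares the same mm_processor_kwargs across every pair.
enc_dec_prompts = zip_enc_dec_prompts(
    ["An encoder prompt", "Another encoder prompt"],
    ["A decoder prompt", None],
    mm_processor_kwargs={},
)

# Flatten back into (encoder_prompt, decoder_prompt) tuples.
pairs = to_enc_dec_tuple_list(enc_dec_prompts)
assert pairs[1] == ("Another encoder prompt", None)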