@@ -180,11 +180,14 @@ def _get_dummy_mm_inputs(
     def _get_mm_num_tokens(
         self,
         mm_inputs: MultiModalInputs,
+        mm_embeddings_only: bool = True,
     ) -> Mapping[str, int]:
         placeholders_by_modality = mm_inputs["mm_placeholders"]
 
         return {
-            modality: sum(item.get_num_embeds() for item in placeholders)
+            modality:
+            sum(item.get_num_embeds() if mm_embeddings_only else item.length
+                for item in placeholders)
             for modality, placeholders in placeholders_by_modality.items()
         }
 
@@ -253,10 +256,11 @@ def get_decoder_dummy_data(
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
 
-    def get_mm_max_tokens(
+    def _get_mm_max_tokens(
         self,
         seq_len: int,
         mm_counts: Optional[Mapping[str, int]] = None,
+        mm_embeddings_only: bool = True,
     ) -> Mapping[str, int]:
         if mm_counts is None:
             mm_counts = self.get_mm_limits()
@@ -285,4 +289,25 @@ def get_mm_max_tokens(
             return max_tokens_per_item
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        return self._get_mm_num_tokens(mm_inputs)
+        return self._get_mm_num_tokens(mm_inputs,
+                                       mm_embeddings_only=mm_embeddings_only)
+
+    def get_mm_max_contiguous_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
+    ):
+        """
+        Returns the maximum length of the multimodal (image placeholders+text)
+        tokens, including any break/text tokens in-between image embeddings.
+
+        <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
+        Returns 9, even when the number of image embeddings is 6.
+
+        This is important to take into account when profiling and
+        initializing the encoder cache size.
+        """
+
+        return self._get_mm_max_tokens(seq_len,
+                                       mm_counts,
+                                       mm_embeddings_only=False)
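
For illustration, here is a minimal self-contained sketch of the two counting modes the `mm_embeddings_only` flag switches between. `FakePlaceholder` and its `is_embed` mask are assumptions standing in for the real placeholder items, not the library's actual API; only the counting logic mirrors the diff above.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakePlaceholder:
    # Hypothetical stand-in for a placeholder range: `length` spans the whole
    # contiguous placeholder (embeddings plus any break/text tokens), while
    # `is_embed` marks which positions are actual image embeddings.
    length: int
    is_embed: Optional[list] = None  # list of bools, one per position

    def get_num_embeds(self) -> int:
        # Every position counts as an embedding when no mask is given.
        if self.is_embed is None:
            return self.length
        return sum(self.is_embed)

# <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
mask = [False, True, True, True, False, True, True, True, False]
placeholders = {"image": [FakePlaceholder(length=len(mask), is_embed=mask)]}

# mm_embeddings_only=True: count only the image embeddings -> 6
num_embeds = {m: sum(p.get_num_embeds() for p in ps)
              for m, ps in placeholders.items()}
# mm_embeddings_only=False: count the full contiguous span -> 9
num_contiguous = {m: sum(p.length for p in ps)
                  for m, ps in placeholders.items()}

print(num_embeds)      # {'image': 6}
print(num_contiguous)  # {'image': 9}
```

The second number is the one the new `get_mm_max_contiguous_tokens` exposes, which is why it matters when sizing the encoder cache during profiling.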