File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -3351,6 +3351,23 @@ def profile_run(self) -> None:
33513351 expected_num_items = max_mm_items_per_batch ,
33523352 )
33533353
3354+ # NOTE: This happens when encoder cache needs to store
3355+ # the embeddings that encoder outputs are scattered onto.
3356+ # In this case we create dummy embeddings of size
3357+ # (encode_budget, hidden_size) and scatter encoder
3358+ # output into it.
3359+ encoder_output_shape = dummy_encoder_outputs [0 ].shape
3360+ if encoder_output_shape [0 ] < encoder_budget :
3361+ expanded_outputs = []
3362+ for output in dummy_encoder_outputs :
3363+ expanded = output .new_zeros (
3364+ (encoder_budget , encoder_output_shape [- 1 ]))
3365+ num_tokens = output .shape [0 ]
3366+ expanded [:num_tokens ].copy_ (output )
3367+ expanded_outputs .append (expanded )
3368+
3369+ dummy_encoder_outputs = expanded_outputs
3370+
33543371 # Cache the dummy encoder outputs.
33553372 self .encoder_cache ["tmp" ] = dict (
33563373 enumerate (dummy_encoder_outputs ))
You can’t perform that action at this time.
0 commit comments