File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -3429,6 +3429,23 @@ def profile_run(self) -> None:
34293429 expected_num_items = max_mm_items_per_batch ,
34303430 )
34313431
3432+ # NOTE: The encoder cache may need to store the full-size
3433+ # embeddings that encoder outputs are scattered onto. When the
3434+ # dummy outputs have fewer rows than encoder_budget, we create
3435+ # zero embeddings of size (encoder_budget, hidden_size) and copy
3436+ # each encoder output into the leading rows.
3437+ encoder_output_shape = dummy_encoder_outputs [0 ].shape
3438+ if encoder_output_shape [0 ] < encoder_budget :
3439+ expanded_outputs = []
3440+ for output in dummy_encoder_outputs :
3441+ expanded = output .new_zeros (
3442+ (encoder_budget , encoder_output_shape [- 1 ]))
3443+ num_tokens = output .shape [0 ]
3444+ expanded [:num_tokens ].copy_ (output )
3445+ expanded_outputs .append (expanded )
3446+
3447+ dummy_encoder_outputs = expanded_outputs
3448+
34323449 # Cache the dummy encoder outputs.
34333450 self .encoder_cache ["tmp" ] = dict (
34343451 enumerate (dummy_encoder_outputs ))
You can’t perform that action at this time.
0 commit comments