@@ -488,7 +488,9 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
 
         indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
         weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-        weights = weights.to(dtype=self.dtype, device=self.device)
+        weights = weights.to(
+            dtype=self.dtype, device=self.device, non_blocking=True
+        )
 
         embeds = self.pos_embed(indices)
         weighted_embeds = embeds * weights
@@ -524,14 +526,15 @@ def forward(
         x: torch.Tensor,
         grid_thw: list[list[int]],
     ) -> torch.Tensor:
-        hidden_states = x.to(device=self.device, dtype=self.dtype)
+        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)
 
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)
 
-        grid_thw_tensor = torch.tensor(grid_thw, device=self.device, dtype=torch.int32)
+        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
@@ -542,8 +545,8 @@ def forward(
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         hidden_states = hidden_states.unsqueeze(1)
-        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
         max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
 
         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
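Not part of the diff, just a minimal sketch of the pattern these changes rely on (tensor and variable names are illustrative): `.to(..., non_blocking=True)` queues the host-to-device copy on the CUDA stream instead of blocking, and keeping `grid_thw_tensor`/`cu_seqlens` on the CPU lets the sequence-length bookkeeping run concurrently with the pending transfers before the result is finally shipped to the device.

```python
import torch

# Minimal sketch (illustrative values, assumes a CUDA device is available).
if torch.cuda.is_available():
    device = torch.device("cuda")

    # Pinned (page-locked) host tensor: required for a truly async copy.
    x = torch.randn(4096, 1280, pin_memory=True)

    # Queue the copy on the current CUDA stream and return immediately.
    x_gpu = x.to(device, non_blocking=True)

    # CPU-side work (e.g. computing cu_seqlens from grid_thw on the CPU)
    # proceeds while the transfer is still in flight.
    grid_thw = torch.tensor([[1, 16, 16]], dtype=torch.int32)  # stays on CPU
    cu_seqlens = torch.repeat_interleave(
        grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
    ).cumsum(dim=0)

    # Only move the result to the device once the CPU-side math is done.
    cu_seqlens = cu_seqlens.to(device, non_blocking=True)

    # Synchronize before any host-side read of the GPU tensors.
    torch.cuda.current_stream().synchronize()
```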