@@ -104,15 +104,30 @@ def forward(self, x) -> torch.Tensor:
         return self.fc2(self.act(self.fc1(x)))
 
 
+# copied from vllm
 class VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        self.dim = dim
+        self.theta = theta
         self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def update_freqs_cache(self, seqlen: int) -> None:
+        if seqlen > self._seq_len_cached:
+            seqlen *= 2
+            self._seq_len_cached = seqlen
+            self.inv_freq = 1.0 / (
+                self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim)
+            )
+            seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(seq, self.inv_freq)
+            self._freqs_cached = freqs
 
     def forward(self, seqlen: int) -> torch.Tensor:
-        self.seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        self.freqs = torch.outer(self.seq, self.inv_freq)
-        return self.freqs
+        self.update_freqs_cache(seqlen)
+        return self._freqs_cached[:seqlen]
 
 
 class VisionFlashAttention(nn.Module):
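
The new VisionRotaryEmbedding keeps the outer-product frequency table in _freqs_cached and rebuilds it only when a request exceeds the cached length, over-allocating by 2x so the cache rarely regrows. A simplified, standalone sketch of that caching behaviour (plain torch only; the class name below is mine, not the repo's, and device handling is omitted):

import torch

class CachedRotary:
    # Minimal stand-in for the cached rotary table above.
    def __init__(self, dim: int, theta: float = 10000.0):
        self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self._seq_len_cached = 0
        self._freqs_cached = None

    def __call__(self, seqlen: int) -> torch.Tensor:
        if seqlen > self._seq_len_cached:
            # Grow to 2x the requested length so slightly larger requests stay cache hits.
            self._seq_len_cached = seqlen * 2
            seq = torch.arange(self._seq_len_cached, dtype=self.inv_freq.dtype)
            self._freqs_cached = torch.outer(seq, self.inv_freq)
        return self._freqs_cached[:seqlen]

rope = CachedRotary(dim=16)
a = rope(1024)    # builds a 2048-row table once
b = rope(1500)    # served from the cache, no recompute
assert b.shape == (1500, 8) and torch.equal(a, b[:1024])
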
@@ -130,17 +145,19 @@ def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> t
         return output
 
     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
+        rotary_pos_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb.cuda())
-        k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb.cuda())
+        q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
+        k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
         q = q.squeeze(0)
         k = k.squeeze(0)
 
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32, non_blocking=True)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
 
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
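
The attention block now receives max_seqlen as an argument instead of re-uploading cu_seqlens and recomputing the max inside every layer (the per-layer .item() forced a device sync per block); rotary_pos_emb likewise arrives already on the GPU. For reference, a small sketch of the packed varlen layout those two values describe (sizes are illustrative, plain torch):

import torch

# Three images packed along a single token dimension -- no batch axis.
num_heads, head_dim = 16, 80
seq_lens = [64, 144, 100]
q = torch.randn(sum(seq_lens), num_heads, head_dim)   # (total_tokens, heads, head_dim)

cu_seqlens = torch.tensor([0, 64, 208, 308], dtype=torch.int32)  # prefix sums: image i owns q[cu[i]:cu[i+1]]
max_seqlen = max(seq_lens)                                       # 144: longest single image

assert cu_seqlens[-1].item() == q.shape[0]
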
@@ -159,9 +176,9 @@ def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None:
         self.attn = VisionFlashAttention(embed_dim, num_heads=num_heads)
         self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act)
 
-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
-            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            self.norm1(hidden_states), cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -271,9 +288,8 @@ def rot_pos_emb(self, grid_thw):
             pos_shape = (h // s, s, w // s, s)
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids, wpos_ids = hpos_ids.reshape(pos_shape), wpos_ids.reshape(pos_shape)
-            hpos_ids, wpos_ids = hpos_ids.permute(0, 2, 1, 3), wpos_ids.permute(0, 2, 1, 3)
-            hpos_ids, wpos_ids = hpos_ids.flatten(), wpos_ids.flatten()
+            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
 
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
         pos_ids = torch.cat(pos_ids, dim=0)
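
The rewrite above only chains reshape/permute/flatten per tensor; the ordering is unchanged: (hpos_ids, wpos_ids) still enumerate patch coordinates grouped by s x s window (s being the spatial merge size) rather than row-major. A tiny illustration, assuming h = w = 4 and s = 2:

import torch

h, w, s = 4, 4, 2
pos_shape = (h // s, s, w // s, s)
hpos = torch.arange(h).unsqueeze(1).expand(-1, w).reshape(pos_shape).permute(0, 2, 1, 3).flatten()
wpos = torch.arange(w).unsqueeze(0).expand(h, -1).reshape(pos_shape).permute(0, 2, 1, 3).flatten()

print(list(zip(hpos.tolist(), wpos.tolist()))[:8])
# [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (0, 3), (1, 2), (1, 3)] -> the first two 2x2 windows
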
@@ -284,14 +300,18 @@ def rot_pos_emb(self, grid_thw):
 
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0, dtype=torch.int32
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
+        cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
         for blk in self.blocks:
-            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb
+            )
         return self.merger(hidden_states)
 
     def load_image(self, img: List[ImageItem]):
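
With this hunk, forward() derives max_seqlen from the CPU copy of cu_seqlens before anything is transferred, then uploads rotary_pos_emb and cu_seqlens to the GPU once for all blocks. The same computation with two made-up images (grid_thw rows are [t, h, w] in patches):

import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 16, 24], [1, 8, 12]])   # hypothetical 16x24 and 8x12 patch grids
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

print(cu_seqlens.tolist())   # [0, 384, 480]
print(max_seqlen)            # 384, computed on CPU before cu_seqlens moves to the GPU
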