
Commit 88577b3

0808-temp
1 parent eba4b00 commit 88577b3

2 files changed (+36, -16 lines)


lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 35 additions & 15 deletions
@@ -104,15 +104,30 @@ def forward(self, x) -> torch.Tensor:
         return self.fc2(self.act(self.fc1(x)))
 
 
+# copied from vllm
 class VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        self.dim = dim
+        self.theta = theta
         self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def update_freqs_cache(self, seqlen: int) -> None:
+        if seqlen > self._seq_len_cached:
+            seqlen *= 2
+            self._seq_len_cached = seqlen
+            self.inv_freq = 1.0 / (
+                self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim)
+            )
+            seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(seq, self.inv_freq)
+            self._freqs_cached = freqs
 
     def forward(self, seqlen: int) -> torch.Tensor:
-        self.seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        self.freqs = torch.outer(self.seq, self.inv_freq)
-        return self.freqs
+        self.update_freqs_cache(seqlen)
+        return self._freqs_cached[:seqlen]
 
 
 class VisionFlashAttention(nn.Module):
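
Note on the hunk above: the new VisionRotaryEmbedding keeps the outer-product frequency table in `_freqs_cached` and grows the cache geometrically (doubling the requested length), so repeated calls with similar sequence lengths slice the cache instead of rebuilding it. A minimal self-contained sketch of the same idea, assuming the behavior shown in the diff (the class name `_CachedVisionRotary` and the example sizes are illustrative, not part of the patch):

```python
import torch
import torch.nn as nn


class _CachedVisionRotary(nn.Module):
    """Sketch of the doubling-cache strategy used by VisionRotaryEmbedding above."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self._seq_len_cached = 0
        self._freqs_cached = None

    def forward(self, seqlen: int) -> torch.Tensor:
        if seqlen > self._seq_len_cached:
            # Grow the cache to 2x the requested length so nearby, slightly
            # larger requests can be served by slicing instead of recomputing.
            self._seq_len_cached = seqlen * 2
            seq = torch.arange(self._seq_len_cached, dtype=self.inv_freq.dtype)
            self._freqs_cached = torch.outer(seq, self.inv_freq)
        return self._freqs_cached[:seqlen]


rope = _CachedVisionRotary(dim=32)
print(rope(100).shape)  # builds a cache of length 200 -> torch.Size([100, 16])
print(rope(150).shape)  # served from the existing cache -> torch.Size([150, 16])
```

The doubling trades at most a 2x over-allocation of the frequency table for avoiding a fresh arange/outer on every call.
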
@@ -130,17 +145,19 @@ def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> t
         return output
 
     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
+        rotary_pos_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb.cuda())
-        k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb.cuda())
+        q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
+        k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
         q = q.squeeze(0)
         k = k.squeeze(0)
 
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32, non_blocking=True)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
 
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
@@ -159,9 +176,9 @@ def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None:
         self.attn = VisionFlashAttention(embed_dim, num_heads=num_heads)
         self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act)
 
-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
-            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            self.norm1(hidden_states), cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -271,9 +288,8 @@ def rot_pos_emb(self, grid_thw):
             pos_shape = (h // s, s, w // s, s)
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids, wpos_ids = hpos_ids.reshape(pos_shape), wpos_ids.reshape(pos_shape)
-            hpos_ids, wpos_ids = hpos_ids.permute(0, 2, 1, 3), wpos_ids.permute(0, 2, 1, 3)
-            hpos_ids, wpos_ids = hpos_ids.flatten(), wpos_ids.flatten()
+            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
 
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
         pos_ids = torch.cat(pos_ids, dim=0)
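
Note on the hunk above: the three paired tuple assignments are fused into one reshape → permute → flatten chain per tensor; the resulting position ids are unchanged. A quick self-contained check (the h, w, s values below are made up for illustration; s plays the role of the spatial merge size):

```python
import torch

h, w, s = 8, 6, 2
pos_shape = (h // s, s, w // s, s)
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)

# Previous three-step form.
a = hpos_ids.reshape(pos_shape)
a = a.permute(0, 2, 1, 3)
a = a.flatten()

# Fused form from the diff.
b = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()

assert torch.equal(a, b)
```
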
@@ -284,14 +300,18 @@ def rot_pos_emb(self, grid_thw):
 
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0, dtype=torch.int32
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
+        cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
         for blk in self.blocks:
-            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb
+            )
         return self.merger(hidden_states)
 
     def load_image(self, img: List[ImageItem]):
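
Note on the file above: `cu_seqlens` and `max_seqlen` are now derived once in the vision transformer's forward while the tensor is still on the host, and `max_seqlen` is passed down through every block, replacing the per-block device transfer and `.max().item()` call that previously ran inside VisionFlashAttention.forward. A standalone sketch of that computation, assuming a made-up `grid_thw` (two single-frame images, sizes in patch units):

```python
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 16, 16], [1, 8, 12]])  # (t, h, w) per image, illustrative values

# Per-frame token counts, repeated over the temporal dimension, then turned
# into cumulative sequence boundaries for the varlen attention kernel.
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)

# Computed once on the host before cu_seqlens moves to CUDA, instead of
# once per transformer block on a device tensor (which forces a sync).
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

print(cu_seqlens.tolist(), max_seqlen)  # [0, 256, 352] 256
```
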

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 1 addition & 1 deletion
@@ -215,7 +215,7 @@ def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen):
     Unified Flash Attention entry point. If sgl_kernel is available,
     use the sgl_kernel implementation; otherwise fall back to the Triton version.
     """
-    if _flash_attn_v3_available and is_hopper() and False:
+    if _flash_attn_v3_available and is_hopper():
         flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
     else:
         _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen)
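
Note on the change above: removing the `and False` guard re-enables the FlashAttention-3 path whenever the kernel is importable and the GPU is a Hopper part; otherwise the Triton fallback is used. For reference, a Hopper check of this kind is usually a compute-capability test; a hedged sketch (the repo's own is_hopper() helper may be implemented differently):

```python
import torch


def _is_hopper_sketch() -> bool:
    # Hopper-generation GPUs (H100/H200) report CUDA compute capability 9.x.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major == 9
```
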
