
Commit 1aa162e

Apply torchfix (#15532)

Signed-off-by: cyy <cyyever@outlook.com>

Parent: cf5c8f1

File tree: 5 files changed (+15 additions, -11 deletions)

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 3 deletions
@@ -884,9 +884,8 @@ def _sdpa_attention(
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
-        with torch.backends.cuda.sdp_kernel(enable_math=True,
-                                            enable_flash=False,
-                                            enable_mem_efficient=False):
+        with torch.nn.attention.sdpa_kernel(
+                torch.nn.attention.SDPBackend.MATH):
             sub_out = torch.nn.functional.scaled_dot_product_attention(
                 query[:, start:end, :],
                 key[:, start:end, :],
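
The removed torch.backends.cuda.sdp_kernel context manager selected attention implementations with boolean flags; its replacement, torch.nn.attention.sdpa_kernel, takes SDPBackend members instead. A minimal standalone sketch of the new form (not vLLM code; tensor shapes are made up for illustration):

# Restrict scaled_dot_product_attention to the math backend, the
# equivalent of enable_math=True with flash and mem-efficient disabled.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)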

vllm/lora/models.py

Lines changed: 3 additions & 1 deletion
@@ -272,7 +272,9 @@ def from_local_checkpoint(
                 f" target modules in {expected_lora_modules}"
                 f" but received {unexpected_modules}."
                 f" Please verify that the loaded LoRA module is correct")
-            tensors = torch.load(lora_bin_file_path, map_location=device)
+            tensors = torch.load(lora_bin_file_path,
+                                 map_location=device,
+                                 weights_only=True)
         else:
             raise ValueError(f"{lora_dir} doesn't contain tensors")
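
With weights_only=True, torch.load restricts unpickling to tensors and other allow-listed types rather than arbitrary Python objects, which hardens checkpoint loading. A minimal sketch of the pattern (hypothetical path, not vLLM code):

# Save and reload a LoRA-style state dict with the restricted loader.
import torch

state = {"lora_A.weight": torch.randn(4, 8)}
torch.save(state, "/tmp/lora_example.bin")  # hypothetical path

tensors = torch.load("/tmp/lora_example.bin",
                     map_location="cpu",
                     weights_only=True)
print(tensors["lora_A.weight"].shape)  # torch.Size([4, 8])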

vllm/model_executor/models/nemotron.py

Lines changed: 3 additions & 3 deletions
@@ -63,8 +63,8 @@ def _cast_if_autocast_enabled(*args):
     if not torch.is_autocast_enabled():
         return args
     else:
-        return torch.cuda.amp.autocast_mode._cast(
-            args, torch.get_autocast_gpu_dtype())
+        return torch.amp.autocast_mode._cast(
+            args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype())


 class NemotronLayerNorm1P(nn.LayerNorm):
@@ -89,7 +89,7 @@ def forward(
         residual = x
         args = _cast_if_autocast_enabled(x, self.normalized_shape,
                                          self.weight + 1, self.bias, self.eps)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast("cuda", enabled=False):
             x = torch.nn.functional.layer_norm(*args)
         return x if residual is None else (x, residual)
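
Both hunks move from the CUDA-specific torch.cuda.amp entry points to the device-generic torch.amp API, which takes an explicit device type. A minimal sketch of the context-manager form (not vLLM code; falls back to CPU when no GPU is present):

# Disable autocast for a block so layer_norm runs in the tensors' own dtype.
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 4, device=device_type)

with torch.amp.autocast(device_type, enabled=False):
    y = torch.nn.functional.layer_norm(x, x.shape[-1:])
print(y.dtype)  # unchanged inside the disabled-autocast block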

vllm/model_executor/models/phi4mm_utils.py

Lines changed: 6 additions & 3 deletions
@@ -1766,9 +1766,12 @@ def forward(
         if mask.dtype != q.dtype:
             attn_mask = attn_mask.to(q.dtype)

-        with torch.backends.cuda.sdp_kernel(enable_flash=True,
-                                            enable_math=True,
-                                            enable_mem_efficient=True):
+        with torch.nn.attention.sdpa_kernel([
+                torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+                torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+                torch.nn.attention.SDPBackend.MATH,
+                torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        ]):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q,
                 k,
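
Unlike the single-backend case above, sdpa_kernel also accepts a list of SDPBackend members, letting PyTorch choose among all of the allowed implementations at dispatch time. A minimal sketch (not vLLM code; shapes are illustrative):

# Allow several attention backends; PyTorch picks an applicable one.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 4, 32, 64)
k = torch.randn(1, 4, 32, 64)
v = torch.randn(1, 4, 32, 64)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION,
                  SDPBackend.EFFICIENT_ATTENTION,
                  SDPBackend.MATH,
                  SDPBackend.CUDNN_ATTENTION]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)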

vllm/multimodal/image.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def load_base64(self, media_type: str, data: str) -> torch.Tensor:
         return self.load_bytes(base64.b64decode(data))

     def load_file(self, filepath: Path) -> torch.Tensor:
-        return torch.load(filepath)
+        return torch.load(filepath, weights_only=True)

     def encode_base64(self, media: torch.Tensor) -> str:
         return base64.b64encode(media.numpy()).decode('utf-8')
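
The same weights_only=True hardening applies here: a .pt file is a pickle, and the default loader can construct arbitrary Python objects while deserializing. A minimal sketch of what the restricted mode rejects (hypothetical paths, not vLLM code):

# A plain tensor loads fine; a pickled callable is refused.
import torch

torch.save(torch.zeros(2, 2), "/tmp/image_tensor.pt")
print(torch.load("/tmp/image_tensor.pt", weights_only=True))

torch.save({"cb": print}, "/tmp/payload.pt")  # non-tensor payload
try:
    torch.load("/tmp/payload.pt", weights_only=True)
except Exception as exc:
    print(f"rejected: {exc}")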
