Commit 645b6cf

fix lint

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Parent: e47df50

2 files changed, +34 -38 lines changed

vllm_ascend/ops/meta_registration.py renamed to vllm_ascend/meta_registration.py

Lines changed: 32 additions & 36 deletions
@@ -16,7 +16,6 @@
 # Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which
 # is essential for supporting `torch.compile` and aclgraph.

-
 # How to add a new meta implementation in Python:
 # -------------------------------------
 # 1. Write a Python function that takes the same arguments as your operator, and returns
@@ -35,56 +34,53 @@
 # 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask.
 #
 # 5. When developing new custom ops, always provide a meta implementation to enable tracing,
-#    export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile`
+#    export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile`
 # and aclgraph.
 #
 # For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors

 lib = Library("_C", "IMPL")


-def register_meta_if_necessary(ns:str, op_name: str, fn, overload: str = ""):
-    if overload != "":
-        op_name = op_name + "." + overload
-    schema_to_find = ns + "::" + op_name
-    meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key("Meta")
-    if schema_to_find in meta_impl_list:
-        return
-    lib.impl(op_name, fn, "Meta")
+def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
+    if overload != "":
+        op_name = op_name + "." + overload
+    schema_to_find = ns + "::" + op_name
+    meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
+        "Meta")
+    if schema_to_find in meta_impl_list:
+        return
+    lib.impl(op_name, fn, "Meta")

-def rotary_embedding_meta(
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    head_size: int,
-    cos_sin_cache: torch.Tensor,
-    is_neox: bool):

-    num_tokens = positions.numel()
-    query_hidden_size = query.numel() / num_tokens
-    key_hidden_size = key.numel() / num_tokens
-    num_heads = query_hidden_size / head_size
-    num_kv_heads = key_hidden_size / head_size
+def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
+                          key: torch.Tensor, head_size: int,
+                          cos_sin_cache: torch.Tensor, is_neox: bool):

-    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
-    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
-    return query_dst, key_dst
+    num_tokens = positions.numel()
+    query_hidden_size = query.numel() // num_tokens
+    key_hidden_size = key.numel() // num_tokens
+    num_heads = query_hidden_size // head_size
+    num_kv_heads = key_hidden_size // head_size

+    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
+    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
+    return query_dst, key_dst

-def get_masked_input_and_mask_meta(
-    input: torch.Tensor,
-    org_vocab_start_index: int,
-    org_vocab_end_index: int,
-    num_org_vocab_padding: int,
-    added_vocab_start_index: int,
-    added_vocab_end_index: int):

-    masked_input = torch.empty_like(input)
-    mask = torch.empty_like(input).to(torch.bool)
+def get_masked_input_and_mask_meta(input: torch.Tensor,
+                                   org_vocab_start_index: int,
+                                   org_vocab_end_index: int,
+                                   num_org_vocab_padding: int,
+                                   added_vocab_start_index: int,
+                                   added_vocab_end_index: int):

-    return masked_input, mask
+    masked_input = torch.empty_like(input)
+    mask = torch.empty_like(input).to(torch.bool)

+    return masked_input, mask


 register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
-register_meta_if_necessary("_C", "get_masked_input_and_mask", get_masked_input_and_mask_meta)
+register_meta_if_necessary("_C", "get_masked_input_and_mask",
+                           get_masked_input_and_mask_meta)
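
The header comments in this file document the registration pattern that the lint fix reflows: query the dispatcher for an existing Meta registration, and only then call lib.impl(op_name, fn, "Meta"). Below is a minimal, self-contained sketch of that pattern, using a made-up namespace my_ns and op rotate that are not part of vllm_ascend. It also illustrates why the diff switches / to //: view() and torch.empty() need integer dimensions, and true division on element counts yields floats.

import torch
from torch.library import Library

lib = Library("my_ns", "DEF")
lib.define("rotate(Tensor query, int head_size) -> Tensor")


def rotate_meta(query: torch.Tensor, head_size: int) -> torch.Tensor:
    # Meta impls never touch data: they only compute output shape/dtype,
    # which is all torch.compile needs to trace through the op.
    num_tokens = query.shape[0]
    num_heads = query.shape[-1] // head_size  # integer division, as in the diff
    return torch.empty(num_tokens,
                       num_heads,
                       head_size,
                       dtype=query.dtype,
                       device=query.device)


def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
    # Same guard as the file above: skip if a Meta kernel is already registered.
    if overload != "":
        op_name = op_name + "." + overload
    schema_to_find = ns + "::" + op_name
    meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
        "Meta")
    if schema_to_find in meta_impl_list:
        return
    lib.impl(op_name, fn, "Meta")


register_meta_if_necessary("my_ns", "rotate", rotate_meta)

# Shape inference now works without any real kernel: a tensor on the
# "meta" device dispatches straight to rotate_meta.
q = torch.empty(8, 64, device="meta")
out = torch.ops.my_ns.rotate(q, 16)
print(out.shape)  # torch.Size([8, 4, 16])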

vllm_ascend/utils.py

Lines changed: 2 additions & 2 deletions
@@ -215,9 +215,9 @@ def enable_custom_op():
         return _CUSTOM_OP_ENABLED
     try:
         # register custom ops into torch_library here
-        import vllm_ascend.vllm_ascend_C  # type: ignore # noqa: F401
         # register the meta implementation for custom kernel if necessary
-        import vllm_ascend.ops.meta_registration  # type: ignore # noqa: F401
+        import vllm_ascend.meta_registration  # type: ignore # noqa: F401
+        import vllm_ascend.vllm_ascend_C  # type: ignore # noqa: F401
         _CUSTOM_OP_ENABLED = True
     except ImportError:
         _CUSTOM_OP_ENABLED = False
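
For context on the hunk above: enable_custom_op caches whether the compiled extension and its meta registrations imported cleanly, and the reorder is consistent with import sorting, since vllm_ascend.meta_registration sorts ahead of vllm_ascend.vllm_ascend_C after the rename. A rough sketch of the full function follows; everything outside the hunk (the module-level cache and the early-return guard) is an assumption, not the verbatim vllm_ascend/utils.py.

_CUSTOM_OP_ENABLED = None  # assumed module-level cache, per the hunk


def enable_custom_op():
    # Assumed guard: return the cached result on repeat calls.
    global _CUSTOM_OP_ENABLED
    if _CUSTOM_OP_ENABLED is not None:
        return _CUSTOM_OP_ENABLED
    try:
        # register custom ops into torch_library here
        # register the meta implementation for custom kernel if necessary
        import vllm_ascend.meta_registration  # type: ignore # noqa: F401
        import vllm_ascend.vllm_ascend_C  # type: ignore # noqa: F401
        _CUSTOM_OP_ENABLED = True
    except ImportError:
        # Extension not built: callers fall back to non-custom-op paths.
        _CUSTOM_OP_ENABLED = False
    return _CUSTOM_OP_ENABLED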
