
Commit 60ef042

add comments for the meta device registration
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
1 parent 6e40e45 commit 60ef042

File tree

2 files changed: +67 −0 lines

csrc/torch_binding_meta.cpp

Lines changed: 27 additions & 0 deletions
@@ -5,6 +5,33 @@
 #include <torch_npu/csrc/framework/OpCommand.h>
 #include <torch_npu/csrc/npu/Module.h>
 #include "utils.h"
+/*
+ * How to write a meta implementation for a custom operator (meta kernel):
+ *
+ * Meta implementations are used for shape and dtype inference, tracing, and export.
+ * They do NOT perform any real computation or allocate device memory.
+ * Instead, they return empty tensors with the correct shapes, dtypes, and device types.
+ *
+ * Steps to write a meta implementation:
+ * 1. The function signature should match the operator's schema, but only use the arguments
+ *    necessary to infer output shapes and dtypes.
+ * 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes.
+ * 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype.
+ * 4. Do NOT perform any real computation or data movement.
+ * 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar.
+ *
+ * Example:
+ *   std::tuple<at::Tensor, at::Tensor> my_op_meta(
+ *       at::Tensor &input, int64_t some_param) {
+ *     // Infer output shape based on input and parameters
+ *     auto out_shape = ...;
+ *     at::Tensor out = at::empty_symint(out_shape, input.options());
+ *     // Return empty tensor(s) with correct shape/dtype
+ *     return {out, ...};
+ *   }
+ *
+ * See below for real examples.
+ */
 
 namespace vllm_ascend {
 namespace meta {
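The five C++ steps in the comment above have a direct Python analogue. Below is a minimal, hedged sketch: the namespace `demo_meta` and the op `split_halves` are hypothetical examples invented for illustration, not ops from this commit.

```python
import torch
from torch.library import Library

# Hypothetical op in a scratch namespace, used only to illustrate the steps.
demo_lib = Library("demo_meta", "DEF")
demo_lib.define("split_halves(Tensor input) -> (Tensor, Tensor)")

def split_halves_meta(input: torch.Tensor):
    # Step 2: infer output shapes from the input; step 4: no real computation.
    half = input.shape[-1] // 2
    shape = (*input.shape[:-1], half)
    # Step 3: return empty tensors with the inferred shape and dtype.
    out0 = torch.empty(shape, dtype=input.dtype, device=input.device)
    out1 = torch.empty(shape, dtype=input.dtype, device=input.device)
    return out0, out1

# Step 5: register under the "Meta" dispatch key.
demo_lib.impl("split_halves", split_halves_meta, "Meta")

# Calling the op on meta tensors runs only the meta kernel: shapes are
# inferred, but no memory is allocated on any real device.
x = torch.empty(2, 6, device="meta")
a, b = torch.ops.demo_meta.split_halves(x)
print(a.shape, b.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```

Because the meta kernel never touches data, the same function also serves `torch.compile` tracing, where only shapes and dtypes flow through the graph.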

vllm_ascend/ops/meta_registration.py

Lines changed: 40 additions & 0 deletions
@@ -1,8 +1,48 @@
 import torch
 from torch.library import Library
 
+# This file provides a template and registration utilities for writing "meta" implementations
+# of custom operators in Python for the vllm_ascend project.
+#
+# We offer two ways to implement meta implementations for custom ops:
+# 1. Python meta implementation (as shown in this file): Write a Python function that
+#    takes the same arguments as your operator and returns empty tensors with the correct
+#    shapes and dtypes. This is useful for rapid prototyping and for ops that are only
+#    used in Python.
+# 2. C++ meta implementation: You can also implement the meta function in C++ for better
+#    performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp`
+#    for examples of C++ meta implementations and how to register them.
+#
+# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which
+# is essential for supporting `torch.compile` and aclgraph.
+
+
+# How to add a new meta implementation in Python:
+# -----------------------------------------------
+# 1. Write a Python function that takes the same arguments as your operator, and returns
+#    empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes
+#    and dtypes. Do NOT perform any real computation or allocate device memory.
+#
+# 2. Register your meta function using `register_meta_if_necessary`, providing:
+#    - The namespace (usually "_C" for custom ops)
+#    - The operator name (as registered in C++)
+#    - The Python meta function
+#    - (Optional) The overload name, if your op has overloads
+#
+# 3. The registration utility will check if a meta implementation already exists for your
+#    op, and only register if necessary. This avoids duplicate registrations.
+#
+# 4. Example meta implementations are provided below for rotary_embedding and
+#    get_masked_input_and_mask.
+#
+# 5. When developing new custom ops, always provide a meta implementation: it enables
+#    tracing, export, and shape inference in PyTorch and vLLM, which in turn lets
+#    `torch.compile` and aclgraph capture the op.
+#
+# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors
 
 lib = Library("_C", "IMPL")
 
+
 def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
     if overload != "":
         op_name = op_name + "." + overload
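The steps above can be sketched end to end. This is a plausible reconstruction, not the file's actual code: the utility here takes an explicit `lib` argument (the real `register_meta_if_necessary` uses the module-level `lib`), the namespace `demo_reg` and op `scale_rows` are hypothetical, and the already-registered check via `torch._C._dispatch_has_kernel_for_dispatch_key` is an assumption about the internal API used.

```python
import torch
from torch.library import Library

# Scratch namespace so the sketch does not touch the real "_C" ops.
demo_lib = Library("demo_reg", "DEF")
demo_lib.define("scale_rows(Tensor input, float factor) -> Tensor")

def register_meta_if_necessary(lib: Library, ns: str, op_name: str, fn,
                               overload: str = ""):
    if overload != "":
        op_name = op_name + "." + overload
    # Step 3: skip if a Meta kernel already exists (internal API; an assumption).
    if torch._C._dispatch_has_kernel_for_dispatch_key(f"{ns}::{op_name}", "Meta"):
        return
    lib.impl(op_name, fn, "Meta")

def scale_rows_meta(input: torch.Tensor, factor: float):
    # Step 1: shape/dtype inference only; the output matches the input exactly.
    return torch.empty_like(input)

# Step 2: register the meta function; a second call is a no-op rather than
# a duplicate-registration error.
register_meta_if_necessary(demo_lib, "demo_reg", "scale_rows", scale_rows_meta)
register_meta_if_necessary(demo_lib, "demo_reg", "scale_rows", scale_rows_meta)

out = torch.ops.demo_reg.scale_rows(torch.empty(4, 8, device="meta"), 2.0)
print(out.shape)  # torch.Size([4, 8])
```

Guarding the registration matters because `Library.impl` raises if a kernel for the same dispatch key is registered twice, which can happen when both the C++ binding and the Python module provide a meta implementation.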
