|
1 | 1 | import torch |
2 | 2 | from torch.library import Library |
3 | 3 |
|
| 4 | +# This file provides a template and registration utilities for writing "meta" implementations |
| 5 | +# of custom operators in Python for the vllm_ascend project. |
| 6 | +# |
| 7 | +# We offer two ways to implement meta implementations for custom ops: |
| 8 | +# 1. Python meta implementation (as shown in this file): Write a Python function that |
| 9 | +# takes the same arguments as your operator and returns empty tensors with the correct |
| 10 | +# shapes and dtypes. This is useful for rapid prototyping and for ops that are only |
| 11 | +# used in Python. |
| 12 | +# 2. C++ meta implementation: You can also implement the meta function in C++ for better |
| 13 | +# performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp` |
| 14 | +# for examples of C++ meta implementations and how to register them. |
| 15 | +# |
| 16 | +# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which |
| 17 | +# is essential for supporting `torch.compile` and aclgraph. |
| 18 | + |
| 19 | + |
| 20 | +# How to add a new meta implementation in Python: |
| 21 | +# ------------------------------------- |
| 22 | +# 1. Write a Python function that takes the same arguments as your operator, and returns |
| 23 | +# empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes and dtypes. |
| 24 | +# Do NOT perform any real computation or allocate device memory. |
| 25 | +# |
| 26 | +# 2. Register your meta function using `register_meta_if_necessary`, providing: |
| 27 | +# - The namespace (usually "_C" for custom ops) |
| 28 | +# - The operator name (as registered in C++) |
| 29 | +# - The Python meta function |
| 30 | +# - (Optional) The overload name, if your op has overloads |
| 31 | +# |
| 32 | +# 3. The registration utility will check if a meta implementation already exists for your op, |
| 33 | +# and only register if necessary. This avoids duplicate registrations. |
| 34 | +# |
| 35 | +# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask. |
| 36 | +# |
| 37 | +# 5. When developing new custom ops, always provide a meta implementation to enable tracing, |
| 38 | +# export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile` |
| 39 | +# and aclgraph. |
| 40 | +# |
| 41 | +# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors |
| 42 | + |
4 | 43 | lib = Library("_C", "IMPL") |
5 | 44 |
|
| 45 | + |
6 | 46 | def register_meta_if_necessary(ns:str, op_name: str, fn, overload: str = ""): |
7 | 47 | if overload != "": |
8 | 48 | op_name = op_name + "." + overload |
|
0 commit comments