Move op definition for custom kernel from C++ to Python (#949)
move op def to python
gau-nernst authored Sep 26, 2024
Parent: e83c35d · Commit: b149edb
Showing 5 changed files with 21 additions and 33 deletions.
torchao/csrc/README.md (3 additions, 3 deletions)
@@ -11,12 +11,12 @@ To learn more about custom ops in PyTorch you can refer to the [PyTorch Custom O

 ## How to add your own kernel in ao

-We've integrated a test kernel which implements a non-maximum supression (NMS) op which you can use as a template for your own kernels.
+We've integrated several kernels which you can use as a template for your own kernels. `tensor_core_tiled_layout` is the most straight-forward to get started with.

 1. Install the cudatoolkit https://anaconda.org/conda-forge/cudatoolkit
 2. In `csrc/cuda` author your custom kernel and ensure you expose a `TORCH_LIBRARY_IMPL` which will expose `torchao::your_custom_kernel`
-3. In `csrc/` author a `cpp` stub which will include a `TORCH_LIBRARY_FRAGMENT` which will place your custom kernel in the `torchao.ops` namespace and also expose a public function with the right arguments
-4. In `torchao/ops.py` is where you'll expose the python API which your new end users will leverage
+3. In `torchao/ops.py`, define your op signature at the top of the file. You can refer to [this](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md) on how to write the signature correctly
+4. `torchao/ops.py` is also where you'll expose the python API which your new end users will leverage
 5. Write a new test in `test/test_ops.py` which most importantly needs to pass `opcheck()`, this ensures that your custom kernel composes out of the box with `torch.compile()`

 And that's it! Once CI passes and your code merged you'll be able to point people to `torchao.ops.your_custom_kernel`. If you're working on an interesting kernel and would like someone else to handle the release and package management please feel free to open an issue.
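With the C++ stub files removed, steps 3 and 4 now live entirely in `torchao/ops.py`. A minimal sketch of that pattern on PyTorch >= 2.4, using a hypothetical op `my_op` (the CUDA kernel itself is still bound in C++ via `TORCH_LIBRARY_IMPL`; the repo's `register_custom_op` helper additionally handles older PyTorch versions):

```python
import torch
from torch import Tensor

# Declare the op schema in the shared "torchao" namespace.
# "FRAGMENT" lets several files contribute ops to the same library.
lib = torch.library.Library("torchao", "FRAGMENT")
lib.define("my_op(Tensor x, int k) -> Tensor")


def my_op(x: Tensor, k: int) -> Tensor:
    """Public Python API that end users call."""
    return torch.ops.torchao.my_op.default(x, k)


# Fake (meta) implementation: describes output shape/dtype without running
# the kernel, which is what torch.compile() and opcheck() rely on.
@torch.library.register_fake("torchao::my_op")
def _(x: Tensor, k: int) -> Tensor:
    torch._check(x.dim() == 2, lambda: f"expected a 2d tensor, got {x.dim()}D")
    return x.new_empty(x.shape[0], k)
```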
torchao/csrc/fp6_llm.cpp (0 additions, 8 deletions)

This file was deleted.

torchao/csrc/sparse_marlin.cpp (0 additions, 8 deletions)

This file was deleted.

torchao/csrc/tensor_core_tiled_layout.cpp (0 additions, 10 deletions)

This file was deleted.

torchao/ops.py (18 additions, 4 deletions)
@@ -4,6 +4,13 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4


+lib = torch.library.Library("torchao", "FRAGMENT")
+lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor")
+lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor")
+lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor")
+lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor")


 def register_custom_op(name):
     def decorator(func):
         if TORCH_VERSION_AT_LEAST_2_4:
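The body of `register_custom_op` is collapsed in this view. Given the `TORCH_VERSION_AT_LEAST_2_4` guard and the `torch.library` APIs available on each side of that version boundary, a plausible sketch of the helper (an assumption about the collapsed lines, not a verbatim quote):

```python
def register_custom_op(name):
    def decorator(func):
        if TORCH_VERSION_AT_LEAST_2_4:
            # PyTorch >= 2.4: register_fake is the supported way to attach
            # a fake/meta implementation to an op schema.
            return torch.library.register_fake(name)(func)
        else:
            # Older releases used impl_abstract (deprecated in 2.4).
            return torch.library.impl_abstract(name)(func)
    return decorator
```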
@@ -39,7 +46,14 @@ def quant_llm_linear(


 @register_custom_op("torchao::quant_llm_linear")
-def _(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK = 1):
+def _(
+    EXPONENT: int,
+    MANTISSA: int,
+    _in_feats: Tensor,
+    _weights: Tensor,
+    _scales: Tensor,
+    splitK: int = 1,
+) -> Tensor:
     torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
     torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}")
     torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D")
@@ -76,7 +90,7 @@ def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Ten
     )


-@register_custom_op(f"torchao::unpack_tensor_core_tiled_layout")
+@register_custom_op("torchao::unpack_tensor_core_tiled_layout")
 def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor:
     torch._check(
         packed_w.dim() == 4,
@@ -127,7 +141,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens
     )


-@register_custom_op(f"torchao::dequantize_tensor_core_tiled_layout")
+@register_custom_op("torchao::dequantize_tensor_core_tiled_layout")
 def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor:
     # packed_w preconditions
     torch._check(
@@ -192,7 +206,7 @@ def marlin_24_gemm(
     )


-@register_custom_op(f"torchao::marlin_24_gemm")
+@register_custom_op("torchao::marlin_24_gemm")
 def _(
     x: Tensor,
     weight_marlin: Tensor,
     meta: Tensor,
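Since the schemas are now defined in Python, step 5 of the README can exercise a kernel end to end with `opcheck()`. A hedged sketch of such a test (shapes and `test_utils` are illustrative and assume a CUDA device; the actual `test/test_ops.py` may differ):

```python
import torch
from torch.library import opcheck  # torch.library.opcheck needs PyTorch >= 2.4


def test_unpack_tensor_core_tiled_layout():
    inner_k_tiles = 8
    # 4D packed layout: (N / 8, K / (inner_k_tiles * 16), 32, inner_k_tiles / 2),
    # matching the preconditions enforced by the fake implementation.
    packed_w = torch.randint(
        0, 1 << 30, (32, 2, 32, inner_k_tiles // 2),
        dtype=torch.int32, device="cuda",
    )
    # Checks the schema and that the fake impl is consistent with the real
    # kernel, i.e. the op composes with torch.compile().
    opcheck(
        torch.ops.torchao.unpack_tensor_core_tiled_layout,
        (packed_w, inner_k_tiles),
        test_utils=["test_schema", "test_faketensor"],
    )
```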
