-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This is mainly for debugging for now, but it allows disabling the library extensions and using implementations based on pytorch operators only instead.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Quanto operations library | ||
|
||
This contains the `quanto::` operations, available in python under `torch.ops.quanto`. | ||
|
||
To add a new operation: | ||
|
||
- add a definition for the operation in `library/ops.py`, | ||
- provide a default implementation using pytorch operators only under `library/python`, | ||
- provide optimized kernels for all devices under `library/ext`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,3 @@ | ||
import torch | ||
|
||
from .cpp import * | ||
from .ext import * | ||
from .ops import * | ||
|
||
|
||
if torch.backends.mps.is_available(): | ||
from .mps import * | ||
from .python import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Quanto library extensions | ||
|
||
This folder contains the implementations of all `quanto_ext::` operations. | ||
|
||
This namespace corresponds to the device-specific optimized implementations of quanto operations. | ||
|
||
Implementations can be provided as part of: | ||
|
||
- the generic C++ pytorch extension under `cpp`, | ||
- the CUDA extension under `cuda`, | ||
- the Metal Performance Shader extension under `mps`. | ||
|
||
The operations are defined in `library/ops.py`. | ||
|
||
To provide an implementation for specific device types, use the following syntax: | ||
|
||
``` | ||
@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"]) | ||
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor: | ||
return ext().unpack(packed, bits) | ||
``` | ||
|
||
Please refer to each extension folder to see how to add the actual implementation. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import torch | ||
|
||
from .cpp import * | ||
|
||
|
||
if torch.backends.mps.is_available(): | ||
from .mps import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Quanto generic C++ extension | ||
|
||
Kernels in this extension must use only the C++ syntax. | ||
|
||
They can use any pytorch operation defined under `aten::` or `c10::`. | ||
|
||
To add a new implementation for an operation defined in `library/ops.py`: | ||
|
||
- add the corresponding `.cpp` file to the list of sources in `__init__.py`, | ||
- add a binding to `pybind_module.cpp`, | ||
- provide an implementation calling the binding in `__init__.py`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Quanto Metal Performance Shaders extension | ||
|
||
To add a new implementation for an operation defined in `library/ops.py`: | ||
|
||
- add the corresponding `.mm` file to the list of sources in `__init__.py`, | ||
- add a binding to `pybind_module.cpp`, | ||
- provide an implementation calling the binding in `__init__.py`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,48 @@ | ||
from torch.library import define | ||
from contextlib import contextmanager | ||
|
||
import torch | ||
|
||
|
||
# This file contains the definitions of all operations under torch.ops.quanto | ||
|
||
define("quanto::unpack", "(Tensor self, int bits) -> Tensor") | ||
|
||
# Module-level switch read by the top-level quanto operations to decide whether
# calls are routed to the extension or to the python implementations.
_ext_enabled = True


@contextmanager
def disable_extensions():
    """Disable quanto extensions (debug).

    While the context is active, the module-level `_ext_enabled` flag is
    False. On exit (including when the body raises) the flag is restored to
    the value it had on entry, so nested uses of this context manager do not
    re-enable extensions prematurely.
    """
    global _ext_enabled
    # Remember the value on entry rather than assuming it was True.
    previous = _ext_enabled
    _ext_enabled = False
    try:
        yield
    finally:
        _ext_enabled = previous
|
||
|
||
def define(name, schema): | ||
"""Define a new quanto operation. | ||
The operation will actually be defined in three libraries: | ||
- the top-level quanto library as quanto::<op>, | ||
- the quanto python library as quanto_py::<op>, | ||
- the quanto extension library as quanto_ext::<op>. | ||
Only the implementations for the python and extension library need | ||
to be provided: the top-level implementation for the operation is | ||
provided when calling this method and simply routes the calls towards | ||
either the python or extension implementations based on the selected | ||
mode. | ||
""" | ||
for libname in ["quanto", "quanto_py", "quanto_ext"]: | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
dacorvo
Author
Collaborator
|
||
torch.library.define(f"{libname}::{name}", schema) | ||
|
||
# Provide the inplementation for all dispatch key in the main library | ||
@torch.library.impl("quanto::unpack", "default") | ||
def impl(*args, **kwargs): | ||
if _ext_enabled: | ||
return getattr(torch.ops.quanto_ext, name)(*args, **kwargs) | ||
return getattr(torch.ops.quanto_py, name)(*args, **kwargs) | ||
|
||
|
||
define("unpack", "(Tensor self, int bits) -> Tensor") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Quanto library python/pytorch operations | ||
|
||
This folder contains the implementations of all `quanto_py::` operations. | ||
|
||
This namespace corresponds to the default, python-only implementations of quanto operations. | ||
|
||
The operations are defined in `library/ops.py`. | ||
|
||
To provide an implementation for an operation, use the following syntax: | ||
|
||
``` | ||
@torch.library.impl("quanto_py::unpack", "default") | ||
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor: | ||
... | ||
``` | ||
|
||
The implementation **must** support all device types. This is true if it | ||
is a composition of built-in PyTorch operators. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .unpack import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,21 @@ | ||
from contextlib import nullcontext | ||
|
||
import pytest | ||
import torch | ||
|
||
from quanto.library import disable_extensions | ||
from quanto.tensor.core import int2, int4, pack_weights | ||
|
||
|
||
@pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"]) | ||
@pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"]) | ||
def test_unpack(bits, shape, device): | ||
@pytest.mark.parametrize("use_ext", [True, False], ids=["ext", "no-ext"]) | ||
def test_unpack(bits, shape, use_ext, device): | ||
qmax = 2**bits | ||
a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device) | ||
bitsdtype = int2 if bits == 2 else int4 | ||
packed_a = pack_weights(a, bitsdtype) | ||
unpacked_a = torch.ops.quanto.unpack(packed_a, bits) | ||
context = nullcontext() if use_ext else disable_extensions() | ||
with context: | ||
unpacked_a = torch.ops.quanto.unpack(packed_a, bits) | ||
assert torch.equal(unpacked_a, a) |
As discussed in the other thread, having this double level of dispatching adds quite a bit of overhead.
Couldn't we have a single `@torch.library.impl` (or better yet, when possible, `TORCH_LIBRARY_IMPL` in C++), with a `default` being python, and other implementations for `cuda`, `cpu`, etc. if available? A single level?