Skip to content

Commit

Permalink
feat(library): add default python operations
Browse files Browse the repository at this point in the history
This is mainly for debug for now, but it allows to disable the library extensions and
use instead implementations based on pytorch operators only.
  • Loading branch information
dacorvo committed Feb 3, 2024
1 parent 5c4951b commit 8ed7e69
Show file tree
Hide file tree
Showing 20 changed files with 149 additions and 30 deletions.
9 changes: 9 additions & 0 deletions quanto/library/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Quanto operations library

This contains the `quanto::` operations, available in python under `torch.ops.quanto`.

To add a new operation:

- add a definition for the operation in `library/ops.py`,
- provide a default implementation using pytorch operators only under `library/python`,
- provide optimized kernels for all devices under `library/ext`.
9 changes: 2 additions & 7 deletions quanto/library/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import torch

from .cpp import *
from .ext import *
from .ops import *


if torch.backends.mps.is_available():
from .mps import *
from .python import *
Empty file removed quanto/library/builtin/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions quanto/library/ext/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Quanto library extensions

This folder contains the implementations of all `quanto_ext::` operations.

This namespace corresponds to the device-specifc optimized implementations of quanto operations.

Implementations can be provided as part of:

- the generic C++ pytorch extension under `cpp`,
- the CUDA extension under `cuda`,
- the Metal Performance Shader extension under `mps`.

The operations are defined in `library/ops.py`.

To provide an implementation for specific device types, use the following syntax:

```
@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
return ext().unpack(t, bits)
```

Please refer to each extension folder to see how to add the actual implementation.
7 changes: 7 additions & 0 deletions quanto/library/ext/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch

from .cpp import *


if torch.backends.mps.is_available():
from .mps import *
11 changes: 11 additions & 0 deletions quanto/library/ext/cpp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Quanto generic C++ extension

Kernels in this extension must use only the C++ syntax.

They can use any pytorch operation defined under `aten::` or `c10::`.

To add a new implementation for an operation defined in `library./ops.py`:

- add the corresponding `.cpp` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,25 @@
__all__ = []


_backend = None
_ext = None


def backend():
"""Helper to load the CPU backend only when it is required"""
global _backend
if _backend is None:
def ext():
"""Helper to load the CPU ext only when it is required"""
global _ext
if _ext is None:
module_path = os.path.dirname(__file__)
_backend = load(
_ext = load(
name="quanto_cpp",
sources=[
f"{module_path}/unpack.cpp",
f"{module_path}/pybind_module.cpp",
],
extra_cflags=["-O3"],
)
return _backend
return _ext


@impl("quanto::unpack", ["CPU", "CUDA"])
@impl("quanto_ext::unpack", ["CPU", "CUDA"])
def unpack_cpp(t: torch.Tensor, bits: int):
return backend().unpack(t, bits)
return ext().unpack(t, bits)
File renamed without changes.
File renamed without changes.
File renamed without changes.
7 changes: 7 additions & 0 deletions quanto/library/ext/mps/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Quanto Metal Performance Shaders extension

To add a new implementation for an operation defined in `library./ops.py`:

- add the corresponding `.mm` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@
__all__ = []


_backend = None
_ext = None


def backend():
"""Helper to load the MPS backend only when it is required"""
global _backend
if _backend is None:
def ext():
"""Helper to load the MPS extension only when it is required"""
global _ext
if _ext is None:
module_path = os.path.dirname(__file__)
_backend = load(
_ext = load(
name="quanto_mps",
sources=[f"{module_path}/unpack.mm", f"{module_path}/pybind_module.cpp"],
extra_cflags=["-std=c++17"],
)
return _backend
return _ext


@impl("quanto::unpack", "MPS")
@impl("quanto_ext::unpack", "MPS")
def unpack_mps(t: torch.Tensor, bits: int):
return backend().unpack(t, bits)
return ext().unpack(t, bits)
File renamed without changes.
File renamed without changes.
File renamed without changes.
46 changes: 44 additions & 2 deletions quanto/library/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,48 @@
from torch.library import define
from contextlib import contextmanager

import torch


# This file contains the definitions of all operations under torch.ops.quanto

define("quanto::unpack", "(Tensor self, int bits) -> Tensor")

_ext_enabled = True


@contextmanager
def disable_extensions():
"""Disable quanto extensions (debug)"""
try:
global _ext_enabled
_ext_enabled = False
yield
finally:
_ext_enabled = True


def define(name, schema):
"""Define a new quanto operation.
The operation will actually be defined in three libraries:
- the top-level quanto library as quanto::<op>,
- the quanto python library as quanto_py::<op>,
- the quanto extension library as quanto_ext::<op>.
Only the implementations for the python and extension library need
to be provided: the top-level implementation for the operation is
provided when calling this method and simply routes the calls towards
either the python or extension implementations based on the selected
mode.
"""
for libname in ["quanto", "quanto_py", "quanto_ext"]:

This comment has been minimized.

Copy link
@fxmarty

fxmarty Jul 17, 2024

Contributor

As discussed in the other thread, having this double level of dispatching adds quite a bit of overhead.

Couldn't we have a single @torch.library.impl (or better yet when possible TORCH_LIBRARY_IMPL in C++), with a default being python, and other implem for cuda, cpu, etc. if available? single level?

This comment has been minimized.

Copy link
@dacorvo

dacorvo Jul 17, 2024

Author Collaborator

Can you elaborate on the overhead ? This is not crystal clear to me.

This comment has been minimized.

Copy link
@fxmarty

fxmarty Jul 17, 2024

Contributor

For sure. What I meant is that currently there is two levels of dispatching, following one each other: torch.ops.quanto.whatever and then torch.ops.quanto_ext.whatever. I assume this adds overhead.
image

This comment has been minimized.

Copy link
@dacorvo

dacorvo Jul 17, 2024

Author Collaborator

That's clearer, thanks ! The dual library comes from the early days of torch.library (a few months back only), and serves to:

  • provide a fallback if a kernel is flaky,
  • allow to disable extensions.

The first one is not really needed (better to have an actual error than to very very slowly falling back to python after having cauqht an exception), and the second is to be replaced by this.

Anyway, I already implemented one-level ops for tinygemm, which has support for multiple devices.

It should be even simpler for kernel-specific op like those required for AWQ and Marlin kernels, as they can only run on CUDA: they could be defined and declared directly in ext/cuda/.

This comment has been minimized.

Copy link
@fxmarty

fxmarty Jul 17, 2024

Contributor

Great, sounds good

torch.library.define(f"{libname}::{name}", schema)

# Provide the inplementation for all dispatch key in the main library
@torch.library.impl("quanto::unpack", "default")
def impl(*args, **kwargs):
if _ext_enabled:
return getattr(torch.ops.quanto_ext, name)(*args, **kwargs)
return getattr(torch.ops.quanto_py, name)(*args, **kwargs)


define("unpack", "(Tensor self, int bits) -> Tensor")
18 changes: 18 additions & 0 deletions quanto/library/python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Quanto library python/pytorch operations

This folder contains the implementations of all `quanto_py::` operations.

This namespace corresponds to the default, python-only implementations of quanto operations.

The operations are defined in `library/ops.py`.

To provide an implementation for an operation, use the following syntax:

```
@torch.library.impl("quanto_py::unpack", "default")
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
...
```

The implementation **must** support all device types. This is true if it
is a composition of built-in PyTorch operators.
1 change: 1 addition & 0 deletions quanto/library/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .unpack import *
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch


@torch.libary.impl("quanto::unpack", "default")
@torch.library.impl("quanto_py::unpack", "default")
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
"""
Un-Pack int4 / int2 weights (packed in a uint8) into a torch.int8 tensor
Expand Down
10 changes: 8 additions & 2 deletions test/library/test_unpack.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from contextlib import nullcontext

import pytest
import torch

from quanto.library import disable_extensions
from quanto.tensor.core import int2, int4, pack_weights


@pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"])
@pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
def test_unpack(bits, shape, device):
@pytest.mark.parametrize("use_ext", [True, False], ids=["ext", "no-ext"])
def test_unpack(bits, shape, use_ext, device):
qmax = 2**bits
a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
bitsdtype = int2 if bits == 2 else int4
packed_a = pack_weights(a, bitsdtype)
unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
context = nullcontext() if use_ext else disable_extensions()
with context:
unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
assert torch.equal(unpacked_a, a)

0 comments on commit 8ed7e69

Please sign in to comment.