Graceful handling of cpp extensions #296

Merged · 8 commits · May 30, 2024
Changes from all commits
9 changes: 6 additions & 3 deletions README.md
@@ -39,12 +39,15 @@ python setup.py develop

If you want to install from source run
```shell
python setup.py install
```

**Note:**
Since we build PyTorch C++/CUDA extensions by default, running `pip install .` will
not work.
If you run into issues while building the `ao` C++ extensions, you can instead build with

```shell
USE_CPP=0 python setup.py install
```

### Quantization

4 changes: 3 additions & 1 deletion setup.py
@@ -19,6 +19,8 @@ def read_requirements(file_path):
# Determine the package name based on the presence of an environment variable
package_name = "torchao-nightly" if os.environ.get("TORCHAO_NIGHTLY") else "torchao"
version_suffix = os.getenv("VERSION_SUFFIX", "")
use_cpp = os.getenv("USE_CPP")


# Version is year.month.date if using nightlies
version = current_date if package_name == "torchao-nightly" else "0.2.0"
@@ -92,7 +94,7 @@ def get_extensions():
package_data={
"torchao.kernel.configs": ["*.pkl"],
},
ext_modules=get_extensions(),
ext_modules=get_extensions() if use_cpp != "0" else None,
install_requires=read_requirements("requirements.txt"),
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
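The env-var gating in `setup.py` above can be sketched in isolation. This is a minimal, hypothetical stand-in: `build_ext_modules` substitutes for torchao's real `get_extensions()`, which constructs actual C++/CUDA extension objects.

```python
import os

def build_ext_modules():
    # Stand-in for torchao's get_extensions(); a real setup.py would
    # build CppExtension/CUDAExtension objects here.
    return ["placeholder_extension"]

def setup_kwargs():
    # USE_CPP=0 disables building C++/CUDA extensions entirely;
    # any other value, or leaving it unset, keeps the default behaviour.
    use_cpp = os.getenv("USE_CPP")
    return {
        "ext_modules": build_ext_modules() if use_cpp != "0" else None,
    }
```

Note that only the exact string `"0"` disables the build, matching the `use_cpp != "0"` check in the diff; an unset variable (which `os.getenv` returns as `None`) still builds extensions.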
7 changes: 7 additions & 0 deletions test/dtypes/test_float6_e3m2.py
@@ -5,6 +5,13 @@
parametrize,
run_tests,
)

import pytest

try:
    import torchao.ops
except RuntimeError:
    pytest.skip("torchao.ops not available", allow_module_level=True)


from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2


5 changes: 5 additions & 0 deletions test/test_ops.py
Collaborator
Not sure what the standard should be, but I know other tests (ex. quantization) import some boolean like "TORCH_VERSION_AFTER_2_4" and use that instead of a try except block

Contributor @jerryzh168 · May 30, 2024

I think the main problem is we are not sure why/when the import fails, so the boolean probably won't help here.
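The trade-off discussed in this thread can be sketched as a small helper. This is an illustrative pattern, not torchao code: unlike a precomputed boolean flag (the `TORCH_VERSION_AFTER_2_4` style), a try/except wrapper can preserve the exception, which records *why* the import failed.

```python
import importlib

def try_import(module_name):
    """Attempt an optional import.

    Returns (module, None) on success or (None, error) on failure.
    Keeping the exception object preserves the reason the import
    failed, which a plain boolean flag would discard.
    """
    try:
        return importlib.import_module(module_name), None
    except (ImportError, RuntimeError) as exc:
        return None, exc
```

A caller can then log or re-raise `error` when debugging why the extension is unavailable, instead of only knowing that it is.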

@@ -7,6 +7,11 @@
from parameterized import parameterized
import pytest

try:
    import torchao.ops
except RuntimeError:
    pytest.skip("torchao.ops not available", allow_module_level=True)


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
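The module-level guard used in these test files can be factored into a reusable helper. This is a hypothetical sketch, not part of the PR; note that `pytest.skip` at collection time requires `allow_module_level=True`, otherwise pytest raises a usage error.

```python
import pytest

def guarded_import(module_name):
    # Mirrors the pattern in test_ops.py: skip the whole test module
    # when an optional extension fails to import at collection time.
    # RuntimeError is included because compiled-extension imports can
    # fail with it rather than ImportError.
    try:
        return __import__(module_name)
    except (ImportError, RuntimeError):
        pytest.skip(f"{module_name} not available", allow_module_level=True)
```

On failure, `pytest.skip` raises pytest's `Skipped` outcome exception, so execution never falls through to the tests below the guard.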
10 changes: 8 additions & 2 deletions torchao/__init__.py
@@ -1,12 +1,18 @@
import torch
import logging

_IS_FBCODE = (
    hasattr(torch._utils_internal, "IS_FBSOURCE") and
    torch._utils_internal.IS_FBSOURCE
)

if not _IS_FBCODE:
    try:
        from . import _C
        from . import ops
    except:
        _C = None
        logging.info("Skipping import of cpp extensions")

from torchao.quantization import (
apply_weight_only_int8_quant,
Expand Down
10 changes: 7 additions & 3 deletions torchao/dtypes/__init__.py
@@ -1,14 +1,18 @@
from .nf4tensor import NF4Tensor, to_nf4
from .uint4 import UInt4Tensor
from .aqt import AffineQuantizedTensor, to_aq

__all__ = [
    "NF4Tensor",
    "to_nf4",
    "UInt4Tensor",
    "AffineQuantizedTensor",
    "to_aq",
]

# CPP extensions
try:
    from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2
    __all__.extend(["to_float6_e3m2", "from_float6_e3m2"])
except RuntimeError:
    pass