diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh index 4bbd09694..1690c6275 100644 --- a/packaging/pre_build_script.sh +++ b/packaging/pre_build_script.sh @@ -9,4 +9,12 @@ set -eux echo "This script is run before building torchao binaries" +python -m pip install --upgrade pip +if [ -z "$PYTORCH_VERSION" ]; then + PYTORCH_DEP="torch" +else + PYTORCH_DEP="torch==$PYTORCH_VERSION" +fi +pip install $PYTORCH_DEP + pip install setuptools wheel twine auditwheel diff --git a/setup.py b/setup.py index a6a2da8a7..fc85dcea9 100644 --- a/setup.py +++ b/setup.py @@ -110,15 +110,6 @@ def get_extensions(): return ext_modules -# Mimic code from torchvision https://github.com/pytorch/vision/blob/143d078b28f00471156a4e562dd3836370acc9ee/setup.py#L58 -pytorch_dep = "torch" -if os.getenv("PYTORCH_VERSION"): - pytorch_dep += "==" + os.getenv("PYTORCH_VERSION") - -requirements = [ - pytorch_dep, -] - setup( name=package_name, version=version+version_suffix, @@ -128,7 +119,6 @@ def get_extensions(): "torchao.kernel.configs": ["*.pkl"], }, ext_modules=get_extensions() if use_cpp != "0" else None, - install_requires=requirements, extras_require={"dev": read_requirements("dev-requirements.txt")}, description="Package for applying ao techniques to GPU models", long_description=open("README.md").read(), diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 00082acad..381f91c4a 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -15,7 +15,7 @@ # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert # at the callsite prevents usage of this on unsupported versions. -if TORCH_VERSION_AFTER_2_4: +if TORCH_VERSION_AFTER_2_4 and has_triton(): from torch._inductor.runtime.triton_helpers import libdevice from torchao.prototype.mx_formats.constants import ( diff --git a/version.txt b/version.txt index 0d91a54c7..9e11b32fc 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.0 +0.3.1