Skip to content

Commit

Permalink
[1/2] Wean off of PYBIND in favor of torch.ops.load_library (pytorch#…
Browse files Browse the repository at this point in the history
…1276)

Wean ao off of PYBIND [part 1]
  • Loading branch information
janeyx99 authored Nov 13, 2024
1 parent 01dc7da commit c546c5c
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 17 deletions.
24 changes: 12 additions & 12 deletions packaging/post_build_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@

set -eux

WHEEL_NAME=$(ls dist/)
# Prepare manywheel, only for CUDA.
# The wheel is a pure python wheel for other platforms.
if [[ "$CU_VERSION" == cu* ]]; then
WHEEL_NAME=$(ls dist/)

pushd dist
# Prepare manywheel
manylinux_plat=manylinux2014_x86_64
if [[ "$CU_VERSION" == "xpu" ]]; then
manylinux_plat=manylinux_2_28_x86_64
fi
auditwheel repair --plat "$manylinux_plat" -w . \
pushd dist
manylinux_plat=manylinux2014_x86_64
auditwheel repair --plat "$manylinux_plat" -w . \
--exclude libtorch.so \
--exclude libtorch_python.so \
--exclude libtorch_cuda.so \
Expand All @@ -26,10 +25,11 @@ auditwheel repair --plat "$manylinux_plat" -w . \
--exclude libcudart.so.11.0 \
"${WHEEL_NAME}"

ls -lah .
# Clean up the linux_x86_64 wheel
rm "${WHEEL_NAME}"
popd
ls -lah .
# Clean up the linux_x86_64 wheel
rm "${WHEEL_NAME}"
popd
fi

MANYWHEEL_NAME=$(ls dist/)
# Try to install the new wheel
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def get_extensions():
if use_cuda:
sources += cuda_sources

if len(sources) == 0:
return None

ext_modules = [
extension(
"torchao._C",
Expand Down
6 changes: 4 additions & 2 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
)
if not _IS_FBCODE:
try:
from . import _C
from pathlib import Path
so_files = list(Path(__file__).parent.glob("_C*.so"))
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])
from . import ops
except:
_C = None
logging.info("Skipping import of cpp extensions")

from torchao.quantization import (
Expand Down
3 changes: 0 additions & 3 deletions torchao/csrc/init.cpp

This file was deleted.

0 comments on commit c546c5c

Please sign in to comment.