Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 3 additions & 102 deletions tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import distutils.sysconfig

import os
import platform
import subprocess
import sys
from pathlib import Path

from setuptools.command.build_ext import build_ext


__all__ = [
"get_ext_modules",
"CMakeBuild",
]
__all__ = ["get_ext_modules"]


_THIS_DIR = Path(__file__).parent.resolve()
Expand All @@ -38,96 +30,5 @@ def _get_build(var, default=False):
return False


_BUILD_S3 = _get_build("BUILD_S3", False)
_USE_SYSTEM_AWS_SDK_CPP = _get_build("USE_SYSTEM_AWS_SDK_CPP", False)
_USE_SYSTEM_PYBIND11 = _get_build("USE_SYSTEM_PYBIND11", False)
_USE_SYSTEM_LIBS = _get_build("USE_SYSTEM_LIBS", False)


try:
# Use the pybind11 from third_party
if not (_USE_SYSTEM_PYBIND11 or _USE_SYSTEM_LIBS):
sys.path.insert(0, str(_ROOT_DIR / "third_party/pybind11/"))
from pybind11.setup_helpers import Pybind11Extension
except ImportError:
from setuptools import Extension as Pybind11Extension


def get_ext_modules():
if _BUILD_S3:
return [Pybind11Extension(name="torchdata._torchdata", sources=[])]
else:
return []


class CMakeBuild(build_ext):
def run(self):
try:
subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake is not available.") from None
super().run()

def build_extension(self, ext):
# Because the following `cmake` command will build all of `ext_modules`` at the same time,
# we would like to prevent multiple calls to `cmake`.
# Therefore, we call `cmake` only for `torchdata._torchdata`,
# in case `ext_modules` contains more than one module.
if ext.name != "torchdata._torchdata":
return

extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

# required for auto-detection of auxiliary "native" libs
if not extdir.endswith(os.path.sep):
extdir += os.path.sep

debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
cfg = "Debug" if debug else "Release"

cmake_args = [
f"-DCMAKE_BUILD_TYPE={cfg}",
f"-DCMAKE_INSTALL_PREFIX={extdir}",
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY={extdir}", # For Windows
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
f"-DBUILD_S3:BOOL={'ON' if _BUILD_S3 else 'OFF'}",
f"-DUSE_SYSTEM_AWS_SDK_CPP:BOOL={'ON' if _USE_SYSTEM_AWS_SDK_CPP else 'OFF'}",
f"-DUSE_SYSTEM_PYBIND11:BOOL={'ON' if _USE_SYSTEM_PYBIND11 else 'OFF'}",
f"-DUSE_SYSTEM_LIBS:BOOL={'ON' if _USE_SYSTEM_LIBS else 'OFF'}",
]

build_args = ["--config", cfg]

# Default to Ninja
if "CMAKE_GENERATOR" not in os.environ or platform.system() == "Windows":
cmake_args += ["-GNinja"]
if platform.system() == "Windows":
python_version = sys.version_info
cmake_args += [
"-DCMAKE_C_COMPILER=cl",
"-DCMAKE_CXX_COMPILER=cl",
f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}",
]

# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
# across all generators.
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
# self.parallel is a Python 3 only way to set parallel jobs by hand
# using -j in the build_ext call, not supported by pip or PyPA-build.
if hasattr(self, "parallel") and self.parallel:
# CMake 3.12+ only.
build_args += [f"-j{self.parallel}"]

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

subprocess.check_call(["cmake", str(_ROOT_DIR)] + cmake_args, cwd=self.build_temp)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)

def get_ext_filename(self, fullname):
ext_filename = super().get_ext_filename(fullname)
ext_filename_parts = ext_filename.split(".")
without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
ext_filename = ".".join(without_abi)
return ext_filename
return []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know about this, can we check if we need to keep get_ext_modules ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Thanks!