Skip to content

Commit c5cbbae

Browse files
authored
Merge pull request #30 from nicolas-chaulet/patch
Handle missing torch better
2 parents 17d4f8a + eeddc5d commit c5cbbae

File tree

3 files changed

+59
-37
lines changed

3 files changed

+59
-37
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
# Unreleased
22

3+
# 0.6.0
4+
5+
## Bug fix
6+
- Require pytorch implicitely and log nice message when missing
7+
38
# 0.5.3
49

510
## Update
611
- ball query returns squared distance instead of distance
712
- leaner Point Cloud struct that avoids copying data
813

914
## Bug fix
10-
- Pcakage would not install if pytorch is not already installed
15+
- Package would not install if pytorch is not already installed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ exclude = '''
2525
'''
2626

2727
[build-system]
28-
requires = ["setuptools>=41.0", "setuptools-scm", "wheel", "torch"]
28+
requires = ["setuptools>=41.0", "setuptools-scm", "wheel"]
2929
build-backend = "setuptools.build_meta"

setup.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,71 @@
11
from setuptools import setup, find_packages
2-
import torch
3-
from torch.utils.cpp_extension import (
4-
BuildExtension,
5-
CUDAExtension,
6-
CUDA_HOME,
7-
CppExtension,
8-
)
2+
3+
try:
4+
import torch
5+
from torch.utils.cpp_extension import (
6+
BuildExtension,
7+
CUDAExtension,
8+
CUDA_HOME,
9+
CppExtension,
10+
)
11+
HAS_TORCH=True
12+
except:
13+
HAS_TORCH=False
14+
915
import glob
1016

1117
from os import path
1218
this_directory = path.abspath(path.dirname(__file__))
1319
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
1420
long_description = f.read()
1521

16-
TORCH_MAJOR = int(torch.__version__.split(".")[0])
17-
TORCH_MINOR = int(torch.__version__.split(".")[1])
18-
extra_compile_args = ["-O3"]
19-
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
20-
extra_compile_args += ["-DVERSION_GE_1_3"]
2122

22-
ext_src_root = "cuda"
23-
ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root))
23+
def get_ext_modules():
24+
if not HAS_TORCH:
25+
return []
26+
TORCH_MAJOR = int(torch.__version__.split(".")[0])
27+
TORCH_MINOR = int(torch.__version__.split(".")[1])
28+
extra_compile_args = ["-O3"]
29+
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
30+
extra_compile_args += ["-DVERSION_GE_1_3"]
2431

25-
ext_modules = []
26-
if CUDA_HOME:
27-
ext_modules.append(
28-
CUDAExtension(
29-
name="torch_points_kernels.points_cuda",
30-
sources=ext_sources,
31-
include_dirs=["{}/include".format(ext_src_root)],
32-
extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,},
32+
ext_src_root = "cuda"
33+
ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root))
34+
35+
ext_modules = []
36+
if CUDA_HOME:
37+
ext_modules.append(
38+
CUDAExtension(
39+
name="torch_points_kernels.points_cuda",
40+
sources=ext_sources,
41+
include_dirs=["{}/include".format(ext_src_root)],
42+
extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,},
43+
)
3344
)
34-
)
3545

36-
cpu_ext_src_root = "cpu"
37-
cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))
46+
cpu_ext_src_root = "cpu"
47+
cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))
3848

39-
ext_modules.append(
40-
CppExtension(
41-
name="torch_points_kernels.points_cpu",
42-
sources=cpu_ext_sources,
43-
include_dirs=["{}/include".format(cpu_ext_src_root)],
44-
extra_compile_args={"cxx": extra_compile_args,},
49+
ext_modules.append(
50+
CppExtension(
51+
name="torch_points_kernels.points_cpu",
52+
sources=cpu_ext_sources,
53+
include_dirs=["{}/include".format(cpu_ext_src_root)],
54+
extra_compile_args={"cxx": extra_compile_args,},
55+
)
4556
)
46-
)
57+
return ext_modules
58+
59+
def get_cmdclass():
60+
if HAS_TORCH:
61+
return {"build_ext": BuildExtension}
62+
else:
63+
return {}
4764

4865
requirements = ["torch>=1.1.0"]
4966

5067
url = 'https://github.com/nicolas-chaulet/torch-points-kernels'
51-
__version__="0.5.3"
68+
__version__="0.6.0"
5269
setup(
5370
name="torch-points-kernels",
5471
version=__version__,
@@ -57,8 +74,8 @@
5774
url=url,
5875
download_url='{}/archive/{}.tar.gz'.format(url, __version__),
5976
install_requires=requirements,
60-
ext_modules=ext_modules,
61-
cmdclass={"build_ext": BuildExtension},
77+
ext_modules=get_ext_modules(),
78+
cmdclass=get_cmdclass(),
6279
long_description=long_description,
6380
long_description_content_type='text/markdown',
6481
classifiers=[

0 commit comments

Comments
 (0)