|
1 | 1 | from setuptools import setup, find_packages |
2 | | -import torch |
3 | | -from torch.utils.cpp_extension import ( |
4 | | - BuildExtension, |
5 | | - CUDAExtension, |
6 | | - CUDA_HOME, |
7 | | - CppExtension, |
8 | | -) |
| 2 | + |
| 3 | +try: |
| 4 | + import torch |
| 5 | + from torch.utils.cpp_extension import ( |
| 6 | + BuildExtension, |
| 7 | + CUDAExtension, |
| 8 | + CUDA_HOME, |
| 9 | + CppExtension, |
| 10 | + ) |
| 11 | + HAS_TORCH=True |
| 12 | +except: |
| 13 | + HAS_TORCH=False |
| 14 | + |
9 | 15 | import glob |
10 | 16 |
|
11 | 17 | from os import path |
12 | 18 | this_directory = path.abspath(path.dirname(__file__)) |
13 | 19 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: |
14 | 20 | long_description = f.read() |
15 | 21 |
|
16 | | -TORCH_MAJOR = int(torch.__version__.split(".")[0]) |
17 | | -TORCH_MINOR = int(torch.__version__.split(".")[1]) |
18 | | -extra_compile_args = ["-O3"] |
19 | | -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): |
20 | | - extra_compile_args += ["-DVERSION_GE_1_3"] |
21 | 22 |
|
22 | | -ext_src_root = "cuda" |
23 | | -ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root)) |
| 23 | +def get_ext_modules(): |
| 24 | + if not HAS_TORCH: |
| 25 | + return [] |
| 26 | + TORCH_MAJOR = int(torch.__version__.split(".")[0]) |
| 27 | + TORCH_MINOR = int(torch.__version__.split(".")[1]) |
| 28 | + extra_compile_args = ["-O3"] |
| 29 | + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): |
| 30 | + extra_compile_args += ["-DVERSION_GE_1_3"] |
24 | 31 |
|
25 | | -ext_modules = [] |
26 | | -if CUDA_HOME: |
27 | | - ext_modules.append( |
28 | | - CUDAExtension( |
29 | | - name="torch_points_kernels.points_cuda", |
30 | | - sources=ext_sources, |
31 | | - include_dirs=["{}/include".format(ext_src_root)], |
32 | | - extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,}, |
| 32 | + ext_src_root = "cuda" |
| 33 | + ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root)) |
| 34 | + |
| 35 | + ext_modules = [] |
| 36 | + if CUDA_HOME: |
| 37 | + ext_modules.append( |
| 38 | + CUDAExtension( |
| 39 | + name="torch_points_kernels.points_cuda", |
| 40 | + sources=ext_sources, |
| 41 | + include_dirs=["{}/include".format(ext_src_root)], |
| 42 | + extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,}, |
| 43 | + ) |
33 | 44 | ) |
34 | | - ) |
35 | 45 |
|
36 | | -cpu_ext_src_root = "cpu" |
37 | | -cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root)) |
| 46 | + cpu_ext_src_root = "cpu" |
| 47 | + cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root)) |
38 | 48 |
|
39 | | -ext_modules.append( |
40 | | - CppExtension( |
41 | | - name="torch_points_kernels.points_cpu", |
42 | | - sources=cpu_ext_sources, |
43 | | - include_dirs=["{}/include".format(cpu_ext_src_root)], |
44 | | - extra_compile_args={"cxx": extra_compile_args,}, |
| 49 | + ext_modules.append( |
| 50 | + CppExtension( |
| 51 | + name="torch_points_kernels.points_cpu", |
| 52 | + sources=cpu_ext_sources, |
| 53 | + include_dirs=["{}/include".format(cpu_ext_src_root)], |
| 54 | + extra_compile_args={"cxx": extra_compile_args,}, |
| 55 | + ) |
45 | 56 | ) |
46 | | -) |
| 57 | + return ext_modules |
| 58 | + |
| 59 | +def get_cmdclass(): |
| 60 | + if HAS_TORCH: |
| 61 | + return {"build_ext": BuildExtension} |
| 62 | + else: |
| 63 | + return {} |
47 | 64 |
|
48 | 65 | requirements = ["torch>=1.1.0"] |
49 | 66 |
|
50 | 67 | url = 'https://github.com/nicolas-chaulet/torch-points-kernels' |
51 | | -__version__="0.5.3" |
| 68 | +__version__="0.6.0" |
52 | 69 | setup( |
53 | 70 | name="torch-points-kernels", |
54 | 71 | version=__version__, |
|
57 | 74 | url=url, |
58 | 75 | download_url='{}/archive/{}.tar.gz'.format(url, __version__), |
59 | 76 | install_requires=requirements, |
60 | | - ext_modules=ext_modules, |
61 | | - cmdclass={"build_ext": BuildExtension}, |
| 77 | + ext_modules=get_ext_modules(), |
| 78 | + cmdclass=get_cmdclass(), |
62 | 79 | long_description=long_description, |
63 | 80 | long_description_content_type='text/markdown', |
64 | 81 | classifiers=[ |
|
0 commit comments