forked from rusty1s/pytorch_cluster
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup.py
53 lines (48 loc) · 1.99 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
# CPU implementations are always compiled: one extension per algorithm,
# named torch_cluster.<algo>_cpu and built from cpu/<algo>.cpp.
ext_modules = [
    CppExtension('torch_cluster.{}_cpu'.format(algo),
                 ['cpu/{}.cpp'.format(algo)])
    for algo in ('graclus', 'grid', 'fps')
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}

# CUDA implementations are added only when torch.utils.cpp_extension
# reports a CUDA_HOME. Each one is named torch_cluster.<algo>_cuda and is
# built from the pair cuda/<algo>.cpp + cuda/<algo>_kernel.cu.
if CUDA_HOME is not None:
    ext_modules += [
        CUDAExtension('torch_cluster.{}_cuda'.format(algo),
                      ['cuda/{}.cpp'.format(algo),
                       'cuda/{}_kernel.cu'.format(algo)])
        for algo in ('graclus', 'grid', 'fps', 'nearest', 'knn', 'radius',
                     'rw')
    ]
# Release identity; also used to build the download URL passed to setup().
__version__ = '1.2.4'
url = 'https://github.com/rusty1s/pytorch_cluster'

# Runtime dependency of the installed package.
install_requires = ['scipy']

# Test tooling: a setup-time requirement plus the packages the tests need.
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
# Register the package with setuptools. The extension modules are built
# through the build_ext command mapped in cmdclass.
setup(
    # Identity and authorship.
    name='torch_cluster',
    version=__version__,
    author='Matthias Fey',
    author_email='matthias.fey@tu-dortmund.de',
    description=('PyTorch Extension Library of '
                 'Optimized Graph Cluster Algorithms'),
    keywords=['pytorch', 'cluster', 'geometric-deep-learning', 'graph'],
    # Where the project and its release tarball live.
    url=url,
    download_url='%s/archive/%s.tar.gz' % (url, __version__),
    # Pure-Python packages plus their dependencies.
    packages=find_packages(),
    install_requires=install_requires,
    setup_requires=setup_requires,
    tests_require=tests_require,
    # Compiled C++/CUDA extensions and the command class that builds them.
    ext_modules=ext_modules,
    cmdclass=cmdclass,
)