Skip to content

Commit

Permalink
trying to build diptest and cuda extension with a single setup.py call
Browse files Browse the repository at this point in the history
  • Loading branch information
edublancas committed Dec 17, 2020
1 parent 5c9cb90 commit 98547ad
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.vscode/
tmp/
.pytest_cache/
*.mat
Expand Down
63 changes: 35 additions & 28 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from os.path import splitext
from setuptools import find_packages, setup
from distutils.extension import Extension
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

NAME = 'yass-algorithm'
DESCRIPTION = 'YASS: Yet Another Spike Sorter'
Expand All @@ -27,15 +28,29 @@
# autodoc_mock_imports list in conf.py
INSTALL_REQUIRES = [
# these first two are only required for Python 2
'pathlib2;python_version<"3"', 'funcsigs;python_version<"3"',
'pathlib2;python_version<"3"',
'funcsigs;python_version<"3"',
# dependencies...
'numpy', 'scipy', 'scikit-learn', 'pyyaml', 'python-dateutil', 'click',
'tqdm', 'multiprocess', 'coloredlogs', 'cerberus',
'numpy',
'scipy',
'scikit-learn',
'pyyaml',
'python-dateutil',
'click',
'tqdm',
'multiprocess',
'coloredlogs',
'cerberus',
# 'torch',
# from experimental pipeline (nnet and clustering)
# TODO: consider reducing the number of dependencies: parmap, matplotlib
# and progressbar2 are not necessary
'parmap', 'statsmodels', 'matplotlib', 'networkx', 'Cython', 'progressbar2',
'parmap',
'statsmodels',
'matplotlib',
'networkx',
'Cython',
'progressbar2',
'h5py'
]

Expand All @@ -51,35 +66,27 @@
_version_re = re.compile(r'__version__\s+=\s+(.*)')

with open('src/yass/__init__.py', 'rb') as f:
VERSION = str(ast.literal_eval(_version_re.search(
f.read().decode('utf-8')).group(1)))
VERSION = str(
ast.literal_eval(
_version_re.search(f.read().decode('utf-8')).group(1)))

# Cython and numpy installation based on this:
# https://stackoverflow.com/a/42163080/709975

ext_modules = [
Extension(name="diptest._diptest",
sources=["src/diptest/_dip.c", "src/diptest/_diptest.c"],
extra_compile_args=['-O3', '-std=c99']),
CUDAExtension('rowshift', [
'src/gpu_rowshift/rowshift.cpp', 'src/gpu_rowshift/rowshift_kernels.cu'
]),
]

try:
from Cython.setuptools import build_ext
except Exception:
# If we couldn't import Cython, use the normal setuptools
# and look for a pre-compiled .c file instead of a .pyx file
from setuptools.command.build_ext import build_ext
ext_modules = [Extension(name="diptest._diptest",
sources=["src/diptest/_dip.c",
"src/diptest/_diptest.c"],
extra_compile_args=['-O3', '-std=c99'])]
else:
# If we successfully imported Cython, look for a .pyx file
ext_modules = [Extension(name="diptest._diptest",
sources=["src/diptest/_dip.c",
"src/diptest/_diptest.pyx"],
extra_compile_args=['-O3', '-std=c99'])]


class CustomBuildExtCommand(build_ext):
"""build_ext command for use when numpy headers are needed
"""

class CustomBuildExtCommand(BuildExtension):
"""Custom build_ext command to use when numpy headers are needed
(for diptest) and also
"""
def run(self):

# Import numpy here, only when headers are needed
Expand All @@ -89,7 +96,7 @@ def run(self):
self.include_dirs.append(numpy.get_include())

# Call original build_ext command
build_ext.run(self)
BuildExtension.run(self)


setup(
Expand Down

0 comments on commit 98547ad

Please sign in to comment.