-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup.py
88 lines (76 loc) · 2.77 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) 2024, DeepLink.
from setuptools import find_packages, setup, Extension
from torch.utils.cpp_extension import BuildExtension, include_paths, library_paths
import glob
import os
import subprocess
def _getenv_or_die(env_name: str):
    """Read a mandatory environment variable.

    Args:
        env_name: Name of the environment variable to look up.

    Returns:
        The variable's value. An empty string counts as "set" and is returned
        unchanged.

    Raises:
        ValueError: If the variable is absent from the environment.
    """
    value = os.environ.get(env_name)
    if value is not None:
        return value
    raise ValueError(f"{env_name} is not set")
class BuildExtensionWithCompdb(BuildExtension):
    """``BuildExtension`` that also emits a compilation database.

    After the regular extension build, the ``build.ninja`` file produced by the
    ninja-based build is passed to ``ninja -t compdb``. The output is saved as
    ``build/compile_commands.json`` (relative to the current working directory).

    Generating the database is best-effort: any failure is printed and never
    fails the build itself.
    """

    def build_extensions(self):
        """Build all extensions, then try to write ``compile_commands.json``."""
        super().build_extensions()
        try:
            self._gen_compdb()
        except Exception as e:  # best-effort: tooling output must not break the build
            print(f"Failed to generate compile_commands.json: {e}")

    def _find_build_ninja(self):
        """Return the path of the ``build.ninja`` file to feed to ``ninja -f``.

        Raises:
            RuntimeError: If no unique ``build.ninja`` can be located.
        """
        # NOTE(review): torch's ninja build appears to write build.ninja into
        # ``build_temp``. Looking there first avoids ambiguity when ./build
        # holds several configurations; otherwise fall back to a search.
        build_temp = getattr(self, "build_temp", None)
        if build_temp:
            candidate = os.path.join(build_temp, "build.ninja")
            if os.path.isfile(candidate):
                return candidate
        found = glob.glob("./build/**/build.ninja", recursive=True)
        if len(found) != 1:
            raise RuntimeError(
                f"expected exactly one build.ninja under ./build, found {len(found)}"
            )
        return found[0]

    def _gen_compdb(self):
        """Write ``build/compile_commands.json`` using ``ninja -t compdb``.

        Raises:
            RuntimeError: If the build did not use ninja, or no unique
                ``build.ninja`` file can be found.
            OSError: If the ``ninja`` executable cannot be started.
            subprocess.CalledProcessError: If ``ninja`` exits with an error.
        """
        # Explicit checks instead of ``assert``: asserts vanish under -O and
        # produce an empty message in the failure print above.
        if not self.use_ninja:
            raise RuntimeError("ninja is not enabled for this build")
        build_ninja_file = self._find_build_ninja()
        # Run ninja before touching the output, so a failure leaves any
        # existing compile_commands.json intact instead of truncating it.
        result = subprocess.run(
            ["ninja", "-f", build_ninja_file, "-t", "compdb"],
            stdout=subprocess.PIPE,
            check=True,
        )
        os.makedirs("build", exist_ok=True)
        # Binary mode writes ninja's output byte-for-byte.
        with open("build/compile_commands.json", "wb") as f:
            f.write(result.stdout)
        print("Generated build/compile_commands.json")
def get_ext():
    """Describe the ``deeplink_ext.cpp_extensions`` C++ extension module.

    Compiles every ``csrc/*.cpp`` file (found relative to the current working
    directory). Headers come from PyTorch, DIPU and DIOPI. Linking is against
    the PyTorch libraries plus ``torch_dipu``.

    Environment variables:
        DIPU_ROOT (required): used as a system header dir. Its ``dist/include``
            subdirectory is also a system header dir, and DIPU_ROOT itself is a
            library dir.
        DIOPI_PATH (required): its ``include`` subdirectory is a system header
            dir.
        VENDOR_INCLUDE_DIRS (optional): extra system header dir, used when
            non-empty. It is passed through as a single path and not split.
        NCCL_INCLUDE_DIRS (optional): extra system header dir, needed for
            NVIDIA builds. Used when non-empty and passed as a single path.

    Returns:
        A one-element list containing the configured ``Extension``.

    Raises:
        ValueError: If DIPU_ROOT or DIOPI_PATH is not set.
    """
    ext_name = "deeplink_ext.cpp_extensions"
    # All operator sources. glob() returns files in filesystem-dependent
    # order, so sort them to keep the compile/link order reproducible.
    op_files = sorted(glob.glob("./csrc/*.cpp"))
    include_dirs = []
    system_include_dirs = include_paths()
    define_macros = []
    extra_objects = []
    library_dirs = library_paths()
    libraries = ["c10", "torch", "torch_cpu", "torch_python"]
    extra_link_args = []

    dipu_root = _getenv_or_die("DIPU_ROOT")
    diopi_path = _getenv_or_die("DIOPI_PATH")
    vendor_include_dirs = os.getenv("VENDOR_INCLUDE_DIRS")
    nccl_include_dirs = os.getenv("NCCL_INCLUDE_DIRS")  # required for NVIDIA builds
    system_include_dirs += [
        dipu_root,
        os.path.join(dipu_root, "dist/include"),
        os.path.join(diopi_path, "include"),
    ]
    # Empty strings are treated the same as unset.
    if vendor_include_dirs:
        system_include_dirs.append(vendor_include_dirs)
    if nccl_include_dirs:
        system_include_dirs.append(nccl_include_dirs)
    library_dirs += [dipu_root]
    libraries += ["torch_dipu"]

    extra_compile_args = ["-std=c++17", "-Wno-deprecated-declarations"]
    # Pass these header dirs via -isystem rather than include_dirs, so the
    # compiler treats them as system headers.
    extra_compile_args += ["-isystem" + path for path in system_include_dirs]
    ext_ops = Extension(
        name=ext_name,  # extension module name
        sources=op_files,
        include_dirs=include_dirs,
        define_macros=define_macros,  # preprocessor macro definitions
        extra_objects=extra_objects,  # prebuilt object files to link in
        extra_compile_args=extra_compile_args,
        library_dirs=library_dirs,
        libraries=libraries,
        extra_link_args=extra_link_args,
    )
    return [ext_ops]
setup(
    name="deeplink_ext",
    # find_packages' exclude matches full package names, so a bare "tests"
    # would still let subpackages such as "tests.<sub>" through. Exclude
    # each directory together with its subpackages.
    packages=find_packages(
        exclude=["build", "build.*", "csrc", "csrc.*", "tests", "tests.*"]
    ),
    ext_modules=get_ext(),
    # Use the build_ext variant that also writes build/compile_commands.json.
    cmdclass={"build_ext": BuildExtensionWithCompdb},
    install_requires=["einops"],
)