-
Notifications
You must be signed in to change notification settings - Fork 37
/
setup.py
128 lines (105 loc) · 3.78 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import glob
import os
import re
import shutil
import subprocess
from typing import List, Optional

from setuptools import find_packages, setup

import distutils.command.clean  # isort:skip
# Package name passed to setup(); also the directory that contains the
# version file read by get_version().
_PACKAGE_NAME: str = "torchrecipes"
# File inside the _PACKAGE_NAME directory that holds the
# `__version__: str = "..."` assignment parsed by get_version().
_VERSION_FILE: str = "version.py"
# The paths below are relative; they are opened as-is, i.e. resolved against
# the current working directory at the time they are read.
# Markdown file used as the long description (see long_description_content_type).
_README: str = "README.md"
# Runtime dependencies, passed to setup() as install_requires.
_REQUIREMENTS: str = "requirements.txt"
# Development dependencies, exposed as the "dev" extra.
_DEV_REQUIREMENTS: str = "dev-requirements.txt"
# Glob source for the custom `clean` command.
_GITIGNORE: str = ".gitignore"
def get_version() -> str:
    """Return the version string to build with.

    If the ``BUILD_VERSION`` environment variable is set and non-empty, it is
    returned verbatim. Otherwise the version is parsed from the
    ``__version__: str = "..."`` assignment in ``<_PACKAGE_NAME>/<_VERSION_FILE>``
    (resolved next to this setup script). When ``git rev-parse HEAD``
    succeeds there, ``+<first 7 chars of the commit sha>`` is appended.

    Raises:
        RuntimeError: If the version file has no ``__version__`` assignment.
    """
    version = os.getenv("BUILD_VERSION")
    if version:
        return version

    # Resolve everything relative to this file so the result does not depend
    # on the caller's working directory.
    root = os.path.dirname(os.path.abspath(__file__))
    version_file_path = os.path.join(root, _PACKAGE_NAME, _VERSION_FILE)
    version_regex = r"__version__: str = ['\"]([^'\"]*)['\"]"
    with open(version_file_path, "r", encoding="utf-8") as f:
        match = re.search(version_regex, f.read(), re.M)
    if match is None:
        # Explicit raise instead of `assert`, which is removed under `python -O`.
        raise RuntimeError(f"Could not find __version__ in {version_file_path}")
    version = match.group(1)

    try:
        sha = (
            subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                cwd=root,
                # Keep "fatal: not a git repository" out of the build log.
                stderr=subprocess.DEVNULL,
            )
            .decode("ascii")
            .strip()
        )
    except (OSError, subprocess.CalledProcessError):
        # git is not installed, or this is not a git checkout (e.g. an sdist):
        # fall back to the plain version.
        return version
    return f"{version}+{sha[:7]}"
def get_long_description() -> str:
    """Return the contents of the README (Markdown) for the package long description.

    The file is decoded as UTF-8 explicitly. Without an encoding, ``open()``
    uses the locale's preferred encoding, which can fail on non-ASCII
    Markdown (e.g. on Windows).
    """
    with open(_README, mode="r", encoding="utf-8") as f:
        return f.read()
# Matches a pip-style comment: "#" at the start of a line or preceded by
# whitespace. A "#" inside a token (e.g. "...repo.git#egg=name") is kept.
_REQUIREMENT_COMMENT_RE = re.compile(r"(^|\s)#.*$")


def _read_requirements(path: str) -> List[str]:
    """Return the requirement specifiers listed in the file at *path*.

    Comments and blank lines are dropped, and each entry is stripped of
    surrounding whitespace (including the trailing newline).
    """
    requirements = []
    with open(path, mode="r", encoding="utf-8") as f:
        for raw_line in f:
            line = _REQUIREMENT_COMMENT_RE.sub("", raw_line).strip()
            if line:
                requirements.append(line)
    return requirements


def get_requirements(path: Optional[str] = None) -> List[str]:
    """Fetch runtime requirements.

    Args:
        path: Requirements file to read; defaults to ``_REQUIREMENTS``.

    Returns:
        One cleaned requirement specifier per entry, in file order.
    """
    return _read_requirements(_REQUIREMENTS if path is None else path)


def get_dev_requirements(path: Optional[str] = None) -> List[str]:
    """Fetch requirements for library development.

    Args:
        path: Requirements file to read; defaults to ``_DEV_REQUIREMENTS``.

    Returns:
        One cleaned requirement specifier per entry, in file order.
    """
    return _read_requirements(_DEV_REQUIREMENTS if path is None else path)
class clean(distutils.command.clean.clean):
    """``clean`` command that also deletes paths matched by ``.gitignore``.

    Each non-blank, non-comment line of ``.gitignore`` is used as a ``glob``
    pattern relative to the current working directory. Matching files are
    removed, and anything ``os.remove`` rejects (e.g. directories) is removed
    recursively on a best-effort basis. The stock distutils ``clean`` then
    runs.

    This is a simplified reading of gitignore syntax: negation (``!``) rules
    are skipped, and patterns only match from the working directory (no
    any-depth matching).
    """

    def run(self) -> None:
        with open(_GITIGNORE, "r", encoding="utf-8") as f:
            # Strip each line: with the trailing "\n" kept, glob("build/\n")
            # matches nothing and blank lines are never filtered out.
            patterns = [line.strip() for line in f]
        for pattern in patterns:
            # Skip blank lines, comments and negation (re-include) rules.
            if not pattern or pattern.startswith(("#", "!")):
                continue
            # In gitignore a leading "/" anchors the pattern to the repo root;
            # passed to glob as-is it would search the filesystem root instead.
            pattern = pattern.lstrip("/")
            if not pattern:
                continue
            for filename in glob.glob(pattern):
                try:
                    os.remove(filename)
                except OSError:
                    # Directories (or otherwise unremovable files): best-effort
                    # recursive delete.
                    shutil.rmtree(filename, ignore_errors=True)
        super().run()
def main() -> None:
    """Resolve the build version and run setuptools ``setup()``.

    The resolved version is published via the module-level ``version`` global
    before any other build input is read.
    """
    global version
    version = get_version()
    banner = f"Building wheel {_PACKAGE_NAME}-{version}"
    print(banner)

    # Collect file-derived build inputs before calling setup().
    readme_text = get_long_description()
    runtime_deps = get_requirements()
    package_list = find_packages()
    dev_deps = get_dev_requirements()

    pypi_classifiers = [
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: BSD License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ]

    setup(
        # Metadata
        name=_PACKAGE_NAME,
        version=version,
        author="PyTorch Ecosystem Foundations Team",
        author_email="luispe@fb.com",
        description="Prototype of training recipes for PyTorch",
        long_description=readme_text,
        long_description_content_type="text/markdown",
        url="https://github.com/facebookresearch/recipes",
        license="BSD-3",
        keywords=["pytorch", "machine learning"],
        python_requires=">=3.7",
        install_requires=runtime_deps,
        include_package_data=True,
        # Package info
        packages=package_list,
        # pyre-fixme[6]: For 15th argument expected `Mapping[str, Type[Command]]`
        #  but got `Mapping[str, Type[clean]]`.
        cmdclass={"clean": clean},
        extras_require={"dev": dev_deps},
        # PyPI package information.
        classifiers=pypi_classifiers,
    )
# Annotation only: no value is bound here. main() assigns it through
# `global version` before calling setup().
version: str
# Run the build only when executed as a script, not when imported.
if __name__ == "__main__":
    main()  # pragma: no cover