Skip to content

Commit

Permalink
generalize req. extras (#590)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Oct 28, 2021
1 parent 6f02550 commit 61e9c9c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
35 changes: 21 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#!/usr/bin/env python
import glob
import os
from functools import partial
from importlib.util import module_from_spec, spec_from_file_location
from typing import List, Tuple

from setuptools import find_packages, setup

Expand All @@ -24,19 +27,23 @@ def _load_py_module(fname, pkg="torchmetrics"):
)


def _prepare_extras():
extras = {
"image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="image.txt"), # skipcq: PYL-W0212
"text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="text.txt"), # skipcq: PYL-W0212
"detection": setup_tools._load_requirements( # skipcq: PYL-W0212
path_dir=_PATH_REQUIRE, file_name="detection.txt"
),
"audio": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="audio.txt"),
}
# create an 'all' keyword that install all possible denpendencies
extras["all"] = [package for extra in extras.values() for package in extra]
BASE_REQUIREMENTS = setup_tools._load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt")


return extras
def _prepare_extras(base_req: List[str], skip_files: Tuple[str] = ("devel.txt")):
# find all extra requirements
_load_req = partial(setup_tools._load_requirements, path_dir=_PATH_REQUIRE)
found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(_PATH_REQUIRE, "*.txt")))
# filter unwanted files
found_req_files = [n for n in found_req_files if n not in skip_files]
found_req_names = [os.path.splitext(req)[0].replace("datatype_", "") for req in found_req_files]
# define basic and extra extras
extras_req = {name: base_req + _load_req(file_name=fname) for name, fname in zip(found_req_names, found_req_files)}
# filter the uniques
extras_req = {n: list(set(req)) for n, req in extras_req.items()}
# create an 'all' keyword that install all possible denpendencies
extras_req["all"] = [pkg for reqs in extras_req.values() for pkg in reqs]
return extras_req


# https://packaging.python.org/discussions/install-requires-vs-requirements /
Expand All @@ -61,7 +68,8 @@ def _prepare_extras():
keywords=["deep learning", "machine learning", "pytorch", "metrics", "AI"],
python_requires=">=3.6",
setup_requires=[],
install_requires=setup_tools._load_requirements(_PATH_ROOT),
install_requires=BASE_REQUIREMENTS,
extras_require=_prepare_extras(BASE_REQUIREMENTS),
project_urls={
"Bug Tracker": os.path.join(about.__homepage__, "issues"),
"Documentation": "https://torchmetrics.rtfd.io/en/latest/",
Expand Down Expand Up @@ -89,5 +97,4 @@ def _prepare_extras():
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
extras_require=_prepare_extras(),
)
2 changes: 1 addition & 1 deletion torchmetrics/setup_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comme
if comment_char in ln:
ln = ln[: ln.index(comment_char)].strip()
# skip directly installed dependencies
if ln.startswith("http"):
if ln.startswith("http") or ln.startswith("git") or ln.startswith("-r"):
continue
if ln: # if requirement is not empty
reqs.append(ln)
Expand Down

0 comments on commit 61e9c9c

Please sign in to comment.