diff --git a/setup.py b/setup.py index c6fefb67caf..920ac4fd1e7 100755 --- a/setup.py +++ b/setup.py @@ -1,6 +1,9 @@ #!/usr/bin/env python +import glob import os +from functools import partial from importlib.util import module_from_spec, spec_from_file_location +from typing import List, Tuple from setuptools import find_packages, setup @@ -24,19 +27,23 @@ def _load_py_module(fname, pkg="torchmetrics"): ) -def _prepare_extras(): - extras = { - "image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="image.txt"), # skipcq: PYL-W0212 - "text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="text.txt"), # skipcq: PYL-W0212 - "detection": setup_tools._load_requirements( # skipcq: PYL-W0212 - path_dir=_PATH_REQUIRE, file_name="detection.txt" - ), - "audio": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="audio.txt"), - } - # create an 'all' keyword that install all possible denpendencies - extras["all"] = [package for extra in extras.values() for package in extra] +BASE_REQUIREMENTS = setup_tools._load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt") + - return extras +def _prepare_extras(base_req: List[str], skip_files: Tuple[str] = ("devel.txt")): + # find all extra requirements + _load_req = partial(setup_tools._load_requirements, path_dir=_PATH_REQUIRE) + found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(_PATH_REQUIRE, "*.txt"))) + # filter unwanted files + found_req_files = [n for n in found_req_files if n not in skip_files] + found_req_names = [os.path.splitext(req)[0].replace("datatype_", "") for req in found_req_files] + # define basic and extra extras + extras_req = {name: base_req + _load_req(file_name=fname) for name, fname in zip(found_req_names, found_req_files)} + # filter the uniques + extras_req = {n: list(set(req)) for n, req in extras_req.items()} + # create an 'all' keyword that install all possible denpendencies + extras_req["all"] = [pkg for reqs in extras_req.values() for pkg in reqs] + return extras_req # https://packaging.python.org/discussions/install-requires-vs-requirements / @@ -61,7 +68,8 @@ def _prepare_extras(): keywords=["deep learning", "machine learning", "pytorch", "metrics", "AI"], python_requires=">=3.6", setup_requires=[], - install_requires=setup_tools._load_requirements(_PATH_ROOT), + install_requires=BASE_REQUIREMENTS, + extras_require=_prepare_extras(BASE_REQUIREMENTS), project_urls={ "Bug Tracker": os.path.join(about.__homepage__, "issues"), "Documentation": "https://torchmetrics.rtfd.io/en/latest/", @@ -89,5 +97,4 @@ def _prepare_extras(): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", ], - extras_require=_prepare_extras(), ) diff --git a/torchmetrics/setup_tools.py b/torchmetrics/setup_tools.py index 696169dc11e..68c361c60be 100644 --- a/torchmetrics/setup_tools.py +++ b/torchmetrics/setup_tools.py @@ -32,7 +32,7 @@ def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comme if comment_char in ln: ln = ln[: ln.index(comment_char)].strip() # skip directly installed dependencies - if ln.startswith("http"): + if ln.startswith("http") or ln.startswith("git") or ln.startswith("-r"): continue if ln: # if requirement is not empty reqs.append(ln)