diff --git a/poetry/installation/installer.py b/poetry/installation/installer.py index afc7b668602..704cc3e133b 100644 --- a/poetry/installation/installer.py +++ b/poetry/installation/installer.py @@ -17,6 +17,7 @@ from poetry.repositories.installed_repository import InstalledRepository from poetry.semver import parse_constraint from poetry.utils.helpers import canonicalize_name +from poetry.utils.extras import get_extra_package_names from .base_installer import BaseInstaller from .pip_installer import PipInstaller @@ -399,7 +400,7 @@ def _get_operations_from_lock( installed_repo = self._installed_repository ops = [] - extra_packages = [p.name for p in self._get_extra_packages(locked_repository)] + extra_packages = self._get_extra_packages(locked_repository) for locked in locked_repository.packages: is_installed = False for installed in installed_repo.packages: @@ -429,7 +430,7 @@ def _get_operations_from_lock( def _filter_operations( self, ops, repo ): # type: (List[Operation], Repository) -> None - extra_packages = [p.name for p in self._get_extra_packages(repo)] + extra_packages = self._get_extra_packages(repo) for op in ops: if isinstance(op, Update): package = op.target_package @@ -468,9 +469,9 @@ def _filter_operations( if package.category == "dev" and not self.is_dev_mode(): op.skip("Dev dependencies not requested") - def _get_extra_packages(self, repo): + def _get_extra_packages(self, repo): # type: (Repository) -> List[str] """ - Returns all packages required by extras. + Returns all package names required by extras. Maybe we just let the solver handle it? """ @@ -479,26 +480,7 @@ def _get_extra_packages(self, repo): else: extras = self._locker.lock_data.get("extras", {}) - extra_packages = [] - for extra_name, packages in extras.items(): - if extra_name not in self._extras: - continue - - extra_packages += [Dependency(p, "*") for p in packages] - - def _extra_packages(packages): - pkgs = [] - for package in packages: - for pkg in repo.packages: - if pkg.name == package.name: - pkgs.append(package) - pkgs += _extra_packages(pkg.requires) - - break - - return pkgs - - return _extra_packages(extra_packages) + return list(get_extra_package_names(repo.packages, extras, self._extras)) def _get_installer(self): # type: () -> BaseInstaller return PipInstaller(self._env, self._io, self._pool) diff --git a/poetry/utils/exporter.py b/poetry/utils/exporter.py index 3c05a4da864..94b3c797111 100644 --- a/poetry/utils/exporter.py +++ b/poetry/utils/exporter.py @@ -9,6 +9,7 @@ from poetry.poetry import Poetry from poetry.utils._compat import Path from poetry.utils._compat import decode +from poetry.utils.extras import get_extra_package_names class Exporter(object): @@ -55,20 +56,16 @@ def _export_requirements_txt( ): # type: (Path, Union[IO, str], bool, bool, bool) -> None indexes = [] content = "" + packages = self._poetry.locker.locked_repository(dev).packages - # Generate a list of package names we have opted into via `extras` - extras_set = frozenset(extras or ()) - extra_package_names = set() - if extras: - for extra_name, extra_packages in self._poetry.locker.lock_data.get( - "extras", {} - ).items(): - if extra_name in extras_set: - extra_package_names.update(extra_packages) - - for package in sorted( - self._poetry.locker.locked_repository(dev).packages, key=lambda p: p.name - ): + # Build a set of all packages required by our selected extras + extra_package_names = set( + get_extra_package_names( + packages, self._poetry.locker.lock_data.get("extras", {}), extras or () + ) + ) + + for package in sorted(packages, key=lambda p: p.name): # If a package is optional and we haven't opted in to it, continue if package.optional and package.name not in extra_package_names: continue diff --git a/poetry/utils/extras.py b/poetry/utils/extras.py new file mode 100644 index 00000000000..7a1b24ad42e --- /dev/null +++ b/poetry/utils/extras.py @@ -0,0 +1,48 @@ +from typing import Iterator, Mapping, Sequence + +from poetry.packages import Package +from poetry.utils.helpers import canonicalize_name + + +def get_extra_package_names( + packages, # type: Sequence[Package] + extras, # type: Mapping[str, Collection[str]] + extra_names, # type: Sequence[str] +): # type: (...) -> Iterator[str] + """ + Returns all package names required by the given extras. + + :param packages: A collection of packages, such as from Repository.packages + :param extras: A mapping of `extras` names to lists of package names, as defined + in the `extras` section of `poetry.lock`. + :param extra_names: A list of strings specifying names of extra groups to resolve. + """ + if not extra_names: + return [] + + # lookup for packages by name, faster than looping over packages repeatedly + packages_by_name = {package.name: package for package in packages} + + # get and flatten names of packages we've opted into as extras + extra_package_names = [ + canonicalize_name(extra_package_name) + for extra_name in extra_names + for extra_package_name in extras.get(extra_name, ()) + ] + + def _extra_packages(package_names): + """Recursively find dependencies for packages names""" + # for each extra pacakge name + for package_name in package_names: + # Find the actual Package object. A missing key indicates an implicit + # dependency (like setuptools), which should be ignored + package = packages_by_name.get(canonicalize_name(package_name)) + if package: + yield package.name + # Recurse for dependencies + for dependency_package_name in _extra_packages( + dependency.name for dependency in package.requires + ): + yield dependency_package_name + + return _extra_packages(extra_package_names) diff --git a/tests/utils/test_exporter.py b/tests/utils/test_exporter.py index ad5be296bf4..ccb8ecf248b 100644 --- a/tests/utils/test_exporter.py +++ b/tests/utils/test_exporter.py @@ -367,7 +367,15 @@ def test_exporter_exports_requirements_txt_with_optional_packages_if_opted_in( { "name": "bar", "version": "4.5.6", - "category": "dev", + "category": "main", + "optional": True, + "python-versions": "*", + "dependencies": {"spam": ">=0.1"}, + }, + { + "name": "spam", + "version": "0.1.0", + "category": "main", "optional": True, "python-versions": "*", }, @@ -375,7 +383,7 @@ def test_exporter_exports_requirements_txt_with_optional_packages_if_opted_in( "metadata": { "python-versions": "*", "content-hash": "123456789", - "hashes": {"foo": ["12345"], "bar": ["67890"]}, + "hashes": {"foo": ["12345"], "bar": ["67890"], "spam": ["abcde"]}, }, "extras": {"feature_bar": ["bar"]}, } @@ -398,6 +406,8 @@ def test_exporter_exports_requirements_txt_with_optional_packages_if_opted_in( --hash=sha256:67890 foo==1.2.3 \\ --hash=sha256:12345 +spam==0.1.0 \\ + --hash=sha256:abcde """ assert expected == content diff --git a/tests/utils/test_extras.py b/tests/utils/test_extras.py new file mode 100644 index 00000000000..4bb03d19035 --- /dev/null +++ b/tests/utils/test_extras.py @@ -0,0 +1,50 @@ +import pytest + +from poetry.utils.extras import get_extra_package_names + +from poetry.packages import Package + +_PACKAGE_FOO = Package("foo", "0.1.0") +_PACKAGE_SPAM = Package("spam", "0.2.0") +_PACKAGE_BAR = Package("bar", "0.3.0") +_PACKAGE_BAR.add_dependency("foo") + + +@pytest.mark.parametrize( + "packages,extras,extra_names,expected_extra_package_names", + [ + # Empty edge case + ([], {}, [], []), + # Selecting no extras is fine + ([_PACKAGE_FOO], {}, [], []), + # An empty extras group should return an empty list + ([_PACKAGE_FOO], {"group0": []}, ["group0"], []), + # Selecting an extras group should return the contained packages + ( + [_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR], + {"group0": ["foo"]}, + ["group0"], + ["foo"], + ), + # If a package has dependencies, we should also get their names + ( + [_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR], + {"group0": ["bar"], "group1": ["spam"]}, + ["group0"], + ["bar", "foo"], + ), + # Selecting multpile extras should get us the union of all package names + ( + [_PACKAGE_FOO, _PACKAGE_SPAM, _PACKAGE_BAR], + {"group0": ["bar"], "group1": ["spam"]}, + ["group0", "group1"], + ["bar", "foo", "spam"], + ), + ], +) +def test_get_extra_package_names( + packages, extras, extra_names, expected_extra_package_names +): + assert expected_extra_package_names == list( + get_extra_package_names(packages, extras, extra_names) + )