diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 17ac85d5..b2fc2dd3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,7 @@ Changelog * Enforce that the entire marker string is parsed (:issue:`687`) * Requirement parsing no longer automatically validates the URL (:issue:`120`) +* Canonicalize names for requirements comparison (:issue:`644`) 23.1 - 2023-04-12 ~~~~~~~~~~~~~~~~~ diff --git a/src/packaging/requirements.py b/src/packaging/requirements.py index e828c61f..0c00eba3 100644 --- a/src/packaging/requirements.py +++ b/src/packaging/requirements.py @@ -2,12 +2,13 @@ # 2.0, and the BSD License. See the LICENSE file in the root of this repository # for complete details. -from typing import Any, List, Optional, Set +from typing import Any, Iterator, Optional, Set from ._parser import parse_requirement as _parse_requirement from ._tokenizer import ParserSyntaxError from .markers import Marker, _normalize_extra_values from .specifiers import SpecifierSet +from .utils import canonicalize_name class InvalidRequirement(ValueError): @@ -44,38 +45,44 @@ def __init__(self, requirement_string: str) -> None: self.marker = Marker.__new__(Marker) self.marker._markers = _normalize_extra_values(parsed.marker) - def __str__(self) -> str: - parts: List[str] = [self.name] + def _iter_parts(self, name: str) -> Iterator[str]: + yield name if self.extras: formatted_extras = ",".join(sorted(self.extras)) - parts.append(f"[{formatted_extras}]") + yield f"[{formatted_extras}]" if self.specifier: - parts.append(str(self.specifier)) + yield str(self.specifier) if self.url: - parts.append(f"@ {self.url}") + yield f"@ {self.url}" if self.marker: - parts.append(" ") + yield " " if self.marker: - parts.append(f"; {self.marker}") + yield f"; {self.marker}" - return "".join(parts) + def __str__(self) -> str: + return "".join(self._iter_parts(self.name)) def __repr__(self) -> str: return f"" def __hash__(self) -> int: - return hash((self.__class__.__name__, str(self))) + return hash( + ( + self.__class__.__name__, + *self._iter_parts(canonicalize_name(self.name)), + ) + ) def __eq__(self, other: Any) -> bool: if not isinstance(other, Requirement): return NotImplemented return ( - self.name == other.name + canonicalize_name(self.name) == canonicalize_name(other.name) and self.extras == other.extras and self.specifier == other.specifier and self.url == other.url diff --git a/tests/test_requirements.py b/tests/test_requirements.py index 45d3937e..491b3e03 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -26,6 +26,10 @@ ), ] +EQUIVALENT_DEPENDENCIES = [ + ("scikit-learn==1.0.1", "scikit_learn==1.0.1"), +] + DIFFERENT_DEPENDENCIES = [ ("package_one", "package_two"), ("packaging>20.1", "packaging>=20.1"), @@ -632,12 +636,25 @@ def test_str_and_repr( @pytest.mark.parametrize("dep1, dep2", EQUAL_DEPENDENCIES) def test_equal_reqs_equal_hashes(self, dep1: str, dep2: str) -> None: - """Requirement objects created from equivalent strings should be equal.""" + """Requirement objects created from equal strings should be equal.""" + # GIVEN / WHEN + req1, req2 = Requirement(dep1), Requirement(dep2) + + assert req1 == req2 + assert hash(req1) == hash(req2) + + @pytest.mark.parametrize("dep1, dep2", EQUIVALENT_DEPENDENCIES) + def test_equivalent_reqs_equal_hashes_unequal_strings( + self, dep1: str, dep2: str + ) -> None: + """Requirement objects created from equivalent strings should be equal, + even though their string representation will not.""" # GIVEN / WHEN req1, req2 = Requirement(dep1), Requirement(dep2) assert req1 == req2 assert hash(req1) == hash(req2) + assert str(req1) != str(req2) @pytest.mark.parametrize("dep1, dep2", DIFFERENT_DEPENDENCIES) def test_different_reqs_different_hashes(self, dep1: str, dep2: str) -> None: