Skip to content

Commit

Permalink
Fix equality checks of equivalent requirements
Browse files Browse the repository at this point in the history
Signed-off-by: Juan Luis Cano Rodríguez <juan_luis_cano@mckinsey.com>
  • Loading branch information
astrojuanlu committed Jun 5, 2023
1 parent d1279d0 commit 3dbb218
Showing 1 changed file with 56 additions and 8 deletions.
64 changes: 56 additions & 8 deletions kedro/framework/cli/micropkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import tempfile
from importlib import import_module
from pathlib import Path
from typing import Iterable, List, Tuple, Union
from typing import Any, Iterable, List, Tuple, Union

import click
from build.util import project_wheel_metadata
from packaging.requirements import InvalidRequirement, Requirement
from packaging.utils import canonicalize_name
from rope.base.project import Project
from rope.contrib import generate
from rope.refactor.move import MoveModule
Expand Down Expand Up @@ -49,6 +50,53 @@
"""


class _EquivalentRequirement(Requirement):
# See https://github.com/pypa/packaging/issues/644#issuecomment-1567982812

@property
def canonical_name(self) -> str:
"""Canonicalized name according to the rules of PEP 503."""
return canonicalize_name(self.name)

def _to_str(self, name: str) -> str:
parts: list[str] = [name]

if self.extras:
formatted_extras = ",".join(sorted(self.extras))
parts.append(f"[{formatted_extras}]")

if self.specifier:
parts.append(str(self.specifier))

if self.url:
parts.append(f"@ {self.url}")
if self.marker:
parts.append(" ")

if self.marker:
parts.append(f"; {self.marker}")

return "".join(parts)

def __str__(self) -> str:
return self._to_str(self.name)

def __hash__(self) -> int:
return hash((self.__class__.__name__, self._to_str(self.canonical_name)))

def __eq__(self, other: Any) -> bool:
if not isinstance(other, (Requirement, _EquivalentRequirement)):
return NotImplemented

return (
self.canonical_name == other.canonical_name
and self.extras == other.extras
and self.specifier == other.specifier
and self.url == other.url
and self.marker == other.marker
)


def _check_module_path(ctx, param, value): # pylint: disable=unused-argument
if value and not re.match(r"^[\w.]+$", value):
message = (
Expand Down Expand Up @@ -620,7 +668,7 @@ def _make_install_requires(requirements_txt: Path) -> list[str]:
if not requirements_txt.exists():
return []
return [
str(Requirement(_drop_comment(requirement_line)))
str(_EquivalentRequirement(_drop_comment(requirement_line)))
for requirement_line in requirements_txt.read_text().splitlines()
if requirement_line and not requirement_line.startswith("#")
]
Expand Down Expand Up @@ -868,9 +916,6 @@ def _append_package_reqs(
requirements_txt: Path, package_reqs: list[str], package_name: str
) -> None:
"""Appends micro-package requirements to project level requirements.txt"""
# NOTE: packaging.requirements.Requirement equality check
# does not normalize names, and as such is not equivalent to pkg_resources.Requirement,
# see https://github.com/pypa/packaging/issues/644#issuecomment-1567982812
incoming_reqs = _safe_parse_requirements(package_reqs)
if requirements_txt.is_file():
existing_reqs = _safe_parse_requirements(requirements_txt.read_text())
Expand Down Expand Up @@ -909,13 +954,14 @@ def _get_all_library_reqs(metadata):
# See https://discuss.python.org/t/\
# programmatically-getting-non-optional-requirements-of-current-directory/26963/2
return [
str(Requirement(dep_str)) for dep_str in metadata.get_all("Requires-Dist", [])
str(_EquivalentRequirement(dep_str))
for dep_str in metadata.get_all("Requires-Dist", [])
]


def _safe_parse_requirements(
requirements: str | Iterable[str],
) -> set[Requirement]:
) -> set[_EquivalentRequirement]:
"""Safely parse a requirement or set of requirements. This avoids blowing up when it
encounters a requirement it cannot parse (e.g. `-r requirements.txt`). This way
we can still extract all the parseable requirements out of a set containing some
Expand All @@ -933,7 +979,9 @@ def _safe_parse_requirements(
and not requirement_line.startswith("-e")
):
try:
parseable_requirements.add(Requirement(_drop_comment(requirement_line)))
parseable_requirements.add(
_EquivalentRequirement(_drop_comment(requirement_line))
)
except InvalidRequirement:
continue
return parseable_requirements

0 comments on commit 3dbb218

Please sign in to comment.