Skip to content

Commit

Permalink
Merge pull request #209 from dhellmann/match-constraints-by-normalize…
Browse files Browse the repository at this point in the history
…d-name

look up constraints using their canonical name
  • Loading branch information
mergify[bot] authored Jul 22, 2024
2 parents a83f4c8 + d114fac commit a733327
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
9 changes: 5 additions & 4 deletions src/fromager/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing

from packaging.requirements import Requirement
from packaging.utils import canonicalize_name
from packaging.version import Version

from . import requirements_file
Expand All @@ -13,13 +14,13 @@

class Constraints:
def __init__(self, data: dict[str, Requirement]):
self._data = data
self._data = {canonicalize_name(n): v for n, v in data.items()}

def get_constraint(self, req: Requirement):
return self._data.get(req.name)
def get_constraint(self, name: str):
return self._data.get(canonicalize_name(name))

def is_satisfied_by(self, pkg_name: str, version: Version):
constraint = self._data.get(pkg_name)
constraint = self.get_constraint(pkg_name)
if constraint:
return version in constraint.specifier
return True
Expand Down
2 changes: 1 addition & 1 deletion src/fromager/sdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def handle_requirement(
f'{req.name}: {"*" * (len(why) + 1)} handling {req_type} requirement {req} {why}'
)

constraint = ctx.constraints.get_constraint(req)
constraint = ctx.constraints.get_constraint(req.name)
if constraint:
logger.info(
f"{req.name}: incoming requirement {req} matches constraint {constraint}. Will apply both."
Expand Down
2 changes: 1 addition & 1 deletion src/fromager/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def resolve_dist(
include_wheels: bool = True,
) -> tuple[str, str]:
"Return URL to source and its version."
constraint = ctx.constraints.get_constraint(req)
constraint = ctx.constraints.get_constraint(req.name)
logger.debug(
f"{req.name}: resolving requirement {req} using {sdist_server_url} with constraint {constraint}"
)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def test_constraint_is_satisfied_by():
assert c.is_satisfied_by("bar", Version("2.0"))


def test_constraint_canonical_name():
c = constraints.Constraints({"flash_attn": Requirement("flash_attn<=1.1")})
assert c.is_satisfied_by("flash_attn", "1.1")
assert c.is_satisfied_by("flash-attn", "1.1")
assert c.is_satisfied_by("Flash-ATTN", "1.1")


def test_constraint_not_is_satisfied_by():
c = constraints.Constraints({"foo": Requirement("foo<=1.1")})
assert not c.is_satisfied_by("foo", "1.2")
Expand Down

0 comments on commit a733327

Please sign in to comment.