From a3cf190923742c2b2fe56e4b553d1cf28ee8a34a Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Thu, 23 Jan 2025 13:13:38 +0100 Subject: [PATCH] Refactor constraints module Refactor and modify the API of `fromager.constraints` module to prepare the code for multiple constraints files. The `Constraints` class no longer takes an mapping of package name to `Requirement`. Instead it has a new method `add_constraint()` that takes a requirements strings. The method does not take a `Requirements` object, because we might want to modify and merge constraints in the future. The `load()` module-level function has been replaced by `Constraints.load_constraints_file()` method. Signed-off-by: Christian Heimes --- src/fromager/constraints.py | 48 +++++++++++++++++++-------------- src/fromager/context.py | 4 +-- src/fromager/resolver.py | 2 +- tests/test_constraints.py | 53 +++++++++++++++++++++++++++---------- tests/test_resolver.py | 17 +++++++----- 5 files changed, 80 insertions(+), 44 deletions(-) diff --git a/src/fromager/constraints.py b/src/fromager/constraints.py index 488afb75..81e616e2 100644 --- a/src/fromager/constraints.py +++ b/src/fromager/constraints.py @@ -3,7 +3,7 @@ import typing from packaging.requirements import Requirement -from packaging.utils import canonicalize_name +from packaging.utils import NormalizedName, canonicalize_name from packaging.version import Version from . import requirements_file @@ -12,8 +12,33 @@ class Constraints: - def __init__(self, data: dict[str, Requirement]): - self._data = {canonicalize_name(n): v for n, v in data.items()} + def __init__(self) -> None: + # mapping of canonical names to requirements + # NOTE: Requirement.name is not normalized + self._data: dict[NormalizedName, Requirement] = {} + + def __iter__(self) -> typing.Iterable[NormalizedName]: + yield from self._data + + def add_constraint(self, unparsed: str) -> None: + """Add new constraint, must not conflict with any existing constraints""" + req = Requirement(unparsed) + canon_name = canonicalize_name(req.name) + previous = self._data.get(canon_name) + if previous is not None: + raise KeyError( + f"{canon_name}: new constraint '{req}' conflicts with '{previous}'" + ) + if requirements_file.evaluate_marker(req, req): + logger.debug(f"adding constraint {req}") + self._data[canon_name] = req + + def load_constraints_file(self, constraints_file: str | pathlib.Path) -> None: + """Load constraints from a constraints file""" + logger.info("loading constraints from %s", constraints_file) + content = requirements_file.parse_requirements_file(constraints_file) + for line in content: + self.add_constraint(line) def get_constraint(self, name: str) -> Requirement | None: return self._data.get(canonicalize_name(name)) @@ -29,20 +54,3 @@ def is_satisfied_by(self, pkg_name: str, version: Version) -> bool: if constraint: return constraint.specifier.contains(version, prereleases=True) return True - - -def _parse(content: typing.Iterable[str]) -> Constraints: - constraints = {} - for line in content: - req = Requirement(line) - if requirements_file.evaluate_marker(req, req): - constraints[req.name] = req - return Constraints(constraints) - - -def load(constraints_file: str | pathlib.Path | None) -> Constraints: - if not constraints_file: - return Constraints({}) - logger.info("loading constraints from %s", constraints_file) - parsed_req_file = requirements_file.parse_requirements_file(constraints_file) - return _parse(parsed_req_file) diff --git a/src/fromager/context.py b/src/fromager/context.py index 93154d8f..ddc6f166 100644 --- a/src/fromager/context.py +++ b/src/fromager/context.py @@ -42,12 +42,12 @@ def __init__( ) self.settings = active_settings self.input_constraints_uri: str | None + self.constraints = constraints.Constraints() if constraints_file is not None: self.input_constraints_uri = constraints_file - self.constraints = constraints.load(constraints_file) + self.constraints.load_constraints_file(constraints_file) else: self.input_constraints_uri = None - self.constraints = constraints.Constraints({}) self.sdists_repo = pathlib.Path(sdists_repo).absolute() self.sdists_downloads = self.sdists_repo / "downloads" self.sdists_builds = self.sdists_repo / "builds" diff --git a/src/fromager/resolver.py b/src/fromager/resolver.py index 8a39f1e9..babd6e3e 100644 --- a/src/fromager/resolver.py +++ b/src/fromager/resolver.py @@ -215,7 +215,7 @@ def __init__( self.include_sdists = include_sdists self.include_wheels = include_wheels self.sdist_server_url = sdist_server_url - self.constraints = constraints or Constraints({}) + self.constraints = constraints or Constraints() self.req_type = req_type def identify(self, requirement_or_candidate: Requirement | Candidate) -> str: diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 84378949..6e0c96e0 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1,5 +1,4 @@ import pathlib -from unittest.mock import Mock, patch import pytest from packaging.requirements import Requirement @@ -9,40 +8,66 @@ def test_constraint_is_satisfied_by(): - c = constraints.Constraints({"foo": Requirement("foo<=1.1")}) + c = constraints.Constraints() + c.add_constraint("foo<=1.1") assert c.is_satisfied_by("foo", "1.1") assert c.is_satisfied_by("foo", Version("1.0")) assert c.is_satisfied_by("bar", Version("2.0")) def test_constraint_canonical_name(): - c = constraints.Constraints({"flash_attn": Requirement("flash_attn<=1.1")}) + c = constraints.Constraints() + c.add_constraint("flash_attn<=1.1") assert c.is_satisfied_by("flash_attn", "1.1") assert c.is_satisfied_by("flash-attn", "1.1") assert c.is_satisfied_by("Flash-ATTN", "1.1") + assert list(c) == ["flash-attn"] def test_constraint_not_is_satisfied_by(): - c = constraints.Constraints({"foo": Requirement("foo<=1.1")}) + c = constraints.Constraints() + c.add_constraint("foo<=1.1") + c.add_constraint("bar>=2.0") assert not c.is_satisfied_by("foo", "1.2") assert not c.is_satisfied_by("foo", Version("2.0")) + assert not c.is_satisfied_by("bar", Version("1.0")) -def test_load_empty_constraints_file(): - assert constraints.load(None)._data == {} +def test_add_constraint_conflict(): + c = constraints.Constraints() + c.add_constraint("foo<=1.1") + c.add_constraint("flit_core==2.0rc3") + with pytest.raises(KeyError): + c.add_constraint("foo<=1.1") + with pytest.raises(KeyError): + c.add_constraint("foo>1.1") + with pytest.raises(KeyError): + c.add_constraint("flit_core>2.0.0") + with pytest.raises(KeyError): + c.add_constraint("flit-core>2.0.0") + + +def test_allow_prerelease(): + c = constraints.Constraints() + c.add_constraint("foo>=1.1") + assert not c.allow_prerelease("foo") + c.add_constraint("bar>=1.1a0") + assert c.allow_prerelease("bar") + c.add_constraint("flit_core==2.0rc3") + assert c.allow_prerelease("flit_core") def test_load_non_existant_constraints_file(tmp_path: pathlib.Path): non_existant_file = tmp_path / "non_existant.txt" + c = constraints.Constraints() with pytest.raises(FileNotFoundError): - constraints.load(non_existant_file) + c.load_constraints_file(non_existant_file) -@patch("fromager.requirements_file.parse_requirements_file") -def test_load_constraints_file(parse_requirements_file: Mock, tmp_path: pathlib.Path): +def test_load_constraints_file(tmp_path: pathlib.Path): constraint_file = tmp_path / "constraint.txt" - constraint_file.write_text("a\n") - parse_requirements_file.return_value = ["torch==3.1.0"] - assert constraints.load(constraint_file)._data == { - "torch": Requirement("torch==3.1.0") - } + constraint_file.write_text("egg\ntorch==3.1.0 # comment\n") + c = constraints.Constraints() + c.load_constraints_file(constraint_file) + assert list(c) == ["egg", "torch"] # type: ignore + assert c.get_constraint("torch") == Requirement("torch==3.1.0") diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 5ae87db5..47bfffc5 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -176,9 +176,8 @@ def test_provider_choose_sdist(): def test_provider_choose_either_with_constraint(): - constraint = constraints.Constraints( - {"hydra-core": Requirement("hydra-core==1.3.2")} - ) + constraint = constraints.Constraints() + constraint.add_constraint("hydra-core==1.3.2") with requests_mock.Mocker() as r: r.get( "https://pypi.org/simple/hydra-core/", @@ -204,7 +203,8 @@ def test_provider_choose_either_with_constraint(): def test_provider_constraint_mismatch(): - constraint = constraints.Constraints({"hydra-core": Requirement("hydra-core<=1.1")}) + constraint = constraints.Constraints() + constraint.add_constraint("hydra-core<=1.1") with requests_mock.Mocker() as r: r.get( "https://pypi.org/simple/hydra-core/", @@ -220,7 +220,8 @@ def test_provider_constraint_mismatch(): def test_provider_constraint_match(): - constraint = constraints.Constraints({"hydra-core": Requirement("hydra-core<=1.3")}) + constraint = constraints.Constraints() + constraint.add_constraint("hydra-core<=1.3") with requests_mock.Mocker() as r: r.get( "https://pypi.org/simple/hydra-core/", @@ -525,7 +526,8 @@ def test_resolve_github(): def test_github_constraint_mismatch(): - constraint = constraints.Constraints({"fromager": Requirement("fromager>=1.0")}) + constraint = constraints.Constraints() + constraint.add_constraint("fromager>=1.0") with requests_mock.Mocker() as r: r.get( "https://api.github.com:443/repos/python-wheel-build/fromager", @@ -547,7 +549,8 @@ def test_github_constraint_mismatch(): def test_github_constraint_match(): - constraint = constraints.Constraints({"fromager": Requirement("fromager<0.9")}) + constraint = constraints.Constraints() + constraint.add_constraint("fromager<0.9") with requests_mock.Mocker() as r: r.get( "https://api.github.com:443/repos/python-wheel-build/fromager",