Skip to content

Commit

Permalink
add GenericProvider for resolving versions
Browse files Browse the repository at this point in the history
Refactor existing resolvers to define base classes
to reduce duplication and make it easier to create
new providers.

Add a GenericProvider class that accepts a
callable as argument and picks versions from the
return values.
  • Loading branch information
dhellmann committed Jul 28, 2024
1 parent 990962c commit d807d77
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 58 deletions.
160 changes: 102 additions & 58 deletions src/fromager/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,15 @@ def get_metadata_for_wheel(url: str) -> EmailMessage:
return EmailMessage()


class PyPIProvider(ExtrasProvider):
RequirementsMap: typing.TypeAlias = dict[str, typing.Iterable[Requirement]]
CandidatesMap: typing.TypeAlias = dict[str, typing.Iterable[Candidate]]
VersionSource: typing.TypeAlias = typing.Callable[
[str, RequirementsMap, CandidatesMap],
typing.Iterable[str | Version],
]


class BaseProvider(ExtrasProvider):
def __init__(
self,
include_sdists: bool = True,
Expand Down Expand Up @@ -215,7 +223,47 @@ def get_base_requirement(self, candidate: Candidate) -> Requirement:
def get_preference(self, identifier, resolutions, candidates, information, **kwds):
return sum(1 for _ in candidates[identifier])

def find_matches(self, identifier, requirements, incompatibilities):
def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool:
if canonicalize_name(requirement.name) != candidate.name:
return False
return (
candidate.version in requirement.specifier
and self.constraints.is_satisfied_by(requirement.name, candidate.version)
)

def get_dependencies(self, candidate: Candidate) -> list:
# return candidate.dependencies
return []

def find_matches(
self,
identifier: str,
requirements: RequirementsMap,
incompatibilities: CandidatesMap,
) -> typing.Iterable[Version]:
raise NotImplementedError()


class PyPIProvider(BaseProvider):
def __init__(
self,
include_sdists: bool = True,
include_wheels: bool = True,
sdist_server_url: str = "https://pypi.org/simple/",
constraints: Constraints | None = None,
):
super().__init__()
self.include_sdists = include_sdists
self.include_wheels = include_wheels
self.sdist_server_url = sdist_server_url
self.constraints = constraints or Constraints({})

def find_matches(
self,
identifier: str,
requirements: RequirementsMap,
incompatibilities: CandidatesMap,
) -> typing.Iterable[Version]:
requirements = list(requirements[identifier])
bad_versions = {c.version for c in incompatibilities[identifier]}

Expand Down Expand Up @@ -265,64 +313,41 @@ def find_matches(self, identifier, requirements, incompatibilities):
candidates.append(candidate)
return sorted(candidates, key=attrgetter("version"), reverse=True)

def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool:
if canonicalize_name(requirement.name) != candidate.name:
return False
return (
candidate.version in requirement.specifier
and self.constraints.is_satisfied_by(requirement.name, candidate.version)
)

def get_dependencies(self, candidate: Candidate) -> list:
# return candidate.dependencies
return []


class GitHubTagProvider(ExtrasProvider):
class GenericProvider(BaseProvider):
def __init__(
self, organization: str, repo: str, constraints: Constraints | None = None
self,
version_source: VersionSource,
constraints: Constraints | None = None,
):
super().__init__()
self.organization = organization
self.repo = repo
token = os.getenv("GITHUB_TOKEN")
auth = github.Auth.Token(token) if token else None
self.client = github.Github(auth=auth)
self._version_source = version_source
self.constraints = constraints or Constraints({})

def identify(self, requirement_or_candidate: Requirement | Candidate) -> str:
return canonicalize_name(requirement_or_candidate.name)

def get_extras_for(
self, requirement_or_candidate: Requirement | Candidate
) -> tuple[str]:
# Extras is a set, which is not hashable
return tuple(sorted(requirement_or_candidate.extras))

def get_base_requirement(self, candidate: Candidate) -> Requirement:
return Requirement(f"{candidate.name}=={candidate.version}")

def get_preference(self, identifier, resolutions, candidates, information, **kwds):
return sum(1 for _ in candidates[identifier])

def find_matches(self, identifier, requirements, incompatibilities):
repo = self.client.get_repo(f"{self.organization}/{self.repo}")

def find_matches(
self,
identifier: str,
requirements: RequirementsMap,
incompatibilities: CandidatesMap,
) -> typing.Iterable[Version]:
requirements = list(requirements[identifier])
bad_versions = {c.version for c in incompatibilities[identifier]}

# Need to pass the extras to the search, so they
# are added to the candidate at creation - we
# treat candidates as immutable once created.
candidates = []
for tag in repo.get_tags():
try:
version = Version(tag.name)
except Exception as err:
logger.debug(
f"{identifier}: could not parse version from git tag {tag.name} on {repo.full_name}: {err}"
)
continue
for item in self._version_source(identifier, requirements, incompatibilities):
if isinstance(item, Version):
version = item
else:
try:
version = Version(item)
except Exception as err:
logger.debug(
f"{identifier}: could not parse version from {item}: {err}"
)
continue
# Skip versions that are known bad
if version in bad_versions:
if DEBUG_RESOLVER:
Expand All @@ -345,17 +370,36 @@ def find_matches(self, identifier, requirements, incompatibilities):
f"{identifier}: skipping {version} due to constraint {c}"
)
continue
candidates.append(Candidate(identifier, version, url=tag.name))
candidates.append(Candidate(identifier, version, url=item))
return sorted(candidates, key=attrgetter("version"), reverse=True)

def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool:
if canonicalize_name(requirement.name) != candidate.name:
return False
return (
candidate.version in requirement.specifier
and self.constraints.is_satisfied_by(requirement.name, candidate.version)
)

def get_dependencies(self, candidate: Candidate) -> list[Requirement]:
# return candidate.dependencies
return []
class GitHubTagProvider(GenericProvider):
def __init__(
self, organization: str, repo: str, constraints: Constraints | None = None
):
self.organization = organization
self.repo = repo
token = os.getenv("GITHUB_TOKEN")
auth = github.Auth.Token(token) if token else None
self.client = github.Github(auth=auth)
self.constraints = constraints or Constraints({})
super().__init__(version_source=self._find_tags, constraints=constraints)

def _find_tags(
self,
identifier: str,
requirements: RequirementsMap,
incompatibilities: CandidatesMap,
) -> typing.Iterable[Version]:
repo = self.client.get_repo(f"{self.organization}/{self.repo}")

for tag in repo.get_tags():
try:
version = Version(tag.name)
except Exception as err:
logger.debug(
f"{identifier}: could not parse version from git tag {tag.name} on {repo.full_name}: {err}"
)
continue
yield version
15 changes: 15 additions & 0 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,18 @@ def test_github_constraint_match():
assert str(candidate.version) == "0.8.1"
# check the "URL" in case tag syntax does not match version syntax
assert str(candidate.url) == "0.8.1"


def test_resolve_generic():
def _versions(*args, **kwds):
return ["1.2", "1.3", "1.4.1"]

provider = resolver.GenericProvider(_versions, None)
reporter = resolvelib.BaseReporter()
rslvr = resolvelib.Resolver(provider, reporter)

result = rslvr.resolve([Requirement("fromager")])
assert "fromager" in result.mapping

candidate = result.mapping["fromager"]
assert str(candidate.version) == "1.4.1"

0 comments on commit d807d77

Please sign in to comment.