Skip to content

Commit

Permalink
add GenericProvider for resolving versions
Browse files Browse the repository at this point in the history
Refactor existing resolvers to define base classes
to reduce duplication and make it easier to create
new providers.

Add a GenericProvider class that accepts a
callable as argument and picks versions from the
return values.
  • Loading branch information
dhellmann committed Jul 28, 2024
1 parent 990962c commit 9165f37
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 59 deletions.
23 changes: 22 additions & 1 deletion docs/customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,27 @@ def get_resolver_provider(ctx, req, include_sdists, include_wheels, sdist_server
...
```

The `GenericProvider` is a convenient base class, or can be instantiated
directly if given a `version_source` callable that returns an iterator of
version values.

```python
from fromager import resolver

VERSION_MAP = {'1.0': 'first-release', '2.0': 'second-release'}

def _version_source(
identifier: str,
requirements: resolver.RequirementsMap,
incompatibilities: resolver.CandidatesMap,
) -> typing.Iterable[Version]:
return VERSION_MAP.keys()


def get_resolver_provider(ctx, req, include_sdists, include_wheels, sdist_server_url):
return resolver.GenericProvider(version_source=_version_source, constraints=ctx.constraints)
```

### prepare_source

The `prepare_source()` function is responsible for setting up a tree
Expand Down Expand Up @@ -425,7 +446,7 @@ def post_build(
)
```

## Customizations using settings.yaml
## Customizations using settings.yaml

To use predefined urls to download sources from, instead of overriding the entire `download_source` function, a mapping of package to download source url can be provided directly in settings.yaml. Optionally the downloaded sdist can be renamed. Both the url and the destination filename support templating. The only supported template variable is `version` - it is replaced by the version returned by the resolver.

Expand Down
160 changes: 102 additions & 58 deletions src/fromager/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,15 @@ def get_metadata_for_wheel(url: str) -> EmailMessage:
return EmailMessage()


class PyPIProvider(ExtrasProvider):
RequirementsMap: typing.TypeAlias = dict[str, typing.Iterable[Requirement]]
CandidatesMap: typing.TypeAlias = dict[str, typing.Iterable[Candidate]]
VersionSource: typing.TypeAlias = typing.Callable[
[str, RequirementsMap, CandidatesMap],
typing.Iterable[str | Version],
]


class BaseProvider(ExtrasProvider):
def __init__(
self,
include_sdists: bool = True,
Expand Down Expand Up @@ -215,7 +223,47 @@ def get_base_requirement(self, candidate: Candidate) -> Requirement:
def get_preference(self, identifier, resolutions, candidates, information, **kwds):
return sum(1 for _ in candidates[identifier])

def find_matches(self, identifier, requirements, incompatibilities):
def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool:
if canonicalize_name(requirement.name) != candidate.name:
return False
return (
candidate.version in requirement.specifier
and self.constraints.is_satisfied_by(requirement.name, candidate.version)
)

def get_dependencies(self, candidate: Candidate) -> list:
# return candidate.dependencies
return []

def find_matches(
self,
identifier: str,
requirements: RequirementsMap,
incompatibilities: CandidatesMap,
) -> typing.Iterable[Version]:
raise NotImplementedError()


class PyPIProvider(BaseProvider):
def __init__(
self,
include_sdists: bool = True,
include_wheels: bool = True,
sdist_server_url: str = "https://pypi.org/simple/",
constraints: Constraints | None = None,
):
super().__init__()
self.include_sdists = include_sdists
self.include_wheels = include_wheels
self.sdist_server_url = sdist_server_url
self.constraints = constraints or Constraints({})

def find_matches(
self,
identifier: str,
requirements: RequirementsMap,
incompatibilities: CandidatesMap,
) -> typing.Iterable[Version]:
requirements = list(requirements[identifier])
bad_versions = {c.version for c in incompatibilities[identifier]}

Expand Down Expand Up @@ -265,64 +313,41 @@ def find_matches(self, identifier, requirements, incompatibilities):
candidates.append(candidate)
return sorted(candidates, key=attrgetter("version"), reverse=True)

def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool:
if canonicalize_name(requirement.name) != candidate.name:
return False
return (
candidate.version in requirement.specifier
and self.constraints.is_satisfied_by(requirement.name, candidate.version)
)

def get_dependencies(self, candidate: Candidate) -> list:
# return candidate.dependencies
return []


class GitHubTagProvider(ExtrasProvider):
class GenericProvider(BaseProvider):
def __init__(
self, organization: str, repo: str, constraints: Constraints | None = None
self,
version_source: VersionSource,
constraints: Constraints | None = None,
):
super().__init__()
self.organization = organization
self.repo = repo
token = os.getenv("GITHUB_TOKEN")
auth = github.Auth.Token(token) if token else None
self.client = github.Github(auth=auth)
self._version_source = version_source
self.constraints = constraints or Constraints({})

def identify(self, requirement_or_candidate: Requirement | Candidate) -> str:
return canonicalize_name(requirement_or_candidate.name)

def get_extras_for(
self, requirement_or_candidate: Requirement | Candidate
) -> tuple[str]:
# Extras is a set, which is not hashable
return tuple(sorted(requirement_or_candidate.extras))

def get_base_requirement(self, candidate: Candidate) -> Requirement:
return Requirement(f"{candidate.name}=={candidate.version}")

def get_preference(self, identifier, resolutions, candidates, information, **kwds):
return sum(1 for _ in candidates[identifier])

def find_matches(self, identifier, requirements, incompatibilities):
repo = self.client.get_repo(f"{self.organization}/{self.repo}")

def find_matches(
self,
identifier: str,
requirements: RequirementsMap,
incompatibilities: CandidatesMap,
) -> typing.Iterable[Version]:
requirements = list(requirements[identifier])
bad_versions = {c.version for c in incompatibilities[identifier]}

# Need to pass the extras to the search, so they
# are added to the candidate at creation - we
# treat candidates as immutable once created.
candidates = []
for tag in repo.get_tags():
try:
version = Version(tag.name)
except Exception as err:
logger.debug(
f"{identifier}: could not parse version from git tag {tag.name} on {repo.full_name}: {err}"
)
continue
for item in self._version_source(identifier, requirements, incompatibilities):
if isinstance(item, Version):
version = item
else:
try:
version = Version(item)
except Exception as err:
logger.debug(
f"{identifier}: could not parse version from {item}: {err}"
)
continue
# Skip versions that are known bad
if version in bad_versions:
if DEBUG_RESOLVER:
Expand All @@ -345,17 +370,36 @@ def find_matches(self, identifier, requirements, incompatibilities):
f"{identifier}: skipping {version} due to constraint {c}"
)
continue
candidates.append(Candidate(identifier, version, url=tag.name))
candidates.append(Candidate(identifier, version, url=item))
return sorted(candidates, key=attrgetter("version"), reverse=True)

def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool:
if canonicalize_name(requirement.name) != candidate.name:
return False
return (
candidate.version in requirement.specifier
and self.constraints.is_satisfied_by(requirement.name, candidate.version)
)

def get_dependencies(self, candidate: Candidate) -> list[Requirement]:
# return candidate.dependencies
return []
class GitHubTagProvider(GenericProvider):
def __init__(
self, organization: str, repo: str, constraints: Constraints | None = None
):
self.organization = organization
self.repo = repo
token = os.getenv("GITHUB_TOKEN")
auth = github.Auth.Token(token) if token else None
self.client = github.Github(auth=auth)
self.constraints = constraints or Constraints({})
super().__init__(version_source=self._find_tags, constraints=constraints)

def _find_tags(
self,
identifier: str,
requirements: RequirementsMap,
incompatibilities: CandidatesMap,
) -> typing.Iterable[Version]:
repo = self.client.get_repo(f"{self.organization}/{self.repo}")

for tag in repo.get_tags():
try:
version = Version(tag.name)
except Exception as err:
logger.debug(
f"{identifier}: could not parse version from git tag {tag.name} on {repo.full_name}: {err}"
)
continue
yield version
15 changes: 15 additions & 0 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,18 @@ def test_github_constraint_match():
assert str(candidate.version) == "0.8.1"
# check the "URL" in case tag syntax does not match version syntax
assert str(candidate.url) == "0.8.1"


def test_resolve_generic():
def _versions(*args, **kwds):
return ["1.2", "1.3", "1.4.1"]

provider = resolver.GenericProvider(_versions, None)
reporter = resolvelib.BaseReporter()
rslvr = resolvelib.Resolver(provider, reporter)

result = rslvr.resolve([Requirement("fromager")])
assert "fromager" in result.mapping

candidate = result.mapping["fromager"]
assert str(candidate.version) == "1.4.1"

0 comments on commit 9165f37

Please sign in to comment.