Skip to content

Commit

Permalink
Add Support for Multiple Categories per LockedDependency
Browse files Browse the repository at this point in the history
Also change default categories value from `{'main'}` to empty set
because we add instead of overwrite when we apply the categories.
  • Loading branch information
srilman authored and maresb committed Sep 13, 2024
1 parent a061142 commit 1ac5f7c
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 12 deletions.
53 changes: 42 additions & 11 deletions conda_lock/lockfile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@

from collections import defaultdict
from textwrap import dedent
from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Union
from typing import (
Collection,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Union,
)

import yaml

Expand Down Expand Up @@ -38,6 +48,23 @@ def _seperator_munge_get(
return d[key.replace("_", "-")]


def _truncate_main_category(
planned: Mapping[str, Union[List[LockedDependency], LockedDependency]],
) -> None:
"""
Given the package dependencies with their respective categories
for any package that is in the main category, remove all other associated categories
"""
# Packages in the main category are always installed
# so other categories are not necessary
for targets in planned.values():
if not isinstance(targets, list):
targets = [targets]
for target in targets:
if "main" in target.categories:
target.categories = {"main"}


def apply_categories(
requested: Dict[str, Dependency],
planned: Mapping[str, Union[List[LockedDependency], LockedDependency]],
Expand Down Expand Up @@ -111,27 +138,31 @@ def dep_name(manager: str, dep: str) -> str:

by_category[request.category].append(request.name)

# now, map each package to its root request preferring the ones earlier in the
# list
# now, map each package to every root request that requires it
categories = [*categories, *(k for k in by_category if k not in categories)]
root_requests = {}
root_requests: DefaultDict[str, List[str]] = defaultdict(list)
for category in categories:
for root in by_category.get(category, []):
for transitive_dep in dependents[root]:
if transitive_dep not in root_requests:
root_requests[transitive_dep] = root
root_requests[transitive_dep].append(root)
# include root requests themselves
for name in requested:
root_requests[name] = name
root_requests[name].append(name)

for dep, root in root_requests.items():
source = requested[root]
for dep, roots in root_requests.items():
# try a conda target first
targets = _seperator_munge_get(planned, dep)
if not isinstance(targets, list):
targets = [targets]
for target in targets:
target.categories = {source.category}

for root in roots:
source = requested[root]
for target in targets:
target.categories.add(source.category)

# For any dep that is part of the 'main' category
# we should remove all other categories
_truncate_main_category(planned)


def parse_conda_lock_file(path: pathlib.Path) -> Lockfile:
Expand Down
2 changes: 1 addition & 1 deletion conda_lock/lockfile/v2prelim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class LockedDependency(BaseLockedDependency):
categories: Set[str] = {"main"}
categories: Set[str] = set()

def to_v1(self) -> List[LockedDependencyV1]:
"""Convert a v2 dependency into a list of v1 dependencies.
Expand Down
57 changes: 57 additions & 0 deletions tests/test_conda_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_add_auth_to_line,
_add_auth_to_lockfile,
_extract_domain,
_solve_for_arch,
_strip_auth_from_line,
_strip_auth_from_lockfile,
create_lockfile_from_spec,
Expand Down Expand Up @@ -606,6 +607,7 @@ def test_choose_wheel() -> None:
platform="linux-64",
)
assert len(solution) == 1
assert solution["fastavro"].categories == {"main"}
assert solution["fastavro"].hash == HashModel(
sha256="a111a384a786b7f1fd6a8a8307da07ccf4d4c425084e2d61bae33ecfb60de405"
)
Expand Down Expand Up @@ -1820,6 +1822,61 @@ def test_aggregate_lock_specs_invalid_pip_repos():
aggregate_lock_specs([base_spec, spec_a, spec_a_b], platforms=[])


def test_solve_arch_multiple_categories():
_conda_exe = determine_conda_executable(None, mamba=False, micromamba=False)
channels = [Channel.from_string("conda-forge")]

with tempfile.NamedTemporaryFile(dir=".") as tf:
spec = LockSpecification(
dependencies={
"linux-64": [
VersionedDependency(
name="python",
version="=3.10.9",
manager="conda",
category="main",
extras=[],
),
VersionedDependency(
name="pandas",
version="=1.5.3",
manager="conda",
category="test",
extras=[],
),
VersionedDependency(
name="pyarrow",
version="=9.0.0",
manager="conda",
category="dev",
extras=[],
),
],
},
channels=channels,
# NB: this file must exist for relative path resolution to work
# in create_lockfile_from_spec
sources=[Path(tf.name)],
)

vpr = default_virtual_package_repodata()
with vpr:
locked_deps = _solve_for_arch(
conda=_conda_exe,
spec=spec,
platform="linux-64",
channels=channels,
pip_repositories=[],
virtual_package_repo=vpr,
)
python_deps = [dep for dep in locked_deps if dep.name == "python"]
assert len(python_deps) == 1
assert python_deps[0].categories == {"main"}
numpy_deps = [dep for dep in locked_deps if dep.name == "numpy"]
assert len(numpy_deps) == 1
assert numpy_deps[0].categories == {"test", "dev"}


def _check_package_installed(package: str, prefix: str):
import glob

Expand Down

0 comments on commit 1ac5f7c

Please sign in to comment.