From 70b99bf291e89f5ff2ac5a3ab0faeeee3ce1c74e Mon Sep 17 00:00:00 2001 From: Wackyator Date: Thu, 28 Sep 2023 21:50:24 +0530 Subject: [PATCH 01/12] feat: add solver --- py-rattler/Cargo.lock | 84 +++++++++++++++++++++++++++ py-rattler/Cargo.toml | 2 +- py-rattler/rattler/__init__.py | 4 ++ py-rattler/rattler/solver/__init__.py | 3 + py-rattler/rattler/solver/solver.py | 33 +++++++++++ py-rattler/src/error.rs | 5 ++ py-rattler/src/lib.rs | 5 ++ py-rattler/src/solver.rs | 44 ++++++++++++++ 8 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 py-rattler/rattler/solver/__init__.py create mode 100644 py-rattler/rattler/solver/solver.py create mode 100644 py-rattler/src/solver.rs diff --git a/py-rattler/Cargo.lock b/py-rattler/Cargo.lock index 8db4d65bc..ec182bf7d 100644 --- a/py-rattler/Cargo.lock +++ b/py-rattler/Cargo.lock @@ -248,6 +248,18 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "blake2" version = "0.10.6" @@ -558,6 +570,15 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "elsa" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714f766f3556b44e7e4776ad133fcc3445a489517c25c704ace411bb14790194" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "encoding_rs" version = "0.8.33" @@ -660,6 +681,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.0.27" @@ -700,6 +727,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "futures" version = "0.1.31" @@ -1627,6 +1660,16 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +[[package]] +name = "petgraph" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" +dependencies = [ + "fixedbitset", + "indexmap 2.0.0", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -1836,6 +1879,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.8.5" @@ -2071,6 +2120,7 @@ dependencies = [ "rattler_conda_types", "rattler_digest", "rattler_libsolv_c", + "resolvo", "serde", "tempfile", "thiserror", @@ -2192,6 +2242,19 @@ dependencies = [ "winreg", ] +[[package]] +name = "resolvo" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dab30801b54723f1949c6453a35db09c89e2ce7e052dc63e715f32fb40e427c" +dependencies = [ + "bitvec", + "elsa", + "itertools", + "petgraph", + "tracing", +] + [[package]] name = "retry-policies" version = "0.2.0" @@ -2502,6 +2565,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "static_assertions" version = "1.1.0" @@ -2570,6 +2639,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "tar" version = "0.4.40" @@ -3116,6 +3191,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "xattr" version = "1.0.1" diff --git a/py-rattler/Cargo.toml b/py-rattler/Cargo.toml index 98c3576c2..0de38e5a3 100644 --- a/py-rattler/Cargo.toml +++ b/py-rattler/Cargo.toml @@ -20,7 +20,7 @@ rattler_conda_types = { path = "../crates/rattler_conda_types", default-features rattler_networking = { path = "../crates/rattler_networking", default-features = false } rattler_shell = { path = "../crates/rattler_shell", default-features = false } rattler_virtual_packages = { path = "../crates/rattler_virtual_packages" } -rattler_solve = { path = "../crates/rattler_solve" } +rattler_solve = { path = "../crates/rattler_solve", features = ["resolvo"] } pyo3 = { version = "0.19", features = [ "abi3-py38", diff --git a/py-rattler/rattler/__init__.py b/py-rattler/rattler/__init__.py index 4104d03cc..5076cf9d8 100644 --- a/py-rattler/rattler/__init__.py +++ b/py-rattler/rattler/__init__.py @@ -12,6 +12,8 @@ from rattler.virtual_package import GenericVirtualPackage, VirtualPackage from rattler.package import PackageName from rattler.prefix import PrefixRecord, PrefixPaths +from rattler.solver import solve +from rattler.platform import Platform __all__ = [ "Version", @@ -32,4 +34,6 @@ "PrefixRecord", "PrefixPaths", "SparseRepoData", + "solve", + "Platform", ] diff --git a/py-rattler/rattler/solver/__init__.py b/py-rattler/rattler/solver/__init__.py new file mode 100644 index 000000000..084e30ab0 --- /dev/null +++ b/py-rattler/rattler/solver/__init__.py @@ -0,0 +1,3 @@ +from rattler.solver.solver import solve + +__all__ = ["solve"] diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py new file mode 100644 index 000000000..82d498383 --- /dev/null +++ b/py-rattler/rattler/solver/solver.py @@ -0,0 +1,33 @@ +from __future__ import annotations +from typing import List +from rattler.match_spec.match_spec import MatchSpec + +from rattler.rattler import py_solve +from rattler.repo_data.record import RepoDataRecord +from rattler.virtual_package.generic import GenericVirtualPackage + + +def solve( + specs: List[MatchSpec], + available_packages: List[List[RepoDataRecord]], + locked_packages: List[RepoDataRecord], + pinned_packages: List[RepoDataRecord], + virtual_packages: List[GenericVirtualPackage], +) -> List[RepoDataRecord]: + """ + Resolve the dependencies and return the [`RepoDataRecord`]s + that should be present in the environment. + """ + return [ + RepoDataRecord._from_py_record(solved_package) + for solved_package in py_solve( + [spec._match_spec for spec in specs], + [ + [package._record for package in list_of_packages] + for list_of_packages in available_packages + ], + [package._record for package in locked_packages], + [package._record for package in pinned_packages], + [v_package._generic_virtual_package for v_package in virtual_packages], + ) + ] diff --git a/py-rattler/src/error.rs b/py-rattler/src/error.rs index b806e7e22..b79ccd787 100644 --- a/py-rattler/src/error.rs +++ b/py-rattler/src/error.rs @@ -8,6 +8,7 @@ use rattler_conda_types::{ }; use rattler_repodata_gateway::fetch::FetchRepoDataError; use rattler_shell::activation::ActivationError; +use rattler_solve::SolveError; use rattler_virtual_packages::DetectVirtualPackageError; use thiserror::Error; @@ -38,6 +39,8 @@ pub enum PyRattlerError { DetectVirtualPackageError(#[from] DetectVirtualPackageError), #[error(transparent)] IoError(#[from] io::Error), + #[error(transparent)] + SolverError(#[from] SolveError), } impl From for PyErr { @@ -69,6 +72,7 @@ impl From for PyErr { DetectVirtualPackageException::new_err(err.to_string()) } PyRattlerError::IoError(err) => IoException::new_err(err.to_string()), + PyRattlerError::SolverError(err) => SolverException::new_err(err.to_string()), } } } @@ -85,3 +89,4 @@ create_exception!(exceptions, FetchRepoDataException, PyException); create_exception!(exceptions, CacheDirException, PyException); create_exception!(exceptions, DetectVirtualPackageException, PyException); create_exception!(exceptions, IoException, PyException); +create_exception!(exceptions, SolverException, PyException); diff --git a/py-rattler/src/lib.rs b/py-rattler/src/lib.rs index e3429efc7..dd0264ffc 100644 --- a/py-rattler/src/lib.rs +++ b/py-rattler/src/lib.rs @@ -9,6 +9,7 @@ mod platform; mod prefix_record; mod repo_data; mod shell; +mod solver; mod version; mod virtual_package; @@ -34,6 +35,7 @@ use pyo3::prelude::*; use platform::{PyArch, PyPlatform}; use shell::{PyActivationResult, PyActivationVariables, PyActivator, PyShellEnum}; +use solver::py_solve; use virtual_package::PyVirtualPackage; #[pymodule] @@ -71,6 +73,9 @@ fn rattler(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::().unwrap(); m.add_class::().unwrap(); + m.add_function(wrap_pyfunction!(py_solve, m).unwrap()) + .unwrap(); + // Exceptions m.add( "InvalidVersionError", diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs new file mode 100644 index 000000000..37af0d809 --- /dev/null +++ b/py-rattler/src/solver.rs @@ -0,0 +1,44 @@ +use pyo3::{pyfunction, PyResult}; +use rattler_conda_types::RepoDataRecord; +use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; + +use crate::{ + error::PyRattlerError, generic_virtual_package::PyGenericVirtualPackage, + match_spec::PyMatchSpec, repo_data::repo_data_record::PyRepoDataRecord, +}; + +#[pyfunction] +pub fn py_solve( + specs: Vec, + available_packages: Vec>, + locked_packages: Vec, + pinned_packages: Vec, + virtual_packages: Vec, +) -> PyResult> { + let available_packages = available_packages + .into_iter() + .map(|records| { + records + .into_iter() + .map(Into::::into) + .collect::>() + }) + .collect::>(); + + let task = SolverTask { + available_packages: &available_packages, + locked_packages: locked_packages.into_iter().map(Into::into).collect(), + pinned_packages: pinned_packages.into_iter().map(Into::into).collect(), + virtual_packages: virtual_packages.into_iter().map(Into::into).collect(), + specs: specs.into_iter().map(Into::into).collect(), + }; + + Ok(Solver + .solve(task) + .map(|res| { + res.into_iter() + .map(Into::into) + .collect::>() + }) + .map_err(PyRattlerError::from)?) +} From f2bbdfccdda812df580526b6739312cc054e4bc6 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Fri, 29 Sep 2023 13:44:55 +0530 Subject: [PATCH 02/12] test: add solver test --- py-rattler/tests/unit/test_solver.py | 109 +++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 py-rattler/tests/unit/test_solver.py diff --git a/py-rattler/tests/unit/test_solver.py b/py-rattler/tests/unit/test_solver.py new file mode 100644 index 000000000..ead220402 --- /dev/null +++ b/py-rattler/tests/unit/test_solver.py @@ -0,0 +1,109 @@ +# type: ignore +import os.path +import subprocess + +import pytest +from rattler import ( + Channel, + ChannelConfig, + fetch_repo_data, + Platform, + solve, + MatchSpec, + RepoDataRecord, +) + + +@pytest.fixture(scope="session") +def noarch_repo_data() -> None: + port, repo_name = 8912, "test-repo-1" + + test_data_dir = os.path.join( + os.path.dirname(__file__), "../../../test-data/test-server" + ) + + with subprocess.Popen( + [ + "python", + os.path.join(test_data_dir, "reposerver.py"), + "-d", + os.path.join(test_data_dir, "repo"), + "-n", + repo_name, + "-p", + str(port), + ] + ) as proc: + yield port, repo_name + proc.terminate() + + +@pytest.fixture(scope="session") +def linux64_repo_data() -> None: + port, repo_name = 8913, "test-repo-2" + + test_data_dir = os.path.join(os.path.dirname(__file__), "../../../test-data/") + + with subprocess.Popen( + [ + "python", + os.path.join(test_data_dir, "test-server/reposerver.py"), + "-d", + os.path.join(test_data_dir, "channels/dummy/"), + "-n", + repo_name, + "-p", + str(port), + ] + ) as proc: + yield port, repo_name + proc.terminate() + + +@pytest.mark.asyncio +async def test_solve( + tmp_path, + noarch_repo_data, + linux64_repo_data, +): + noarch_port, noarch_repo = noarch_repo_data + linux64_port, linux64_repo = linux64_repo_data + cache_dir = tmp_path / "test_repo_data_download" + noarch_chan = Channel( + noarch_repo, ChannelConfig(f"http://localhost:{noarch_port}/") + ) + plat_noarch = Platform("noarch") + linux64_chan = Channel( + linux64_repo, ChannelConfig(f"http://localhost:{linux64_port}/") + ) + plat_linux64 = Platform("linux-64") + + noarch_data = await fetch_repo_data( + channels=[noarch_chan], + platforms=[plat_noarch], + cache_path=cache_dir, + callback=None, + ) + + linux64_data = await fetch_repo_data( + channels=[linux64_chan], + platforms=[plat_linux64], + cache_path=cache_dir, + callback=None, + ) + + available_packages = [ + package.into_repo_data(noarch_chan) for package in noarch_data + ] + [package.into_repo_data(linux64_chan) for package in linux64_data] + + solved_data = solve( + [MatchSpec("test-package"), MatchSpec("foobar"), MatchSpec("baz")], + available_packages, + [], + [], + [], + ) + + assert isinstance(solved_data, list) + assert isinstance(solved_data[0], RepoDataRecord) + assert len(solved_data) == 4 From 5d75959710da1147ccbd00211f213264002e4bfc Mon Sep 17 00:00:00 2001 From: Wackyator Date: Fri, 29 Sep 2023 13:56:29 +0530 Subject: [PATCH 03/12] fix: try different ports for solver test --- py-rattler/tests/unit/test_solver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py-rattler/tests/unit/test_solver.py b/py-rattler/tests/unit/test_solver.py index ead220402..792ff19dc 100644 --- a/py-rattler/tests/unit/test_solver.py +++ b/py-rattler/tests/unit/test_solver.py @@ -16,7 +16,7 @@ @pytest.fixture(scope="session") def noarch_repo_data() -> None: - port, repo_name = 8912, "test-repo-1" + port, repo_name = 8812, "test-repo-1" test_data_dir = os.path.join( os.path.dirname(__file__), "../../../test-data/test-server" @@ -40,7 +40,7 @@ def noarch_repo_data() -> None: @pytest.fixture(scope="session") def linux64_repo_data() -> None: - port, repo_name = 8913, "test-repo-2" + port, repo_name = 8813, "test-repo-2" test_data_dir = os.path.join(os.path.dirname(__file__), "../../../test-data/") From f1689894cf3f303e86761e23caa48f132233e3e0 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Fri, 29 Sep 2023 14:18:10 +0530 Subject: [PATCH 04/12] fix: solver takes repodata directly --- py-rattler/rattler/solver/solver.py | 10 ++++++---- py-rattler/src/solver.rs | 17 +++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py index 82d498383..b31b4c271 100644 --- a/py-rattler/rattler/solver/solver.py +++ b/py-rattler/rattler/solver/solver.py @@ -1,15 +1,17 @@ from __future__ import annotations -from typing import List +from typing import List, Tuple from rattler.match_spec.match_spec import MatchSpec from rattler.rattler import py_solve from rattler.repo_data.record import RepoDataRecord from rattler.virtual_package.generic import GenericVirtualPackage +from rattler.repo_data.repo_data import RepoData +from rattler.channel import Channel def solve( specs: List[MatchSpec], - available_packages: List[List[RepoDataRecord]], + available_packages: List[Tuple[RepoData, Channel]], locked_packages: List[RepoDataRecord], pinned_packages: List[RepoDataRecord], virtual_packages: List[GenericVirtualPackage], @@ -23,8 +25,8 @@ def solve( for solved_package in py_solve( [spec._match_spec for spec in specs], [ - [package._record for package in list_of_packages] - for list_of_packages in available_packages + (repo_data._repo_data, channel._channel) + for (repo_data, channel) in available_packages ], [package._record for package in locked_packages], [package._record for package in pinned_packages], diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs index 37af0d809..f88c7f0a4 100644 --- a/py-rattler/src/solver.rs +++ b/py-rattler/src/solver.rs @@ -1,28 +1,25 @@ use pyo3::{pyfunction, PyResult}; -use rattler_conda_types::RepoDataRecord; use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; use crate::{ - error::PyRattlerError, generic_virtual_package::PyGenericVirtualPackage, - match_spec::PyMatchSpec, repo_data::repo_data_record::PyRepoDataRecord, + channel::PyChannel, + error::PyRattlerError, + generic_virtual_package::PyGenericVirtualPackage, + match_spec::PyMatchSpec, + repo_data::{repo_data_record::PyRepoDataRecord, PyRepoData}, }; #[pyfunction] pub fn py_solve( specs: Vec, - available_packages: Vec>, + available_packages: Vec<(PyRepoData, PyChannel)>, locked_packages: Vec, pinned_packages: Vec, virtual_packages: Vec, ) -> PyResult> { let available_packages = available_packages .into_iter() - .map(|records| { - records - .into_iter() - .map(Into::::into) - .collect::>() - }) + .map(|(records, channel)| records.inner.into_repo_data_records(&channel.inner)) .collect::>(); let task = SolverTask { From a2596e45da070b5fc58484203520cc4bc95e3956 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Fri, 29 Sep 2023 14:18:39 +0530 Subject: [PATCH 05/12] test: update test --- py-rattler/tests/unit/test_solver.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/py-rattler/tests/unit/test_solver.py b/py-rattler/tests/unit/test_solver.py index 792ff19dc..864da0f78 100644 --- a/py-rattler/tests/unit/test_solver.py +++ b/py-rattler/tests/unit/test_solver.py @@ -92,10 +92,9 @@ async def test_solve( callback=None, ) - available_packages = [ - package.into_repo_data(noarch_chan) for package in noarch_data - ] + [package.into_repo_data(linux64_chan) for package in linux64_data] - + available_packages = [(data, noarch_chan) for data in noarch_data] + [ + (data, linux64_chan) for data in linux64_data + ] solved_data = solve( [MatchSpec("test-package"), MatchSpec("foobar"), MatchSpec("baz")], available_packages, From d77323108b0cb188c0d618037bd06f2e0d3d0987 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Mon, 2 Oct 2023 12:37:05 +0530 Subject: [PATCH 06/12] fix: move large SparseRepoData extraction move to rust ffi --- py-rattler/rattler/solver/solver.py | 12 +++----- py-rattler/src/error.rs | 6 ++++ py-rattler/src/match_spec.rs | 2 +- py-rattler/src/repo_data/sparse.rs | 33 +++++++++++++++++++-- py-rattler/src/solver.rs | 45 +++++++++++++++++++---------- 5 files changed, 71 insertions(+), 27 deletions(-) diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py index b31b4c271..4a3c4ee80 100644 --- a/py-rattler/rattler/solver/solver.py +++ b/py-rattler/rattler/solver/solver.py @@ -1,17 +1,16 @@ from __future__ import annotations -from typing import List, Tuple +from typing import List from rattler.match_spec.match_spec import MatchSpec from rattler.rattler import py_solve from rattler.repo_data.record import RepoDataRecord +from rattler.repo_data.sparse import SparseRepoData from rattler.virtual_package.generic import GenericVirtualPackage -from rattler.repo_data.repo_data import RepoData -from rattler.channel import Channel def solve( specs: List[MatchSpec], - available_packages: List[Tuple[RepoData, Channel]], + available_packages: List[SparseRepoData], locked_packages: List[RepoDataRecord], pinned_packages: List[RepoDataRecord], virtual_packages: List[GenericVirtualPackage], @@ -24,10 +23,7 @@ def solve( RepoDataRecord._from_py_record(solved_package) for solved_package in py_solve( [spec._match_spec for spec in specs], - [ - (repo_data._repo_data, channel._channel) - for (repo_data, channel) in available_packages - ], + available_packages, [package._record for package in locked_packages], [package._record for package in pinned_packages], [v_package._generic_virtual_package for v_package in virtual_packages], diff --git a/py-rattler/src/error.rs b/py-rattler/src/error.rs index b79ccd787..efdc7cd91 100644 --- a/py-rattler/src/error.rs +++ b/py-rattler/src/error.rs @@ -41,6 +41,8 @@ pub enum PyRattlerError { IoError(#[from] io::Error), #[error(transparent)] SolverError(#[from] SolveError), + #[error("invalid 'SparseRepoData' object found")] + InvalidSparseDataError, } impl From for PyErr { @@ -73,6 +75,9 @@ impl From for PyErr { } PyRattlerError::IoError(err) => IoException::new_err(err.to_string()), PyRattlerError::SolverError(err) => SolverException::new_err(err.to_string()), + PyRattlerError::InvalidSparseDataError => InvalidSparseDataException::new_err( + PyRattlerError::InvalidSparseDataError.to_string(), + ), } } } @@ -90,3 +95,4 @@ create_exception!(exceptions, CacheDirException, PyException); create_exception!(exceptions, DetectVirtualPackageException, PyException); create_exception!(exceptions, IoException, PyException); create_exception!(exceptions, SolverException, PyException); +create_exception!(exceptions, InvalidSparseDataException, PyException); diff --git a/py-rattler/src/match_spec.rs b/py-rattler/src/match_spec.rs index 7c2a3a241..3c52b7f18 100644 --- a/py-rattler/src/match_spec.rs +++ b/py-rattler/src/match_spec.rs @@ -11,7 +11,7 @@ use crate::{ #[repr(transparent)] #[derive(Clone)] pub struct PyMatchSpec { - inner: MatchSpec, + pub(crate) inner: MatchSpec, } impl From for PyMatchSpec { diff --git a/py-rattler/src/repo_data/sparse.rs b/py-rattler/src/repo_data/sparse.rs index e387219ec..890aff80f 100644 --- a/py-rattler/src/repo_data/sparse.rs +++ b/py-rattler/src/repo_data/sparse.rs @@ -1,10 +1,11 @@ use std::{path::PathBuf, sync::Arc}; -use pyo3::{pyclass, pymethods, PyResult, Python}; +use pyo3::{intern, pyclass, pymethods, FromPyObject, PyAny, PyResult, Python}; use rattler_repodata_gateway::sparse::SparseRepoData; use crate::channel::PyChannel; +use crate::error::PyRattlerError; use crate::package_name::PyPackageName; use crate::repo_data::repo_data_record::PyRepoDataRecord; @@ -23,6 +24,34 @@ impl From for PySparseRepoData { } } +impl<'a> From<&'a PySparseRepoData> for &'a SparseRepoData { + fn from(value: &'a PySparseRepoData) -> Self { + value.inner.as_ref() + } +} + +impl<'a> TryFrom<&'a PyAny> for PySparseRepoData { + type Error = pyo3::PyErr; + fn try_from(value: &'a PyAny) -> Result { + let intern_val = intern!(value.py(), "_sparse"); + if !value.hasattr(intern_val)? { + return Err(PyRattlerError::from(anyhow::anyhow!( + "TypeError: Object is not an instance of 'SparseRepoData'" + )) + .into()); + } + + let inner = value.getattr(intern_val)?; + if !inner.is_instance_of::() { + return Err( + PyRattlerError::from(anyhow::anyhow!("TypeError: '_sparse' is invalid!")).into(), + ); + } + + PySparseRepoData::extract(inner) + } +} + #[pymethods] impl PySparseRepoData { #[new] @@ -57,7 +86,7 @@ impl PySparseRepoData { repo_data: Vec, package_names: Vec, ) -> PyResult>> { - let repo_data = repo_data.iter().map(|r| r.inner.as_ref()); + let repo_data = repo_data.iter().map(Into::into); let package_names = package_names.into_iter().map(Into::into); // release gil to allow other threads to progress diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs index f88c7f0a4..e85a0e97f 100644 --- a/py-rattler/src/solver.rs +++ b/py-rattler/src/solver.rs @@ -1,26 +1,36 @@ -use pyo3::{pyfunction, PyResult}; +use pyo3::{pyfunction, PyAny, PyResult, Python}; +use rattler_repodata_gateway::sparse::SparseRepoData; use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; use crate::{ - channel::PyChannel, error::PyRattlerError, generic_virtual_package::PyGenericVirtualPackage, match_spec::PyMatchSpec, - repo_data::{repo_data_record::PyRepoDataRecord, PyRepoData}, + repo_data::{repo_data_record::PyRepoDataRecord, sparse::PySparseRepoData}, }; #[pyfunction] pub fn py_solve( + py: Python<'_>, specs: Vec, - available_packages: Vec<(PyRepoData, PyChannel)>, + available_packages: Vec<&'_ PyAny>, locked_packages: Vec, pinned_packages: Vec, virtual_packages: Vec, ) -> PyResult> { - let available_packages = available_packages - .into_iter() - .map(|(records, channel)| records.inner.into_repo_data_records(&channel.inner)) - .collect::>(); + let packages = available_packages + .iter() + .map(|&pkg| TryInto::::try_into(pkg)) + .collect::>>()?; + + let package_names = specs + .iter() + .filter_map(|match_spec| match_spec.inner.name.clone()); + + // pure rust operation, release gil to allow python threads to do other work + let available_packages = py.allow_threads(move || { + SparseRepoData::load_records_recursive(packages.iter().map(Into::into), package_names, None) + })?; let task = SolverTask { available_packages: &available_packages, @@ -30,12 +40,15 @@ pub fn py_solve( specs: specs.into_iter().map(Into::into).collect(), }; - Ok(Solver - .solve(task) - .map(|res| { - res.into_iter() - .map(Into::into) - .collect::>() - }) - .map_err(PyRattlerError::from)?) + // pure rust operation, release gil to allow python threads to do other work + py.allow_threads(move || { + Ok(Solver + .solve(task) + .map(|res| { + res.into_iter() + .map(Into::into) + .collect::>() + }) + .map_err(PyRattlerError::from)?) + }) } From 5cb5b4edd3b4ed203b868a3384d71120e59712f6 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Mon, 2 Oct 2023 13:26:41 +0530 Subject: [PATCH 07/12] revert: move large SparseRepoData extraction move to rust ffi --- py-rattler/src/repo_data/sparse.rs | 31 +++------------------- py-rattler/src/solver.rs | 41 +++++++++++++----------------- 2 files changed, 21 insertions(+), 51 deletions(-) diff --git a/py-rattler/src/repo_data/sparse.rs b/py-rattler/src/repo_data/sparse.rs index 890aff80f..b8e5be46f 100644 --- a/py-rattler/src/repo_data/sparse.rs +++ b/py-rattler/src/repo_data/sparse.rs @@ -1,11 +1,10 @@ use std::{path::PathBuf, sync::Arc}; -use pyo3::{intern, pyclass, pymethods, FromPyObject, PyAny, PyResult, Python}; +use pyo3::{pyclass, pymethods, PyResult, Python}; use rattler_repodata_gateway::sparse::SparseRepoData; use crate::channel::PyChannel; -use crate::error::PyRattlerError; use crate::package_name::PyPackageName; use crate::repo_data::repo_data_record::PyRepoDataRecord; @@ -30,28 +29,6 @@ impl<'a> From<&'a PySparseRepoData> for &'a SparseRepoData { } } -impl<'a> TryFrom<&'a PyAny> for PySparseRepoData { - type Error = pyo3::PyErr; - fn try_from(value: &'a PyAny) -> Result { - let intern_val = intern!(value.py(), "_sparse"); - if !value.hasattr(intern_val)? { - return Err(PyRattlerError::from(anyhow::anyhow!( - "TypeError: Object is not an instance of 'SparseRepoData'" - )) - .into()); - } - - let inner = value.getattr(intern_val)?; - if !inner.is_instance_of::() { - return Err( - PyRattlerError::from(anyhow::anyhow!("TypeError: '_sparse' is invalid!")).into(), - ); - } - - PySparseRepoData::extract(inner) - } -} - #[pymethods] impl PySparseRepoData { #[new] @@ -86,11 +63,9 @@ impl PySparseRepoData { repo_data: Vec, package_names: Vec, ) -> PyResult>> { - let repo_data = repo_data.iter().map(Into::into); - let package_names = package_names.into_iter().map(Into::into); - - // release gil to allow other threads to progress py.allow_threads(move || { + let repo_data = repo_data.iter().map(Into::into); + let package_names = package_names.into_iter().map(Into::into); Ok( SparseRepoData::load_records_recursive(repo_data, package_names, None)? .into_iter() diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs index e85a0e97f..380b4d213 100644 --- a/py-rattler/src/solver.rs +++ b/py-rattler/src/solver.rs @@ -1,4 +1,4 @@ -use pyo3::{pyfunction, PyAny, PyResult, Python}; +use pyo3::{pyfunction, PyResult, Python}; use rattler_repodata_gateway::sparse::SparseRepoData; use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; @@ -13,35 +13,30 @@ use crate::{ pub fn py_solve( py: Python<'_>, specs: Vec, - available_packages: Vec<&'_ PyAny>, + available_packages: Vec, locked_packages: Vec, pinned_packages: Vec, virtual_packages: Vec, ) -> PyResult> { - let packages = available_packages - .iter() - .map(|&pkg| TryInto::::try_into(pkg)) - .collect::>>()?; - - let package_names = specs - .iter() - .filter_map(|match_spec| match_spec.inner.name.clone()); + py.allow_threads(move || { + let package_names = specs + .iter() + .filter_map(|match_spec| match_spec.inner.name.clone()); - // pure rust operation, release gil to allow python threads to do other work - let available_packages = py.allow_threads(move || { - SparseRepoData::load_records_recursive(packages.iter().map(Into::into), package_names, None) - })?; + let available_packages = SparseRepoData::load_records_recursive( + available_packages.iter().map(Into::into), + package_names, + None, + )?; - let task = SolverTask { - available_packages: &available_packages, - locked_packages: locked_packages.into_iter().map(Into::into).collect(), - pinned_packages: pinned_packages.into_iter().map(Into::into).collect(), - virtual_packages: virtual_packages.into_iter().map(Into::into).collect(), - specs: specs.into_iter().map(Into::into).collect(), - }; + let task = SolverTask { + available_packages: &available_packages, + locked_packages: locked_packages.into_iter().map(Into::into).collect(), + pinned_packages: pinned_packages.into_iter().map(Into::into).collect(), + virtual_packages: virtual_packages.into_iter().map(Into::into).collect(), + specs: specs.into_iter().map(Into::into).collect(), + }; - // pure rust operation, release gil to allow python threads to do other work - py.allow_threads(move || { Ok(Solver .solve(task) .map(|res| { From 1701fac885539425f8067a3eeeb7c01a9f5c63b7 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Mon, 2 Oct 2023 14:06:12 +0530 Subject: [PATCH 08/12] fix: add default values for optional values --- py-rattler/rattler/solver/solver.py | 45 +++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py index 4a3c4ee80..b8be17c35 100644 --- a/py-rattler/rattler/solver/solver.py +++ b/py-rattler/rattler/solver/solver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import List +from typing import List, Optional from rattler.match_spec.match_spec import MatchSpec from rattler.rattler import py_solve @@ -11,19 +11,52 @@ def solve( specs: List[MatchSpec], available_packages: List[SparseRepoData], - locked_packages: List[RepoDataRecord], - pinned_packages: List[RepoDataRecord], - virtual_packages: List[GenericVirtualPackage], + locked_packages: Optional[List[RepoDataRecord]] = None, + pinned_packages: Optional[List[RepoDataRecord]] = None, + virtual_packages: Optional[List[GenericVirtualPackage]] = None, ) -> List[RepoDataRecord]: """ - Resolve the dependencies and return the [`RepoDataRecord`]s + Resolve the dependencies and return the `RepoDataRecord`s that should be present in the environment. + + Arguments: + specs: A list of matchspec to solve. + available_packages: A list of RepoData to use for solving the `specs`. + locked_packages: Records of packages that are previously selected. + If the solver encounters multiple variants of a single + package (identified by its name), it will sort the records + and select the best possible version. However, if there + exists a locked version it will prefer that variant instead. + This is useful to reduce the number of packages that are + updated when installing new packages. Usually you add the + currently installed packages or packages from a lock-file here. + pinned_packages: Records of packages that are previously selected and CANNOT + be changed. If the solver encounters multiple variants of + a single package (identified by its name), it will sort the + records and select the best possible version. However, if + there is a variant available in the `pinned_packages` field it + will always select that version no matter what even if that + means other packages have to be downgraded. + virtual_packages: A list of virtual packages considered active. + + Returns: + Resolved list of `RepoDataRecord`s. """ + + if not locked_packages: + locked_packages = list() + + if not pinned_packages: + pinned_packages = list() + + if not virtual_packages: + virtual_packages = list() + return [ RepoDataRecord._from_py_record(solved_package) for solved_package in py_solve( [spec._match_spec for spec in specs], - available_packages, + [package._sparse for package in available_packages], [package._record for package in locked_packages], [package._record for package in pinned_packages], [v_package._generic_virtual_package for v_package in virtual_packages], From 91a322fc5a66a044ad77382c953d2cd2972caf93 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Mon, 2 Oct 2023 16:42:10 +0530 Subject: [PATCH 09/12] test: fix test --- py-rattler/pixi.lock | 21 ++++++ py-rattler/pixi.toml | 4 ++ py-rattler/tests/unit/test_solver.py | 101 ++++----------------------- 3 files changed, 38 insertions(+), 88 deletions(-) diff --git a/py-rattler/pixi.lock b/py-rattler/pixi.lock index 7e25dd991..bbfa9c170 100644 --- a/py-rattler/pixi.lock +++ b/py-rattler/pixi.lock @@ -1779,6 +1779,27 @@ package: noarch: python size: 46098 timestamp: 1681337144376 +- name: patchelf + version: 0.17.2 + manager: conda + platform: linux-64 + dependencies: + libgcc-ng: '>=7.5.0' + libstdcxx-ng: '>=7.5.0' + url: https://conda.anaconda.org/conda-forge/linux-64/patchelf-0.17.2-h58526e2_0.conda + hash: + md5: ba76a6a448819560b5f8b08a9c74f415 + sha256: eb355ac225be2f698e19dba4dcab7cb0748225677a9799e9cc8e4cadc3cb738f + optional: false + category: main + build: h58526e2_0 + arch: x86_64 + subdir: linux-64 + build_number: 0 + license: GPL-3.0-or-later + license_family: GPL + size: 94048 + timestamp: 1673473024463 - name: pathspec version: 0.11.2 manager: conda diff --git a/py-rattler/pixi.toml b/py-rattler/pixi.toml index dafb3165b..a8d10b66c 100644 --- a/py-rattler/pixi.toml +++ b/py-rattler/pixi.toml @@ -12,6 +12,7 @@ license = "BSD-3-Clause" [tasks] build = "PIP_REQUIRE_VIRTUALENV=false maturin develop" +build-release = "PIP_REQUIRE_VIRTUALENV=false maturin develop --release" test = { cmd = "pytest --doctest-modules", depends_on = ["build"] } fmt-python = "black ." fmt-rust = "cargo fmt --all" @@ -34,3 +35,6 @@ black = "~=23.7.0" ruff = "~=0.0.285" mypy = "~=1.5.1" pytest-asyncio = "0.21.1.*" + +[target.linux-64.dependencies] +patchelf = "~=0.17.2" diff --git a/py-rattler/tests/unit/test_solver.py b/py-rattler/tests/unit/test_solver.py index 864da0f78..e76d83267 100644 --- a/py-rattler/tests/unit/test_solver.py +++ b/py-rattler/tests/unit/test_solver.py @@ -1,108 +1,33 @@ # type: ignore import os.path -import subprocess import pytest from rattler import ( + solve, Channel, ChannelConfig, - fetch_repo_data, - Platform, - solve, MatchSpec, RepoDataRecord, + SparseRepoData, ) -@pytest.fixture(scope="session") -def noarch_repo_data() -> None: - port, repo_name = 8812, "test-repo-1" - - test_data_dir = os.path.join( - os.path.dirname(__file__), "../../../test-data/test-server" - ) - - with subprocess.Popen( - [ - "python", - os.path.join(test_data_dir, "reposerver.py"), - "-d", - os.path.join(test_data_dir, "repo"), - "-n", - repo_name, - "-p", - str(port), - ] - ) as proc: - yield port, repo_name - proc.terminate() - - -@pytest.fixture(scope="session") -def linux64_repo_data() -> None: - port, repo_name = 8813, "test-repo-2" - - test_data_dir = os.path.join(os.path.dirname(__file__), "../../../test-data/") - - with subprocess.Popen( - [ - "python", - os.path.join(test_data_dir, "test-server/reposerver.py"), - "-d", - os.path.join(test_data_dir, "channels/dummy/"), - "-n", - repo_name, - "-p", - str(port), - ] - ) as proc: - yield port, repo_name - proc.terminate() - - @pytest.mark.asyncio -async def test_solve( - tmp_path, - noarch_repo_data, - linux64_repo_data, -): - noarch_port, noarch_repo = noarch_repo_data - linux64_port, linux64_repo = linux64_repo_data - cache_dir = tmp_path / "test_repo_data_download" - noarch_chan = Channel( - noarch_repo, ChannelConfig(f"http://localhost:{noarch_port}/") - ) - plat_noarch = Platform("noarch") - linux64_chan = Channel( - linux64_repo, ChannelConfig(f"http://localhost:{linux64_port}/") - ) - plat_linux64 = Platform("linux-64") - - noarch_data = await fetch_repo_data( - channels=[noarch_chan], - platforms=[plat_noarch], - cache_path=cache_dir, - callback=None, - ) - - linux64_data = await fetch_repo_data( - channels=[linux64_chan], - platforms=[plat_linux64], - cache_path=cache_dir, - callback=None, +async def test_solve(): + linux64_chan = Channel("conda-forge", ChannelConfig(f"http://localhost:{8813}/")) + data_dir = os.path.join(os.path.dirname(__file__), "../../../test-data/") + linux64_path = os.path.join(data_dir, "channels/conda-forge/linux-64/repodata.json") + linux64_data = SparseRepoData( + channel=linux64_chan, + subdir="linux-64", + path=linux64_path, ) - available_packages = [(data, noarch_chan) for data in noarch_data] + [ - (data, linux64_chan) for data in linux64_data - ] solved_data = solve( - [MatchSpec("test-package"), MatchSpec("foobar"), MatchSpec("baz")], - available_packages, - [], - [], - [], + [MatchSpec("python"), MatchSpec("sqlite")], + [linux64_data], ) assert isinstance(solved_data, list) assert isinstance(solved_data[0], RepoDataRecord) - assert len(solved_data) == 4 + assert len(solved_data) == 19 From 31b7cd7e80e834d2178d87da68a9b15d1f541443 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Mon, 2 Oct 2023 18:20:30 +0530 Subject: [PATCH 10/12] fix: Channel uses default ChannelConfig if not defined --- py-rattler/rattler/channel/channel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py-rattler/rattler/channel/channel.py b/py-rattler/rattler/channel/channel.py index 29bd5b533..e22eebc30 100644 --- a/py-rattler/rattler/channel/channel.py +++ b/py-rattler/rattler/channel/channel.py @@ -6,7 +6,9 @@ class Channel: - def __init__(self, name: str, channel_configuration: ChannelConfig) -> None: + def __init__( + self, name: str, channel_configuration: ChannelConfig = ChannelConfig() + ) -> None: """ Create a new channel. From 633e33bbbdf7ebe5fe3f2d3259727c4c10c392a9 Mon Sep 17 00:00:00 2001 From: Wackyator Date: Mon, 2 Oct 2023 18:21:08 +0530 Subject: [PATCH 11/12] test: fix test --- py-rattler/tests/unit/test_solver.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py-rattler/tests/unit/test_solver.py b/py-rattler/tests/unit/test_solver.py index e76d83267..a9354ea95 100644 --- a/py-rattler/tests/unit/test_solver.py +++ b/py-rattler/tests/unit/test_solver.py @@ -5,7 +5,6 @@ from rattler import ( solve, Channel, - ChannelConfig, MatchSpec, RepoDataRecord, SparseRepoData, @@ -14,7 +13,7 @@ @pytest.mark.asyncio async def test_solve(): - linux64_chan = Channel("conda-forge", ChannelConfig(f"http://localhost:{8813}/")) + linux64_chan = Channel("conda-forge") data_dir = os.path.join(os.path.dirname(__file__), "../../../test-data/") linux64_path = os.path.join(data_dir, "channels/conda-forge/linux-64/repodata.json") linux64_data = SparseRepoData( From 76e916947038baa367728177a06d3bddc7d08c0c Mon Sep 17 00:00:00 2001 From: Wackyator Date: Mon, 2 Oct 2023 18:27:48 +0530 Subject: [PATCH 12/12] fix: remove unnecessary if checks --- py-rattler/rattler/solver/solver.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py index b8be17c35..d443c5b28 100644 --- a/py-rattler/rattler/solver/solver.py +++ b/py-rattler/rattler/solver/solver.py @@ -43,22 +43,16 @@ def solve( Resolved list of `RepoDataRecord`s. """ - if not locked_packages: - locked_packages = list() - - if not pinned_packages: - pinned_packages = list() - - if not virtual_packages: - virtual_packages = list() - return [ RepoDataRecord._from_py_record(solved_package) for solved_package in py_solve( [spec._match_spec for spec in specs], [package._sparse for package in available_packages], - [package._record for package in locked_packages], - [package._record for package in pinned_packages], - [v_package._generic_virtual_package for v_package in virtual_packages], + [package._record for package in locked_packages or []], + [package._record for package in pinned_packages or []], + [ + v_package._generic_virtual_package + for v_package in virtual_packages or [] + ], ) ]