diff --git a/crates/rattler-bin/src/commands/create.rs b/crates/rattler-bin/src/commands/create.rs index 5492602a5..b7f6bc631 100644 --- a/crates/rattler-bin/src/commands/create.rs +++ b/crates/rattler-bin/src/commands/create.rs @@ -167,6 +167,7 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { record.depends.push("pip".to_string()); } }), + true, ) })?; diff --git a/crates/rattler_repodata_gateway/src/sparse/mod.rs b/crates/rattler_repodata_gateway/src/sparse/mod.rs index 150ec4322..f15820932 100644 --- a/crates/rattler_repodata_gateway/src/sparse/mod.rs +++ b/crates/rattler_repodata_gateway/src/sparse/mod.rs @@ -118,10 +118,15 @@ impl SparseRepoData { /// /// This will parse the records for the specified packages as well as all the packages these records /// depend on. + /// + /// When strict_channel_priority is true, the channel where a package is found first will be + /// the only channel used for that package. Make it false to search in all channels for all packages. + /// pub fn load_records_recursive<'a>( repo_data: impl IntoIterator, package_names: impl IntoIterator, patch_function: Option, + strict_channel_priority: bool, ) -> io::Result>> { let repo_data: Vec<_> = repo_data.into_iter().collect(); @@ -136,7 +141,13 @@ impl SparseRepoData { // Iterate over the list of packages that still need to be processed. 
while let Some(next_package) = pending.pop_front() { + let mut found_in_channel = None; for (i, repo_data) in repo_data.iter().enumerate() { + // If package was found in other channel, skip this repodata + if found_in_channel.map_or(false, |c| c != &repo_data.channel.base_url) { + continue; + } + let repo_data_packages = repo_data.inner.borrow_repo_data(); let base_url = repo_data_packages .info @@ -162,6 +173,10 @@ impl SparseRepoData { )?; records.append(&mut conda_records); + if strict_channel_priority && !records.is_empty() { + found_in_channel = Some(&repo_data.channel.base_url); + } + // Iterate over all packages to find recursive dependencies. for record in records.iter() { for dependency in &record.package_record.depends { @@ -259,6 +274,7 @@ pub async fn load_repo_data_recursively( repo_data_paths: impl IntoIterator, impl AsRef)>, package_names: impl IntoIterator, patch_function: Option, + strict_channel_priority: bool, ) -> Result>, io::Error> { // Open the different files and memory map them to get access to their bytes. Do this in parallel. 
let lazy_repo_data = stream::iter(repo_data_paths) @@ -277,7 +293,12 @@ pub async fn load_repo_data_recursively( .try_collect::>() .await?; - SparseRepoData::load_records_recursive(&lazy_repo_data, package_names, patch_function) + SparseRepoData::load_records_recursive( + &lazy_repo_data, + package_names, + patch_function, + strict_channel_priority, + ) } fn deserialize_filename_and_raw_record<'d, D: Deserializer<'d>>( @@ -380,6 +401,7 @@ impl<'de> TryFrom<&'de str> for PackageFilename<'de> { #[cfg(test)] mod test { use super::{load_repo_data_recursively, PackageFilename}; + use itertools::Itertools; use rattler_conda_types::{Channel, ChannelConfig, PackageName, RepoData, RepoDataRecord}; use rstest::rstest; use std::path::{Path, PathBuf}; @@ -390,6 +412,7 @@ mod test { async fn load_sparse( package_names: impl IntoIterator>, + strict_channel_priority: bool, ) -> Vec> { load_repo_data_recursively( [ @@ -403,11 +426,17 @@ mod test { "linux-64", test_dir().join("channels/conda-forge/linux-64/repodata.json"), ), + ( + Channel::from_str("pytorch", &ChannelConfig::default()).unwrap(), + "linux-64", + test_dir().join("channels/pytorch/linux-64/repodata.json"), + ), ], package_names .into_iter() .map(|name| PackageName::try_from(name.as_ref()).unwrap()), None, + strict_channel_priority, ) .await .unwrap() @@ -415,13 +444,13 @@ mod test { #[tokio::test] async fn test_empty_sparse_load() { - let sparse_empty_data = load_sparse(Vec::::new()).await; - assert_eq!(sparse_empty_data, vec![vec![], vec![]]); + let sparse_empty_data = load_sparse(Vec::::new(), false).await; + assert_eq!(sparse_empty_data, vec![vec![], vec![], vec![]]); } #[tokio::test] async fn test_sparse_single() { - let sparse_empty_data = load_sparse(["_libgcc_mutex"]).await; + let sparse_empty_data = load_sparse(["_libgcc_mutex"], false).await; let total_records = sparse_empty_data .iter() .map(|repo| repo.len()) @@ -430,9 +459,45 @@ mod test { assert_eq!(total_records, 3); } + #[tokio::test] + async fn 
test_sparse_strict() { + // If we load pytorch-cpu from all channels (non-strict) we expect records from both + // conda-forge and the pytorch channels. + let sparse_data = load_sparse(["pytorch-cpu"], false).await; + let channels = sparse_data + .into_iter() + .flatten() + .filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu") + .map(|record| record.channel) + .unique() + .collect_vec(); + assert_eq!( + channels, + vec![ + String::from("https://conda.anaconda.org/conda-forge/"), + String::from("https://conda.anaconda.org/pytorch/") + ] + ); + + // If we load pytorch-cpu from strict channels we expect records only from the first channel + // that contains the package. In this case we expect records only from conda-forge. + let strict_sparse_data = load_sparse(["pytorch-cpu"], true).await; + let channels = strict_sparse_data + .into_iter() + .flatten() + .filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu") + .map(|record| record.channel) + .unique() + .collect_vec(); + assert_eq!( + channels, + vec![String::from("https://conda.anaconda.org/conda-forge/")] + ); + } + #[tokio::test] async fn test_parse_duplicate() { - let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"]).await; + let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"], false).await; let total_records = sparse_empty_data .iter() .map(|repo| repo.len()) @@ -444,7 +509,7 @@ mod test { #[tokio::test] async fn test_sparse_jupyterlab_detectron2() { - let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"]).await; + let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"], true).await; let total_records = sparse_empty_data .iter() @@ -456,30 +521,33 @@ mod test { #[tokio::test] async fn test_sparse_numpy_dev() { - let sparse_empty_data = load_sparse([ - "python", - "cython", - "compilers", - "openblas", - "nomkl", - "pytest", - "pytest-cov", - "pytest-xdist", - "hypothesis", - "mypy", - "typing_extensions", - "
"sphinx", - "numpydoc", - "ipython", - "scipy", - "pandas", - "matplotlib", - "pydata-sphinx-theme", - "pycodestyle", - "gitpython", - "cffi", - "pytz", - ]) + let sparse_empty_data = load_sparse( + [ + "python", + "cython", + "compilers", + "openblas", + "nomkl", + "pytest", + "pytest-cov", + "pytest-xdist", + "hypothesis", + "mypy", + "typing_extensions", + "sphinx", + "numpydoc", + "ipython", + "scipy", + "pandas", + "matplotlib", + "pydata-sphinx-theme", + "pycodestyle", + "gitpython", + "cffi", + "pytz", + ], + false, + ) .await; let total_records = sparse_empty_data diff --git a/crates/rattler_solve/benches/bench.rs b/crates/rattler_solve/benches/bench.rs index fbb23927b..98c08efa6 100644 --- a/crates/rattler_solve/benches/bench.rs +++ b/crates/rattler_solve/benches/bench.rs @@ -52,7 +52,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) { let names = specs.iter().map(|s| s.name.clone().unwrap()); let available_packages = - SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None).unwrap(); + SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None, true).unwrap(); #[cfg(feature = "libsolv_c")] group.bench_function("libsolv_c", |b| { diff --git a/crates/rattler_solve/tests/backends.rs b/crates/rattler_solve/tests/backends.rs index 5c56ab4ab..be97681df 100644 --- a/crates/rattler_solve/tests/backends.rs +++ b/crates/rattler_solve/tests/backends.rs @@ -110,7 +110,7 @@ fn solve_real_world(specs: Vec<&str>) -> Vec { let names = specs.iter().filter_map(|s| s.name.as_ref().cloned()); let available_packages = - SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap(); + SparseRepoData::load_records_recursive(sparse_repo_datas, names, None, true).unwrap(); let solver_task = SolverTask { available_packages: &available_packages, @@ -592,7 +592,7 @@ fn compare_solve(specs: Vec<&str>) { let names = specs.iter().filter_map(|s| s.name.as_ref().cloned()); let available_packages = - 
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap(); + SparseRepoData::load_records_recursive(sparse_repo_datas, names, None, true).unwrap(); let extract_pkgs = |records: Vec| { let mut pkgs = records diff --git a/py-rattler/Cargo.lock b/py-rattler/Cargo.lock index 3fa8c4ca7..565ae0511 100644 --- a/py-rattler/Cargo.lock +++ b/py-rattler/Cargo.lock @@ -1922,7 +1922,7 @@ dependencies = [ [[package]] name = "rattler" -version = "0.10.0" +version = "0.11.0" dependencies = [ "anyhow", "async-compression", @@ -1961,7 +1961,7 @@ dependencies = [ [[package]] name = "rattler_conda_types" -version = "0.10.0" +version = "0.11.0" dependencies = [ "chrono", "fxhash", @@ -1988,7 +1988,7 @@ dependencies = [ [[package]] name = "rattler_digest" -version = "0.10.0" +version = "0.11.0" dependencies = [ "blake2", "digest", @@ -2002,7 +2002,7 @@ dependencies = [ [[package]] name = "rattler_macros" -version = "0.10.0" +version = "0.11.0" dependencies = [ "quote", "syn 2.0.37", @@ -2010,7 +2010,7 @@ dependencies = [ [[package]] name = "rattler_networking" -version = "0.10.0" +version = "0.11.0" dependencies = [ "anyhow", "dirs", @@ -2028,7 +2028,7 @@ dependencies = [ [[package]] name = "rattler_package_streaming" -version = "0.10.0" +version = "0.11.0" dependencies = [ "bzip2", "chrono", @@ -2050,7 +2050,7 @@ dependencies = [ [[package]] name = "rattler_repodata_gateway" -version = "0.10.0" +version = "0.11.0" dependencies = [ "anyhow", "async-compression", @@ -2087,7 +2087,7 @@ dependencies = [ [[package]] name = "rattler_shell" -version = "0.10.0" +version = "0.11.0" dependencies = [ "enum_dispatch", "indexmap 2.0.2", @@ -2102,7 +2102,7 @@ dependencies = [ [[package]] name = "rattler_solve" -version = "0.10.0" +version = "0.11.0" dependencies = [ "anyhow", "chrono", @@ -2120,7 +2120,7 @@ dependencies = [ [[package]] name = "rattler_virtual_packages" -version = "0.10.0" +version = "0.11.0" dependencies = [ "cfg-if", "libloading", diff --git 
a/py-rattler/rattler/repo_data/sparse.py b/py-rattler/rattler/repo_data/sparse.py index bc6432280..b7c40f83d 100644 --- a/py-rattler/rattler/repo_data/sparse.py +++ b/py-rattler/rattler/repo_data/sparse.py @@ -115,6 +115,7 @@ def subdir(self) -> str: def load_records_recursive( repo_data: List[SparseRepoData], package_names: List[PackageName], + strict_channel_priority: bool = True, ) -> List[List[RepoDataRecord]]: """ Given a set of [`SparseRepoData`]s load all the records @@ -142,6 +143,7 @@ def load_records_recursive( for list_of_records in PySparseRepoData.load_records_recursive( [r._sparse for r in repo_data], [p._name for p in package_names], + strict_channel_priority, ) ] diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py index d443c5b28..eacf23ef5 100644 --- a/py-rattler/rattler/solver/solver.py +++ b/py-rattler/rattler/solver/solver.py @@ -14,6 +14,7 @@ def solve( locked_packages: Optional[List[RepoDataRecord]] = None, pinned_packages: Optional[List[RepoDataRecord]] = None, virtual_packages: Optional[List[GenericVirtualPackage]] = None, + strict_channel_priority: bool = True, ) -> List[RepoDataRecord]: """ Resolve the dependencies and return the `RepoDataRecord`s @@ -38,6 +39,9 @@ def solve( will always select that version no matter what even if that means other packages have to be downgraded. virtual_packages: A list of virtual packages considered active. + strict_channel_priority: (Default = True) When `True` the channel that the package + is first found in will be used as the only channel for that package. + When `False` it will search for every package in every channel. Returns: Resolved list of `RepoDataRecord`s. 
@@ -54,5 +58,6 @@ def solve( v_package._generic_virtual_package for v_package in virtual_packages or [] ], + strict_channel_priority, ) ] diff --git a/py-rattler/src/repo_data/sparse.rs b/py-rattler/src/repo_data/sparse.rs index b8e5be46f..55afaa595 100644 --- a/py-rattler/src/repo_data/sparse.rs +++ b/py-rattler/src/repo_data/sparse.rs @@ -62,16 +62,20 @@ impl PySparseRepoData { py: Python<'_>, repo_data: Vec, package_names: Vec, + strict_channel_priority: bool, ) -> PyResult>> { py.allow_threads(move || { let repo_data = repo_data.iter().map(Into::into); let package_names = package_names.into_iter().map(Into::into); - Ok( - SparseRepoData::load_records_recursive(repo_data, package_names, None)? - .into_iter() - .map(|v| v.into_iter().map(Into::into).collect::>()) - .collect::>(), - ) + Ok(SparseRepoData::load_records_recursive( + repo_data, + package_names, + None, + strict_channel_priority, + )? + .into_iter() + .map(|v| v.into_iter().map(Into::into).collect::>()) + .collect::>()) }) } } diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs index 380b4d213..1697b5724 100644 --- a/py-rattler/src/solver.rs +++ b/py-rattler/src/solver.rs @@ -17,6 +17,7 @@ pub fn py_solve( locked_packages: Vec, pinned_packages: Vec, virtual_packages: Vec, + strict_channel_priority: bool, ) -> PyResult> { py.allow_threads(move || { let package_names = specs @@ -27,6 +28,7 @@ pub fn py_solve( available_packages.iter().map(Into::into), package_names, None, + strict_channel_priority, )?; let task = SolverTask { diff --git a/test-data/channels/pytorch/linux-64/repodata.json b/test-data/channels/pytorch/linux-64/repodata.json new file mode 100644 index 000000000..a856bdc0a --- /dev/null +++ b/test-data/channels/pytorch/linux-64/repodata.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9aee421aedb89c060bc81ef89ff9f5c05fff41be0ef19d55c75966781875524 +size 1273854