Skip to content

Commit

Permalink
feat: add strict channel priority option (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
ruben-arts authored Oct 21, 2023
1 parent a70f195 commit a39a35d
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 49 deletions.
1 change: 1 addition & 0 deletions crates/rattler-bin/src/commands/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
record.depends.push("pip".to_string());
}
}),
true,
)
})?;

Expand Down
128 changes: 98 additions & 30 deletions crates/rattler_repodata_gateway/src/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,15 @@ impl SparseRepoData {
///
/// This will parse the records for the specified packages as well as all the packages these records
/// depend on.
///
/// When strict_channel_priority is true, the first channel in which a package is found becomes
/// the only channel used for that package. Set it to false to search all channels for every package.
///
pub fn load_records_recursive<'a>(
repo_data: impl IntoIterator<Item = &'a SparseRepoData>,
package_names: impl IntoIterator<Item = PackageName>,
patch_function: Option<fn(&mut PackageRecord)>,
strict_channel_priority: bool,
) -> io::Result<Vec<Vec<RepoDataRecord>>> {
let repo_data: Vec<_> = repo_data.into_iter().collect();

Expand All @@ -136,7 +141,13 @@ impl SparseRepoData {

// Iterate over the list of packages that still need to be processed.
while let Some(next_package) = pending.pop_front() {
let mut found_in_channel = None;
for (i, repo_data) in repo_data.iter().enumerate() {
// If package was found in other channel, skip this repodata
if found_in_channel.map_or(false, |c| c != &repo_data.channel.base_url) {
continue;
}

let repo_data_packages = repo_data.inner.borrow_repo_data();
let base_url = repo_data_packages
.info
Expand All @@ -162,6 +173,10 @@ impl SparseRepoData {
)?;
records.append(&mut conda_records);

if strict_channel_priority && !records.is_empty() {
found_in_channel = Some(&repo_data.channel.base_url);
}

// Iterate over all packages to find recursive dependencies.
for record in records.iter() {
for dependency in &record.package_record.depends {
Expand Down Expand Up @@ -259,6 +274,7 @@ pub async fn load_repo_data_recursively(
repo_data_paths: impl IntoIterator<Item = (Channel, impl Into<String>, impl AsRef<Path>)>,
package_names: impl IntoIterator<Item = PackageName>,
patch_function: Option<fn(&mut PackageRecord)>,
strict_channel_priority: bool,
) -> Result<Vec<Vec<RepoDataRecord>>, io::Error> {
// Open the different files and memory map them to get access to their bytes. Do this in parallel.
let lazy_repo_data = stream::iter(repo_data_paths)
Expand All @@ -277,7 +293,12 @@ pub async fn load_repo_data_recursively(
.try_collect::<Vec<_>>()
.await?;

SparseRepoData::load_records_recursive(&lazy_repo_data, package_names, patch_function)
SparseRepoData::load_records_recursive(
&lazy_repo_data,
package_names,
patch_function,
strict_channel_priority,
)
}

fn deserialize_filename_and_raw_record<'d, D: Deserializer<'d>>(
Expand Down Expand Up @@ -380,6 +401,7 @@ impl<'de> TryFrom<&'de str> for PackageFilename<'de> {
#[cfg(test)]
mod test {
use super::{load_repo_data_recursively, PackageFilename};
use itertools::Itertools;
use rattler_conda_types::{Channel, ChannelConfig, PackageName, RepoData, RepoDataRecord};
use rstest::rstest;
use std::path::{Path, PathBuf};
Expand All @@ -390,6 +412,7 @@ mod test {

async fn load_sparse(
package_names: impl IntoIterator<Item = impl AsRef<str>>,
strict_channel_priority: bool,
) -> Vec<Vec<RepoDataRecord>> {
load_repo_data_recursively(
[
Expand All @@ -403,25 +426,31 @@ mod test {
"linux-64",
test_dir().join("channels/conda-forge/linux-64/repodata.json"),
),
(
Channel::from_str("pytorch", &ChannelConfig::default()).unwrap(),
"linux-64",
test_dir().join("channels/pytorch/linux-64/repodata.json"),
),
],
package_names
.into_iter()
.map(|name| PackageName::try_from(name.as_ref()).unwrap()),
None,
strict_channel_priority,
)
.await
.unwrap()
}

#[tokio::test]
async fn test_empty_sparse_load() {
let sparse_empty_data = load_sparse(Vec::<String>::new()).await;
assert_eq!(sparse_empty_data, vec![vec![], vec![]]);
let sparse_empty_data = load_sparse(Vec::<String>::new(), false).await;
assert_eq!(sparse_empty_data, vec![vec![], vec![], vec![]]);
}

#[tokio::test]
async fn test_sparse_single() {
let sparse_empty_data = load_sparse(["_libgcc_mutex"]).await;
let sparse_empty_data = load_sparse(["_libgcc_mutex"], false).await;
let total_records = sparse_empty_data
.iter()
.map(|repo| repo.len())
Expand All @@ -430,9 +459,45 @@ mod test {
assert_eq!(total_records, 3);
}

#[tokio::test]
async fn test_sparse_strict() {
    // Without strict channel priority, loading pytorch-cpu should surface records
    // from both the conda-forge and pytorch channels.
    let non_strict_data = load_sparse(["pytorch-cpu"], false).await;
    let non_strict_channels = non_strict_data
        .into_iter()
        .flatten()
        .filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu")
        .map(|record| record.channel)
        .unique()
        .collect_vec();
    assert_eq!(
        non_strict_channels,
        vec![
            String::from("https://conda.anaconda.org/conda-forge/"),
            String::from("https://conda.anaconda.org/pytorch/")
        ]
    );

    // With strict channel priority enabled, only the first channel that contains the
    // package contributes records — here that is conda-forge.
    let strict_data = load_sparse(["pytorch-cpu"], true).await;
    let strict_channels = strict_data
        .into_iter()
        .flatten()
        .filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu")
        .map(|record| record.channel)
        .unique()
        .collect_vec();
    assert_eq!(
        strict_channels,
        vec![String::from("https://conda.anaconda.org/conda-forge/")]
    );
}

#[tokio::test]
async fn test_parse_duplicate() {
let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"]).await;
let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"], false).await;
let total_records = sparse_empty_data
.iter()
.map(|repo| repo.len())
Expand All @@ -444,7 +509,7 @@ mod test {

#[tokio::test]
async fn test_sparse_jupyterlab_detectron2() {
let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"]).await;
let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"], true).await;

let total_records = sparse_empty_data
.iter()
Expand All @@ -456,30 +521,33 @@ mod test {

#[tokio::test]
async fn test_sparse_numpy_dev() {
let sparse_empty_data = load_sparse([
"python",
"cython",
"compilers",
"openblas",
"nomkl",
"pytest",
"pytest-cov",
"pytest-xdist",
"hypothesis",
"mypy",
"typing_extensions",
"sphinx",
"numpydoc",
"ipython",
"scipy",
"pandas",
"matplotlib",
"pydata-sphinx-theme",
"pycodestyle",
"gitpython",
"cffi",
"pytz",
])
let sparse_empty_data = load_sparse(
[
"python",
"cython",
"compilers",
"openblas",
"nomkl",
"pytest",
"pytest-cov",
"pytest-xdist",
"hypothesis",
"mypy",
"typing_extensions",
"sphinx",
"numpydoc",
"ipython",
"scipy",
"pandas",
"matplotlib",
"pydata-sphinx-theme",
"pycodestyle",
"gitpython",
"cffi",
"pytz",
],
false,
)
.await;

let total_records = sparse_empty_data
Expand Down
2 changes: 1 addition & 1 deletion crates/rattler_solve/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {

let names = specs.iter().map(|s| s.name.clone().unwrap());
let available_packages =
SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None).unwrap();
SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None, true).unwrap();

#[cfg(feature = "libsolv_c")]
group.bench_function("libsolv_c", |b| {
Expand Down
4 changes: 2 additions & 2 deletions crates/rattler_solve/tests/backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ fn solve_real_world<T: SolverImpl + Default>(specs: Vec<&str>) -> Vec<String> {

let names = specs.iter().filter_map(|s| s.name.as_ref().cloned());
let available_packages =
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap();
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None, true).unwrap();

let solver_task = SolverTask {
available_packages: &available_packages,
Expand Down Expand Up @@ -592,7 +592,7 @@ fn compare_solve(specs: Vec<&str>) {

let names = specs.iter().filter_map(|s| s.name.as_ref().cloned());
let available_packages =
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap();
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None, true).unwrap();

let extract_pkgs = |records: Vec<RepoDataRecord>| {
let mut pkgs = records
Expand Down
20 changes: 10 additions & 10 deletions py-rattler/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions py-rattler/rattler/repo_data/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def subdir(self) -> str:
def load_records_recursive(
repo_data: List[SparseRepoData],
package_names: List[PackageName],
strict_channel_priority: bool = True,
) -> List[List[RepoDataRecord]]:
"""
Given a set of [`SparseRepoData`]s load all the records
Expand Down Expand Up @@ -142,6 +143,7 @@ def load_records_recursive(
for list_of_records in PySparseRepoData.load_records_recursive(
[r._sparse for r in repo_data],
[p._name for p in package_names],
strict_channel_priority,
)
]

Expand Down
5 changes: 5 additions & 0 deletions py-rattler/rattler/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def solve(
locked_packages: Optional[List[RepoDataRecord]] = None,
pinned_packages: Optional[List[RepoDataRecord]] = None,
virtual_packages: Optional[List[GenericVirtualPackage]] = None,
strict_channel_priority: bool = True,
) -> List[RepoDataRecord]:
"""
Resolve the dependencies and return the `RepoDataRecord`s
Expand All @@ -38,6 +39,9 @@ def solve(
will always select that version no matter what even if that
means other packages have to be downgraded.
virtual_packages: A list of virtual packages considered active.
strict_channel_priority: (Default = True) When `True` the channel that the package
is first found in will be used as the only channel for that package.
When `False` it will search for every package in every channel.
Returns:
Resolved list of `RepoDataRecord`s.
Expand All @@ -54,5 +58,6 @@ def solve(
v_package._generic_virtual_package
for v_package in virtual_packages or []
],
strict_channel_priority,
)
]
Loading

0 comments on commit a39a35d

Please sign in to comment.