feat: add strict channel priority option #385

Merged: 8 commits, Oct 21, 2023
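
For orientation, here is a minimal sketch of how a caller could exercise the new strict_channel_priority flag through load_repo_data_recursively. It is illustrative only and not part of this PR's diff: the channel names, repodata paths, and the tokio/anyhow scaffolding are assumptions, and the function is assumed to be reachable under rattler_repodata_gateway::sparse.

use rattler_conda_types::{Channel, ChannelConfig, PackageName};
use rattler_repodata_gateway::sparse::load_repo_data_recursively;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let config = ChannelConfig::default();
    // Hypothetical local repodata paths; point these at real repodata.json files.
    let records = load_repo_data_recursively(
        [
            (
                Channel::from_str("conda-forge", &config)?,
                "linux-64",
                "channels/conda-forge/linux-64/repodata.json",
            ),
            (
                Channel::from_str("pytorch", &config)?,
                "linux-64",
                "channels/pytorch/linux-64/repodata.json",
            ),
        ],
        [PackageName::try_from("pytorch-cpu")?],
        None, // no patch function
        true, // strict channel priority: later channels are skipped once a package is found
    )
    .await?;

    // The result is one Vec<RepoDataRecord> per input channel, in the order given above.
    for (channel_records, name) in records.iter().zip(["conda-forge", "pytorch"]) {
        println!("{name}: {} records", channel_records.len());
    }
    Ok(())
}

With the flag set to true, later channels are skipped for any package already found in an earlier one, so pytorch-cpu here would be expected to resolve from conda-forge only, mirroring the test_sparse_strict test added in this PR.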
1 change: 1 addition & 0 deletions crates/rattler-bin/src/commands/create.rs
@@ -167,6 +167,7 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
record.depends.push("pip".to_string());
}
}),
true,
)
})?;

128 changes: 98 additions & 30 deletions crates/rattler_repodata_gateway/src/sparse/mod.rs
@@ -118,10 +118,15 @@ impl SparseRepoData {
///
/// This will parse the records for the specified packages as well as all the packages these records
/// depend on.
///
/// When `strict_channel_priority` is true, the first channel in which a package is found becomes
/// the only channel used for that package. Set it to false to search all channels for every package.
///
pub fn load_records_recursive<'a>(
repo_data: impl IntoIterator<Item = &'a SparseRepoData>,
package_names: impl IntoIterator<Item = PackageName>,
patch_function: Option<fn(&mut PackageRecord)>,
strict_channel_priority: bool,
) -> io::Result<Vec<Vec<RepoDataRecord>>> {
let repo_data: Vec<_> = repo_data.into_iter().collect();

@@ -136,7 +141,13 @@ impl SparseRepoData {

// Iterate over the list of packages that still need to be processed.
while let Some(next_package) = pending.pop_front() {
let mut found_in_channel = None;
for (i, repo_data) in repo_data.iter().enumerate() {
// If the package was already found in another channel, skip this repodata
if found_in_channel.map_or(false, |c| c != &repo_data.channel.base_url) {
continue;
}

let repo_data_packages = repo_data.inner.borrow_repo_data();
let base_url = repo_data_packages
.info
@@ -162,6 +173,10 @@
)?;
records.append(&mut conda_records);

if strict_channel_priority && !records.is_empty() {
found_in_channel = Some(&repo_data.channel.base_url);
}

// Iterate over all packages to find recursive dependencies.
for record in records.iter() {
for dependency in &record.package_record.depends {
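
To isolate the control flow added above, here is a small self-contained sketch of the strict-channel-priority rule, using simplified, hypothetical types rather than the crate's real ones: once a package has yielded records from one channel, every later channel is skipped for that package.

// Simplified illustration of the strict-channel-priority rule; `channels` pairs a
// channel base URL with the package names it contains.
fn channels_providing<'a>(
    channels: &'a [(&'a str, &'a [&'a str])],
    package: &str,
    strict_channel_priority: bool,
) -> Vec<&'a str> {
    let mut found_in_channel: Option<&str> = None;
    let mut hits = Vec::new();
    for (base_url, packages) in channels {
        // If the package was already found in another channel, skip this one entirely.
        if found_in_channel.map_or(false, |c| c != *base_url) {
            continue;
        }
        if packages.contains(&package) {
            hits.push(*base_url);
            if strict_channel_priority {
                found_in_channel = Some(*base_url);
            }
        }
    }
    hits
}

fn main() {
    let channels: &[(&str, &[&str])] = &[
        ("https://conda.anaconda.org/conda-forge/", &["pytorch-cpu", "python"]),
        ("https://conda.anaconda.org/pytorch/", &["pytorch-cpu"]),
    ];
    // Strict: only the first channel that provides the package is used.
    assert_eq!(
        channels_providing(channels, "pytorch-cpu", true),
        vec!["https://conda.anaconda.org/conda-forge/"]
    );
    // Non-strict: all channels contribute records.
    assert_eq!(channels_providing(channels, "pytorch-cpu", false).len(), 2);
}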
@@ -259,6 +274,7 @@ pub async fn load_repo_data_recursively(
repo_data_paths: impl IntoIterator<Item = (Channel, impl Into<String>, impl AsRef<Path>)>,
package_names: impl IntoIterator<Item = PackageName>,
patch_function: Option<fn(&mut PackageRecord)>,
strict_channel_priority: bool,
) -> Result<Vec<Vec<RepoDataRecord>>, io::Error> {
// Open the different files and memory map them to get access to their bytes. Do this in parallel.
let lazy_repo_data = stream::iter(repo_data_paths)
@@ -277,7 +293,12 @@
.try_collect::<Vec<_>>()
.await?;

SparseRepoData::load_records_recursive(&lazy_repo_data, package_names, patch_function)
SparseRepoData::load_records_recursive(
&lazy_repo_data,
package_names,
patch_function,
strict_channel_priority,
)
}

fn deserialize_filename_and_raw_record<'d, D: Deserializer<'d>>(
@@ -380,6 +401,7 @@ impl<'de> TryFrom<&'de str> for PackageFilename<'de> {
#[cfg(test)]
mod test {
use super::{load_repo_data_recursively, PackageFilename};
use itertools::Itertools;
use rattler_conda_types::{Channel, ChannelConfig, PackageName, RepoData, RepoDataRecord};
use rstest::rstest;
use std::path::{Path, PathBuf};
@@ -390,6 +412,7 @@ mod test {

async fn load_sparse(
package_names: impl IntoIterator<Item = impl AsRef<str>>,
strict_channel_priority: bool,
) -> Vec<Vec<RepoDataRecord>> {
load_repo_data_recursively(
[
@@ -403,25 +426,31 @@
"linux-64",
test_dir().join("channels/conda-forge/linux-64/repodata.json"),
),
(
Channel::from_str("pytorch", &ChannelConfig::default()).unwrap(),
"linux-64",
test_dir().join("channels/pytorch/linux-64/repodata.json"),
),
],
package_names
.into_iter()
.map(|name| PackageName::try_from(name.as_ref()).unwrap()),
None,
strict_channel_priority,
)
.await
.unwrap()
}

#[tokio::test]
async fn test_empty_sparse_load() {
let sparse_empty_data = load_sparse(Vec::<String>::new()).await;
assert_eq!(sparse_empty_data, vec![vec![], vec![]]);
let sparse_empty_data = load_sparse(Vec::<String>::new(), false).await;
assert_eq!(sparse_empty_data, vec![vec![], vec![], vec![]]);
}

#[tokio::test]
async fn test_sparse_single() {
let sparse_empty_data = load_sparse(["_libgcc_mutex"]).await;
let sparse_empty_data = load_sparse(["_libgcc_mutex"], false).await;
let total_records = sparse_empty_data
.iter()
.map(|repo| repo.len())
@@ -430,9 +459,45 @@
assert_eq!(total_records, 3);
}

#[tokio::test]
async fn test_sparse_strict() {
// If we load pytorch-cpu from all channels (non-strict) we expect records from both
// the conda-forge and pytorch channels.
let sparse_data = load_sparse(["pytorch-cpu"], false).await;
let channels = sparse_data
.into_iter()
.flatten()
.filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu")
.map(|record| record.channel)
.unique()
.collect_vec();
assert_eq!(
channels,
vec![
String::from("https://conda.anaconda.org/conda-forge/"),
String::from("https://conda.anaconda.org/pytorch/")
]
);

// If we load pytorch-cpu with strict channel priority we expect records only from the first
// channel that contains the package, in this case conda-forge.
let strict_sparse_data = load_sparse(["pytorch-cpu"], true).await;
let channels = strict_sparse_data
.into_iter()
.flatten()
.filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu")
.map(|record| record.channel)
.unique()
.collect_vec();
assert_eq!(
channels,
vec![String::from("https://conda.anaconda.org/conda-forge/")]
);
}

#[tokio::test]
async fn test_parse_duplicate() {
let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"]).await;
let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"], false).await;
let total_records = sparse_empty_data
.iter()
.map(|repo| repo.len())
@@ -444,7 +509,7 @@

#[tokio::test]
async fn test_sparse_jupyterlab_detectron2() {
let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"]).await;
let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"], true).await;

let total_records = sparse_empty_data
.iter()
@@ -456,30 +521,33 @@

#[tokio::test]
async fn test_sparse_numpy_dev() {
let sparse_empty_data = load_sparse([
"python",
"cython",
"compilers",
"openblas",
"nomkl",
"pytest",
"pytest-cov",
"pytest-xdist",
"hypothesis",
"mypy",
"typing_extensions",
"sphinx",
"numpydoc",
"ipython",
"scipy",
"pandas",
"matplotlib",
"pydata-sphinx-theme",
"pycodestyle",
"gitpython",
"cffi",
"pytz",
])
let sparse_empty_data = load_sparse(
[
"python",
"cython",
"compilers",
"openblas",
"nomkl",
"pytest",
"pytest-cov",
"pytest-xdist",
"hypothesis",
"mypy",
"typing_extensions",
"sphinx",
"numpydoc",
"ipython",
"scipy",
"pandas",
"matplotlib",
"pydata-sphinx-theme",
"pycodestyle",
"gitpython",
"cffi",
"pytz",
],
false,
)
.await;

let total_records = sparse_empty_data
2 changes: 1 addition & 1 deletion crates/rattler_solve/benches/bench.rs
@@ -52,7 +52,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {

let names = specs.iter().map(|s| s.name.clone().unwrap());
let available_packages =
SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None).unwrap();
SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None, true).unwrap();

#[cfg(feature = "libsolv_c")]
group.bench_function("libsolv_c", |b| {
4 changes: 2 additions & 2 deletions crates/rattler_solve/tests/backends.rs
@@ -110,7 +110,7 @@ fn solve_real_world<T: SolverImpl + Default>(specs: Vec<&str>) -> Vec<String> {

let names = specs.iter().filter_map(|s| s.name.as_ref().cloned());
let available_packages =
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap();
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None, true).unwrap();

let solver_task = SolverTask {
available_packages: &available_packages,
@@ -592,7 +592,7 @@ fn compare_solve(specs: Vec<&str>) {

let names = specs.iter().filter_map(|s| s.name.as_ref().cloned());
let available_packages =
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap();
SparseRepoData::load_records_recursive(sparse_repo_datas, names, None, true).unwrap();

let extract_pkgs = |records: Vec<RepoDataRecord>| {
let mut pkgs = records
20 changes: 10 additions & 10 deletions py-rattler/Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions py-rattler/rattler/repo_data/sparse.py
@@ -115,6 +115,7 @@ def subdir(self) -> str:
def load_records_recursive(
repo_data: List[SparseRepoData],
package_names: List[PackageName],
strict_channel_priority: bool = True,
) -> List[List[RepoDataRecord]]:
"""
Given a set of [`SparseRepoData`]s load all the records
@@ -142,6 +143,7 @@ def load_records_recursive(
for list_of_records in PySparseRepoData.load_records_recursive(
[r._sparse for r in repo_data],
[p._name for p in package_names],
strict_channel_priority,
)
]

5 changes: 5 additions & 0 deletions py-rattler/rattler/solver/solver.py
@@ -14,6 +14,7 @@ def solve(
locked_packages: Optional[List[RepoDataRecord]] = None,
pinned_packages: Optional[List[RepoDataRecord]] = None,
virtual_packages: Optional[List[GenericVirtualPackage]] = None,
strict_channel_priority: bool = True,
) -> List[RepoDataRecord]:
"""
Resolve the dependencies and return the `RepoDataRecord`s
@@ -38,6 +39,9 @@
will always select that version no matter what even if that
means other packages have to be downgraded.
virtual_packages: A list of virtual packages considered active.
strict_channel_priority: (Default = True) When `True`, the first channel in which a
package is found is used as the only channel for that package.
When `False`, every channel is searched for every package.

Returns:
Resolved list of `RepoDataRecord`s.
Expand All @@ -54,5 +58,6 @@ def solve(
v_package._generic_virtual_package
for v_package in virtual_packages or []
],
strict_channel_priority,
)
]