diff --git a/crates/rattler-bin/src/commands/create.rs b/crates/rattler-bin/src/commands/create.rs index 2afa7f118..6e2263d02 100644 --- a/crates/rattler-bin/src/commands/create.rs +++ b/crates/rattler-bin/src/commands/create.rs @@ -7,8 +7,8 @@ use rattler::{ package_cache::PackageCache, }; use rattler_conda_types::{ - Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, Platform, PrefixRecord, - RepoDataRecord, Version, + Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, PackageRecord, Platform, + PrefixRecord, RepoDataRecord, Version, }; use rattler_networking::{AuthenticatedClient, AuthenticationStorage}; use rattler_repodata_gateway::fetch::{ @@ -158,7 +158,15 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { // Get the package names from the matchspecs so we can only load the package records that we need. let package_names = specs.iter().filter_map(|spec| spec.name.as_ref()); let repodatas = wrap_in_progress("parsing repodata", move || { - SparseRepoData::load_records_recursive(&sparse_repo_datas, package_names) + SparseRepoData::load_records_recursive( + &sparse_repo_datas, + package_names, + Some(|record| { + if record.name == "python" { + record.depends.push("pip".to_string()); + } + }), + ) })?; // Determine virtual packages of the system. These packages define the capabilities of the @@ -220,6 +228,9 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { } })?; + // sort topologically + let required_packages = PackageRecord::sort_topologically(required_packages); + // Construct a transaction to let transaction = Transaction::from_current_and_desired( installed_packages, @@ -604,7 +615,16 @@ async fn fetch_repo_data_records_with_progress( // task. let repo_data_json_path = result.repo_data_json_path.clone(); match tokio::task::spawn_blocking(move || { - SparseRepoData::new(channel, platform.to_string(), repo_data_json_path) + SparseRepoData::new( + channel, + platform.to_string(), + repo_data_json_path, + Some(|record: &mut PackageRecord| { + if record.name == "python" { + record.depends.push("pip".to_string()); + } + }), + ) }) .await { diff --git a/crates/rattler/src/install/transaction.rs b/crates/rattler/src/install/transaction.rs index a233ad5dd..6ebec560c 100644 --- a/crates/rattler/src/install/transaction.rs +++ b/crates/rattler/src/install/transaction.rs @@ -1,7 +1,8 @@ +use std::collections::HashSet; + use crate::install::python::PythonInfoError; use crate::install::PythonInfo; use rattler_conda_types::{PackageRecord, Platform}; -use std::collections::HashMap; #[derive(Debug, thiserror::Error)] pub enum TransactionError { @@ -86,49 +87,60 @@ impl, New: AsRef> Transaction CurIter::IntoIter: Clone, NewIter::IntoIter: Clone, { - let current = current.into_iter(); - let desired = desired.into_iter(); + let current_iter = current.into_iter(); + let desired_iter = desired.into_iter(); // Determine the python version used in the current situation. - let current_python_info = find_python_info(current.clone(), platform)?; - let desired_python_info = find_python_info(desired.clone(), platform)?; + let current_python_info = find_python_info(current_iter.clone(), platform)?; + let desired_python_info = find_python_info(desired_iter.clone(), platform)?; let needs_python_relink = match (¤t_python_info, &desired_python_info) { (Some(current), Some(desired)) => desired.is_relink_required(current), _ => false, }; - // Create a lookup table by name for the desired packages. - let mut desired: HashMap = desired - .into_iter() - .map(|record| (record.as_ref().name.clone(), record)) - .collect(); - let mut operations = Vec::new(); - // Find all the elements that are no longer in the desired set - for record in current { - match desired.remove(&record.as_ref().name) { - None => operations.push(TransactionOperation::Remove(record)), - Some(desired) => { - // If the desired differs from the current it has to be updated. - if !describe_same_content(desired.as_ref(), record.as_ref()) { - operations.push(TransactionOperation::Change { - old: record, - new: desired, - }) - } - // If this is a noarch package and all python packages need to be relinked, - // reinstall the package completely. - else if desired.as_ref().noarch.is_python() && needs_python_relink { - operations.push(TransactionOperation::Reinstall(record)); - } - } + let mut current_map = current_iter + .clone() + .map(|r| (r.as_ref().name.clone(), r)) + .collect::>(); + + let desired_names = desired_iter + .clone() + .map(|r| r.as_ref().name.clone()) + .collect::>(); + + // Remove all current packages that are not in desired (but keep order of current) + for record in current_iter { + if !desired_names.contains(&record.as_ref().name) { + operations.push(TransactionOperation::Remove(record)); } } - // The remaining packages from the desired list need to be explicitly installed. - for record in desired.into_values() { - operations.push(TransactionOperation::Install(record)) + // reverse all removals, last in first out + operations.reverse(); + + // Figure out the operations to perform, but keep the order of the original "desired" iterator + for record in desired_iter { + let name = &record.as_ref().name; + let old_record = current_map.remove(name); + + if let Some(old_record) = old_record { + if !describe_same_content(record.as_ref(), old_record.as_ref()) { + // if the content changed, we need to reinstall (remove and install) + operations.push(TransactionOperation::Change { + old: old_record, + new: record, + }); + } else if needs_python_relink { + // when the python version changed, we need to relink all noarch packages + // to recompile the bytecode + operations.push(TransactionOperation::Reinstall(old_record)); + } + // if the content is the same, we dont need to do anything + } else { + operations.push(TransactionOperation::Install(record)); + } } Ok(Self { diff --git a/crates/rattler_conda_types/src/repo_data/mod.rs b/crates/rattler_conda_types/src/repo_data/mod.rs index a789d0165..f3909e736 100644 --- a/crates/rattler_conda_types/src/repo_data/mod.rs +++ b/crates/rattler_conda_types/src/repo_data/mod.rs @@ -213,7 +213,7 @@ impl PackageRecord { /// the order of `records` and of the `depends` vector inside the records. /// /// Note that this function only works for packages with unique names. - pub fn sort_topologically>(records: Vec) -> Vec { + pub fn sort_topologically + Clone>(records: Vec) -> Vec { topological_sort::sort_topologically(records) } } diff --git a/crates/rattler_conda_types/src/repo_data/topological_sort.rs b/crates/rattler_conda_types/src/repo_data/topological_sort.rs index 2a61d0f91..eb6560373 100644 --- a/crates/rattler_conda_types/src/repo_data/topological_sort.rs +++ b/crates/rattler_conda_types/src/repo_data/topological_sort.rs @@ -6,20 +6,56 @@ use fxhash::{FxHashMap, FxHashSet}; /// This function is deterministic, meaning that it will return the same result regardless of the /// order of `packages` and of the `depends` vector inside the records. /// +/// If cycles are encountered, and one of the packages in the cycle is noarch, the noarch package +/// is sorted _after_ the other packages in the cycle. This is done to ensure that the noarch +/// package is installed last, so that it can be linked correctly (ie. compiled with Python if +/// necessary). +/// /// Note that this function only works for packages with unique names. -pub fn sort_topologically>(packages: Vec) -> Vec { - let roots = get_graph_roots(&packages); +pub fn sort_topologically + Clone>(packages: Vec) -> Vec { + let roots = get_graph_roots(&packages, None); let mut all_packages = packages - .into_iter() + .iter() + .cloned() .map(|p| (p.as_ref().name.clone(), p)) .collect(); - get_topological_order(roots, &mut all_packages) + // Detect cycles + let mut visited = FxHashSet::default(); + let mut stack = Vec::new(); + let mut cycles = Vec::new(); + + for root in &roots { + if !visited.contains(root) { + if let Some(cycle) = find_cycles(root, &all_packages, &mut visited, &mut stack) { + cycles.push(cycle); + } + } + } + + // print all cycles + for cycle in &cycles { + tracing::debug!("Found cycle: {:?}", cycle); + } + + // Break cycles + let cycle_breaks = break_cycles(cycles, &all_packages); + + // obtain the new roots (packages that are not dependencies of any other package) + // this is needed because breaking cycles can create new roots + let roots = get_graph_roots(&packages, Some(&cycle_breaks)); + + get_topological_order(roots, &mut all_packages, &cycle_breaks) } -/// Retrieves the names of the packages that form the roots of the graph -fn get_graph_roots>(records: &[T]) -> Vec { +/// Retrieves the names of the packages that form the roots of the graph and breaks specified +/// cycles (e.g. if there is a cycle between A and B and there is a cycle_break (A, B), the edge +/// A -> B will be removed) +fn get_graph_roots>( + records: &[T], + cycle_breaks: Option<&FxHashSet<(String, String)>>, +) -> Vec { let all_packages: FxHashSet<_> = records.iter().map(|r| r.as_ref().name.as_str()).collect(); let dependencies: FxHashSet<_> = records @@ -29,6 +65,14 @@ fn get_graph_roots>(records: &[T]) -> Vec { .depends .iter() .map(|d| package_name_from_match_spec(d)) + .filter(|d| { + // filter out circular dependencies + if let Some(cycle_breaks) = cycle_breaks { + !cycle_breaks.contains(&(r.as_ref().name.clone(), d.to_string())) + } else { + true + } + }) }) .collect(); @@ -45,11 +89,77 @@ enum Action { Install(String), } +/// Find cycles with DFS +fn find_cycles>( + node: &str, + packages: &FxHashMap, + visited: &mut FxHashSet, + stack: &mut Vec, +) -> Option> { + visited.insert(node.to_string()); + stack.push(node.to_string()); + + if let Some(package) = packages.get(node) { + for dependency in &package.as_ref().depends { + let dep_name = package_name_from_match_spec(dependency); + + if !visited.contains(dep_name) { + if let Some(cycle) = find_cycles(dep_name, packages, visited, stack) { + return Some(cycle); + } + } else if stack.contains(&dep_name.to_string()) { + // Cycle detected. We clone the part of the stack that forms the cycle. + if let Some(pos) = stack.iter().position(|x| *x == dep_name) { + return Some(stack[pos..].to_vec()); + } + } + } + } + + stack.pop(); + None +} + +/// Breaks cycles by removing the edges that form them +/// Edges from arch to noarch packages are removed to break the cycles. +fn break_cycles>( + cycles: Vec>, + packages: &FxHashMap, +) -> FxHashSet<(String, String)> { + // we record the edges that we want to remove + let mut cycle_breaks = FxHashSet::default(); + + for cycle in cycles { + for i in 0..cycle.len() { + let pi1 = &cycle[i]; + // Next package in cycle, wraps around + let pi2 = &cycle[(i + 1) % cycle.len()]; + + let p1 = &packages[pi1]; + let p2 = &packages[pi2]; + + // prefer arch packages over noarch packages + let p1_noarch = p1.as_ref().noarch.is_none(); + let p2_noarch = p2.as_ref().noarch.is_none(); + if p1_noarch && !p2_noarch { + cycle_breaks.insert((pi1.clone(), pi2.clone())); + break; + } else if !p1_noarch && p2_noarch { + cycle_breaks.insert((pi2.clone(), pi1.clone())); + break; + } + } + } + tracing::debug!("Breaking cycle: {:?}", cycle_breaks); + cycle_breaks +} + /// Returns a vector containing the topological ordering of the packages, based on the provided /// roots fn get_topological_order>( mut roots: Vec, packages: &mut FxHashMap, + cycle_breaks: &FxHashSet<(String, String)>, ) -> Vec { // Sorting makes this step deterministic (i.e. the same output is returned, regardless of the // original order of the input) @@ -83,6 +193,9 @@ fn get_topological_order>( } }; + // Remove the edges that form cycles + deps.retain(|dep| !cycle_breaks.contains(&(package_name.clone(), dep.clone()))); + // Sorting makes this step deterministic (i.e. the same output is returned, regardless of the // original order of the input) deps.sort(); @@ -204,13 +317,14 @@ mod tests { #[rstest] #[case(get_resolved_packages_for_python(), &["python"])] + #[case(get_resolved_packages_for_python_pip(), &["pip"])] #[case(get_resolved_packages_for_numpy(), &["numpy"])] #[case(get_resolved_packages_for_two_roots(), &["4ti2", "micromamba"])] fn test_get_graph_roots( #[case] packages: Vec, #[case] expected_roots: &[&str], ) { - let mut roots = get_graph_roots(&packages); + let mut roots = get_graph_roots(&packages, None); roots.sort(); assert_eq!(roots.as_slice(), expected_roots); } @@ -219,6 +333,7 @@ mod tests { #[case(get_resolved_packages_for_python(), "python", &[("libzlib", "libgcc-ng")])] #[case(get_resolved_packages_for_numpy(), "numpy", &[("llvm-openmp", "libzlib")])] #[case(get_resolved_packages_for_two_roots(), "4ti2", &[("libzlib", "libgcc-ng")])] + #[case(get_resolved_packages_for_python_pip(), "pip", &[("pip", "python"), ("libzlib", "libgcc-ng")])] fn test_topological_sort( #[case] packages: Vec, #[case] expected_last_package: &str, @@ -230,7 +345,7 @@ mod tests { sanity_check_topological_sort(&sorted_packages, &packages); simulate_install(&sorted_packages, &circular_deps); - // Sanity check: the last package should be python + // Sanity check: the last package should be python (or pip when it is present) let last_package = &sorted_packages[sorted_packages.len() - 1]; assert_eq!(last_package.package_record.name, expected_last_package) } @@ -1622,4 +1737,91 @@ mod tests { serde_json::from_str(repodata_json).unwrap() } + + fn get_resolved_packages_for_python_pip() -> Vec { + let pip = r#" + [ + { + "arch": null, + "build": "pyhd8ed1ab_0", + "build_number": 0, + "build_string": "pyhd8ed1ab_0", + "channel": "https://conda.anaconda.org/conda-forge/noarch", + "constrains": [], + "depends": [ + "python >=3.7", + "setuptools", + "wheel" + ], + "fn": "pip-23.1.2-pyhd8ed1ab_0.conda", + "license": "MIT", + "license_family": "MIT", + "md5": "7288da0d36821349cf1126e8670292df", + "name": "pip", + "noarch": "python", + "platform": null, + "sha256": "4fe1f47f6eac5b2635a622b6f985640bf835843c1d8d7ccbbae0f7d27cadec92", + "size": 1367644, + "subdir": "noarch", + "timestamp": 1682507713321, + "track_features": "", + "url": "https://conda.anaconda.org/conda-forge/noarch/pip-23.1.2-pyhd8ed1ab_0.conda", + "version": "23.1.2" + }, + { + "arch": null, + "build": "pyhd8ed1ab_0", + "build_number": 0, + "build_string": "pyhd8ed1ab_0", + "channel": "https://conda.anaconda.org/conda-forge/noarch", + "constrains": [], + "depends": [ + "python >=3.7" + ], + "fn": "wheel-0.40.0-pyhd8ed1ab_0.conda", + "license": "MIT", + "md5": "49bb0d9e60ce1db25e151780331bb5f3", + "name": "wheel", + "noarch": "python", + "platform": null, + "sha256": "79b4d29b0c004014a2abd5fc2c9fcd35cc6256222b960c2a317a27c4b0d8884d", + "size": 55729, + "subdir": "noarch", + "timestamp": 1678812153506, + "track_features": "", + "url": "https://conda.anaconda.org/conda-forge/noarch/wheel-0.40.0-pyhd8ed1ab_0.conda", + "version": "0.40.0" + }, + { + "arch": null, + "build": "pyhd8ed1ab_0", + "build_number": 0, + "build_string": "pyhd8ed1ab_0", + "channel": "https://conda.anaconda.org/conda-forge/noarch", + "constrains": [], + "depends": [ + "python >=3.7" + ], + "fn": "setuptools-68.0.0-pyhd8ed1ab_0.conda", + "license": "MIT", + "license_family": "MIT", + "md5": "5a7739d0f57ee64133c9d32e6507c46d", + "name": "setuptools", + "noarch": "python", + "platform": null, + "sha256": "083a0913f5b56644051f31ac40b4eeea762a88c00aa12437817191b85a753cec", + "size": 463712, + "subdir": "noarch", + "timestamp": 1687527994911, + "track_features": "", + "url": "https://conda.anaconda.org/conda-forge/noarch/setuptools-68.0.0-pyhd8ed1ab_0.conda", + "version": "68.0.0" + } + ]"#; + + let mut python = get_resolved_packages_for_python(); + let pip: Vec = serde_json::from_str(pip).unwrap(); + python.extend(pip); + return python; + } } diff --git a/crates/rattler_repodata_gateway/src/sparse/mod.rs b/crates/rattler_repodata_gateway/src/sparse/mod.rs index 37d38adf1..58aa4d318 100644 --- a/crates/rattler_repodata_gateway/src/sparse/mod.rs +++ b/crates/rattler_repodata_gateway/src/sparse/mod.rs @@ -29,6 +29,10 @@ pub struct SparseRepoData { /// The subdirectory from where the repodata is downloaded subdir: String, + + /// A function that can be used to patch the package record after it has been parsed. + /// This is mainly used to add `pip` to `python` if desired + patch_record_fn: Option, } /// A struct that holds a memory map of a `repodata.json` file and also a self-referential field which @@ -47,10 +51,13 @@ struct SparseRepoDataInner { impl SparseRepoData { /// Construct an instance of self from a file on disk and a [`Channel`]. + /// The `patch_function` can be used to patch the package record after it has been parsed + /// (e.g. to add `pip` to `python`). pub fn new( channel: Channel, subdir: impl Into, path: impl AsRef, + patch_function: Option, ) -> Result { let file = std::fs::File::open(path)?; let memory_map = unsafe { memmap2::Mmap::map(&file) }?; @@ -62,6 +69,7 @@ impl SparseRepoData { .try_build()?, subdir: subdir.into(), channel, + patch_record_fn: patch_function, }) } @@ -87,12 +95,14 @@ impl SparseRepoData { &repo_data.packages, &self.channel, &self.subdir, + self.patch_record_fn, )?; let mut conda_records = parse_records( package_name, &repo_data.conda_packages, &self.channel, &self.subdir, + self.patch_record_fn, )?; records.append(&mut conda_records); Ok(records) @@ -106,6 +116,7 @@ impl SparseRepoData { pub fn load_records_recursive<'a>( repo_data: impl IntoIterator, package_names: impl IntoIterator>, + patch_function: Option, ) -> io::Result>> { let repo_data: Vec<_> = repo_data.into_iter().collect(); @@ -130,12 +141,14 @@ impl SparseRepoData { &repo_data_packages.packages, &repo_data.channel, &repo_data.subdir, + patch_function, )?; let mut conda_records = parse_records( &next_package, &repo_data_packages.conda_packages, &repo_data.channel, &repo_data.subdir, + patch_function, )?; records.append(&mut conda_records); @@ -185,6 +198,7 @@ fn parse_records<'i>( packages: &[(PackageFilename<'i>, &'i RawValue)], channel: &Channel, subdir: &str, + patch_function: Option, ) -> io::Result> { let channel_name = channel.canonical_name(); @@ -206,30 +220,44 @@ fn parse_records<'i>( file_name: key.filename.to_owned(), }); } + + // Apply the patch function if one was specified + if let Some(patch_fn) = patch_function { + for record in &mut result { + patch_fn(&mut record.package_record); + } + } + Ok(result) } /// A helper function that immediately loads the records for the given packages (and their dependencies). +/// Records for the specified packages are loaded from the repodata files. +/// The patch_record_fn is applied to each record after it has been parsed and can mutate the record after +/// it has been loaded. pub async fn load_repo_data_recursively( repo_data_paths: impl IntoIterator, impl AsRef)>, package_names: impl IntoIterator>, + patch_function: Option, ) -> Result>, io::Error> { // Open the different files and memory map them to get access to their bytes. Do this in parallel. let lazy_repo_data = stream::iter(repo_data_paths) .map(|(channel, subdir, path)| { let path = path.as_ref().to_path_buf(); let subdir = subdir.into(); - tokio::task::spawn_blocking(move || SparseRepoData::new(channel, subdir, path)) - .unwrap_or_else(|r| match r.try_into_panic() { - Ok(panic) => std::panic::resume_unwind(panic), - Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string())), - }) + tokio::task::spawn_blocking(move || { + SparseRepoData::new(channel, subdir, path, patch_function) + }) + .unwrap_or_else(|r| match r.try_into_panic() { + Ok(panic) => std::panic::resume_unwind(panic), + Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string())), + }) }) .buffered(50) .try_collect::>() .await?; - SparseRepoData::load_records_recursive(&lazy_repo_data, package_names) + SparseRepoData::load_records_recursive(&lazy_repo_data, package_names, patch_function) } fn deserialize_filename_and_raw_record<'d, D: Deserializer<'d>>( @@ -357,6 +385,7 @@ mod test { ), ], package_names, + None, ) .await .unwrap() diff --git a/crates/rattler_solve/tests/backends.rs b/crates/rattler_solve/tests/backends.rs index 312af901c..a3ee03663 100644 --- a/crates/rattler_solve/tests/backends.rs +++ b/crates/rattler_solve/tests/backends.rs @@ -57,6 +57,7 @@ fn read_sparse_repodata(path: &str) -> SparseRepoData { Channel::from_str("dummy", &ChannelConfig::default()).unwrap(), "dummy".to_string(), path, + None, ) .unwrap() } @@ -108,7 +109,7 @@ fn solve_real_world(specs: Vec<&str>) -> Vec { let names = specs.iter().map(|s| s.name.clone().unwrap()); let available_packages = - SparseRepoData::load_records_recursive(sparse_repo_datas, names).unwrap(); + SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap(); let solver_task = SolverTask { available_packages: &available_packages, @@ -511,7 +512,7 @@ fn compare_solve(specs: Vec<&str>) { let names = specs.iter().filter_map(|s| s.name.clone()); let available_packages = - SparseRepoData::load_records_recursive(sparse_repo_datas, names).unwrap(); + SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap(); let extract_pkgs = |records: Vec| { let mut pkgs = records