diff --git a/crates/rattler-bin/src/commands/create.rs b/crates/rattler-bin/src/commands/create.rs index fda3e2580..28f363bc5 100644 --- a/crates/rattler-bin/src/commands/create.rs +++ b/crates/rattler-bin/src/commands/create.rs @@ -21,7 +21,10 @@ use rattler_repodata_gateway::fetch::{ CacheResult, DownloadProgress, FetchRepoDataError, FetchRepoDataOptions, }; use rattler_repodata_gateway::sparse::SparseRepoData; -use rattler_solve::{libsolv_c, resolvo, SolverImpl, SolverTask}; +use rattler_solve::{ + libsolv_c::{self}, + resolvo, SolverImpl, SolverTask, +}; use reqwest::Client; use std::sync::Arc; use std::{ @@ -53,7 +56,10 @@ pub struct Opt { virtual_package: Option>, #[clap(long)] - use_experimental_libsolv_rs: bool, + use_resolvo: bool, + + #[clap(long)] + timeout: Option, } pub async fn create(opt: Opt) -> anyhow::Result<()> { @@ -221,13 +227,13 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { virtual_packages, specs, pinned_packages: Vec::new(), + timeout: opt.timeout.map(Duration::from_millis), }; // Next, use a solver to solve this specific problem. This provides us with all the operations // we need to apply to our environment to bring it up to date. - let use_libsolv_rs = opt.use_experimental_libsolv_rs; let required_packages = wrap_in_progress("solving", move || { - if use_libsolv_rs { + if opt.use_resolvo { resolvo::Solver.solve(solver_task) } else { libsolv_c::Solver.solve(solver_task) diff --git a/crates/rattler_solve/benches/bench.rs b/crates/rattler_solve/benches/bench.rs index 2ce89d258..3962dac0f 100644 --- a/crates/rattler_solve/benches/bench.rs +++ b/crates/rattler_solve/benches/bench.rs @@ -64,6 +64,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) { pinned_packages: vec![], virtual_packages: vec![], specs: specs.clone(), + timeout: None, })) .unwrap() }); @@ -79,6 +80,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) { pinned_packages: vec![], virtual_packages: vec![], specs: specs.clone(), + timeout: None, })) .unwrap() }); diff --git a/crates/rattler_solve/src/lib.rs b/crates/rattler_solve/src/lib.rs index 3c4d7f26a..591c6626c 100644 --- a/crates/rattler_solve/src/lib.rs +++ b/crates/rattler_solve/src/lib.rs @@ -99,6 +99,9 @@ pub struct SolverTask { /// The specs we want to solve pub specs: Vec, + + /// The timeout after which the solver should stop + pub timeout: Option, } /// A representation of a collection of [`RepoDataRecord`] usable by a [`SolverImpl`] diff --git a/crates/rattler_solve/src/libsolv_c/mod.rs b/crates/rattler_solve/src/libsolv_c/mod.rs index 902b8cfcc..40ef05d90 100644 --- a/crates/rattler_solve/src/libsolv_c/mod.rs +++ b/crates/rattler_solve/src/libsolv_c/mod.rs @@ -91,6 +91,12 @@ impl super::SolverImpl for Solver { &mut self, task: SolverTask, ) -> Result, SolveError> { + if task.timeout.is_some() { + return Err(SolveError::UnsupportedOperations(vec![ + "timeout".to_string() + ])); + } + // Construct a default libsolv pool let pool = Pool::default(); diff --git a/crates/rattler_solve/src/resolvo/conda_util.rs b/crates/rattler_solve/src/resolvo/conda_util.rs index 9244e3310..9e0c425b2 100644 --- a/crates/rattler_solve/src/resolvo/conda_util.rs +++ b/crates/rattler_solve/src/resolvo/conda_util.rs @@ -49,12 +49,14 @@ pub(super) fn compare_candidates<'a>( // Otherwise, compare the dependencies of the variants. If there are similar // dependencies select the variant that selects the highest version of the dependency. - let a_dependencies = solver - .get_or_cache_dependencies(a) - .expect("should not get here, resolution process aborted"); - let b_dependencies = solver - .get_or_cache_dependencies(b) - .expect("should not get here, resolution process aborted"); + let (a_dependencies, b_dependencies) = match ( + solver.get_or_cache_dependencies(a), + solver.get_or_cache_dependencies(b), + ) { + (Ok(a_deps), Ok(b_deps)) => (a_deps, b_deps), + // If either call fails, it's likely due to solver cancellation; thus, we can't compare dependencies + _ => return Ordering::Equal, + }; // If the MatchSpecs are known use these // map these into a HashMap diff --git a/crates/rattler_solve/src/resolvo/mod.rs b/crates/rattler_solve/src/resolvo/mod.rs index 0520389bd..6db1055a1 100644 --- a/crates/rattler_solve/src/resolvo/mod.rs +++ b/crates/rattler_solve/src/resolvo/mod.rs @@ -165,6 +165,8 @@ pub(crate) struct CondaDependencyProvider<'a> { RefCell>>, parse_match_spec_cache: RefCell>, + + stop_time: Option, } impl<'a> CondaDependencyProvider<'a> { @@ -174,6 +176,7 @@ impl<'a> CondaDependencyProvider<'a> { locked_records: &'a [RepoDataRecord], virtual_packages: &'a [GenericVirtualPackage], match_specs: &[MatchSpec], + stop_time: Option, ) -> Self { let pool = Pool::default(); let mut records: HashMap = HashMap::default(); @@ -333,10 +336,17 @@ impl<'a> CondaDependencyProvider<'a> { records, matchspec_to_highest_version: RefCell::default(), parse_match_spec_cache: RefCell::default(), + stop_time, } } } +/// The reason why the solver was cancelled +pub enum CancelReason { + /// The solver was cancelled because the timeout was reached + Timeout, +} + impl<'a> DependencyProvider> for CondaDependencyProvider<'a> { fn pool(&self) -> &Pool, String> { &self.pool @@ -378,6 +388,15 @@ impl<'a> DependencyProvider> for CondaDependencyProvider<'a> Dependencies::Known(dependencies) } + + fn should_cancel_with_value(&self) -> Option> { + if let Some(stop_time) = self.stop_time { + if std::time::SystemTime::now() > stop_time { + return Some(Box::new(CancelReason::Timeout)); + } + } + None + } } /// Displays the different candidates by their version and sorted by their version @@ -414,6 +433,10 @@ impl super::SolverImpl for Solver { &mut self, task: SolverTask, ) -> Result, SolveError> { + let stop_time = task + .timeout + .map(|timeout| std::time::SystemTime::now() + timeout); + // Construct a provider that can serve the data. let provider = CondaDependencyProvider::from_solver_task( task.available_packages.into_iter().map(|r| r.into()), @@ -421,6 +444,7 @@ impl super::SolverImpl for Solver { &task.pinned_packages, &task.virtual_packages, task.specs.clone().as_ref(), + stop_time, ); // Construct the requirements that the solver needs to satisfy. diff --git a/crates/rattler_solve/tests/backends.rs b/crates/rattler_solve/tests/backends.rs index 294ea655c..5f148eecc 100644 --- a/crates/rattler_solve/tests/backends.rs +++ b/crates/rattler_solve/tests/backends.rs @@ -127,6 +127,7 @@ fn solve_real_world(specs: Vec<&str>) -> Vec { locked_packages: Vec::default(), pinned_packages: Vec::default(), virtual_packages: Vec::default(), + timeout: None, }; let pkgs1 = match T::default().solve(solver_task) { @@ -527,6 +528,7 @@ mod libsolv_c { available_packages: [libsolv_repodata], specs, pinned_packages: Vec::new(), + timeout: None, }) .unwrap(); @@ -618,6 +620,7 @@ fn solve( available_packages: [&repo_data], specs, pinned_packages, + timeout: None, }; let pkgs = T::default().solve(task)?; @@ -675,6 +678,7 @@ fn compare_solve(specs: Vec<&str>) { locked_packages: Vec::default(), pinned_packages: Vec::default(), virtual_packages: Vec::default(), + timeout: None, }) .unwrap(), ), @@ -696,6 +700,7 @@ fn compare_solve(specs: Vec<&str>) { locked_packages: Vec::default(), pinned_packages: Vec::default(), virtual_packages: Vec::default(), + timeout: None, }) .unwrap(), ), @@ -766,6 +771,7 @@ fn solve_to_get_channel_of_spec( locked_packages: Vec::default(), pinned_packages: Vec::default(), virtual_packages: Vec::default(), + timeout: None, }) .unwrap(); diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py index d443c5b28..0934c2bb0 100644 --- a/py-rattler/rattler/solver/solver.py +++ b/py-rattler/rattler/solver/solver.py @@ -1,4 +1,5 @@ from __future__ import annotations +import datetime from typing import List, Optional from rattler.match_spec.match_spec import MatchSpec @@ -14,6 +15,7 @@ def solve( locked_packages: Optional[List[RepoDataRecord]] = None, pinned_packages: Optional[List[RepoDataRecord]] = None, virtual_packages: Optional[List[GenericVirtualPackage]] = None, + timeout: Optional[datetime.timedelta] = None, ) -> List[RepoDataRecord]: """ Resolve the dependencies and return the `RepoDataRecord`s @@ -54,5 +56,6 @@ def solve( v_package._generic_virtual_package for v_package in virtual_packages or [] ], + timeout.microseconds if timeout else None, ) ] diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs index 644a57269..e11fd257d 100644 --- a/py-rattler/src/solver.rs +++ b/py-rattler/src/solver.rs @@ -15,6 +15,7 @@ pub fn py_solve( locked_packages: Vec, pinned_packages: Vec, virtual_packages: Vec, + timeout: Option, ) -> PyResult> { py.allow_threads(move || { let package_names = specs @@ -39,6 +40,7 @@ pub fn py_solve( .collect::>>()?, virtual_packages: virtual_packages.into_iter().map(Into::into).collect(), specs: specs.into_iter().map(Into::into).collect(), + timeout: timeout.map(std::time::Duration::from_micros), }; Ok(Solver