Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add timeout parameter and SolverOptions to return early #499

Merged
merged 6 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions crates/rattler-bin/src/commands/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use rattler_repodata_gateway::fetch::{
CacheResult, DownloadProgress, FetchRepoDataError, FetchRepoDataOptions,
};
use rattler_repodata_gateway::sparse::SparseRepoData;
use rattler_solve::{libsolv_c, resolvo, SolverImpl, SolverTask};
use rattler_solve::{
libsolv_c::{self},
resolvo, SolverImpl, SolverTask,
};
use reqwest::Client;
use std::sync::Arc;
use std::{
Expand Down Expand Up @@ -53,7 +56,10 @@ pub struct Opt {
virtual_package: Option<Vec<String>>,

#[clap(long)]
use_experimental_libsolv_rs: bool,
use_resolvo: bool,

#[clap(long)]
timeout: Option<u64>,
}

pub async fn create(opt: Opt) -> anyhow::Result<()> {
Expand Down Expand Up @@ -221,13 +227,13 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
virtual_packages,
specs,
pinned_packages: Vec::new(),
timeout: opt.timeout.map(Duration::from_millis),
};

// Next, use a solver to solve this specific problem. This provides us with all the operations
// we need to apply to our environment to bring it up to date.
let use_libsolv_rs = opt.use_experimental_libsolv_rs;
let required_packages = wrap_in_progress("solving", move || {
if use_libsolv_rs {
if opt.use_resolvo {
resolvo::Solver.solve(solver_task)
} else {
libsolv_c::Solver.solve(solver_task)
Expand Down
2 changes: 2 additions & 0 deletions crates/rattler_solve/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {
pinned_packages: vec![],
virtual_packages: vec![],
specs: specs.clone(),
timeout: None,
}))
.unwrap()
});
Expand All @@ -79,6 +80,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {
pinned_packages: vec![],
virtual_packages: vec![],
specs: specs.clone(),
timeout: None,
}))
.unwrap()
});
Expand Down
3 changes: 3 additions & 0 deletions crates/rattler_solve/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ pub struct SolverTask<TAvailablePackagesIterator> {

/// The specs we want to solve
pub specs: Vec<MatchSpec>,

/// The timeout after which the solver should stop
pub timeout: Option<std::time::Duration>,
}

/// A representation of a collection of [`RepoDataRecord`] usable by a [`SolverImpl`]
Expand Down
6 changes: 6 additions & 0 deletions crates/rattler_solve/src/libsolv_c/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ impl super::SolverImpl for Solver {
&mut self,
task: SolverTask<TAvailablePackagesIterator>,
) -> Result<Vec<RepoDataRecord>, SolveError> {
if task.timeout.is_some() {
return Err(SolveError::UnsupportedOperations(vec![
"timeout".to_string()
]));
}

// Construct a default libsolv pool
let pool = Pool::default();

Expand Down
14 changes: 8 additions & 6 deletions crates/rattler_solve/src/resolvo/conda_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@ pub(super) fn compare_candidates<'a>(

// Otherwise, compare the dependencies of the variants. If there are similar
// dependencies select the variant that selects the highest version of the dependency.
let a_dependencies = solver
.get_or_cache_dependencies(a)
.expect("should not get here, resolution process aborted");
let b_dependencies = solver
.get_or_cache_dependencies(b)
.expect("should not get here, resolution process aborted");
let (a_dependencies, b_dependencies) = match (
solver.get_or_cache_dependencies(a),
solver.get_or_cache_dependencies(b),
) {
(Ok(a_deps), Ok(b_deps)) => (a_deps, b_deps),
// If either call fails, it's likely due to solver cancellation; thus, we can't compare dependencies
_ => return Ordering::Equal,
};

// If the MatchSpecs are known use these
// map these into a HashMap<PackageName, VersionSetId>
Expand Down
24 changes: 24 additions & 0 deletions crates/rattler_solve/src/resolvo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ pub(crate) struct CondaDependencyProvider<'a> {
RefCell<HashMap<VersionSetId, Option<(rattler_conda_types::Version, bool)>>>,

parse_match_spec_cache: RefCell<HashMap<&'a str, VersionSetId>>,

stop_time: Option<std::time::SystemTime>,
}

impl<'a> CondaDependencyProvider<'a> {
Expand All @@ -174,6 +176,7 @@ impl<'a> CondaDependencyProvider<'a> {
locked_records: &'a [RepoDataRecord],
virtual_packages: &'a [GenericVirtualPackage],
match_specs: &[MatchSpec],
stop_time: Option<std::time::SystemTime>,
) -> Self {
let pool = Pool::default();
let mut records: HashMap<NameId, Candidates> = HashMap::default();
Expand Down Expand Up @@ -333,10 +336,17 @@ impl<'a> CondaDependencyProvider<'a> {
records,
matchspec_to_highest_version: RefCell::default(),
parse_match_spec_cache: RefCell::default(),
stop_time,
}
}
}

/// The reason why the solver was cancelled
pub enum CancelReason {
/// The solver was cancelled because the timeout was reached
Timeout,
}

impl<'a> DependencyProvider<SolverMatchSpec<'a>> for CondaDependencyProvider<'a> {
fn pool(&self) -> &Pool<SolverMatchSpec<'a>, String> {
&self.pool
Expand Down Expand Up @@ -378,6 +388,15 @@ impl<'a> DependencyProvider<SolverMatchSpec<'a>> for CondaDependencyProvider<'a>

Dependencies::Known(dependencies)
}

fn should_cancel_with_value(&self) -> Option<Box<dyn std::any::Any>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for being slow - why the return type is not Option<Box<dyn CancelReason>> but dyn Any?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's defined like that in the trait (in resolvo) so I think we need to honor that contract and return an Box<Any>

if let Some(stop_time) = self.stop_time {
if std::time::SystemTime::now() > stop_time {
return Some(Box::new(CancelReason::Timeout));
}
}
None
}
}

/// Displays the different candidates by their version and sorted by their version
Expand Down Expand Up @@ -414,13 +433,18 @@ impl super::SolverImpl for Solver {
&mut self,
task: SolverTask<TAvailablePackagesIterator>,
) -> Result<Vec<RepoDataRecord>, SolveError> {
let stop_time = task
.timeout
.map(|timeout| std::time::SystemTime::now() + timeout);

// Construct a provider that can serve the data.
let provider = CondaDependencyProvider::from_solver_task(
task.available_packages.into_iter().map(|r| r.into()),
&task.locked_packages,
&task.pinned_packages,
&task.virtual_packages,
task.specs.clone().as_ref(),
stop_time,
);

// Construct the requirements that the solver needs to satisfy.
Expand Down
6 changes: 6 additions & 0 deletions crates/rattler_solve/tests/backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ fn solve_real_world<T: SolverImpl + Default>(specs: Vec<&str>) -> Vec<String> {
locked_packages: Vec::default(),
pinned_packages: Vec::default(),
virtual_packages: Vec::default(),
timeout: None,
};

let pkgs1 = match T::default().solve(solver_task) {
Expand Down Expand Up @@ -527,6 +528,7 @@ mod libsolv_c {
available_packages: [libsolv_repodata],
specs,
pinned_packages: Vec::new(),
timeout: None,
})
.unwrap();

Expand Down Expand Up @@ -618,6 +620,7 @@ fn solve<T: SolverImpl + Default>(
available_packages: [&repo_data],
specs,
pinned_packages,
timeout: None,
};

let pkgs = T::default().solve(task)?;
Expand Down Expand Up @@ -675,6 +678,7 @@ fn compare_solve(specs: Vec<&str>) {
locked_packages: Vec::default(),
pinned_packages: Vec::default(),
virtual_packages: Vec::default(),
timeout: None,
})
.unwrap(),
),
Expand All @@ -696,6 +700,7 @@ fn compare_solve(specs: Vec<&str>) {
locked_packages: Vec::default(),
pinned_packages: Vec::default(),
virtual_packages: Vec::default(),
timeout: None,
})
.unwrap(),
),
Expand Down Expand Up @@ -766,6 +771,7 @@ fn solve_to_get_channel_of_spec(
locked_packages: Vec::default(),
pinned_packages: Vec::default(),
virtual_packages: Vec::default(),
timeout: None,
})
.unwrap();

Expand Down
3 changes: 3 additions & 0 deletions py-rattler/rattler/solver/solver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import datetime
from typing import List, Optional
from rattler.match_spec.match_spec import MatchSpec

Expand All @@ -14,6 +15,7 @@ def solve(
locked_packages: Optional[List[RepoDataRecord]] = None,
pinned_packages: Optional[List[RepoDataRecord]] = None,
virtual_packages: Optional[List[GenericVirtualPackage]] = None,
timeout: Optional[datetime.timedelta] = None,
) -> List[RepoDataRecord]:
"""
Resolve the dependencies and return the `RepoDataRecord`s
Expand Down Expand Up @@ -54,5 +56,6 @@ def solve(
v_package._generic_virtual_package
for v_package in virtual_packages or []
],
timeout.microseconds if timeout else None,
)
]
2 changes: 2 additions & 0 deletions py-rattler/src/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub fn py_solve(
locked_packages: Vec<PyRecord>,
pinned_packages: Vec<PyRecord>,
virtual_packages: Vec<PyGenericVirtualPackage>,
timeout: Option<u64>,
) -> PyResult<Vec<PyRecord>> {
py.allow_threads(move || {
let package_names = specs
Expand All @@ -39,6 +40,7 @@ pub fn py_solve(
.collect::<PyResult<Vec<_>>>()?,
virtual_packages: virtual_packages.into_iter().map(Into::into).collect(),
specs: specs.into_iter().map(Into::into).collect(),
timeout: timeout.map(std::time::Duration::from_micros),
};

Ok(Solver
Expand Down