Skip to content

Commit

Permalink
more concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
aochagavia committed Feb 2, 2024
1 parent 0938551 commit 96c8b60
Showing 1 changed file with 40 additions and 27 deletions.
67 changes: 40 additions & 27 deletions src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,16 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
&self,
solvable_id: SolvableId,
) -> Result<AddClauseOutput, Box<dyn Any>> {
let mut output = AddClauseOutput::default();
let mut queue = vec![solvable_id];
let mut seen = HashSet::new();
seen.insert(solvable_id);
let output = RefCell::new(AddClauseOutput::default());
let queue = RefCell::new(vec![solvable_id]);
let seen = RefCell::new(HashSet::new());
seen.borrow_mut().insert(solvable_id);

while let Some(solvable_id) = queue.borrow_mut().pop() {

Check failure on line 196 in src/solver/mod.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

this `RefCell` reference is held across an `await` point
let output = &output;
let queue = &queue;
let seen = &seen;

while let Some(solvable_id) = queue.pop() {
let mutex = {
let mut clauses = self.clauses_added_for_solvable.borrow_mut();
let mutex = clauses
Expand Down Expand Up @@ -237,6 +241,7 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
.alloc(ClauseState::exclude(solvable_id, *reason));

// Exclusions are negative assertions, tracked outside of the watcher system
let mut output = output.borrow_mut();
output.negative_assertions.push((solvable_id, clause_id));

// There might be a conflict now
Expand All @@ -251,14 +256,10 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
};

// Add clauses for the requirements
for version_set_id in requirements {
let add_requirements = requirements.into_iter().map(|version_set_id| async move {
let dependency_name = self.pool.resolve_version_set_package_name(version_set_id);
self.add_clauses_for_package(
&mut output.negative_assertions,
&mut output.clauses_to_watch,
dependency_name,
)
.await?;
self.add_clauses_for_package(&output, dependency_name)

Check failure on line 261 in src/solver/mod.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

this expression creates a reference which is immediately dereferenced by the compiler
.await?;

// Find all the solvables that match for the given version set
let candidates = self
Expand All @@ -268,6 +269,8 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol

// Queue requesting the dependencies of the candidates as well if they are cheaply
// available from the dependency provider.
let mut queue = queue.borrow_mut();
let mut seen = seen.borrow_mut();
for &candidate in candidates {
if seen.insert(candidate)
&& self.cache.are_dependencies_available_for(candidate)
Expand All @@ -292,6 +295,7 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
unreachable!();
};

let mut output = output.borrow_mut();
if clause.has_watches() {
output.clauses_to_watch.push(clause_id);
}
Expand All @@ -306,17 +310,14 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
// Add assertions for unit clauses (i.e. those with no matching candidates)
output.negative_assertions.push((solvable_id, clause_id));
}
}

// Add clauses for the constraints
for version_set_id in constrains {
Ok::<_, Box<dyn Any>>(())
});

let add_constrains = constrains.into_iter().map(|version_set_id| async move {
let dependency_name = self.pool.resolve_version_set_package_name(version_set_id);
self.add_clauses_for_package(
&mut output.negative_assertions,
&mut output.clauses_to_watch,
dependency_name,
)
.await?;
self.add_clauses_for_package(&output, dependency_name)

Check failure on line 319 in src/solver/mod.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

this expression creates a reference which is immediately dereferenced by the compiler
.await?;

// Find all the solvables that match for the given version set
let constrained_candidates = self
Expand All @@ -325,6 +326,7 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
.await?;

// Add forbidden clauses for the candidates
let mut output = output.borrow_mut();
for forbidden_candidate in constrained_candidates.iter().copied().collect_vec() {
let (clause, conflict) = ClauseState::constrains(
solvable_id,
Expand All @@ -340,12 +342,22 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
output.conflicting_clauses.push(clause_id);
}
}

Ok::<_, Box<dyn Any>>(())
});

let add_requirements = futures::future::join_all(add_requirements);
let add_constrains = futures::future::join_all(add_constrains);
let (results1, results2) =
futures::future::join(add_requirements, add_constrains).await;
for result in results1.into_iter().chain(results2) {
result?;
}

*clauses_added = true;
}

Ok(output)
Ok(output.into_inner())
}

/// Adds all clauses for a specific package name.
Expand All @@ -366,8 +378,7 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
/// will be returned as an `Err(...)`.
async fn add_clauses_for_package(
&self,
negative_assertions: &mut Vec<(SolvableId, ClauseId)>,
clauses_to_watch: &mut Vec<ClauseId>,
output: &RefCell<AddClauseOutput>,
package_name: NameId,
) -> Result<(), Box<dyn Any>> {
let mutex = {
Expand Down Expand Up @@ -402,6 +413,8 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
);
}

let mut output = output.borrow_mut();

// Each candidate gets a clause to disallow other candidates.
for (i, &candidate) in candidates.iter().enumerate() {
for &other_candidate in &candidates[i + 1..] {
Expand All @@ -411,7 +424,7 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
.alloc(ClauseState::forbid_multiple(candidate, other_candidate));

debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches());
clauses_to_watch.push(clause_id);
output.clauses_to_watch.push(clause_id);
}
}

Expand All @@ -425,7 +438,7 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
.alloc(ClauseState::lock(locked_solvable_id, other_candidate));

debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches());
clauses_to_watch.push(clause_id);
output.clauses_to_watch.push(clause_id);
}
}
}
Expand All @@ -438,7 +451,7 @@ impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Sol
.alloc(ClauseState::exclude(solvable, reason));

// Exclusions are negative assertions, tracked outside of the watcher system
negative_assertions.push((solvable, clause_id));
output.negative_assertions.push((solvable, clause_id));

// Conflicts should be impossible here
debug_assert!(self.decision_tracker.assigned_value(solvable) != Some(true));
Expand Down

0 comments on commit 96c8b60

Please sign in to comment.