Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add N generic everywhere #323

Merged
merged 2 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions crates/rattler_libsolv_rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ pub use transaction::Transaction;

pub use mapping::Mapping;

/// Blanket trait implementation for something that we consider a package name.
pub trait PackageName: Eq + Hash {}
impl<N: Eq + Hash> PackageName for N {}

/// Version is a name and a version specification.
pub trait VersionTrait: Display {
/// The version associated with this record.
Expand All @@ -52,11 +56,11 @@ pub trait VersionSet: Debug + Display + Clone + Eq + Hash {
}

/// Bla
pub trait DependencyProvider<VS: VersionSet> {
pub trait DependencyProvider<VS: VersionSet, N: PackageName = String> {
/// Sort the specified solvables based on which solvable to try first.
fn sort_candidates(
&mut self,
pool: &Pool<VS>,
pool: &Pool<VS, N>,
solvables: &mut [SolvableId],
match_spec_to_candidates: &Mapping<VersionSetId, OnceCell<Vec<SolvableId>>>,
);
Expand Down
31 changes: 17 additions & 14 deletions crates/rattler_libsolv_rs/src/pool.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::hash::Hash;

use crate::arena::Arena;
use crate::id::{NameId, RepoId, SolvableId, VersionSetId};
use crate::mapping::Mapping;
use crate::solvable::{PackageSolvable, Solvable};
use crate::VersionSet;
use crate::{PackageName, VersionSet};

/// A pool that stores data related to the available packages
///
/// Because it stores solvables, it contains references to `PackageRecord`s (the `'a` lifetime comes
/// from the original `PackageRecord`s)
pub struct Pool<VS: VersionSet, Name: Hash + Eq = String> {
pub struct Pool<VS: VersionSet, N: PackageName = String> {
/// All the solvables that have been registered
pub(crate) solvables: Arena<SolvableId, Solvable<VS::V>>,

/// The total amount of registered repos
total_repos: u32,

/// Interned package names
package_names: Arena<NameId, Name>,
package_names: Arena<NameId, N>,

/// Map from package names to the id of their interned counterpart
pub(crate) names_to_ids: HashMap<Name, NameId>,
pub(crate) names_to_ids: HashMap<N, NameId>,

/// Map from interned package names to the solvables that have that name
pub(crate) packages_by_name: Mapping<NameId, Vec<SolvableId>>,
Expand All @@ -42,7 +41,7 @@ pub struct Pool<VS: VersionSet, Name: Hash + Eq = String> {
pub(crate) match_spec_to_forbidden: Mapping<VersionSetId, Vec<SolvableId>>,
}

impl<VS: VersionSet, Name: Hash + Eq> Default for Pool<VS, Name> {
impl<VS: VersionSet, N: PackageName> Default for Pool<VS, N> {
fn default() -> Self {
let mut solvables = Arena::new();
solvables.alloc(Solvable::new_root());
Expand All @@ -63,7 +62,7 @@ impl<VS: VersionSet, Name: Hash + Eq> Default for Pool<VS, Name> {
}
}

impl<VS: VersionSet, Name: Hash + Eq + Clone> Pool<VS, Name> {
impl<VS: VersionSet, N: PackageName> Pool<VS, N> {
/// Creates a new [`Pool`]
pub fn new() -> Self {
Self::default()
Expand Down Expand Up @@ -136,7 +135,8 @@ impl<VS: VersionSet, Name: Hash + Eq + Clone> Pool<VS, Name> {
/// Interns a package name into the `Pool`, returning its `NameId`
pub fn intern_package_name<NValue>(&mut self, name: NValue) -> NameId
where
NValue: Into<Name>,
NValue: Into<N>,
N: Clone,
{
match self.names_to_ids.entry(name.into()) {
Entry::Occupied(e) => *e.get(),
Expand All @@ -153,14 +153,14 @@ impl<VS: VersionSet, Name: Hash + Eq + Clone> Pool<VS, Name> {
}

/// Lookup the package name id associated to the provided name
pub fn lookup_package_name(&self, name: &Name) -> Option<NameId> {
pub fn lookup_package_name(&self, name: &N) -> Option<NameId> {
self.names_to_ids.get(name).copied()
}

/// Returns the package name associated to the provided id
///
/// Panics if the package name is not found in the pool
pub fn resolve_package_name(&self, name_id: NameId) -> &Name {
pub fn resolve_package_name(&self, name_id: NameId) -> &N {
&self.package_names[name_id]
}

Expand Down Expand Up @@ -207,12 +207,12 @@ impl<VS: VersionSet, Name: Hash + Eq + Clone> Pool<VS, Name> {
}

/// A helper struct to visualize a name.
pub struct NameDisplay<'pool, VS: VersionSet> {
pub struct NameDisplay<'pool, VS: VersionSet, N: PackageName> {
id: NameId,
pool: &'pool Pool<VS>,
pool: &'pool Pool<VS, N>,
}

impl<'pool, VS: VersionSet> Display for NameDisplay<'pool, VS> {
impl<'pool, VS: VersionSet, N: PackageName + Display> Display for NameDisplay<'pool, VS, N> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let name = self.pool.resolve_package_name(self.id);
write!(f, "{}", name)
Expand All @@ -221,7 +221,10 @@ impl<'pool, VS: VersionSet> Display for NameDisplay<'pool, VS> {

impl NameId {
/// Returns an object that can be used to format the name.
pub fn display<VS: VersionSet>(self, pool: &Pool<VS>) -> NameDisplay<'_, VS> {
pub fn display<VS: VersionSet, N: PackageName + Display>(
self,
pool: &Pool<VS, N>,
) -> NameDisplay<'_, VS, N> {
NameDisplay { id: self, pool }
}
}
39 changes: 24 additions & 15 deletions crates/rattler_libsolv_rs/src/problem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

use std::collections::{HashMap, HashSet};
use std::fmt;
use std::fmt::Formatter;
use std::fmt::{Display, Formatter};
use std::hash::Hash;
use std::rc::Rc;

use itertools::Itertools;
Expand All @@ -15,7 +16,7 @@ use crate::id::{ClauseId, SolvableId, VersionSetId};
use crate::pool::Pool;
use crate::solver::clause::Clause;
use crate::solver::Solver;
use crate::{DependencyProvider, VersionSet, VersionTrait};
use crate::{DependencyProvider, PackageName, VersionSet, VersionTrait};

/// Represents the cause of the solver being unable to find a solution
#[derive(Debug)]
Expand All @@ -38,9 +39,9 @@ impl Problem {
}

/// Generates a graph representation of the problem (see [`ProblemGraph`] for details)
pub fn graph<VS: VersionSet, D: DependencyProvider<VS>>(
pub fn graph<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>>(
&self,
solver: &Solver<VS, D>,
solver: &Solver<VS, N, D>,
) -> ProblemGraph {
let mut graph = DiGraph::<ProblemNode, ProblemEdge>::default();
let mut nodes: HashMap<SolvableId, NodeIndex> = HashMap::default();
Expand Down Expand Up @@ -142,10 +143,15 @@ impl Problem {
}

/// Display a user-friendly error explaining the problem
pub fn display_user_friendly<'a, VS: VersionSet, D: DependencyProvider<VS>>(
pub fn display_user_friendly<
'a,
VS: VersionSet,
N: PackageName + Display,
D: DependencyProvider<VS, N>,
>(
&self,
solver: &'a Solver<VS, D>,
) -> DisplayUnsat<'a, VS> {
solver: &'a Solver<VS, N, D>,
) -> DisplayUnsat<'a, VS, N> {
let graph = self.graph(solver);
DisplayUnsat::new(graph, solver.pool())
}
Expand Down Expand Up @@ -311,9 +317,9 @@ impl ProblemGraph {
write!(f, "}}")
}

fn simplify<VS: VersionSet>(
fn simplify<VS: VersionSet, N: PackageName>(
&self,
pool: &Pool<VS>,
pool: &Pool<VS, N>,
) -> HashMap<SolvableId, Rc<MergedProblemNode>> {
let graph = &self.graph;

Expand Down Expand Up @@ -471,16 +477,16 @@ impl ProblemGraph {

/// A struct implementing [`fmt::Display`] that generates a user-friendly representation of a
/// problem graph
pub struct DisplayUnsat<'pool, VS: VersionSet> {
pub struct DisplayUnsat<'pool, VS: VersionSet, N: PackageName + Display> {
graph: ProblemGraph,
merged_candidates: HashMap<SolvableId, Rc<MergedProblemNode>>,
installable_set: HashSet<NodeIndex>,
missing_set: HashSet<NodeIndex>,
pool: &'pool Pool<VS>,
pool: &'pool Pool<VS, N>,
}

impl<'pool, VS: VersionSet> DisplayUnsat<'pool, VS> {
pub(crate) fn new(graph: ProblemGraph, pool: &'pool Pool<VS>) -> Self {
impl<'pool, VS: VersionSet, N: PackageName + Display> DisplayUnsat<'pool, VS, N> {
pub(crate) fn new(graph: ProblemGraph, pool: &'pool Pool<VS, N>) -> Self {
let merged_candidates = graph.simplify(pool);
let installable_set = graph.get_installable_set();
let missing_set = graph.get_missing_set();
Expand Down Expand Up @@ -512,7 +518,10 @@ impl<'pool, VS: VersionSet> DisplayUnsat<'pool, VS> {
f: &mut Formatter<'_>,
top_level_edges: &[EdgeReference<ProblemEdge>],
top_level_indent: bool,
) -> fmt::Result {
) -> fmt::Result
where
N: Display,
{
pub enum DisplayOp {
Requirement(VersionSetId, Vec<EdgeIndex>),
Candidate(NodeIndex),
Expand Down Expand Up @@ -695,7 +704,7 @@ impl<'pool, VS: VersionSet> DisplayUnsat<'pool, VS> {
}
}

impl<VS: VersionSet> fmt::Display for DisplayUnsat<'_, VS> {
impl<VS: VersionSet, N: PackageName + Display> fmt::Display for DisplayUnsat<'_, VS, N> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let (top_level_missing, top_level_conflicts): (Vec<_>, _) = self
.graph
Expand Down
24 changes: 14 additions & 10 deletions crates/rattler_libsolv_rs/src/solver/clause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use crate::{
mapping::Mapping,
pool::Pool,
solver::decision_map::DecisionMap,
VersionSet,
PackageName, VersionSet,
};

use std::fmt::{Debug, Formatter};
use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

/// Represents a single clause in the SAT problem
///
Expand Down Expand Up @@ -113,10 +114,10 @@ impl Clause {
}

/// Visits each literal in the clause
pub fn visit_literals<VS: VersionSet>(
pub fn visit_literals<VS: VersionSet, N: PackageName>(
&self,
learnt_clauses: &Arena<LearntClauseId, Vec<Literal>>,
pool: &Pool<VS>,
pool: &Pool<VS, N>,
mut visit: impl FnMut(Literal),
) {
match *self {
Expand Down Expand Up @@ -203,7 +204,10 @@ impl ClauseState {
clause
}

pub fn debug<'a, VS: VersionSet>(&self, pool: &'a Pool<VS>) -> ClauseDebug<'a, VS> {
pub fn debug<'a, VS: VersionSet, N: PackageName>(
&self,
pool: &'a Pool<VS, N>,
) -> ClauseDebug<'a, VS, N> {
ClauseDebug {
kind: self.kind,
pool,
Expand Down Expand Up @@ -315,9 +319,9 @@ impl ClauseState {
}
}

pub fn next_unwatched_variable<VS: VersionSet>(
pub fn next_unwatched_variable<VS: VersionSet, N: PackageName>(
&self,
pool: &Pool<VS>,
pool: &Pool<VS, N>,
learnt_clauses: &Arena<LearntClauseId, Vec<Literal>>,
decision_map: &DecisionMap,
) -> Option<SolvableId> {
Expand Down Expand Up @@ -395,12 +399,12 @@ impl Literal {
}

/// A representation of a clause that implements [`Debug`]
pub(crate) struct ClauseDebug<'pool, VS: VersionSet> {
pub(crate) struct ClauseDebug<'pool, VS: VersionSet, N: PackageName> {
kind: Clause,
pool: &'pool Pool<VS>,
pool: &'pool Pool<VS, N>,
}

impl<VS: VersionSet> Debug for ClauseDebug<'_, VS> {
impl<VS: VersionSet, N: PackageName + Display> Debug for ClauseDebug<'_, VS, N> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self.kind {
Clause::InstallRoot => write!(f, "install root"),
Expand Down
15 changes: 9 additions & 6 deletions crates/rattler_libsolv_rs/src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ use std::cell::OnceCell;

use itertools::Itertools;
use std::collections::{HashMap, HashSet};
use std::fmt::Display;

use crate::{DependencyProvider, VersionSet, VersionSetId};
use crate::{DependencyProvider, PackageName, VersionSet, VersionSetId};
use clause::{Clause, ClauseState, Literal};
use decision::Decision;
use decision_tracker::DecisionTracker;
Expand All @@ -28,9 +29,9 @@ mod watch_map;
///
/// Keeps solvables in a `Pool`, which contains references to `PackageRecord`s (the `'a` lifetime
/// comes from the original `PackageRecord`s)
pub struct Solver<VS: VersionSet, D: DependencyProvider<VS>> {
pub struct Solver<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> {
provider: D,
pool: Pool<VS>,
pool: Pool<VS, N>,

pub(crate) clauses: Vec<ClauseState>,
watches: WatchMap,
Expand All @@ -42,9 +43,9 @@ pub struct Solver<VS: VersionSet, D: DependencyProvider<VS>> {
decision_tracker: DecisionTracker,
}

impl<VS: VersionSet, D: DependencyProvider<VS>> Solver<VS, D> {
impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> Solver<VS, N, D> {
/// Create a solver, using the provided pool
pub fn new(pool: Pool<VS>, provider: D) -> Self {
pub fn new(pool: Pool<VS, N>, provider: D) -> Self {
Self {
clauses: Vec::new(),
watches: WatchMap::new(),
Expand All @@ -58,10 +59,12 @@ impl<VS: VersionSet, D: DependencyProvider<VS>> Solver<VS, D> {
}

/// Returns a reference to the pool used by the solver
pub fn pool(&self) -> &Pool<VS> {
pub fn pool(&self) -> &Pool<VS, N> {
&self.pool
}
}

impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Solver<VS, N, D> {
/// Solves the provided `jobs` and returns a transaction from the found solution
///
/// Returns a [`Problem`] if no solution was found, which provides ways to inspect the causes
Expand Down