Skip to content

Commit

Permalink
fix: add N generic everywhere (#323)
Browse files Browse the repository at this point in the history
* fix: add N generic everywhere

* Update crates/rattler_libsolv_rs/src/lib.rs

Co-authored-by: Tim de Jager <tdejager89@gmail.com>

---------

Co-authored-by: Tim de Jager <tdejager89@gmail.com>
  • Loading branch information
baszalmstra and tdejager authored Sep 8, 2023
1 parent 570b4ba commit 0b384d0
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 47 deletions.
8 changes: 6 additions & 2 deletions crates/rattler_libsolv_rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ pub use transaction::Transaction;

pub use mapping::Mapping;

/// Blanket trait implementation for something that we consider a package name.
pub trait PackageName: Eq + Hash {}
impl<N: Eq + Hash> PackageName for N {}

/// Version is a name and a version specification.
pub trait VersionTrait: Display {
/// The version associated with this record.
Expand All @@ -52,11 +56,11 @@ pub trait VersionSet: Debug + Display + Clone + Eq + Hash {
}

/// Bla
pub trait DependencyProvider<VS: VersionSet> {
pub trait DependencyProvider<VS: VersionSet, N: PackageName = String> {
/// Sort the specified solvables based on which solvable to try first.
fn sort_candidates(
&mut self,
pool: &Pool<VS>,
pool: &Pool<VS, N>,
solvables: &mut [SolvableId],
match_spec_to_candidates: &Mapping<VersionSetId, OnceCell<Vec<SolvableId>>>,
);
Expand Down
31 changes: 17 additions & 14 deletions crates/rattler_libsolv_rs/src/pool.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::hash::Hash;

use crate::arena::Arena;
use crate::id::{NameId, RepoId, SolvableId, VersionSetId};
use crate::mapping::Mapping;
use crate::solvable::{PackageSolvable, Solvable};
use crate::VersionSet;
use crate::{PackageName, VersionSet};

/// A pool that stores data related to the available packages
///
/// Because it stores solvables, it contains references to `PackageRecord`s (the `'a` lifetime comes
/// from the original `PackageRecord`s)
pub struct Pool<VS: VersionSet, Name: Hash + Eq = String> {
pub struct Pool<VS: VersionSet, N: PackageName = String> {
/// All the solvables that have been registered
pub(crate) solvables: Arena<SolvableId, Solvable<VS::V>>,

/// The total amount of registered repos
total_repos: u32,

/// Interned package names
package_names: Arena<NameId, Name>,
package_names: Arena<NameId, N>,

/// Map from package names to the id of their interned counterpart
pub(crate) names_to_ids: HashMap<Name, NameId>,
pub(crate) names_to_ids: HashMap<N, NameId>,

/// Map from interned package names to the solvables that have that name
pub(crate) packages_by_name: Mapping<NameId, Vec<SolvableId>>,
Expand All @@ -42,7 +41,7 @@ pub struct Pool<VS: VersionSet, Name: Hash + Eq = String> {
pub(crate) match_spec_to_forbidden: Mapping<VersionSetId, Vec<SolvableId>>,
}

impl<VS: VersionSet, Name: Hash + Eq> Default for Pool<VS, Name> {
impl<VS: VersionSet, N: PackageName> Default for Pool<VS, N> {
fn default() -> Self {
let mut solvables = Arena::new();
solvables.alloc(Solvable::new_root());
Expand All @@ -63,7 +62,7 @@ impl<VS: VersionSet, Name: Hash + Eq> Default for Pool<VS, Name> {
}
}

impl<VS: VersionSet, Name: Hash + Eq + Clone> Pool<VS, Name> {
impl<VS: VersionSet, N: PackageName> Pool<VS, N> {
/// Creates a new [`Pool`]
pub fn new() -> Self {
Self::default()
Expand Down Expand Up @@ -136,7 +135,8 @@ impl<VS: VersionSet, Name: Hash + Eq + Clone> Pool<VS, Name> {
/// Interns a package name into the `Pool`, returning its `NameId`
pub fn intern_package_name<NValue>(&mut self, name: NValue) -> NameId
where
NValue: Into<Name>,
NValue: Into<N>,
N: Clone,
{
match self.names_to_ids.entry(name.into()) {
Entry::Occupied(e) => *e.get(),
Expand All @@ -153,14 +153,14 @@ impl<VS: VersionSet, Name: Hash + Eq + Clone> Pool<VS, Name> {
}

/// Lookup the package name id associated to the provided name
pub fn lookup_package_name(&self, name: &Name) -> Option<NameId> {
pub fn lookup_package_name(&self, name: &N) -> Option<NameId> {
self.names_to_ids.get(name).copied()
}

/// Returns the package name associated to the provided id
///
/// Panics if the package name is not found in the pool
pub fn resolve_package_name(&self, name_id: NameId) -> &Name {
pub fn resolve_package_name(&self, name_id: NameId) -> &N {
&self.package_names[name_id]
}

Expand Down Expand Up @@ -207,12 +207,12 @@ impl<VS: VersionSet, Name: Hash + Eq + Clone> Pool<VS, Name> {
}

/// A helper struct to visualize a name.
pub struct NameDisplay<'pool, VS: VersionSet> {
pub struct NameDisplay<'pool, VS: VersionSet, N: PackageName> {
id: NameId,
pool: &'pool Pool<VS>,
pool: &'pool Pool<VS, N>,
}

impl<'pool, VS: VersionSet> Display for NameDisplay<'pool, VS> {
impl<'pool, VS: VersionSet, N: PackageName + Display> Display for NameDisplay<'pool, VS, N> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let name = self.pool.resolve_package_name(self.id);
write!(f, "{}", name)
Expand All @@ -221,7 +221,10 @@ impl<'pool, VS: VersionSet> Display for NameDisplay<'pool, VS> {

impl NameId {
/// Returns an object that can be used to format the name.
pub fn display<VS: VersionSet>(self, pool: &Pool<VS>) -> NameDisplay<'_, VS> {
pub fn display<VS: VersionSet, N: PackageName + Display>(
self,
pool: &Pool<VS, N>,
) -> NameDisplay<'_, VS, N> {
NameDisplay { id: self, pool }
}
}
39 changes: 24 additions & 15 deletions crates/rattler_libsolv_rs/src/problem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::fmt::Formatter;
use std::fmt::{Display, Formatter};
use std::hash::Hash;
use std::rc::Rc;

use itertools::Itertools;
Expand All @@ -15,7 +16,7 @@ use crate::id::{ClauseId, SolvableId, VersionSetId};
use crate::pool::Pool;
use crate::solver::clause::Clause;
use crate::solver::Solver;
use crate::{DependencyProvider, VersionSet, VersionTrait};
use crate::{DependencyProvider, PackageName, VersionSet, VersionTrait};

/// Represents the cause of the solver being unable to find a solution
#[derive(Debug)]
Expand All @@ -38,9 +39,9 @@ impl Problem {
}

/// Generates a graph representation of the problem (see [`ProblemGraph`] for details)
pub fn graph<VS: VersionSet, D: DependencyProvider<VS>>(
pub fn graph<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>>(
&self,
solver: &Solver<VS, D>,
solver: &Solver<VS, N, D>,
) -> ProblemGraph {
let mut graph = DiGraph::<ProblemNode, ProblemEdge>::default();
let mut nodes: HashMap<SolvableId, NodeIndex> = HashMap::default();
Expand Down Expand Up @@ -142,10 +143,15 @@ impl Problem {
}

/// Display a user-friendly error explaining the problem
pub fn display_user_friendly<'a, VS: VersionSet, D: DependencyProvider<VS>>(
pub fn display_user_friendly<
'a,
VS: VersionSet,
N: PackageName + Display,
D: DependencyProvider<VS, N>,
>(
&self,
solver: &'a Solver<VS, D>,
) -> DisplayUnsat<'a, VS> {
solver: &'a Solver<VS, N, D>,
) -> DisplayUnsat<'a, VS, N> {
let graph = self.graph(solver);
DisplayUnsat::new(graph, solver.pool())
}
Expand Down Expand Up @@ -311,9 +317,9 @@ impl ProblemGraph {
write!(f, "}}")
}

fn simplify<VS: VersionSet>(
fn simplify<VS: VersionSet, N: PackageName>(
&self,
pool: &Pool<VS>,
pool: &Pool<VS, N>,
) -> HashMap<SolvableId, Rc<MergedProblemNode>> {
let graph = &self.graph;

Expand Down Expand Up @@ -471,16 +477,16 @@ impl ProblemGraph {

/// A struct implementing [`fmt::Display`] that generates a user-friendly representation of a
/// problem graph
pub struct DisplayUnsat<'pool, VS: VersionSet> {
pub struct DisplayUnsat<'pool, VS: VersionSet, N: PackageName + Display> {
graph: ProblemGraph,
merged_candidates: HashMap<SolvableId, Rc<MergedProblemNode>>,
installable_set: HashSet<NodeIndex>,
missing_set: HashSet<NodeIndex>,
pool: &'pool Pool<VS>,
pool: &'pool Pool<VS, N>,
}

impl<'pool, VS: VersionSet> DisplayUnsat<'pool, VS> {
pub(crate) fn new(graph: ProblemGraph, pool: &'pool Pool<VS>) -> Self {
impl<'pool, VS: VersionSet, N: PackageName + Display> DisplayUnsat<'pool, VS, N> {
pub(crate) fn new(graph: ProblemGraph, pool: &'pool Pool<VS, N>) -> Self {
let merged_candidates = graph.simplify(pool);
let installable_set = graph.get_installable_set();
let missing_set = graph.get_missing_set();
Expand Down Expand Up @@ -512,7 +518,10 @@ impl<'pool, VS: VersionSet> DisplayUnsat<'pool, VS> {
f: &mut Formatter<'_>,
top_level_edges: &[EdgeReference<ProblemEdge>],
top_level_indent: bool,
) -> fmt::Result {
) -> fmt::Result
where
N: Display,
{
pub enum DisplayOp {
Requirement(VersionSetId, Vec<EdgeIndex>),
Candidate(NodeIndex),
Expand Down Expand Up @@ -695,7 +704,7 @@ impl<'pool, VS: VersionSet> DisplayUnsat<'pool, VS> {
}
}

impl<VS: VersionSet> fmt::Display for DisplayUnsat<'_, VS> {
impl<VS: VersionSet, N: PackageName + Display> fmt::Display for DisplayUnsat<'_, VS, N> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let (top_level_missing, top_level_conflicts): (Vec<_>, _) = self
.graph
Expand Down
24 changes: 14 additions & 10 deletions crates/rattler_libsolv_rs/src/solver/clause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use crate::{
mapping::Mapping,
pool::Pool,
solver::decision_map::DecisionMap,
VersionSet,
PackageName, VersionSet,
};

use std::fmt::{Debug, Formatter};
use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

/// Represents a single clause in the SAT problem
///
Expand Down Expand Up @@ -113,10 +114,10 @@ impl Clause {
}

/// Visits each literal in the clause
pub fn visit_literals<VS: VersionSet>(
pub fn visit_literals<VS: VersionSet, N: PackageName>(
&self,
learnt_clauses: &Arena<LearntClauseId, Vec<Literal>>,
pool: &Pool<VS>,
pool: &Pool<VS, N>,
mut visit: impl FnMut(Literal),
) {
match *self {
Expand Down Expand Up @@ -203,7 +204,10 @@ impl ClauseState {
clause
}

pub fn debug<'a, VS: VersionSet>(&self, pool: &'a Pool<VS>) -> ClauseDebug<'a, VS> {
pub fn debug<'a, VS: VersionSet, N: PackageName>(
&self,
pool: &'a Pool<VS, N>,
) -> ClauseDebug<'a, VS, N> {
ClauseDebug {
kind: self.kind,
pool,
Expand Down Expand Up @@ -315,9 +319,9 @@ impl ClauseState {
}
}

pub fn next_unwatched_variable<VS: VersionSet>(
pub fn next_unwatched_variable<VS: VersionSet, N: PackageName>(
&self,
pool: &Pool<VS>,
pool: &Pool<VS, N>,
learnt_clauses: &Arena<LearntClauseId, Vec<Literal>>,
decision_map: &DecisionMap,
) -> Option<SolvableId> {
Expand Down Expand Up @@ -395,12 +399,12 @@ impl Literal {
}

/// A representation of a clause that implements [`Debug`]
pub(crate) struct ClauseDebug<'pool, VS: VersionSet> {
pub(crate) struct ClauseDebug<'pool, VS: VersionSet, N: PackageName> {
kind: Clause,
pool: &'pool Pool<VS>,
pool: &'pool Pool<VS, N>,
}

impl<VS: VersionSet> Debug for ClauseDebug<'_, VS> {
impl<VS: VersionSet, N: PackageName + Display> Debug for ClauseDebug<'_, VS, N> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self.kind {
Clause::InstallRoot => write!(f, "install root"),
Expand Down
15 changes: 9 additions & 6 deletions crates/rattler_libsolv_rs/src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ use std::cell::OnceCell;

use itertools::Itertools;
use std::collections::{HashMap, HashSet};
use std::fmt::Display;

use crate::{DependencyProvider, VersionSet, VersionSetId};
use crate::{DependencyProvider, PackageName, VersionSet, VersionSetId};
use clause::{Clause, ClauseState, Literal};
use decision::Decision;
use decision_tracker::DecisionTracker;
Expand All @@ -28,9 +29,9 @@ mod watch_map;
///
/// Keeps solvables in a `Pool`, which contains references to `PackageRecord`s (the `'a` lifetime
/// comes from the original `PackageRecord`s)
pub struct Solver<VS: VersionSet, D: DependencyProvider<VS>> {
pub struct Solver<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> {
provider: D,
pool: Pool<VS>,
pool: Pool<VS, N>,

pub(crate) clauses: Vec<ClauseState>,
watches: WatchMap,
Expand All @@ -42,9 +43,9 @@ pub struct Solver<VS: VersionSet, D: DependencyProvider<VS>> {
decision_tracker: DecisionTracker,
}

impl<VS: VersionSet, D: DependencyProvider<VS>> Solver<VS, D> {
impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> Solver<VS, N, D> {
/// Create a solver, using the provided pool
pub fn new(pool: Pool<VS>, provider: D) -> Self {
pub fn new(pool: Pool<VS, N>, provider: D) -> Self {
Self {
clauses: Vec::new(),
watches: WatchMap::new(),
Expand All @@ -58,10 +59,12 @@ impl<VS: VersionSet, D: DependencyProvider<VS>> Solver<VS, D> {
}

/// Returns a reference to the pool used by the solver
pub fn pool(&self) -> &Pool<VS> {
pub fn pool(&self) -> &Pool<VS, N> {
&self.pool
}
}

impl<VS: VersionSet, N: PackageName + Display, D: DependencyProvider<VS, N>> Solver<VS, N, D> {
/// Solves the provided `jobs` and returns a transaction from the found solution
///
/// Returns a [`Problem`] if no solution was found, which provides ways to inspect the causes
Expand Down

0 comments on commit 0b384d0

Please sign in to comment.