diff --git a/crates/hir_ty/src/db.rs b/crates/hir_ty/src/db.rs index e144dd43f44d..a42053396478 100644 --- a/crates/hir_ty/src/db.rs +++ b/crates/hir_ty/src/db.rs @@ -13,6 +13,7 @@ use la_arena::ArenaMap; use crate::{ chalk_db, method_resolution::{InherentImpls, TraitImpls}, + traits::{ChalkCache, ChalkCacheKey}, Binders, CallableDefId, FnDefId, ImplTraitId, InferenceResult, Interner, PolyFnSig, QuantifiedWhereClause, ReturnTypeImplTraits, TraitRef, Ty, TyDefId, ValueTyDefId, }; @@ -159,6 +160,9 @@ pub trait HirDatabase: DefDatabase + Upcast { krate: CrateId, env: chalk_ir::Environment, ) -> chalk_ir::ProgramClauses; + + #[salsa::invoke(crate::traits::chalk_cache)] + fn chalk_cache(&self, cache_key: ChalkCacheKey) -> ChalkCache; } fn infer_wait(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc { diff --git a/crates/hir_ty/src/traits.rs b/crates/hir_ty/src/traits.rs index b139edbee945..bd17e966b813 100644 --- a/crates/hir_ty/src/traits.rs +++ b/crates/hir_ty/src/traits.rs @@ -1,16 +1,19 @@ //! Trait solving using Chalk. -use std::env::var; +use std::fmt::Debug; +use std::hash::Hash; +use std::{env::var, sync::Arc}; -use chalk_ir::GoalData; -use chalk_recursive::Cache; +use chalk_ir::{Fallible, GoalData}; +use chalk_recursive::{Cache, UCanonicalGoal}; use chalk_solve::{logging_db::LoggingRustIrDatabase, Solver}; -use base_db::CrateId; +use base_db::{CrateGraph, CrateId}; use hir_def::{lang_item::LangItemTarget, TraitId}; use stdx::panic_context; use syntax::SmolStr; +use crate::method_resolution::TraitImpls; use crate::{ db::HirDatabase, AliasEq, AliasTy, Canonical, DomainGoal, Goal, Guidance, InEnvironment, Interner, Solution, TraitRefExt, Ty, TyKind, WhereClause, @@ -25,11 +28,63 @@ pub(crate) struct ChalkContext<'a> { pub(crate) krate: CrateId, } -fn create_chalk_solver() -> chalk_recursive::RecursiveSolver { +#[derive(Clone)] +pub struct ChalkCache { + cache: Arc, Fallible>>, +} + +impl PartialEq for ChalkCache { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.cache, &other.cache) + } +} + +impl Eq for ChalkCache {} + +impl Debug for ChalkCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ChalkCache").finish() + } +} + +#[derive(Debug, Clone)] +pub struct ChalkCacheKey { + trait_impls_in_crate: Arc, + crate_graph: Arc, +} + +impl PartialEq for ChalkCacheKey { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.trait_impls_in_crate, &other.trait_impls_in_crate) + && Arc::ptr_eq(&self.crate_graph, &other.crate_graph) + } +} + +impl Eq for ChalkCacheKey {} + +impl Hash for ChalkCacheKey { + fn hash(&self, state: &mut H) { + Arc::as_ptr(&self.trait_impls_in_crate).hash(state); + Arc::as_ptr(&self.crate_graph).hash(state); + } +} + +pub fn chalk_cache(_: &dyn HirDatabase, _: ChalkCacheKey) -> ChalkCache { + ChalkCache { cache: Arc::new(Cache::new()) } +} + +fn create_chalk_solver( + db: &dyn HirDatabase, + cache_key: ChalkCacheKey, +) -> chalk_recursive::RecursiveSolver { let overflow_depth = var("CHALK_OVERFLOW_DEPTH").ok().and_then(|s| s.parse().ok()).unwrap_or(300); let max_size = var("CHALK_SOLVER_MAX_SIZE").ok().and_then(|s| s.parse().ok()).unwrap_or(150); - chalk_recursive::RecursiveSolver::new(overflow_depth, max_size, Some(Cache::new())) + chalk_recursive::RecursiveSolver::new( + overflow_depth, + max_size, + Some(Cache::clone(&db.chalk_cache(cache_key).cache)), + ) } /// A set of clauses that we assume to be true. E.g. if we are inside this function: @@ -103,7 +158,13 @@ fn solve( ) -> Option> { let context = ChalkContext { db, krate }; tracing::debug!("solve goal: {:?}", goal); - let mut solver = create_chalk_solver(); + let mut solver = create_chalk_solver( + db, + ChalkCacheKey { + trait_impls_in_crate: db.trait_impls_in_crate(krate), + crate_graph: db.crate_graph(), + }, + ); let fuel = std::cell::Cell::new(CHALK_SOLVER_FUEL);