Skip to content

Commit ff3a75c

Browse files
committed
Fix closure kind inference
1 parent b5510f0 commit ff3a75c

File tree

3 files changed

+98
-11
lines changed

3 files changed

+98
-11
lines changed

crates/hir-ty/src/infer/closure.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ impl InferenceContext<'_> {
922922
}
923923
}
924924

925-
fn closure_kind(&self) -> FnTrait {
925+
fn closure_kind_from_capture(&self) -> FnTrait {
926926
let mut r = FnTrait::Fn;
927927
for it in &self.current_captures {
928928
r = cmp::min(
@@ -939,7 +939,7 @@ impl InferenceContext<'_> {
939939
r
940940
}
941941

942-
fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait {
942+
fn analyze_closure(&mut self, closure: ClosureId, predicate: Option<FnTrait>) -> FnTrait {
943943
let (_, root) = self.db.lookup_intern_closure(closure.into());
944944
self.current_closure = Some(closure);
945945
let Expr::Closure { body, capture_by, .. } = &self.body[root] else {
@@ -957,7 +957,13 @@ impl InferenceContext<'_> {
957957
}
958958
self.restrict_precision_for_unsafe();
959959
// closure_kind should be done before adjust_for_move_closure
960-
let closure_kind = self.closure_kind();
960+
let closure_kind = {
961+
let from_capture = self.closure_kind_from_capture();
962+
// if predicate.unwrap_or(FnTrait::Fn) < from_capture {
963+
// // TODO: Diagnostic
964+
// }
965+
predicate.unwrap_or(from_capture)
966+
};
961967
match capture_by {
962968
CaptureBy::Value => self.adjust_for_move_closure(),
963969
CaptureBy::Ref => (),
@@ -973,7 +979,9 @@ impl InferenceContext<'_> {
973979
let deferred_closures = self.sort_closures();
974980
for (closure, exprs) in deferred_closures.into_iter().rev() {
975981
self.current_captures = vec![];
976-
let kind = self.analyze_closure(closure);
982+
983+
let predicate = self.table.get_closure_fn_trait_predicate(closure);
984+
let kind = self.analyze_closure(closure, predicate);
977985

978986
for (derefed_callee, callee_ty, params, expr) in exprs {
979987
if let &Expr::Call { callee, .. } = &self.body[expr] {

crates/hir-ty/src/infer/unify.rs

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
//! Unification and canonicalization logic.
22
3-
use std::{fmt, iter, mem};
3+
use std::{cmp, fmt, iter, mem};
44

55
use chalk_ir::{
66
cast::Cast, fold::TypeFoldable, interner::HasInterner, zip::Zip, CanonicalVarKind, FloatTy,
7-
IntTy, TyVariableKind, UniverseIndex,
7+
IntTy, TyVariableKind, UniverseIndex, WhereClause,
88
};
99
use chalk_solve::infer::ParameterEnaVariableExt;
1010
use either::Either;
@@ -14,11 +14,12 @@ use triomphe::Arc;
1414

1515
use super::{InferOk, InferResult, InferenceContext, TypeError};
1616
use crate::{
17-
consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime,
18-
to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue,
19-
DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar,
20-
Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution,
21-
TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind,
17+
chalk_db::TraitId, consteval::unknown_const, db::HirDatabase, fold_tys_and_consts,
18+
static_lifetime, to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical,
19+
ClosureId, Const, ConstValue, DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal,
20+
GoalData, Guidance, InEnvironment, InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy,
21+
ProjectionTyExt, Scalar, Solution, Substitution, TraitEnvironment, Ty, TyBuilder, TyExt,
22+
TyKind, VariableKind,
2223
};
2324

2425
impl InferenceContext<'_> {
@@ -146,6 +147,8 @@ pub(crate) struct InferenceTable<'a> {
146147
/// Double buffer used in [`Self::resolve_obligations_as_possible`] to cut down on
147148
/// temporary allocations.
148149
resolve_obligations_buffer: Vec<Canonicalized<InEnvironment<Goal>>>,
150+
fn_trait_predicates: Vec<(Ty, FnTrait)>,
151+
cached_fn_trait_ids: Option<CachedFnTraitIds>,
149152
}
150153

151154
pub(crate) struct InferenceTableSnapshot {
@@ -154,15 +157,36 @@ pub(crate) struct InferenceTableSnapshot {
154157
type_variable_table_snapshot: Vec<TypeVariableFlags>,
155158
}
156159

160+
#[derive(Clone)]
161+
struct CachedFnTraitIds {
162+
fn_trait: TraitId,
163+
fn_mut_trait: TraitId,
164+
fn_once_trait: TraitId,
165+
}
166+
167+
impl CachedFnTraitIds {
168+
fn new(db: &dyn HirDatabase, trait_env: &Arc<TraitEnvironment>) -> Option<Self> {
169+
let fn_trait = FnTrait::Fn.get_id(db, trait_env.krate.clone()).map(to_chalk_trait_id)?;
170+
let fn_mut_trait =
171+
FnTrait::FnMut.get_id(db, trait_env.krate.clone()).map(to_chalk_trait_id)?;
172+
let fn_once_trait =
173+
FnTrait::FnOnce.get_id(db, trait_env.krate.clone()).map(to_chalk_trait_id)?;
174+
Some(Self { fn_trait, fn_mut_trait, fn_once_trait })
175+
}
176+
}
177+
157178
impl<'a> InferenceTable<'a> {
158179
pub(crate) fn new(db: &'a dyn HirDatabase, trait_env: Arc<TraitEnvironment>) -> Self {
180+
let cached_fn_trait_ids = CachedFnTraitIds::new(db, &trait_env);
159181
InferenceTable {
160182
db,
161183
trait_env,
162184
var_unification_table: ChalkInferenceTable::new(),
163185
type_variable_table: Vec::new(),
164186
pending_obligations: Vec::new(),
165187
resolve_obligations_buffer: Vec::new(),
188+
fn_trait_predicates: Vec::new(),
189+
cached_fn_trait_ids,
166190
}
167191
}
168192

@@ -498,6 +522,22 @@ impl<'a> InferenceTable<'a> {
498522
}
499523

500524
fn register_obligation_in_env(&mut self, goal: InEnvironment<Goal>) {
525+
if let Some(fn_trait_ids) = &self.cached_fn_trait_ids {
526+
if let GoalData::DomainGoal(DomainGoal::Holds(WhereClause::Implemented(trait_ref))) =
527+
goal.goal.data(Interner)
528+
{
529+
if let Some(ty) = trait_ref.substitution.type_parameters(Interner).next() {
530+
if trait_ref.trait_id == fn_trait_ids.fn_trait {
531+
self.fn_trait_predicates.push((ty, FnTrait::Fn));
532+
} else if trait_ref.trait_id == fn_trait_ids.fn_mut_trait {
533+
self.fn_trait_predicates.push((ty, FnTrait::FnMut));
534+
} else if trait_ref.trait_id == fn_trait_ids.fn_once_trait {
535+
self.fn_trait_predicates.push((ty, FnTrait::FnOnce));
536+
}
537+
}
538+
}
539+
}
540+
501541
let canonicalized = self.canonicalize(goal);
502542
if !self.try_resolve_obligation(&canonicalized) {
503543
self.pending_obligations.push(canonicalized);
@@ -791,6 +831,23 @@ impl<'a> InferenceTable<'a> {
791831
_ => c,
792832
}
793833
}
834+
835+
pub(super) fn get_closure_fn_trait_predicate(
836+
&mut self,
837+
closure_id: ClosureId,
838+
) -> Option<FnTrait> {
839+
let predicates = mem::take(&mut self.fn_trait_predicates);
840+
let res = predicates.iter().filter_map(|(ty, fn_trait)| {
841+
if matches!(self.resolve_completely(ty.clone()).kind(Interner), TyKind::Closure(c, ..) if *c == closure_id) {
842+
Some(*fn_trait)
843+
} else {
844+
None
845+
}
846+
}).fold(None, |acc, x| Some(cmp::max(acc.unwrap_or(FnTrait::FnOnce), x)));
847+
self.fn_trait_predicates = predicates;
848+
849+
return res;
850+
}
794851
}
795852

796853
impl fmt::Debug for InferenceTable<'_> {

crates/hir-ty/src/mir/eval/tests.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,28 @@ fn main() {
636636
);
637637
}
638638

639+
#[test]
640+
fn check_check_check() {
641+
check_pass(
642+
r#"
643+
//- minicore: fn
644+
struct Foo<F: FnOnce(i32)>(F);
645+
fn should_not_reach() {
646+
_ // FIXME: replace this function with panic when that works
647+
}
648+
649+
fn main() {
650+
let mut a = 0;
651+
let foo = Foo(move |val| {
652+
a = val;
653+
});
654+
foo.0(10);
655+
foo.0(20);
656+
}
657+
"#,
658+
)
659+
}
660+
639661
#[test]
640662
fn closure_capture_array_const_generic() {
641663
check_pass(

0 commit comments

Comments
 (0)