diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 22a70f951ea7..53883aeb7186 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -924,7 +924,7 @@ impl InferenceContext<'_> { } } - fn closure_kind(&self) -> FnTrait { + fn closure_kind_from_capture(&self) -> FnTrait { let mut r = FnTrait::Fn; for it in &self.current_captures { r = cmp::min( @@ -941,7 +941,7 @@ impl InferenceContext<'_> { r } - fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait { + fn analyze_closure(&mut self, closure: ClosureId, predicate: Option) -> FnTrait { let InternedClosure(_, root) = self.db.lookup_intern_closure(closure.into()); self.current_closure = Some(closure); let Expr::Closure { body, capture_by, .. } = &self.body[root] else { @@ -959,7 +959,14 @@ impl InferenceContext<'_> { } self.restrict_precision_for_unsafe(); // closure_kind should be done before adjust_for_move_closure - let closure_kind = self.closure_kind(); + let closure_kind = { + let from_capture = self.closure_kind_from_capture(); + // if predicate.unwrap_or(FnTrait::Fn) < from_capture { + // // Diagnostics here, like compiler does in + // // https://github.com/rust-lang/rust/blob/11f32b73e0dc9287e305b5b9980d24aecdc8c17f/compiler/rustc_hir_typeck/src/upvar.rs#L264 + // } + predicate.unwrap_or(from_capture) + }; match capture_by { CaptureBy::Value => self.adjust_for_move_closure(), CaptureBy::Ref => (), @@ -975,7 +982,9 @@ impl InferenceContext<'_> { let deferred_closures = self.sort_closures(); for (closure, exprs) in deferred_closures.into_iter().rev() { self.current_captures = vec![]; - let kind = self.analyze_closure(closure); + + let predicate = self.table.get_closure_fn_trait_predicate(closure); + let kind = self.analyze_closure(closure, predicate); for (derefed_callee, callee_ty, params, expr) in exprs { if let &Expr::Call { callee, .. } = &self.body[expr] { diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index 709760b64fd3..f5d5be57024e 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -1,10 +1,10 @@ //! Unification and canonicalization logic. -use std::{fmt, iter, mem}; +use std::{cmp, fmt, iter, mem}; use chalk_ir::{ cast::Cast, fold::TypeFoldable, interner::HasInterner, zip::Zip, CanonicalVarKind, FloatTy, - IntTy, TyVariableKind, UniverseIndex, + IntTy, TyVariableKind, UniverseIndex, WhereClause, }; use chalk_solve::infer::ParameterEnaVariableExt; use either::Either; @@ -14,11 +14,12 @@ use triomphe::Arc; use super::{InferOk, InferResult, InferenceContext, TypeError}; use crate::{ - consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime, - to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue, - DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar, - Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution, - TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind, + chalk_db::TraitId, consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, + static_lifetime, to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, + ClosureId, Const, ConstValue, DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal, + GoalData, Guidance, InEnvironment, InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy, + ProjectionTyExt, Scalar, Solution, Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, + TyKind, VariableKind, }; impl InferenceContext<'_> { @@ -181,6 +182,8 @@ pub(crate) struct InferenceTable<'a> { /// Double buffer used in [`Self::resolve_obligations_as_possible`] to cut down on /// temporary allocations. resolve_obligations_buffer: Vec>>, + fn_trait_predicates: Vec<(Ty, FnTrait)>, + cached_fn_trait_ids: Option, } pub(crate) struct InferenceTableSnapshot { @@ -189,8 +192,25 @@ pub(crate) struct InferenceTableSnapshot { type_variable_table_snapshot: Vec, } +#[derive(Clone)] +struct CachedFnTraitIds { + fn_trait: TraitId, + fn_mut_trait: TraitId, + fn_once_trait: TraitId, +} + +impl CachedFnTraitIds { + fn new(db: &dyn HirDatabase, trait_env: &Arc) -> Option { + let fn_trait = FnTrait::Fn.get_id(db, trait_env.krate).map(to_chalk_trait_id)?; + let fn_mut_trait = FnTrait::FnMut.get_id(db, trait_env.krate).map(to_chalk_trait_id)?; + let fn_once_trait = FnTrait::FnOnce.get_id(db, trait_env.krate).map(to_chalk_trait_id)?; + Some(Self { fn_trait, fn_mut_trait, fn_once_trait }) + } +} + impl<'a> InferenceTable<'a> { pub(crate) fn new(db: &'a dyn HirDatabase, trait_env: Arc) -> Self { + let cached_fn_trait_ids = CachedFnTraitIds::new(db, &trait_env); InferenceTable { db, trait_env, @@ -198,6 +218,8 @@ impl<'a> InferenceTable<'a> { type_variable_table: Vec::new(), pending_obligations: Vec::new(), resolve_obligations_buffer: Vec::new(), + fn_trait_predicates: Vec::new(), + cached_fn_trait_ids, } } @@ -547,6 +569,22 @@ impl<'a> InferenceTable<'a> { } fn register_obligation_in_env(&mut self, goal: InEnvironment) { + if let Some(fn_trait_ids) = &self.cached_fn_trait_ids { + if let GoalData::DomainGoal(DomainGoal::Holds(WhereClause::Implemented(trait_ref))) = + goal.goal.data(Interner) + { + if let Some(ty) = trait_ref.substitution.type_parameters(Interner).next() { + if trait_ref.trait_id == fn_trait_ids.fn_trait { + self.fn_trait_predicates.push((ty, FnTrait::Fn)); + } else if trait_ref.trait_id == fn_trait_ids.fn_mut_trait { + self.fn_trait_predicates.push((ty, FnTrait::FnMut)); + } else if trait_ref.trait_id == fn_trait_ids.fn_once_trait { + self.fn_trait_predicates.push((ty, FnTrait::FnOnce)); + } + } + } + } + let canonicalized = self.canonicalize(goal); let solution = self.try_resolve_obligation(&canonicalized); if matches!(solution, Some(Solution::Ambig(_))) { @@ -838,6 +876,23 @@ impl<'a> InferenceTable<'a> { _ => c, } } + + pub(super) fn get_closure_fn_trait_predicate( + &mut self, + closure_id: ClosureId, + ) -> Option { + let predicates = mem::take(&mut self.fn_trait_predicates); + let res = predicates.iter().filter_map(|(ty, fn_trait)| { + if matches!(self.resolve_completely(ty.clone()).kind(Interner), TyKind::Closure(c, ..) if *c == closure_id) { + Some(*fn_trait) + } else { + None + } + }).fold(None, |acc, x| Some(cmp::max(acc.unwrap_or(FnTrait::FnOnce), x))); + self.fn_trait_predicates = predicates; + + res + } } impl fmt::Debug for InferenceTable<'_> { diff --git a/crates/hir-ty/src/tests/patterns.rs b/crates/hir-ty/src/tests/patterns.rs index 069007308225..963b4a2aba05 100644 --- a/crates/hir-ty/src/tests/patterns.rs +++ b/crates/hir-ty/src/tests/patterns.rs @@ -702,25 +702,25 @@ fn test() { 51..58 'loop {}': ! 56..58 '{}': () 72..171 '{ ... x); }': () - 78..81 'foo': fn foo<&(i32, &str), i32, impl Fn(&(i32, &str)) -> i32>(&(i32, &str), impl Fn(&(i32, &str)) -> i32) -> i32 + 78..81 'foo': fn foo<&(i32, &str), i32, impl FnOnce(&(i32, &str)) -> i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> i32) -> i32 78..105 'foo(&(...y)| x)': i32 82..91 '&(1, "a")': &(i32, &str) 83..91 '(1, "a")': (i32, &str) 84..85 '1': i32 87..90 '"a"': &str - 93..104 '|&(x, y)| x': impl Fn(&(i32, &str)) -> i32 + 93..104 '|&(x, y)| x': impl FnOnce(&(i32, &str)) -> i32 94..101 '&(x, y)': &(i32, &str) 95..101 '(x, y)': (i32, &str) 96..97 'x': i32 99..100 'y': &str 103..104 'x': i32 - 142..145 'foo': fn foo<&(i32, &str), &i32, impl Fn(&(i32, &str)) -> &i32>(&(i32, &str), impl Fn(&(i32, &str)) -> &i32) -> &i32 + 142..145 'foo': fn foo<&(i32, &str), &i32, impl FnOnce(&(i32, &str)) -> &i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> &i32) -> &i32 142..168 'foo(&(...y)| x)': &i32 146..155 '&(1, "a")': &(i32, &str) 147..155 '(1, "a")': (i32, &str) 148..149 '1': i32 151..154 '"a"': &str - 157..167 '|(x, y)| x': impl Fn(&(i32, &str)) -> &i32 + 157..167 '|(x, y)| x': impl FnOnce(&(i32, &str)) -> &i32 158..164 '(x, y)': (i32, &str) 159..160 'x': &i32 162..163 'y': &&str diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs index 2ad9a7fe525f..9a8ebd07d015 100644 --- a/crates/hir-ty/src/tests/regression.rs +++ b/crates/hir-ty/src/tests/regression.rs @@ -862,7 +862,7 @@ fn main() { 123..126 'S()': S 132..133 's': S 132..144 's.g(|_x| {})': () - 136..143 '|_x| {}': impl Fn(&i32) + 136..143 '|_x| {}': impl FnOnce(&i32) 137..139 '_x': &i32 141..143 '{}': () 150..151 's': S diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index 6c7dbe1db6ff..4034d3c69284 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -2190,9 +2190,9 @@ fn main() { 149..151 'Ok': extern "rust-call" Ok<(), ()>(()) -> Result<(), ()> 149..155 'Ok(())': Result<(), ()> 152..154 '()': () - 167..171 'test': fn test<(), (), impl Fn() -> impl Future>, impl Future>>(impl Fn() -> impl Future>) + 167..171 'test': fn test<(), (), impl FnMut() -> impl Future>, impl Future>>(impl FnMut() -> impl Future>) 167..228 'test(|... })': () - 172..227 '|| asy... }': impl Fn() -> impl Future> + 172..227 '|| asy... }': impl FnMut() -> impl Future> 175..227 'async ... }': impl Future> 191..205 'return Err(())': ! 198..201 'Err': extern "rust-call" Err<(), ()>(()) -> Result<(), ()> @@ -2743,6 +2743,29 @@ impl B for Astruct {} ) } +#[test] +fn closures_kinds_with_predicates() { + check_types( + r#" +//- minicore: fn +struct A(F); +struct B<'a, F: FnMut()>(&'a F); + +fn f() { + let c1 = || {}; + //^^ impl Fn() + let a1 = A(|| {}); + let c2 = a1.0; + //^^ impl FnOnce() + let c3 = || {}; + //^^ impl FnMut() + let a2 = A(c3); + let b1 = B(&a2.0); +} + "#, + ) +} + #[test] fn capture_kinds_simple() { check_types( diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index db14addaf185..35f73c28d057 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -1333,9 +1333,9 @@ fn foo() -> (impl FnOnce(&str, T), impl Trait) { } "#, expect![[r#" - 134..165 '{ ...(C)) }': (impl Fn(&str, T), Bar) - 140..163 '(|inpu...ar(C))': (impl Fn(&str, T), Bar) - 141..154 '|input, t| {}': impl Fn(&str, T) + 134..165 '{ ...(C)) }': (impl FnOnce(&str, T), Bar) + 140..163 '(|inpu...ar(C))': (impl FnOnce(&str, T), Bar) + 141..154 '|input, t| {}': impl FnOnce(&str, T) 142..147 'input': &str 149..150 't': T 152..154 '{}': () @@ -1963,20 +1963,20 @@ fn test() { 163..167 '1u32': u32 174..175 'x': Option 174..190 'x.map(...v + 1)': Option - 180..189 '|v| v + 1': impl Fn(u32) -> u32 + 180..189 '|v| v + 1': impl FnOnce(u32) -> u32 181..182 'v': u32 184..185 'v': u32 184..189 'v + 1': u32 188..189 '1': u32 196..197 'x': Option 196..212 'x.map(... 1u64)': Option - 202..211 '|_v| 1u64': impl Fn(u32) -> u64 + 202..211 '|_v| 1u64': impl FnOnce(u32) -> u64 203..205 '_v': u32 207..211 '1u64': u64 222..223 'y': Option 239..240 'x': Option 239..252 'x.map(|_v| 1)': Option - 245..251 '|_v| 1': impl Fn(u32) -> i64 + 245..251 '|_v| 1': impl FnOnce(u32) -> i64 246..248 '_v': u32 250..251 '1': i64 "#]], @@ -2062,17 +2062,17 @@ fn test() { 312..314 '{}': () 330..489 '{ ... S); }': () 340..342 'x1': u64 - 345..349 'foo1': fn foo1 u64>(S, impl Fn(S) -> u64) -> u64 + 345..349 'foo1': fn foo1 u64>(S, impl FnOnce(S) -> u64) -> u64 345..368 'foo1(S...hod())': u64 350..351 'S': S - 353..367 '|s| s.method()': impl Fn(S) -> u64 + 353..367 '|s| s.method()': impl FnOnce(S) -> u64 354..355 's': S 357..358 's': S 357..367 's.method()': u64 378..380 'x2': u64 - 383..387 'foo2': fn foo2 u64>(impl Fn(S) -> u64, S) -> u64 + 383..387 'foo2': fn foo2 u64>(impl FnOnce(S) -> u64, S) -> u64 383..406 'foo2(|...(), S)': u64 - 388..402 '|s| s.method()': impl Fn(S) -> u64 + 388..402 '|s| s.method()': impl FnOnce(S) -> u64 389..390 's': S 392..393 's': S 392..402 's.method()': u64 @@ -2081,14 +2081,14 @@ fn test() { 421..422 'S': S 421..446 'S.foo1...hod())': u64 428..429 'S': S - 431..445 '|s| s.method()': impl Fn(S) -> u64 + 431..445 '|s| s.method()': impl FnOnce(S) -> u64 432..433 's': S 435..436 's': S 435..445 's.method()': u64 456..458 'x4': u64 461..462 'S': S 461..486 'S.foo2...(), S)': u64 - 468..482 '|s| s.method()': impl Fn(S) -> u64 + 468..482 '|s| s.method()': impl FnOnce(S) -> u64 469..470 's': S 472..473 's': S 472..482 's.method()': u64 @@ -2562,9 +2562,9 @@ fn main() { 72..74 '_v': F 117..120 '{ }': () 132..163 '{ ... }); }': () - 138..148 'f::<(), _>': fn f<(), impl Fn(&())>(impl Fn(&())) + 138..148 'f::<(), _>': fn f<(), impl FnOnce(&())>(impl FnOnce(&())) 138..160 'f::<()... z; })': () - 149..159 '|z| { z; }': impl Fn(&()) + 149..159 '|z| { z; }': impl FnOnce(&()) 150..151 'z': &() 153..159 '{ z; }': () 155..156 'z': &() @@ -2749,9 +2749,9 @@ fn main() { 983..998 'Vec::::new': fn new() -> Vec 983..1000 'Vec::<...:new()': Vec 983..1012 'Vec::<...iter()': IntoIter - 983..1075 'Vec::<...one })': FilterMap, impl Fn(i32) -> Option> + 983..1075 'Vec::<...one })': FilterMap, impl FnMut(i32) -> Option> 983..1101 'Vec::<... y; })': () - 1029..1074 '|x| if...None }': impl Fn(i32) -> Option + 1029..1074 '|x| if...None }': impl FnMut(i32) -> Option 1030..1031 'x': i32 1033..1074 'if x >...None }': Option 1036..1037 'x': i32 @@ -2764,7 +2764,7 @@ fn main() { 1049..1057 'x as u32': u32 1066..1074 '{ None }': Option 1068..1072 'None': Option - 1090..1100 '|y| { y; }': impl Fn(u32) + 1090..1100 '|y| { y; }': impl FnMut(u32) 1091..1092 'y': u32 1094..1100 '{ y; }': () 1096..1097 'y': u32 diff --git a/crates/ide/src/hover/tests.rs b/crates/ide/src/hover/tests.rs index 69ddc1e45efb..54c8437086a4 100644 --- a/crates/ide/src/hover/tests.rs +++ b/crates/ide/src/hover/tests.rs @@ -353,9 +353,9 @@ fn main() { expect![[r#" ```rust {closure#0} // size = 8, align = 8, niches = 1 - impl FnOnce() -> S2 + impl Fn() -> S2 ``` - Coerced to: &impl FnOnce() -> S2 + Coerced to: &impl Fn() -> S2 ## Captures * `x` by move"#]], @@ -401,17 +401,17 @@ fn main() { }, }, HoverGotoTypeData { - mod_path: "core::ops::function::FnOnce", + mod_path: "core::ops::function::Fn", nav: NavigationTarget { file_id: FileId( 1, ), - full_range: 632..867, - focus_range: 693..699, - name: "FnOnce", + full_range: 254..425, + focus_range: 310..312, + name: "Fn", kind: Trait, container_name: "function", - description: "pub trait FnOnce\nwhere\n Args: Tuple,", + description: "pub trait Fn\nwhere\n Self: FnMut,\n Args: Tuple,", }, }, ],