Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/hir-def/src/lang_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,9 @@ language_item_table! {
Fn, sym::fn_, fn_trait, Target::Trait, GenericRequirement::Exact(1);
FnMut, sym::fn_mut, fn_mut_trait, Target::Trait, GenericRequirement::Exact(1);
FnOnce, sym::fn_once, fn_once_trait, Target::Trait, GenericRequirement::Exact(1);
AsyncFn, sym::async_fn, async_fn_trait, Target::Trait, GenericRequirement::Exact(1);
AsyncFnMut, sym::async_fn_mut, async_fn_mut_trait, Target::Trait, GenericRequirement::Exact(1);
AsyncFnOnce, sym::async_fn_once, async_fn_once_trait, Target::Trait, GenericRequirement::Exact(1);

FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None;

Expand Down
6 changes: 3 additions & 3 deletions crates/hir-ty/src/infer/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1287,8 +1287,8 @@ impl InferenceContext<'_> {
tgt_expr: ExprId,
) {
match fn_x {
FnTrait::FnOnce => (),
FnTrait::FnMut => {
FnTrait::FnOnce | FnTrait::AsyncFnOnce => (),
FnTrait::FnMut | FnTrait::AsyncFnMut => {
if let TyKind::Ref(Mutability::Mut, lt, inner) = derefed_callee.kind(Interner) {
if adjustments
.last()
Expand All @@ -1312,7 +1312,7 @@ impl InferenceContext<'_> {
));
}
}
FnTrait::Fn => {
FnTrait::Fn | FnTrait::AsyncFn => {
if !matches!(derefed_callee.kind(Interner), TyKind::Ref(Mutability::Not, _, _)) {
adjustments.push(Adjustment::borrow(
Mutability::Not,
Expand Down
114 changes: 60 additions & 54 deletions crates/hir-ty/src/infer/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,69 +794,75 @@ impl<'a> InferenceTable<'a> {
ty: &Ty,
num_args: usize,
) -> Option<(FnTrait, Vec<Ty>, Ty)> {
let krate = self.trait_env.krate;
let fn_once_trait = FnTrait::FnOnce.get_id(self.db, krate)?;
let trait_data = self.db.trait_data(fn_once_trait);
let output_assoc_type =
trait_data.associated_type_by_name(&Name::new_symbol_root(sym::Output.clone()))?;

let mut arg_tys = Vec::with_capacity(num_args);
let arg_ty = TyBuilder::tuple(num_args)
.fill(|it| {
let arg = match it {
ParamKind::Type => self.new_type_var(),
ParamKind::Lifetime => unreachable!("Tuple with lifetime parameter"),
ParamKind::Const(_) => unreachable!("Tuple with const parameter"),
};
arg_tys.push(arg.clone());
arg.cast(Interner)
})
.build();

let b = TyBuilder::trait_ref(self.db, fn_once_trait);
if b.remaining() != 2 {
return None;
}
let mut trait_ref = b.push(ty.clone()).push(arg_ty).build();
for (fn_trait_name, output_assoc_name, subtraits) in [
(FnTrait::FnOnce, sym::Output.clone(), &[FnTrait::Fn, FnTrait::FnMut][..]),
(FnTrait::AsyncFnMut, sym::CallRefFuture.clone(), &[FnTrait::AsyncFn]),
(FnTrait::AsyncFnOnce, sym::CallOnceFuture.clone(), &[]),
] {
let krate = self.trait_env.krate;
let fn_trait = fn_trait_name.get_id(self.db, krate)?;
let trait_data = self.db.trait_data(fn_trait);
let output_assoc_type =
trait_data.associated_type_by_name(&Name::new_symbol_root(output_assoc_name))?;

let mut arg_tys = Vec::with_capacity(num_args);
let arg_ty = TyBuilder::tuple(num_args)
.fill(|it| {
let arg = match it {
ParamKind::Type => self.new_type_var(),
ParamKind::Lifetime => unreachable!("Tuple with lifetime parameter"),
ParamKind::Const(_) => unreachable!("Tuple with const parameter"),
};
arg_tys.push(arg.clone());
arg.cast(Interner)
})
.build();

let b = TyBuilder::trait_ref(self.db, fn_trait);
if b.remaining() != 2 {
return None;
}
let mut trait_ref = b.push(ty.clone()).push(arg_ty).build();

let projection = {
TyBuilder::assoc_type_projection(
let projection = TyBuilder::assoc_type_projection(
self.db,
output_assoc_type,
Some(trait_ref.substitution.clone()),
)
.build()
};
.fill_with_unknown()
.build();

let trait_env = self.trait_env.env.clone();
let obligation = InEnvironment {
goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some() {
self.register_obligation(obligation.goal);
let return_ty = self.normalize_projection_ty(projection);
for fn_x in [FnTrait::Fn, FnTrait::FnMut, FnTrait::FnOnce] {
let fn_x_trait = fn_x.get_id(self.db, krate)?;
trait_ref.trait_id = to_chalk_trait_id(fn_x_trait);
let obligation: chalk_ir::InEnvironment<chalk_ir::Goal<Interner>> = InEnvironment {
goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self
.db
.trait_solve(krate, self.trait_env.block, canonical.cast(Interner))
.is_some()
{
return Some((fn_x, arg_tys, return_ty));
let trait_env = self.trait_env.env.clone();
let obligation = InEnvironment {
goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some()
{
self.register_obligation(obligation.goal);
let return_ty = self.normalize_projection_ty(projection);
for &fn_x in subtraits {
let fn_x_trait = fn_x.get_id(self.db, krate)?;
trait_ref.trait_id = to_chalk_trait_id(fn_x_trait);
let obligation: chalk_ir::InEnvironment<chalk_ir::Goal<Interner>> =
InEnvironment {
goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self
.db
.trait_solve(krate, self.trait_env.block, canonical.cast(Interner))
.is_some()
{
return Some((fn_x, arg_tys, return_ty));
}
}
return Some((fn_trait_name, arg_tys, return_ty));
}
unreachable!("It should at least implement FnOnce at this point");
} else {
None
}
None
}

pub(super) fn insert_type_vars<T>(&mut self, ty: T) -> T
Expand Down
12 changes: 7 additions & 5 deletions crates/hir-ty/src/mir/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2023,11 +2023,11 @@ pub fn mir_body_for_closure_query(
ctx.result.locals.alloc(Local { ty: infer[*root].clone() });
let closure_local = ctx.result.locals.alloc(Local {
ty: match kind {
FnTrait::FnOnce => infer[expr].clone(),
FnTrait::FnMut => {
FnTrait::FnOnce | FnTrait::AsyncFnOnce => infer[expr].clone(),
FnTrait::FnMut | FnTrait::AsyncFnMut => {
TyKind::Ref(Mutability::Mut, error_lifetime(), infer[expr].clone()).intern(Interner)
}
FnTrait::Fn => {
FnTrait::Fn | FnTrait::AsyncFn => {
TyKind::Ref(Mutability::Not, error_lifetime(), infer[expr].clone()).intern(Interner)
}
},
Expand Down Expand Up @@ -2055,8 +2055,10 @@ pub fn mir_body_for_closure_query(
let mut err = None;
let closure_local = ctx.result.locals.iter().nth(1).unwrap().0;
let closure_projection = match kind {
FnTrait::FnOnce => vec![],
FnTrait::FnMut | FnTrait::Fn => vec![ProjectionElem::Deref],
FnTrait::FnOnce | FnTrait::AsyncFnOnce => vec![],
FnTrait::FnMut | FnTrait::Fn | FnTrait::AsyncFnMut | FnTrait::AsyncFn => {
vec![ProjectionElem::Deref]
}
};
ctx.result.walk_places(|p, store| {
if let Some(it) = upvar_map.get(&p.local) {
Expand Down
50 changes: 50 additions & 0 deletions crates/hir-ty/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4811,3 +4811,53 @@ fn bar(v: *const ()) {
"#]],
);
}

#[test]
fn async_fn_traits() {
check_infer(
r#"
//- minicore: async_fn
async fn foo<T: AsyncFn(u32) -> i32>(a: T) {
let fut1 = a(0);
fut1.await;
}
async fn bar<T: AsyncFnMut(u32) -> i32>(mut b: T) {
let fut2 = b(0);
fut2.await;
}
async fn baz<T: AsyncFnOnce(u32) -> i32>(c: T) {
let fut3 = c(0);
fut3.await;
}
"#,
expect![[r#"
37..38 'a': T
43..83 '{ ...ait; }': ()
43..83 '{ ...ait; }': impl Future<Output = ()>
53..57 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
60..61 'a': T
60..64 'a(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
62..63 '0': u32
70..74 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
70..80 'fut1.await': i32
124..129 'mut b': T
134..174 '{ ...ait; }': ()
134..174 '{ ...ait; }': impl Future<Output = ()>
144..148 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
151..152 'b': T
151..155 'b(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
153..154 '0': u32
161..165 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
161..171 'fut2.await': i32
216..217 'c': T
222..262 '{ ...ait; }': ()
222..262 '{ ...ait; }': impl Future<Output = ()>
232..236 'fut3': AsyncFnOnce::CallOnceFuture<T, (u32,)>
239..240 'c': T
239..243 'c(0)': AsyncFnOnce::CallOnceFuture<T, (u32,)>
241..242 '0': u32
249..253 'fut3': AsyncFnOnce::CallOnceFuture<T, (u32,)>
249..259 'fut3.await': i32
"#]],
);
}
26 changes: 23 additions & 3 deletions crates/hir-ty/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ pub enum FnTrait {
FnOnce,
FnMut,
Fn,

AsyncFnOnce,
AsyncFnMut,
AsyncFn,
}

impl fmt::Display for FnTrait {
Expand All @@ -228,6 +232,9 @@ impl fmt::Display for FnTrait {
FnTrait::FnOnce => write!(f, "FnOnce"),
FnTrait::FnMut => write!(f, "FnMut"),
FnTrait::Fn => write!(f, "Fn"),
FnTrait::AsyncFnOnce => write!(f, "AsyncFnOnce"),
FnTrait::AsyncFnMut => write!(f, "AsyncFnMut"),
FnTrait::AsyncFn => write!(f, "AsyncFn"),
}
}
}
Expand All @@ -238,6 +245,9 @@ impl FnTrait {
FnTrait::FnOnce => "call_once",
FnTrait::FnMut => "call_mut",
FnTrait::Fn => "call",
FnTrait::AsyncFnOnce => "async_call_once",
FnTrait::AsyncFnMut => "async_call_mut",
FnTrait::AsyncFn => "async_call",
}
}

Expand All @@ -246,6 +256,9 @@ impl FnTrait {
FnTrait::FnOnce => LangItem::FnOnce,
FnTrait::FnMut => LangItem::FnMut,
FnTrait::Fn => LangItem::Fn,
FnTrait::AsyncFnOnce => LangItem::AsyncFnOnce,
FnTrait::AsyncFnMut => LangItem::AsyncFnMut,
FnTrait::AsyncFn => LangItem::AsyncFn,
}
}

Expand All @@ -254,15 +267,19 @@ impl FnTrait {
LangItem::FnOnce => Some(FnTrait::FnOnce),
LangItem::FnMut => Some(FnTrait::FnMut),
LangItem::Fn => Some(FnTrait::Fn),
LangItem::AsyncFnOnce => Some(FnTrait::AsyncFnOnce),
LangItem::AsyncFnMut => Some(FnTrait::AsyncFnMut),
LangItem::AsyncFn => Some(FnTrait::AsyncFn),
_ => None,
}
}

pub const fn to_chalk_ir(self) -> rust_ir::ClosureKind {
// Chalk doesn't support async fn traits.
match self {
FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce,
FnTrait::FnMut => rust_ir::ClosureKind::FnMut,
FnTrait::Fn => rust_ir::ClosureKind::Fn,
FnTrait::AsyncFnOnce | FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce,
FnTrait::AsyncFnMut | FnTrait::FnMut => rust_ir::ClosureKind::FnMut,
FnTrait::AsyncFn | FnTrait::Fn => rust_ir::ClosureKind::Fn,
}
}

Expand All @@ -271,6 +288,9 @@ impl FnTrait {
FnTrait::FnOnce => Name::new_symbol_root(sym::call_once.clone()),
FnTrait::FnMut => Name::new_symbol_root(sym::call_mut.clone()),
FnTrait::Fn => Name::new_symbol_root(sym::call.clone()),
FnTrait::AsyncFnOnce => Name::new_symbol_root(sym::async_call_once.clone()),
FnTrait::AsyncFnMut => Name::new_symbol_root(sym::async_call_mut.clone()),
FnTrait::AsyncFn => Name::new_symbol_root(sym::async_call.clone()),
}
}

Expand Down
21 changes: 21 additions & 0 deletions crates/ide-diagnostics/src/handlers/expected_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,25 @@ fn foo() {
"#,
);
}

#[test]
fn no_error_for_async_fn_traits() {
check_diagnostics(
r#"
//- minicore: async_fn
async fn f(it: impl AsyncFn(u32) -> i32) {
let fut = it(0);
let _: i32 = fut.await;
}
async fn g(mut it: impl AsyncFnMut(u32) -> i32) {
let fut = it(0);
let _: i32 = fut.await;
}
async fn h(it: impl AsyncFnOnce(u32) -> i32) {
let fut = it(0);
let _: i32 = fut.await;
}
"#,
);
}
}
8 changes: 8 additions & 0 deletions crates/intern/src/symbol/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ define_symbols! {
C,
call_mut,
call_once,
async_call_once,
async_call_mut,
async_call,
call,
cdecl,
Center,
Expand Down Expand Up @@ -221,6 +224,9 @@ define_symbols! {
fn_mut,
fn_once_output,
fn_once,
async_fn_once,
async_fn_mut,
async_fn,
fn_ptr_addr,
fn_ptr_trait,
format_alignment,
Expand Down Expand Up @@ -334,6 +340,8 @@ define_symbols! {
Option,
Ord,
Output,
CallRefFuture,
CallOnceFuture,
owned_box,
packed,
panic_2015,
Expand Down
Loading
Loading