Skip to content

Commit

Permalink
Auto merge of rust-lang#15012 - lowr:patch/generate-fn-async-ret-ty, …
Browse files Browse the repository at this point in the history
…r=HKalbasi

Infer return type for async function in `generate_function`

Part of rust-lang#10122

In `generate_function` assist, when we infer the return type of async function we're generating, we should retrieve the type of parent await expression rather than the call expression itself.
  • Loading branch information
bors committed Jun 9, 2023
2 parents 9c03aa1 + 32768fe commit 9973b11
Showing 1 changed file with 84 additions and 18 deletions.
102 changes: 84 additions & 18 deletions crates/ide-assists/src/handlers/generate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,9 @@ impl FunctionBuilder {
let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
let is_async = await_expr.is_some();

let (ret_type, should_focus_return_type) = make_return_type(
ctx,
&ast::Expr::CallExpr(call.clone()),
target_module,
&mut necessary_generic_params,
);
let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into());
let (ret_type, should_focus_return_type) =
make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params);

let (generic_param_list, where_clause) =
fn_generic_params(ctx, necessary_generic_params, &target)?;
Expand Down Expand Up @@ -338,12 +335,9 @@ impl FunctionBuilder {
let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
let is_async = await_expr.is_some();

let (ret_type, should_focus_return_type) = make_return_type(
ctx,
&ast::Expr::MethodCallExpr(call.clone()),
target_module,
&mut necessary_generic_params,
);
let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into());
let (ret_type, should_focus_return_type) =
make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params);

let (generic_param_list, where_clause) =
fn_generic_params(ctx, necessary_generic_params, &target)?;
Expand Down Expand Up @@ -429,12 +423,12 @@ impl FunctionBuilder {
/// user can change the `todo!` function body.
fn make_return_type(
ctx: &AssistContext<'_>,
call: &ast::Expr,
expr: &ast::Expr,
target_module: Module,
necessary_generic_params: &mut FxHashSet<hir::GenericParam>,
) -> (Option<ast::RetType>, bool) {
let (ret_ty, should_focus_return_type) = {
match ctx.sema.type_of_expr(call).map(TypeInfo::original) {
match ctx.sema.type_of_expr(expr).map(TypeInfo::original) {
Some(ty) if ty.is_unknown() => (Some(make::ty_placeholder()), true),
None => (Some(make::ty_placeholder()), true),
Some(ty) if ty.is_unit() => (None, false),
Expand Down Expand Up @@ -2268,13 +2262,13 @@ impl Foo {
check_assist(
generate_function,
r"
fn foo() {
$0bar(42).await();
async fn foo() {
$0bar(42).await;
}
",
r"
fn foo() {
bar(42).await();
async fn foo() {
bar(42).await;
}
async fn bar(arg: i32) ${0:-> _} {
Expand All @@ -2284,6 +2278,28 @@ async fn bar(arg: i32) ${0:-> _} {
)
}

#[test]
fn return_type_for_async_fn() {
check_assist(
generate_function,
r"
//- minicore: result
async fn foo() {
if Err(()) = $0bar(42).await {}
}
",
r"
async fn foo() {
if Err(()) = bar(42).await {}
}
async fn bar(arg: i32) -> Result<_, ()> {
${0:todo!()}
}
",
);
}

#[test]
fn create_method() {
check_assist(
Expand Down Expand Up @@ -2401,6 +2417,31 @@ fn foo() {S.bar();}
)
}

#[test]
fn create_async_method() {
check_assist(
generate_function,
r"
//- minicore: result
struct S;
async fn foo() {
if let Err(()) = S.$0bar(42).await {}
}
",
r"
struct S;
impl S {
async fn bar(&self, arg: i32) -> Result<_, ()> {
${0:todo!()}
}
}
async fn foo() {
if let Err(()) = S.bar(42).await {}
}
",
)
}

#[test]
fn create_static_method() {
check_assist(
Expand All @@ -2421,6 +2462,31 @@ fn foo() {S::bar();}
)
}

#[test]
fn create_async_static_method() {
check_assist(
generate_function,
r"
//- minicore: result
struct S;
async fn foo() {
if let Err(()) = S::$0bar(42).await {}
}
",
r"
struct S;
impl S {
async fn bar(arg: i32) -> Result<_, ()> {
${0:todo!()}
}
}
async fn foo() {
if let Err(()) = S::bar(42).await {}
}
",
)
}

#[test]
fn create_generic_static_method() {
check_assist(
Expand Down

0 comments on commit 9973b11

Please sign in to comment.