1
1
use std:: iter;
2
2
3
+ use hir:: HasSource ;
3
4
use ide_db:: {
4
5
famous_defs:: FamousDefs ,
5
6
syntax_helpers:: node_ext:: { for_each_tail_expr, walk_expr} ,
6
7
} ;
8
+ use itertools:: Itertools ;
7
9
use syntax:: {
8
- ast:: { self , make, Expr } ,
9
- match_ast, ted, AstNode ,
10
+ ast:: { self , make, Expr , HasGenericParams } ,
11
+ match_ast, ted, AstNode , ToSmolStr ,
10
12
} ;
11
13
12
14
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -39,25 +41,22 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<
39
41
} ;
40
42
41
43
let type_ref = & ret_type. ty ( ) ?;
42
- let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
43
- let result_enum =
44
+ let core_result =
44
45
FamousDefs ( & ctx. sema , ctx. sema . scope ( type_ref. syntax ( ) ) ?. krate ( ) ) . core_result_Result ( ) ?;
45
46
46
- if matches ! ( ty, Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == result_enum) {
47
+ let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
48
+ if matches ! ( ty, Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == core_result) {
49
+ // The return type is already wrapped in a Result
47
50
cov_mark:: hit!( wrap_return_type_in_result_simple_return_type_already_result) ;
48
51
return None ;
49
52
}
50
53
51
- let new_result_ty =
52
- make:: ext:: ty_result ( type_ref. clone ( ) , make:: ty_placeholder ( ) ) . clone_for_update ( ) ;
53
- let generic_args = new_result_ty. syntax ( ) . descendants ( ) . find_map ( ast:: GenericArgList :: cast) ?;
54
- let last_genarg = generic_args. generic_args ( ) . last ( ) ?;
55
-
56
54
acc. add (
57
55
AssistId ( "wrap_return_type_in_result" , AssistKind :: RefactorRewrite ) ,
58
56
"Wrap return type in Result" ,
59
57
type_ref. syntax ( ) . text_range ( ) ,
60
58
|edit| {
59
+ let new_result_ty = result_type ( ctx, & core_result, type_ref) . clone_for_update ( ) ;
61
60
let body = edit. make_mut ( ast:: Expr :: BlockExpr ( body) ) ;
62
61
63
62
let mut exprs_to_wrap = Vec :: new ( ) ;
@@ -81,16 +80,72 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<
81
80
}
82
81
83
82
let old_result_ty = edit. make_mut ( type_ref. clone ( ) ) ;
84
-
85
83
ted:: replace ( old_result_ty. syntax ( ) , new_result_ty. syntax ( ) ) ;
86
84
87
- if let Some ( cap) = ctx. config . snippet_cap {
88
- edit. add_placeholder_snippet ( cap, last_genarg) ;
85
+ // Add a placeholder snippet at the first generic argument that doesn't equal the return type.
86
+ // This is normally the error type, but that may not be the case when we inserted a type alias.
87
+ let args = new_result_ty. syntax ( ) . descendants ( ) . find_map ( ast:: GenericArgList :: cast) ;
88
+ let error_type_arg = args. and_then ( |list| {
89
+ list. generic_args ( ) . find ( |arg| match arg {
90
+ ast:: GenericArg :: TypeArg ( _) => arg. syntax ( ) . text ( ) != type_ref. syntax ( ) . text ( ) ,
91
+ ast:: GenericArg :: LifetimeArg ( _) => false ,
92
+ _ => true ,
93
+ } )
94
+ } ) ;
95
+ if let Some ( error_type_arg) = error_type_arg {
96
+ if let Some ( cap) = ctx. config . snippet_cap {
97
+ edit. add_placeholder_snippet ( cap, error_type_arg) ;
98
+ }
89
99
}
90
100
} ,
91
101
)
92
102
}
93
103
104
+ fn result_type (
105
+ ctx : & AssistContext < ' _ > ,
106
+ core_result : & hir:: Enum ,
107
+ ret_type : & ast:: Type ,
108
+ ) -> ast:: Type {
109
+ // Try to find a Result<T, ...> type alias in the current scope (shadowing the default).
110
+ let result_path = hir:: ModPath :: from_segments (
111
+ hir:: PathKind :: Plain ,
112
+ iter:: once ( hir:: Name :: new_symbol_root ( hir:: sym:: Result . clone ( ) ) ) ,
113
+ ) ;
114
+ let alias = ctx. sema . resolve_mod_path ( ret_type. syntax ( ) , & result_path) . and_then ( |def| {
115
+ def. filter_map ( |def| match def. as_module_def ( ) ? {
116
+ hir:: ModuleDef :: TypeAlias ( alias) => {
117
+ let enum_ty = alias. ty ( ctx. db ( ) ) . as_adt ( ) ?. as_enum ( ) ?;
118
+ ( & enum_ty == core_result) . then_some ( alias)
119
+ }
120
+ _ => None ,
121
+ } )
122
+ . find_map ( |alias| {
123
+ let mut inserted_ret_type = false ;
124
+ let generic_params = alias
125
+ . source ( ctx. db ( ) ) ?
126
+ . value
127
+ . generic_param_list ( ) ?
128
+ . generic_params ( )
129
+ . map ( |param| match param {
130
+ // Replace the very first type parameter with the functions return type.
131
+ ast:: GenericParam :: TypeParam ( _) if !inserted_ret_type => {
132
+ inserted_ret_type = true ;
133
+ ret_type. to_smolstr ( )
134
+ }
135
+ ast:: GenericParam :: LifetimeParam ( _) => make:: lifetime ( "'_" ) . to_smolstr ( ) ,
136
+ _ => make:: ty_placeholder ( ) . to_smolstr ( ) ,
137
+ } )
138
+ . join ( ", " ) ;
139
+
140
+ let name = alias. name ( ctx. db ( ) ) ;
141
+ let name = name. as_str ( ) ;
142
+ Some ( make:: ty ( & format ! ( "{name}<{generic_params}>" ) ) )
143
+ } )
144
+ } ) ;
145
+ // If there is no applicable alias in scope use the default Result type.
146
+ alias. unwrap_or_else ( || make:: ext:: ty_result ( ret_type. clone ( ) , make:: ty_placeholder ( ) ) )
147
+ }
148
+
94
149
fn tail_cb_impl ( acc : & mut Vec < ast:: Expr > , e : & ast:: Expr ) {
95
150
match e {
96
151
Expr :: BreakExpr ( break_expr) => {
@@ -998,4 +1053,216 @@ fn foo(the_field: u32) -> Result<u32, ${0:_}> {
998
1053
"# ,
999
1054
) ;
1000
1055
}
1056
+
1057
+ #[ test]
1058
+ fn wrap_return_type_in_local_result_type ( ) {
1059
+ check_assist (
1060
+ wrap_return_type_in_result,
1061
+ r#"
1062
+ //- minicore: result
1063
+ type Result<T> = core::result::Result<T, ()>;
1064
+
1065
+ fn foo() -> i3$02 {
1066
+ return 42i32;
1067
+ }
1068
+ "# ,
1069
+ r#"
1070
+ type Result<T> = core::result::Result<T, ()>;
1071
+
1072
+ fn foo() -> Result<i32> {
1073
+ return Ok(42i32);
1074
+ }
1075
+ "# ,
1076
+ ) ;
1077
+
1078
+ check_assist (
1079
+ wrap_return_type_in_result,
1080
+ r#"
1081
+ //- minicore: result
1082
+ type Result2<T> = core::result::Result<T, ()>;
1083
+
1084
+ fn foo() -> i3$02 {
1085
+ return 42i32;
1086
+ }
1087
+ "# ,
1088
+ r#"
1089
+ type Result2<T> = core::result::Result<T, ()>;
1090
+
1091
+ fn foo() -> Result<i32, ${0:_}> {
1092
+ return Ok(42i32);
1093
+ }
1094
+ "# ,
1095
+ ) ;
1096
+ }
1097
+
1098
+ #[ test]
1099
+ fn wrap_return_type_in_imported_local_result_type ( ) {
1100
+ check_assist (
1101
+ wrap_return_type_in_result,
1102
+ r#"
1103
+ //- minicore: result
1104
+ mod some_module {
1105
+ pub type Result<T> = core::result::Result<T, ()>;
1106
+ }
1107
+
1108
+ use some_module::Result;
1109
+
1110
+ fn foo() -> i3$02 {
1111
+ return 42i32;
1112
+ }
1113
+ "# ,
1114
+ r#"
1115
+ mod some_module {
1116
+ pub type Result<T> = core::result::Result<T, ()>;
1117
+ }
1118
+
1119
+ use some_module::Result;
1120
+
1121
+ fn foo() -> Result<i32> {
1122
+ return Ok(42i32);
1123
+ }
1124
+ "# ,
1125
+ ) ;
1126
+
1127
+ check_assist (
1128
+ wrap_return_type_in_result,
1129
+ r#"
1130
+ //- minicore: result
1131
+ mod some_module {
1132
+ pub type Result<T> = core::result::Result<T, ()>;
1133
+ }
1134
+
1135
+ use some_module::*;
1136
+
1137
+ fn foo() -> i3$02 {
1138
+ return 42i32;
1139
+ }
1140
+ "# ,
1141
+ r#"
1142
+ mod some_module {
1143
+ pub type Result<T> = core::result::Result<T, ()>;
1144
+ }
1145
+
1146
+ use some_module::*;
1147
+
1148
+ fn foo() -> Result<i32> {
1149
+ return Ok(42i32);
1150
+ }
1151
+ "# ,
1152
+ ) ;
1153
+ }
1154
+
1155
+ #[ test]
1156
+ fn wrap_return_type_in_local_result_type_from_function_body ( ) {
1157
+ check_assist (
1158
+ wrap_return_type_in_result,
1159
+ r#"
1160
+ //- minicore: result
1161
+ fn foo() -> i3$02 {
1162
+ type Result<T> = core::result::Result<T, ()>;
1163
+ 0
1164
+ }
1165
+ "# ,
1166
+ r#"
1167
+ fn foo() -> Result<i32, ${0:_}> {
1168
+ type Result<T> = core::result::Result<T, ()>;
1169
+ Ok(0)
1170
+ }
1171
+ "# ,
1172
+ ) ;
1173
+ }
1174
+
1175
+ #[ test]
1176
+ fn wrap_return_type_in_local_result_type_already_using_alias ( ) {
1177
+ check_assist_not_applicable (
1178
+ wrap_return_type_in_result,
1179
+ r#"
1180
+ //- minicore: result
1181
+ pub type Result<T> = core::result::Result<T, ()>;
1182
+
1183
+ fn foo() -> Result<i3$02> {
1184
+ return Ok(42i32);
1185
+ }
1186
+ "# ,
1187
+ ) ;
1188
+ }
1189
+
1190
+ #[ test]
1191
+ fn wrap_return_type_in_local_result_type_multiple_generics ( ) {
1192
+ check_assist (
1193
+ wrap_return_type_in_result,
1194
+ r#"
1195
+ //- minicore: result
1196
+ type Result<T, E> = core::result::Result<T, E>;
1197
+
1198
+ fn foo() -> i3$02 {
1199
+ 0
1200
+ }
1201
+ "# ,
1202
+ r#"
1203
+ type Result<T, E> = core::result::Result<T, E>;
1204
+
1205
+ fn foo() -> Result<i32, ${0:_}> {
1206
+ Ok(0)
1207
+ }
1208
+ "# ,
1209
+ ) ;
1210
+
1211
+ check_assist (
1212
+ wrap_return_type_in_result,
1213
+ r#"
1214
+ //- minicore: result
1215
+ type Result<T, E> = core::result::Result<Foo<T, E>, ()>;
1216
+
1217
+ fn foo() -> i3$02 {
1218
+ 0
1219
+ }
1220
+ "# ,
1221
+ r#"
1222
+ type Result<T, E> = core::result::Result<Foo<T, E>, ()>;
1223
+
1224
+ fn foo() -> Result<i32, ${0:_}> {
1225
+ Ok(0)
1226
+ }
1227
+ "# ,
1228
+ ) ;
1229
+
1230
+ check_assist (
1231
+ wrap_return_type_in_result,
1232
+ r#"
1233
+ //- minicore: result
1234
+ type Result<'a, T, E> = core::result::Result<Foo<T, E>, &'a ()>;
1235
+
1236
+ fn foo() -> i3$02 {
1237
+ 0
1238
+ }
1239
+ "# ,
1240
+ r#"
1241
+ type Result<'a, T, E> = core::result::Result<Foo<T, E>, &'a ()>;
1242
+
1243
+ fn foo() -> Result<'_, i32, ${0:_}> {
1244
+ Ok(0)
1245
+ }
1246
+ "# ,
1247
+ ) ;
1248
+
1249
+ check_assist (
1250
+ wrap_return_type_in_result,
1251
+ r#"
1252
+ //- minicore: result
1253
+ type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
1254
+
1255
+ fn foo() -> i3$02 {
1256
+ 0
1257
+ }
1258
+ "# ,
1259
+ r#"
1260
+ type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
1261
+
1262
+ fn foo() -> Result<i32, ${0:_}> {
1263
+ Ok(0)
1264
+ }
1265
+ "# ,
1266
+ ) ;
1267
+ }
1001
1268
}
0 commit comments