1
- use crate :: assert:: expr_if_not;
2
1
use rustc_ast:: {
3
2
attr,
4
3
ptr:: P ,
5
4
token,
6
5
tokenstream:: { DelimSpan , TokenStream , TokenTree } ,
7
- BorrowKind , Expr , ExprKind , ItemKind , MacArgs , MacCall , MacDelimiter , Mutability , Path ,
8
- PathSegment , Stmt , StructRest , UseTree , UseTreeKind , DUMMY_NODE_ID ,
6
+ BinOpKind , BorrowKind , Expr , ExprKind , ItemKind , MacArgs , MacCall , MacDelimiter , Mutability ,
7
+ Path , PathSegment , Stmt , StructRest , UnOp , UseTree , UseTreeKind , DUMMY_NODE_ID ,
9
8
} ;
10
9
use rustc_ast_pretty:: pprust;
11
10
use rustc_data_structures:: fx:: FxHashSet ;
@@ -16,11 +15,19 @@ use rustc_span::{
16
15
} ;
17
16
18
17
pub ( super ) struct Context < ' cx , ' a > {
18
+ // An optimization.
19
+ //
20
+ // Elements that aren't consumed (PartialEq, PartialOrd, ...) can be copied **after** the
21
+ // `assert!` expression fails rather than copied on-the-fly.
22
+ best_case_captures : Vec < Stmt > ,
19
23
// Top-level `let captureN = Capture::new()` statements
20
24
capture_decls : Vec < Capture > ,
21
25
cx : & ' cx ExtCtxt < ' a > ,
22
26
// Formatting string used for debugging
23
27
fmt_string : String ,
28
+ // If the current expression being visited consumes itself. Used to construct
29
+ // `best_case_captures`.
30
+ is_consumed : bool ,
24
31
// Top-level `let __local_bindN = &expr` statements
25
32
local_bind_decls : Vec < Stmt > ,
26
33
// Used to avoid capturing duplicated paths
@@ -36,9 +43,11 @@ pub(super) struct Context<'cx, 'a> {
36
43
impl < ' cx , ' a > Context < ' cx , ' a > {
37
44
pub ( super ) fn new ( cx : & ' cx ExtCtxt < ' a > , span : Span ) -> Self {
38
45
Self {
46
+ best_case_captures : <_ >:: default ( ) ,
39
47
capture_decls : <_ >:: default ( ) ,
40
48
cx,
41
49
fmt_string : <_ >:: default ( ) ,
50
+ is_consumed : true ,
42
51
local_bind_decls : <_ >:: default ( ) ,
43
52
paths : <_ >:: default ( ) ,
44
53
span,
@@ -69,14 +78,22 @@ impl<'cx, 'a> Context<'cx, 'a> {
69
78
self . manage_cond_expr ( & mut cond_expr) ;
70
79
let initial_imports = self . build_initial_imports ( ) ;
71
80
let panic = self . build_panic ( & expr_str, panic_path) ;
81
+ let cond_expr_with_unlikely = self . build_unlikely ( cond_expr) ;
82
+
83
+ let Self { best_case_captures, capture_decls, cx, local_bind_decls, span, .. } = self ;
72
84
73
- let Self { capture_decls, cx, local_bind_decls, span, .. } = self ;
85
+ let mut assert_then_stmts = Vec :: with_capacity ( 2 ) ;
86
+ assert_then_stmts. extend ( best_case_captures) ;
87
+ assert_then_stmts. push ( self . cx . stmt_expr ( panic) ) ;
88
+ let assert_then = self . cx . block ( span, assert_then_stmts) ;
74
89
75
90
let mut stmts = Vec :: with_capacity ( 4 ) ;
76
91
stmts. push ( initial_imports) ;
77
92
stmts. extend ( capture_decls. into_iter ( ) . map ( |c| c. decl ) ) ;
78
93
stmts. extend ( local_bind_decls) ;
79
- stmts. push ( cx. stmt_expr ( expr_if_not ( cx, span, cond_expr, panic, None ) ) ) ;
94
+ stmts. push (
95
+ cx. stmt_expr ( cx. expr ( span, ExprKind :: If ( cond_expr_with_unlikely, assert_then, None ) ) ) ,
96
+ ) ;
80
97
cx. expr_block ( cx. block ( span, stmts) )
81
98
}
82
99
@@ -115,6 +132,16 @@ impl<'cx, 'a> Context<'cx, 'a> {
115
132
)
116
133
}
117
134
135
+ /// Takes the conditional expression of `assert!` and then wraps it inside `unlikely`
136
+ fn build_unlikely ( & self , cond_expr : P < Expr > ) -> P < Expr > {
137
+ let unlikely_path = self . cx . std_path ( & [ sym:: intrinsics, sym:: unlikely] ) ;
138
+ self . cx . expr_call (
139
+ self . span ,
140
+ self . cx . expr_path ( self . cx . path ( self . span , unlikely_path) ) ,
141
+ vec ! [ self . cx. expr( self . span, ExprKind :: Unary ( UnOp :: Not , cond_expr) ) ] ,
142
+ )
143
+ }
144
+
118
145
/// The necessary custom `panic!(...)` expression.
119
146
///
120
147
/// panic!(
@@ -167,17 +194,39 @@ impl<'cx, 'a> Context<'cx, 'a> {
167
194
/// See [Self::manage_initial_capture] and [Self::manage_try_capture]
168
195
fn manage_cond_expr ( & mut self , expr : & mut P < Expr > ) {
169
196
match ( * expr) . kind {
170
- ExprKind :: AddrOf ( _, _, ref mut local_expr) => {
171
- self . manage_cond_expr ( local_expr) ;
197
+ ExprKind :: AddrOf ( _, mutability, ref mut local_expr) => {
198
+ self . with_is_consumed_management (
199
+ matches ! ( mutability, Mutability :: Mut ) ,
200
+ |this| this. manage_cond_expr ( local_expr)
201
+ ) ;
172
202
}
173
203
ExprKind :: Array ( ref mut local_exprs) => {
174
204
for local_expr in local_exprs {
175
205
self . manage_cond_expr ( local_expr) ;
176
206
}
177
207
}
178
- ExprKind :: Binary ( _, ref mut lhs, ref mut rhs) => {
179
- self . manage_cond_expr ( lhs) ;
180
- self . manage_cond_expr ( rhs) ;
208
+ ExprKind :: Binary ( ref op, ref mut lhs, ref mut rhs) => {
209
+ self . with_is_consumed_management (
210
+ matches ! (
211
+ op. node,
212
+ BinOpKind :: Add
213
+ | BinOpKind :: And
214
+ | BinOpKind :: BitAnd
215
+ | BinOpKind :: BitOr
216
+ | BinOpKind :: BitXor
217
+ | BinOpKind :: Div
218
+ | BinOpKind :: Mul
219
+ | BinOpKind :: Or
220
+ | BinOpKind :: Rem
221
+ | BinOpKind :: Shl
222
+ | BinOpKind :: Shr
223
+ | BinOpKind :: Sub
224
+ ) ,
225
+ |this| {
226
+ this. manage_cond_expr ( lhs) ;
227
+ this. manage_cond_expr ( rhs) ;
228
+ }
229
+ ) ;
181
230
}
182
231
ExprKind :: Call ( _, ref mut local_exprs) => {
183
232
for local_expr in local_exprs {
@@ -228,8 +277,11 @@ impl<'cx, 'a> Context<'cx, 'a> {
228
277
self . manage_cond_expr ( local_expr) ;
229
278
}
230
279
}
231
- ExprKind :: Unary ( _, ref mut local_expr) => {
232
- self . manage_cond_expr ( local_expr) ;
280
+ ExprKind :: Unary ( un_op, ref mut local_expr) => {
281
+ self . with_is_consumed_management (
282
+ matches ! ( un_op, UnOp :: Neg | UnOp :: Not ) ,
283
+ |this| this. manage_cond_expr ( local_expr)
284
+ ) ;
233
285
}
234
286
// Expressions that are not worth or can not be captured.
235
287
//
@@ -337,9 +389,23 @@ impl<'cx, 'a> Context<'cx, 'a> {
337
389
) )
338
390
. add_trailing_semicolon ( ) ;
339
391
let local_bind_path = self . cx . expr_path ( Path :: from_ident ( local_bind) ) ;
340
- let ret = self . cx . stmt_expr ( local_bind_path) ;
341
- let block = self . cx . expr_block ( self . cx . block ( self . span , vec ! [ try_capture_call, ret] ) ) ;
342
- * expr = self . cx . expr_deref ( self . span , block) ;
392
+ let rslt = if self . is_consumed {
393
+ let ret = self . cx . stmt_expr ( local_bind_path) ;
394
+ self . cx . expr_block ( self . cx . block ( self . span , vec ! [ try_capture_call, ret] ) )
395
+ } else {
396
+ self . best_case_captures . push ( try_capture_call) ;
397
+ local_bind_path
398
+ } ;
399
+ * expr = self . cx . expr_deref ( self . span , rslt) ;
400
+ }
401
+
402
+ // Calls `f` with the internal `is_consumed` set to `curr_is_consumed` and then
403
+ // sets the internal `is_consumed` back to its original value.
404
+ fn with_is_consumed_management ( & mut self , curr_is_consumed : bool , f : impl FnOnce ( & mut Self ) ) {
405
+ let prev_is_consumed = self . is_consumed ;
406
+ self . is_consumed = curr_is_consumed;
407
+ f ( self ) ;
408
+ self . is_consumed = prev_is_consumed;
343
409
}
344
410
}
345
411
0 commit comments