@@ -13,15 +13,19 @@ use crate::transform::{simplify, MirPass, MirSource};
13
13
use itertools:: Itertools as _;
14
14
use rustc:: mir:: * ;
15
15
use rustc:: ty:: { Ty , TyCtxt } ;
16
+ use rustc_index:: vec:: IndexVec ;
16
17
use rustc_target:: abi:: VariantIdx ;
18
+ use std:: iter:: { Enumerate , Peekable } ;
19
+ use std:: slice:: Iter ;
17
20
18
21
/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
19
22
///
20
23
/// This is done by transforming basic blocks where the statements match:
21
24
///
22
25
/// ```rust
23
26
/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY );
24
- /// ((_LOCAL_0 as Variant).FIELD: TY) = move _LOCAL_TMP;
27
+ /// _TMP_2 = _LOCAL_TMP;
28
+ /// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
25
29
/// discriminant(_LOCAL_0) = VAR_IDX;
26
30
/// ```
27
31
///
@@ -32,50 +36,306 @@ use rustc_target::abi::VariantIdx;
32
36
/// ```
33
37
pub struct SimplifyArmIdentity ;
34
38
39
+ #[ derive( Debug ) ]
40
+ struct ArmIdentityInfo < ' tcx > {
41
+ /// Storage location for the variant's field
42
+ local_temp_0 : Local ,
43
+ /// Storage location holding the variant being read from
44
+ local_1 : Local ,
45
+ /// The variant field being read from
46
+ vf_s0 : VarField < ' tcx > ,
47
+ /// Index of the statement which loads the variant being read
48
+ get_variant_field_stmt : usize ,
49
+
50
+ /// Tracks each assignment to a temporary of the variant's field
51
+ field_tmp_assignments : Vec < ( Local , Local ) > ,
52
+
53
+ /// Storage location holding the variant's field that was read from
54
+ local_tmp_s1 : Local ,
55
+ /// Storage location holding the enum that we are writing to
56
+ local_0 : Local ,
57
+ /// The variant field being written to
58
+ vf_s1 : VarField < ' tcx > ,
59
+
60
+ /// Storage location that the discriminant is being written to
61
+ set_discr_local : Local ,
62
+ /// The variant being written
63
+ set_discr_var_idx : VariantIdx ,
64
+
65
+ /// Index of the statement that should be overwritten as a move
66
+ stmt_to_overwrite : usize ,
67
+ /// SourceInfo for the new move
68
+ source_info : SourceInfo ,
69
+
70
+ /// Indices of matching Storage{Live,Dead} statements encountered.
71
+ /// (StorageLive index,, StorageDead index, Local)
72
+ storage_stmts : Vec < ( usize , usize , Local ) > ,
73
+
74
+ /// The statements that should be removed (turned into nops)
75
+ stmts_to_remove : Vec < usize > ,
76
+ }
77
+
78
+ fn get_arm_identity_info < ' a , ' tcx > ( stmts : & ' a [ Statement < ' tcx > ] ) -> Option < ArmIdentityInfo < ' tcx > > {
79
+ // This can't possibly match unless there are at least 3 statements in the block
80
+ // so fail fast on tiny blocks.
81
+ if stmts. len ( ) < 3 {
82
+ return None ;
83
+ }
84
+
85
+ let mut tmp_assigns = Vec :: new ( ) ;
86
+ let mut nop_stmts = Vec :: new ( ) ;
87
+ let mut storage_stmts = Vec :: new ( ) ;
88
+ let mut storage_live_stmts = Vec :: new ( ) ;
89
+ let mut storage_dead_stmts = Vec :: new ( ) ;
90
+
91
+ type StmtIter < ' a , ' tcx > = Peekable < Enumerate < Iter < ' a , Statement < ' tcx > > > > ;
92
+
93
+ fn is_storage_stmt < ' tcx > ( stmt : & Statement < ' tcx > ) -> bool {
94
+ matches ! ( stmt. kind, StatementKind :: StorageLive ( _) | StatementKind :: StorageDead ( _) )
95
+ }
96
+
97
+ fn try_eat_storage_stmts < ' a , ' tcx > (
98
+ stmt_iter : & mut StmtIter < ' a , ' tcx > ,
99
+ storage_live_stmts : & mut Vec < ( usize , Local ) > ,
100
+ storage_dead_stmts : & mut Vec < ( usize , Local ) > ,
101
+ ) {
102
+ while stmt_iter. peek ( ) . map ( |( _, stmt) | is_storage_stmt ( stmt) ) . unwrap_or ( false ) {
103
+ let ( idx, stmt) = stmt_iter. next ( ) . unwrap ( ) ;
104
+
105
+ if let StatementKind :: StorageLive ( l) = stmt. kind {
106
+ storage_live_stmts. push ( ( idx, l) ) ;
107
+ } else if let StatementKind :: StorageDead ( l) = stmt. kind {
108
+ storage_dead_stmts. push ( ( idx, l) ) ;
109
+ }
110
+ }
111
+ }
112
+
113
+ fn is_tmp_storage_stmt < ' tcx > ( stmt : & Statement < ' tcx > ) -> bool {
114
+ if let StatementKind :: Assign ( box ( place, Rvalue :: Use ( op) ) ) = & stmt. kind {
115
+ if let Operand :: Copy ( p) | Operand :: Move ( p) = op {
116
+ return place. as_local ( ) . is_some ( ) && p. as_local ( ) . is_some ( ) ;
117
+ }
118
+ }
119
+
120
+ false
121
+ }
122
+
123
+ fn try_eat_assign_tmp_stmts < ' a , ' tcx > (
124
+ stmt_iter : & mut StmtIter < ' a , ' tcx > ,
125
+ tmp_assigns : & mut Vec < ( Local , Local ) > ,
126
+ nop_stmts : & mut Vec < usize > ,
127
+ ) {
128
+ while stmt_iter. peek ( ) . map ( |( _, stmt) | is_tmp_storage_stmt ( stmt) ) . unwrap_or ( false ) {
129
+ let ( idx, stmt) = stmt_iter. next ( ) . unwrap ( ) ;
130
+
131
+ if let StatementKind :: Assign ( box ( place, Rvalue :: Use ( op) ) ) = & stmt. kind {
132
+ if let Operand :: Copy ( p) | Operand :: Move ( p) = op {
133
+ tmp_assigns. push ( ( place. as_local ( ) . unwrap ( ) , p. as_local ( ) . unwrap ( ) ) ) ;
134
+ nop_stmts. push ( idx) ;
135
+ }
136
+ }
137
+ }
138
+ }
139
+
140
+ fn find_storage_live_dead_stmts_for_local < ' tcx > (
141
+ l : Local ,
142
+ stmts : & [ Statement < ' tcx > ] ,
143
+ ) -> Option < ( usize , usize ) > {
144
+ trace ! ( "looking for {:?}" , l) ;
145
+ let mut storage_live_stmt = None ;
146
+ let mut storage_dead_stmt = None ;
147
+ for ( idx, stmt) in stmts. iter ( ) . enumerate ( ) {
148
+ if stmt. kind == StatementKind :: StorageLive ( l) {
149
+ storage_live_stmt = Some ( idx) ;
150
+ } else if stmt. kind == StatementKind :: StorageDead ( l) {
151
+ storage_dead_stmt = Some ( idx) ;
152
+ }
153
+ }
154
+
155
+ Some ( ( storage_live_stmt?, storage_dead_stmt. unwrap_or ( usize:: MAX ) ) )
156
+ }
157
+
158
+ // Try to match the expected MIR structure with the basic block we're processing.
159
+ // We want to see something that looks like:
160
+ // ```
161
+ // (StorageLive(_) | StorageDead(_));*
162
+ // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
163
+ // (StorageLive(_) | StorageDead(_));*
164
+ // (tmp_n+1 = tmp_n);*
165
+ // (StorageLive(_) | StorageDead(_));*
166
+ // (tmp_n+1 = tmp_n);*
167
+ // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp;
168
+ // discriminant(LOCAL_FROM) = VariantIdx;
169
+ // (StorageLive(_) | StorageDead(_));*
170
+ // ```
171
+ let mut stmt_iter = stmts. iter ( ) . enumerate ( ) . peekable ( ) ;
172
+
173
+ try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
174
+
175
+ let ( get_variant_field_stmt, stmt) = stmt_iter. next ( ) ?;
176
+ let ( local_tmp_s0, local_1, vf_s0) = match_get_variant_field ( stmt) ?;
177
+
178
+ try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
179
+
180
+ try_eat_assign_tmp_stmts ( & mut stmt_iter, & mut tmp_assigns, & mut nop_stmts) ;
181
+
182
+ try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
183
+
184
+ try_eat_assign_tmp_stmts ( & mut stmt_iter, & mut tmp_assigns, & mut nop_stmts) ;
185
+
186
+ let ( idx, stmt) = stmt_iter. next ( ) ?;
187
+ let ( local_tmp_s1, local_0, vf_s1) = match_set_variant_field ( stmt) ?;
188
+ nop_stmts. push ( idx) ;
189
+
190
+ let ( idx, stmt) = stmt_iter. next ( ) ?;
191
+ let ( set_discr_local, set_discr_var_idx) = match_set_discr ( stmt) ?;
192
+ let discr_stmt_source_info = stmt. source_info ;
193
+ nop_stmts. push ( idx) ;
194
+
195
+ try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
196
+
197
+ for ( live_idx, live_local) in storage_live_stmts {
198
+ if let Some ( i) = storage_dead_stmts. iter ( ) . rposition ( |( _, l) | * l == live_local) {
199
+ let ( dead_idx, _) = storage_dead_stmts. swap_remove ( i) ;
200
+ storage_stmts. push ( ( live_idx, dead_idx, live_local) ) ;
201
+
202
+ if live_local == local_tmp_s0 {
203
+ nop_stmts. push ( get_variant_field_stmt) ;
204
+ }
205
+ }
206
+ }
207
+
208
+ nop_stmts. sort ( ) ;
209
+
210
+ // Use one of the statements we're going to discard between the point
211
+ // where the storage location for the variant field becomes live and
212
+ // is killed.
213
+ let ( live_idx, daed_idx) = find_storage_live_dead_stmts_for_local ( local_tmp_s0, stmts) ?;
214
+ let stmt_to_overwrite =
215
+ nop_stmts. iter ( ) . find ( |stmt_idx| live_idx < * * stmt_idx && * * stmt_idx < daed_idx) ;
216
+
217
+ Some ( ArmIdentityInfo {
218
+ local_temp_0 : local_tmp_s0,
219
+ local_1,
220
+ vf_s0,
221
+ get_variant_field_stmt,
222
+ field_tmp_assignments : tmp_assigns,
223
+ local_tmp_s1,
224
+ local_0,
225
+ vf_s1,
226
+ set_discr_local,
227
+ set_discr_var_idx,
228
+ stmt_to_overwrite : * stmt_to_overwrite?,
229
+ source_info : discr_stmt_source_info,
230
+ storage_stmts,
231
+ stmts_to_remove : nop_stmts,
232
+ } )
233
+ }
234
+
235
+ fn optimization_applies < ' tcx > (
236
+ opt_info : & ArmIdentityInfo < ' tcx > ,
237
+ local_decls : & IndexVec < Local , LocalDecl < ' tcx > > ,
238
+ ) -> bool {
239
+ trace ! ( "testing if optimization applies..." ) ;
240
+
241
+ // FIXME(wesleywiser): possible relax this restriction?
242
+ if opt_info. local_0 == opt_info. local_1 {
243
+ trace ! ( "NO: moving into ourselves" ) ;
244
+ return false ;
245
+ } else if opt_info. vf_s0 != opt_info. vf_s1 {
246
+ trace ! ( "NO: the field-and-variant information do not match" ) ;
247
+ return false ;
248
+ } else if local_decls[ opt_info. local_0 ] . ty != local_decls[ opt_info. local_1 ] . ty {
249
+ // FIXME(Centril,oli-obk): possibly relax to same layout?
250
+ trace ! ( "NO: source and target locals have different types" ) ;
251
+ return false ;
252
+ } else if ( opt_info. local_0 , opt_info. vf_s0 . var_idx )
253
+ != ( opt_info. set_discr_local , opt_info. set_discr_var_idx )
254
+ {
255
+ trace ! ( "NO: the discriminants do not match" ) ;
256
+ return false ;
257
+ }
258
+
259
+ // Verify the assigment chain consists of the form b = a; c = b; d = c; etc...
260
+ if opt_info. field_tmp_assignments . len ( ) == 0 {
261
+ trace ! ( "NO: no assignments found" ) ;
262
+ }
263
+ let mut last_assigned_to = opt_info. field_tmp_assignments [ 0 ] . 1 ;
264
+ let source_local = last_assigned_to;
265
+ for ( l, r) in & opt_info. field_tmp_assignments {
266
+ if * r != last_assigned_to {
267
+ trace ! ( "NO: found unexpected assignment {:?} = {:?}" , l, r) ;
268
+ return false ;
269
+ }
270
+
271
+ last_assigned_to = * l;
272
+ }
273
+
274
+ if source_local != opt_info. local_temp_0 {
275
+ trace ! (
276
+ "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}" ,
277
+ source_local,
278
+ opt_info. local_temp_0
279
+ ) ;
280
+ return false ;
281
+ } else if last_assigned_to != opt_info. local_tmp_s1 {
282
+ trace ! (
283
+ "NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}" ,
284
+ last_assigned_to,
285
+ opt_info. local_tmp_s1
286
+ ) ;
287
+ return false ;
288
+ }
289
+
290
+ trace ! ( "SUCCESS: optimization applies!" ) ;
291
+ return true ;
292
+ }
293
+
35
294
impl < ' tcx > MirPass < ' tcx > for SimplifyArmIdentity {
36
- fn run_pass ( & self , _: TyCtxt < ' tcx > , _: MirSource < ' tcx > , body : & mut BodyAndCache < ' tcx > ) {
295
+ fn run_pass ( & self , _: TyCtxt < ' tcx > , source : MirSource < ' tcx > , body : & mut BodyAndCache < ' tcx > ) {
296
+ trace ! ( "running SimplifyArmIdentity on {:?}" , source) ;
37
297
let ( basic_blocks, local_decls) = body. basic_blocks_and_local_decls_mut ( ) ;
38
298
for bb in basic_blocks {
39
- // Need 3 statements:
40
- let ( s0, s1, s2) = match & mut * bb. statements {
41
- [ s0, s1, s2] => ( s0, s1, s2) ,
42
- _ => continue ,
43
- } ;
299
+ if let Some ( opt_info) = get_arm_identity_info ( & bb. statements ) {
300
+ trace ! ( "got opt_info = {:#?}" , opt_info) ;
301
+ if !optimization_applies ( & opt_info, local_decls) {
302
+ debug ! ( "optimization skipped for {:?}" , source) ;
303
+ continue ;
304
+ }
44
305
45
- // Pattern match on the form we want:
46
- let ( local_tmp_s0, local_1, vf_s0) = match match_get_variant_field ( s0) {
47
- None => continue ,
48
- Some ( x) => x,
49
- } ;
50
- let ( local_tmp_s1, local_0, vf_s1) = match match_set_variant_field ( s1) {
51
- None => continue ,
52
- Some ( x) => x,
53
- } ;
54
- if local_tmp_s0 != local_tmp_s1
55
- // Avoid moving into ourselves.
56
- || local_0 == local_1
57
- // The field-and-variant information match up.
58
- || vf_s0 != vf_s1
59
- // Source and target locals have the same type.
60
- // FIXME(Centril | oli-obk): possibly relax to same layout?
61
- || local_decls[ local_0] . ty != local_decls[ local_1] . ty
62
- // We're setting the discriminant of `local_0` to this variant.
63
- || Some ( ( local_0, vf_s0. var_idx ) ) != match_set_discr ( s2)
64
- {
65
- continue ;
66
- }
306
+ // Also remove unused Storage{Live,Dead} statements which correspond
307
+ // to temps used previously.
308
+ for ( live_idx, dead_idx, local) in & opt_info. storage_stmts {
309
+ // The temporary that we've read the variant field into is scoped to this block,
310
+ // so we can remove the assignment.
311
+ if * local == opt_info. local_temp_0 {
312
+ bb. statements [ opt_info. get_variant_field_stmt ] . make_nop ( ) ;
313
+ }
67
314
68
- // Right shape; transform!
69
- s0 . source_info = s2 . source_info ;
70
- match & mut s0 . kind {
71
- StatementKind :: Assign ( box ( place , rvalue ) ) => {
72
- * place = local_0 . into ( ) ;
73
- * rvalue = Rvalue :: Use ( Operand :: Move ( local_1 . into ( ) ) ) ;
315
+ for ( left , right ) in & opt_info . field_tmp_assignments {
316
+ if local == left || local == right {
317
+ bb . statements [ * live_idx ] . make_nop ( ) ;
318
+ bb . statements [ * dead_idx ] . make_nop ( ) ;
319
+ }
320
+ }
74
321
}
75
- _ => unreachable ! ( ) ,
322
+
323
+ // Right shape; transform
324
+ for stmt_idx in opt_info. stmts_to_remove {
325
+ bb. statements [ stmt_idx] . make_nop ( ) ;
326
+ }
327
+
328
+ let stmt = & mut bb. statements [ opt_info. stmt_to_overwrite ] ;
329
+ stmt. source_info = opt_info. source_info ;
330
+ stmt. kind = StatementKind :: Assign ( box (
331
+ opt_info. local_0 . into ( ) ,
332
+ Rvalue :: Use ( Operand :: Move ( opt_info. local_1 . into ( ) ) ) ,
333
+ ) ) ;
334
+
335
+ bb. statements . retain ( |stmt| stmt. kind != StatementKind :: Nop ) ;
336
+
337
+ trace ! ( "block is now {:?}" , bb. statements) ;
76
338
}
77
- s1. make_nop ( ) ;
78
- s2. make_nop ( ) ;
79
339
}
80
340
}
81
341
}
@@ -129,7 +389,7 @@ fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)>
129
389
}
130
390
}
131
391
132
- #[ derive( PartialEq ) ]
392
+ #[ derive( PartialEq , Debug ) ]
133
393
struct VarField < ' tcx > {
134
394
field : Field ,
135
395
field_ty : Ty < ' tcx > ,
0 commit comments