@@ -13,19 +13,31 @@ use super::simple_passes::outgoing_edges;
1313use super :: { apply_rewrite_rules, id} ;
1414use rspirv:: dr:: { Block , Function , Instruction , ModuleHeader , Operand } ;
1515use rspirv:: spirv:: { Op , Word } ;
16- use rustc_data_structures:: fx:: { FxHashMap , FxHashSet } ;
16+ use rustc_data_structures:: fx:: { FxHashMap , FxHashSet , FxIndexMap } ;
1717use rustc_middle:: bug;
1818use std:: collections:: hash_map;
1919
20+ // HACK(eddyb) newtype instead of type alias to avoid mistakes.
21+ #[ derive( Copy , Clone , PartialEq , Eq , Hash ) ]
22+ struct LabelId ( Word ) ;
23+
2024pub fn mem2reg (
2125 header : & mut ModuleHeader ,
2226 types_global_values : & mut Vec < Instruction > ,
2327 pointer_to_pointee : & FxHashMap < Word , Word > ,
2428 constants : & FxHashMap < Word , u32 > ,
2529 func : & mut Function ,
2630) {
27- let reachable = compute_reachable ( & func. blocks ) ;
28- let preds = compute_preds ( & func. blocks , & reachable) ;
31+ // HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but
32+ // it's made completely irrelevant by SPIR-T so only applies to legacy code.
33+ let mut blocks: FxIndexMap < _ , _ > = func
34+ . blocks
35+ . iter_mut ( )
36+ . map ( |block| ( LabelId ( block. label_id ( ) . unwrap ( ) ) , block) )
37+ . collect ( ) ;
38+
39+ let reachable = compute_reachable ( & blocks) ;
40+ let preds = compute_preds ( & blocks, & reachable) ;
2941 let idom = compute_idom ( & preds, & reachable) ;
3042 let dominance_frontier = compute_dominance_frontier ( & preds, & idom) ;
3143 loop {
@@ -34,31 +46,27 @@ pub fn mem2reg(
3446 types_global_values,
3547 pointer_to_pointee,
3648 constants,
37- & mut func . blocks ,
49+ & mut blocks,
3850 & dominance_frontier,
3951 ) ;
4052 if !changed {
4153 break ;
4254 }
4355 // mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
44- super :: dce:: dce_phi ( func ) ;
56+ super :: dce:: dce_phi ( & mut blocks ) ;
4557 }
4658}
4759
48- fn label_to_index ( blocks : & [ Block ] , id : Word ) -> usize {
49- blocks
50- . iter ( )
51- . position ( |b| b. label_id ( ) . unwrap ( ) == id)
52- . unwrap ( )
53- }
54-
55- fn compute_reachable ( blocks : & [ Block ] ) -> Vec < bool > {
56- fn recurse ( blocks : & [ Block ] , reachable : & mut [ bool ] , block : usize ) {
60+ fn compute_reachable ( blocks : & FxIndexMap < LabelId , & mut Block > ) -> Vec < bool > {
61+ fn recurse ( blocks : & FxIndexMap < LabelId , & mut Block > , reachable : & mut [ bool ] , block : usize ) {
5762 if !reachable[ block] {
5863 reachable[ block] = true ;
59- for dest_id in outgoing_edges ( & blocks[ block] ) {
60- let dest_idx = label_to_index ( blocks, dest_id) ;
61- recurse ( blocks, reachable, dest_idx) ;
64+ for dest_id in outgoing_edges ( blocks[ block] ) {
65+ recurse (
66+ blocks,
67+ reachable,
68+ blocks. get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ,
69+ ) ;
6270 }
6371 }
6472 }
@@ -67,17 +75,19 @@ fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
6775 reachable
6876}
6977
70- fn compute_preds ( blocks : & [ Block ] , reachable_blocks : & [ bool ] ) -> Vec < Vec < usize > > {
78+ fn compute_preds (
79+ blocks : & FxIndexMap < LabelId , & mut Block > ,
80+ reachable_blocks : & [ bool ] ,
81+ ) -> Vec < Vec < usize > > {
7182 let mut result = vec ! [ vec![ ] ; blocks. len( ) ] ;
7283 // Do not count unreachable blocks as valid preds of blocks
7384 for ( source_idx, source) in blocks
74- . iter ( )
85+ . values ( )
7586 . enumerate ( )
7687 . filter ( |& ( b, _) | reachable_blocks[ b] )
7788 {
7889 for dest_id in outgoing_edges ( source) {
79- let dest_idx = label_to_index ( blocks, dest_id) ;
80- result[ dest_idx] . push ( source_idx) ;
90+ result[ blocks. get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ] . push ( source_idx) ;
8191 }
8292 }
8393 result
@@ -161,7 +171,7 @@ fn insert_phis_all(
161171 types_global_values : & mut Vec < Instruction > ,
162172 pointer_to_pointee : & FxHashMap < Word , Word > ,
163173 constants : & FxHashMap < Word , u32 > ,
164- blocks : & mut [ Block ] ,
174+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
165175 dominance_frontier : & [ FxHashSet < usize > ] ,
166176) -> bool {
167177 let var_maps_and_types = blocks[ 0 ]
@@ -198,7 +208,11 @@ fn insert_phis_all(
198208 rewrite_rules : FxHashMap :: default ( ) ,
199209 } ;
200210 renamer. rename ( 0 , None ) ;
201- apply_rewrite_rules ( & renamer. rewrite_rules , blocks) ;
211+ // FIXME(eddyb) shouldn't this full rescan of the function be done once?
212+ apply_rewrite_rules (
213+ & renamer. rewrite_rules ,
214+ blocks. values_mut ( ) . map ( |block| & mut * * block) ,
215+ ) ;
202216 remove_nops ( blocks) ;
203217 }
204218 remove_old_variables ( blocks, & var_maps_and_types) ;
@@ -216,7 +230,7 @@ struct VarInfo {
216230fn collect_access_chains (
217231 pointer_to_pointee : & FxHashMap < Word , Word > ,
218232 constants : & FxHashMap < Word , u32 > ,
219- blocks : & [ Block ] ,
233+ blocks : & FxIndexMap < LabelId , & mut Block > ,
220234 base_var : Word ,
221235 base_var_ty : Word ,
222236) -> Option < FxHashMap < Word , VarInfo > > {
@@ -246,7 +260,7 @@ fn collect_access_chains(
246260 // Loop in case a previous block references a later AccessChain
247261 loop {
248262 let mut changed = false ;
249- for inst in blocks. iter ( ) . flat_map ( |b| & b. instructions ) {
263+ for inst in blocks. values ( ) . flat_map ( |b| & b. instructions ) {
250264 for ( index, op) in inst. operands . iter ( ) . enumerate ( ) {
251265 if let Operand :: IdRef ( id) = op {
252266 if variables. contains_key ( id) {
@@ -304,10 +318,10 @@ fn collect_access_chains(
304318// same var map (e.g. `s.x = s.y;`).
305319fn split_copy_memory (
306320 header : & mut ModuleHeader ,
307- blocks : & mut [ Block ] ,
321+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
308322 var_map : & FxHashMap < Word , VarInfo > ,
309323) {
310- for block in blocks {
324+ for block in blocks. values_mut ( ) {
311325 let mut inst_index = 0 ;
312326 while inst_index < block. instructions . len ( ) {
313327 let inst = & block. instructions [ inst_index] ;
@@ -362,7 +376,7 @@ fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
362376}
363377
364378fn insert_phis (
365- blocks : & [ Block ] ,
379+ blocks : & FxIndexMap < LabelId , & mut Block > ,
366380 dominance_frontier : & [ FxHashSet < usize > ] ,
367381 var_map : & FxHashMap < Word , VarInfo > ,
368382) -> FxHashSet < usize > {
@@ -371,7 +385,7 @@ fn insert_phis(
371385 let mut ever_on_work_list = FxHashSet :: default ( ) ;
372386 let mut work_list = Vec :: new ( ) ;
373387 let mut blocks_with_phi = FxHashSet :: default ( ) ;
374- for ( block_idx, block) in blocks. iter ( ) . enumerate ( ) {
388+ for ( block_idx, block) in blocks. values ( ) . enumerate ( ) {
375389 if has_store ( block, var_map) {
376390 ever_on_work_list. insert ( block_idx) ;
377391 work_list. push ( block_idx) ;
@@ -416,10 +430,10 @@ fn top_stack_or_undef(
416430 }
417431}
418432
419- struct Renamer < ' a > {
433+ struct Renamer < ' a , ' b > {
420434 header : & ' a mut ModuleHeader ,
421435 types_global_values : & ' a mut Vec < Instruction > ,
422- blocks : & ' a mut [ Block ] ,
436+ blocks : & ' a mut FxIndexMap < LabelId , & ' b mut Block > ,
423437 blocks_with_phi : FxHashSet < usize > ,
424438 base_var_type : Word ,
425439 var_map : & ' a FxHashMap < Word , VarInfo > ,
@@ -429,7 +443,7 @@ struct Renamer<'a> {
429443 rewrite_rules : FxHashMap < Word , Word > ,
430444}
431445
432- impl Renamer < ' _ > {
446+ impl Renamer < ' _ , ' _ > {
433447 // Returns the phi definition.
434448 fn insert_phi_value ( & mut self , block : usize , from_block : usize ) -> Word {
435449 let from_block_label = self . blocks [ from_block] . label_id ( ) . unwrap ( ) ;
@@ -549,9 +563,8 @@ impl Renamer<'_> {
549563 }
550564 }
551565
552- for dest_id in outgoing_edges ( & self . blocks [ block] ) . collect :: < Vec < _ > > ( ) {
553- // TODO: Don't do this find
554- let dest_idx = label_to_index ( self . blocks , dest_id) ;
566+ for dest_id in outgoing_edges ( self . blocks [ block] ) . collect :: < Vec < _ > > ( ) {
567+ let dest_idx = self . blocks . get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ;
555568 self . rename ( dest_idx, Some ( block) ) ;
556569 }
557570
@@ -561,16 +574,16 @@ impl Renamer<'_> {
561574 }
562575}
563576
564- fn remove_nops ( blocks : & mut [ Block ] ) {
565- for block in blocks {
577+ fn remove_nops ( blocks : & mut FxIndexMap < LabelId , & mut Block > ) {
578+ for block in blocks. values_mut ( ) {
566579 block
567580 . instructions
568581 . retain ( |inst| inst. class . opcode != Op :: Nop ) ;
569582 }
570583}
571584
572585fn remove_old_variables (
573- blocks : & mut [ Block ] ,
586+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
574587 var_maps_and_types : & [ ( FxHashMap < u32 , VarInfo > , u32 ) ] ,
575588) {
576589 blocks[ 0 ] . instructions . retain ( |inst| {
@@ -581,7 +594,7 @@ fn remove_old_variables(
581594 . all ( |( var_map, _) | !var_map. contains_key ( & result_id) )
582595 }
583596 } ) ;
584- for block in blocks {
597+ for block in blocks. values_mut ( ) {
585598 block. instructions . retain ( |inst| {
586599 !matches ! ( inst. class. opcode, Op :: AccessChain | Op :: InBoundsAccessChain )
587600 || inst. operands . iter ( ) . all ( |op| {
0 commit comments