@@ -18,6 +18,7 @@ extern crate stable_mir;
18
18
19
19
use rustc_smir:: rustc_internal;
20
20
use stable_mir:: mir:: MirVisitor ;
21
+ use stable_mir:: mir:: MutMirVisitor ;
21
22
use stable_mir:: * ;
22
23
use std:: collections:: HashSet ;
23
24
use std:: io:: Write ;
@@ -99,6 +100,83 @@ impl<'a> mir::MirVisitor for TestVisitor<'a> {
99
100
}
100
101
}
101
102
103
+ fn test_mut_visitor ( ) -> ControlFlow < ( ) > {
104
+ let main_fn = stable_mir:: entry_fn ( ) ;
105
+ let mut main_body = main_fn. unwrap ( ) . expect_body ( ) ;
106
+ let locals = main_body. locals ( ) . to_vec ( ) ;
107
+ let mut main_visitor = TestMutVisitor :: collect ( locals) ;
108
+ main_visitor. visit_body ( & mut main_body) ;
109
+ assert ! ( main_visitor. ret_val. is_some( ) ) ;
110
+ assert ! ( main_visitor. args. is_empty( ) ) ;
111
+ assert ! ( main_visitor. tys. contains( & main_visitor. ret_val. unwrap( ) . ty) ) ;
112
+ assert ! ( !main_visitor. calls. is_empty( ) ) ;
113
+
114
+ let exit_fn = main_visitor. calls . last ( ) . unwrap ( ) ;
115
+ assert ! ( exit_fn. mangled_name( ) . contains( "exit_fn" ) , "Unexpected last function: {exit_fn:?}" ) ;
116
+
117
+ let mut exit_body = exit_fn. body ( ) . unwrap ( ) ;
118
+ let locals = exit_body. locals ( ) . to_vec ( ) ;
119
+ let mut exit_visitor = TestMutVisitor :: collect ( locals) ;
120
+ exit_visitor. visit_body ( & mut exit_body) ;
121
+ assert ! ( exit_visitor. ret_val. is_some( ) ) ;
122
+ assert_eq ! ( exit_visitor. args. len( ) , 1 ) ;
123
+ assert ! ( exit_visitor. tys. contains( & exit_visitor. ret_val. unwrap( ) . ty) ) ;
124
+ assert ! ( exit_visitor. tys. contains( & exit_visitor. args[ 0 ] . ty) ) ;
125
+ ControlFlow :: Continue ( ( ) )
126
+ }
127
+
128
+ struct TestMutVisitor {
129
+ locals : Vec < mir:: LocalDecl > ,
130
+ pub tys : HashSet < ty:: Ty > ,
131
+ pub ret_val : Option < mir:: LocalDecl > ,
132
+ pub args : Vec < mir:: LocalDecl > ,
133
+ pub calls : Vec < mir:: mono:: Instance > ,
134
+ }
135
+
136
+ impl TestMutVisitor {
137
+ fn collect ( locals : Vec < mir:: LocalDecl > ) -> TestMutVisitor {
138
+ let visitor = TestMutVisitor {
139
+ locals : locals,
140
+ tys : Default :: default ( ) ,
141
+ ret_val : None ,
142
+ args : vec ! [ ] ,
143
+ calls : vec ! [ ] ,
144
+ } ;
145
+ visitor
146
+ }
147
+ }
148
+
149
+ impl mir:: MutMirVisitor for TestMutVisitor {
150
+ fn visit_ty ( & mut self , ty : & mut ty:: Ty , _location : mir:: visit:: Location ) {
151
+ self . tys . insert ( * ty) ;
152
+ self . super_ty ( ty)
153
+ }
154
+
155
+ fn visit_ret_decl ( & mut self , local : mir:: Local , decl : & mut mir:: LocalDecl ) {
156
+ assert ! ( local == mir:: RETURN_LOCAL ) ;
157
+ assert ! ( self . ret_val. is_none( ) ) ;
158
+ self . ret_val = Some ( decl. clone ( ) ) ;
159
+ self . super_ret_decl ( local, decl) ;
160
+ }
161
+
162
+ fn visit_arg_decl ( & mut self , local : mir:: Local , decl : & mut mir:: LocalDecl ) {
163
+ self . args . push ( decl. clone ( ) ) ;
164
+ assert_eq ! ( local, self . args. len( ) ) ;
165
+ self . super_arg_decl ( local, decl) ;
166
+ }
167
+
168
+ fn visit_terminator ( & mut self , term : & mut mir:: Terminator , location : mir:: visit:: Location ) {
169
+ if let mir:: TerminatorKind :: Call { func, .. } = & mut term. kind {
170
+ let ty:: TyKind :: RigidTy ( ty) = func. ty ( & self . locals ) . unwrap ( ) . kind ( ) else {
171
+ unreachable ! ( )
172
+ } ;
173
+ let ty:: RigidTy :: FnDef ( def, args) = ty else { unreachable ! ( ) } ;
174
+ self . calls . push ( mir:: mono:: Instance :: resolve ( def, & args) . unwrap ( ) ) ;
175
+ }
176
+ self . super_terminator ( term, location) ;
177
+ }
178
+ }
179
+
102
180
/// This test will generate and analyze a dummy crate using the stable mir.
103
181
/// For that, it will first write the dummy crate into a file.
104
182
/// Then it will create a `StableMir` using custom arguments and then
@@ -113,7 +191,8 @@ fn main() {
113
191
CRATE_NAME . to_string( ) ,
114
192
path. to_string( ) ,
115
193
] ;
116
- run ! ( args, test_visitor) . unwrap ( ) ;
194
+ run ! ( args. clone( ) , test_visitor) . unwrap ( ) ;
195
+ run ! ( args, test_mut_visitor) . unwrap ( ) ;
117
196
}
118
197
119
198
fn generate_input ( path : & str ) -> std:: io:: Result < ( ) > {
0 commit comments