@@ -12,17 +12,20 @@ macro_rules! include_models {
12
12
// ATTENTION: Modify this macro to include all models in the `model` directory.
13
13
include_models ! (
14
14
add,
15
- sub,
16
- mul,
17
- div,
15
+ avg_pool2d,
18
16
concat,
19
17
conv1d,
20
18
conv2d,
21
- dropout,
19
+ div,
20
+ dropout_opset16,
21
+ dropout_opset7,
22
22
global_avr_pool,
23
- softmax,
24
23
log_softmax,
25
- maxpool2d
24
+ maxpool2d,
25
+ mul,
26
+ reshape,
27
+ softmax,
28
+ sub
26
29
) ;
27
30
28
31
#[ cfg( test) ]
@@ -151,8 +154,27 @@ mod tests {
151
154
}
152
155
153
156
#[ test]
154
- fn dropout ( ) {
155
- let model: dropout:: Model < Backend > = dropout:: Model :: default ( ) ;
157
+ fn dropout_opset16 ( ) {
158
+ let model: dropout_opset16:: Model < Backend > = dropout_opset16:: Model :: default ( ) ;
159
+
160
+ // Run the model with ones as input for easier testing
161
+ let input = Tensor :: < Backend , 4 > :: ones ( [ 2 , 4 , 10 , 15 ] ) ;
162
+
163
+ let output = model. forward ( input) ;
164
+
165
+ let expected_shape = Shape :: from ( [ 2 , 4 , 10 , 15 ] ) ;
166
+ assert_eq ! ( output. shape( ) , expected_shape) ;
167
+
168
+ let output_sum = output. sum ( ) . into_scalar ( ) ;
169
+
170
+ let expected_sum = 1200.0 ; // from pytorch
171
+
172
+ assert ! ( expected_sum. approx_eq( output_sum, ( 1.0e-4 , 2 ) ) ) ;
173
+ }
174
+
175
+ #[ test]
176
+ fn dropout_opset7 ( ) {
177
+ let model: dropout_opset7:: Model < Backend > = dropout_opset7:: Model :: default ( ) ;
156
178
157
179
// Run the model with ones as input for easier testing
158
180
let input = Tensor :: < Backend , 4 > :: ones ( [ 2 , 4 , 10 , 15 ] ) ;
@@ -255,4 +277,36 @@ mod tests {
255
277
256
278
assert_eq ! ( output. to_data( ) , expected) ;
257
279
}
280
+
281
+ #[ test]
282
+ fn avg_pool2d ( ) {
283
+ // Initialize the model without weights (because the exported file does not contain them)
284
+ let model: avg_pool2d:: Model < Backend > = avg_pool2d:: Model :: new ( ) ;
285
+
286
+ // Run the model
287
+ let input = Tensor :: < Backend , 4 > :: from_floats ( [ [ [
288
+ [ -0.077 , 0.360 , -0.782 , 0.072 , 0.665 ] ,
289
+ [ -0.287 , 1.621 , -1.597 , -0.052 , 0.611 ] ,
290
+ [ 0.760 , -0.034 , -0.345 , 0.494 , -0.078 ] ,
291
+ [ -1.805 , -0.476 , 0.205 , 0.338 , 1.353 ] ,
292
+ [ 0.374 , 0.013 , 0.774 , -0.109 , -0.271 ] ,
293
+ ] ] ] ) ;
294
+ let output = model. forward ( input) ;
295
+ let expected = Data :: from ( [ [ [ [ 0.008 , -0.131 , -0.208 , 0.425 ] ] ] ] ) ;
296
+
297
+ output. to_data ( ) . assert_approx_eq ( & expected, 3 ) ;
298
+ }
299
+
300
+ #[ test]
301
+ fn reshape ( ) {
302
+ // Initialize the model without weights (because the exported file does not contain them)
303
+ let model: reshape:: Model < Backend > = reshape:: Model :: new ( ) ;
304
+
305
+ // Run the model
306
+ let input = Tensor :: < Backend , 1 > :: from_floats ( [ 0. , 1. , 2. , 3. ] ) ;
307
+ let output = model. forward ( input) ;
308
+ let expected = Data :: from ( [ [ 0. , 1. , 2. , 3. ] ] ) ;
309
+
310
+ assert_eq ! ( output. to_data( ) , expected) ;
311
+ }
258
312
}
0 commit comments