@@ -285,26 +285,30 @@ def test_backwards_compatible_imports(self):
285
285
def test_list_extend_multi_sample_trait (self ):
286
286
center_crop = mt .CenterSpatialCrop ([128 , 128 ])
287
287
multi_sample_transform = mt .RandSpatialCropSamples ([64 , 64 ], 1 )
288
+ flatten_sequence_transform = mt .FlattenSequence ()
288
289
289
290
img = torch .zeros ([1 , 512 , 512 ])
290
291
291
292
self .assertEqual (execute_compose (img , [center_crop ]).shape , torch .Size ([1 , 128 , 128 ]))
292
- single_multi_sample_trait_result = execute_compose (img , [multi_sample_transform , center_crop ])
293
+ single_multi_sample_trait_result = execute_compose (img , [multi_sample_transform , center_crop , flatten_sequence_transform ])
293
294
self .assertIsInstance (single_multi_sample_trait_result , list )
294
295
self .assertEqual (len (single_multi_sample_trait_result ), 1 )
295
296
self .assertEqual (single_multi_sample_trait_result [0 ].shape , torch .Size ([1 , 64 , 64 ]))
296
297
297
- double_multi_sample_trait_result = execute_compose (img , [multi_sample_transform , multi_sample_transform , center_crop ])
298
+ double_multi_sample_trait_result = execute_compose (img , [
299
+ multi_sample_transform , multi_sample_transform , flatten_sequence_transform , center_crop
300
+ ])
298
301
self .assertIsInstance (double_multi_sample_trait_result , list )
299
302
self .assertEqual (len (double_multi_sample_trait_result ), 1 )
300
303
self .assertEqual (double_multi_sample_trait_result [0 ].shape , torch .Size ([1 , 64 , 64 ]))
301
304
302
305
def test_multi_sample_trait_cardinality (self ):
303
306
img = torch .zeros ([1 , 128 , 128 ])
304
307
t2 = mt .RandSpatialCropSamples ([32 , 32 ], num_samples = 2 )
308
+ flatten_sequence_transform = mt .FlattenSequence ()
305
309
306
310
# chaining should multiply counts: 2 x 2 = 4, flattened
307
- res = execute_compose (img , [t2 , t2 ])
311
+ res = execute_compose (img , [t2 , t2 , flatten_sequence_transform ])
308
312
self .assertIsInstance (res , list )
309
313
self .assertEqual (len (res ), 4 )
310
314
for r in res :
0 commit comments