11import sys
2+ import copy
23import unittest
34import numpy as np
45import warnings
6+ import operator
7+ from collections import defaultdict
58
69from nibabel .testing import assert_arrays_equal
710from nibabel .testing import clear_and_catch_warnings
1619DATA = {}
1720
1821
22+ def make_fake_streamline (nb_points , data_per_point_shapes = {},
23+ data_for_streamline_shapes = {}, rng = None ):
24+ """ Make a single streamline according to provided requirements. """
25+ if rng is None :
26+ rng = np .random .RandomState ()
27+
28+ streamline = rng .randn (nb_points , 3 ).astype ("f4" )
29+
30+ data_per_point = {}
31+ for k , shape in data_per_point_shapes .items ():
32+ data_per_point [k ] = rng .randn (* ((nb_points ,) + shape )).astype ("f4" )
33+
34+ data_for_streamline = {}
35+ for k , shape in data_for_streamline .items ():
36+ data_for_streamline [k ] = rng .randn (* shape ).astype ("f4" )
37+
38+ return streamline , data_per_point , data_for_streamline
39+
40+
41+ def make_fake_tractogram (list_nb_points , data_per_point_shapes = {},
42+ data_for_streamline_shapes = {}, rng = None ):
43+ """ Make multiple streamlines according to provided requirements. """
44+ all_streamlines = []
45+ all_data_per_point = defaultdict (lambda : [])
46+ all_data_per_streamline = defaultdict (lambda : [])
47+ for nb_points in list_nb_points :
48+ data = make_fake_streamline (nb_points , data_per_point_shapes ,
49+ data_for_streamline_shapes , rng )
50+ streamline , data_per_point , data_for_streamline = data
51+
52+ all_streamlines .append (streamline )
53+ for k , v in data_per_point .items ():
54+ all_data_per_point [k ].append (v )
55+
56+ for k , v in data_for_streamline .items ():
57+ all_data_per_streamline [k ].append (v )
58+
59+ return all_streamlines , all_data_per_point , all_data_per_streamline
60+
61+
62+ def make_dummy_streamline (nb_points ):
63+ """ Make the streamlines that have been used to create test data files."""
64+ if nb_points == 1 :
65+ streamline = np .arange (1 * 3 , dtype = "f4" ).reshape ((1 , 3 ))
66+ data_per_point = {"fa" : np .array ([[0.2 ]], dtype = "f4" ),
67+ "colors" : np .array ([(1 , 0 , 0 )]* 1 , dtype = "f4" )}
68+ data_for_streamline = {"mean_curvature" : np .array ([1.11 ], dtype = "f4" ),
69+ "mean_torsion" : np .array ([1.22 ], dtype = "f4" ),
70+ "mean_colors" : np .array ([1 , 0 , 0 ], dtype = "f4" )}
71+
72+ elif nb_points == 2 :
73+ streamline = np .arange (2 * 3 , dtype = "f4" ).reshape ((2 , 3 ))
74+ data_per_point = {"fa" : np .array ([[0.3 ],
75+ [0.4 ]], dtype = "f4" ),
76+ "colors" : np .array ([(0 , 1 , 0 )]* 2 , dtype = "f4" )}
77+ data_for_streamline = {"mean_curvature" : np .array ([2.11 ], dtype = "f4" ),
78+ "mean_torsion" : np .array ([2.22 ], dtype = "f4" ),
79+ "mean_colors" : np .array ([0 , 1 , 0 ], dtype = "f4" )}
80+
81+ elif nb_points == 5 :
82+ streamline = np .arange (5 * 3 , dtype = "f4" ).reshape ((5 , 3 ))
83+ data_per_point = {"fa" : np .array ([[0.5 ],
84+ [0.6 ],
85+ [0.6 ],
86+ [0.7 ],
87+ [0.8 ]], dtype = "f4" ),
88+ "colors" : np .array ([(0 , 0 , 1 )]* 5 , dtype = "f4" )}
89+ data_for_streamline = {"mean_curvature" : np .array ([3.11 ], dtype = "f4" ),
90+ "mean_torsion" : np .array ([3.22 ], dtype = "f4" ),
91+ "mean_colors" : np .array ([0 , 0 , 1 ], dtype = "f4" )}
92+
93+ return streamline , data_per_point , data_for_streamline
94+
95+
1996def setup ():
2097 global DATA
2198 DATA ['rng' ] = np .random .RandomState (1234 )
22- DATA ['streamlines' ] = [np .arange (1 * 3 , dtype = "f4" ).reshape ((1 , 3 )),
23- np .arange (2 * 3 , dtype = "f4" ).reshape ((2 , 3 )),
24- np .arange (5 * 3 , dtype = "f4" ).reshape ((5 , 3 ))]
25-
26- DATA ['fa' ] = [np .array ([[0.2 ]], dtype = "f4" ),
27- np .array ([[0.3 ],
28- [0.4 ]], dtype = "f4" ),
29- np .array ([[0.5 ],
30- [0.6 ],
31- [0.6 ],
32- [0.7 ],
33- [0.8 ]], dtype = "f4" )]
34-
35- DATA ['colors' ] = [np .array ([(1 , 0 , 0 )]* 1 , dtype = "f4" ),
36- np .array ([(0 , 1 , 0 )]* 2 , dtype = "f4" ),
37- np .array ([(0 , 0 , 1 )]* 5 , dtype = "f4" )]
38-
39- DATA ['mean_curvature' ] = [np .array ([1.11 ], dtype = "f4" ),
40- np .array ([2.11 ], dtype = "f4" ),
41- np .array ([3.11 ], dtype = "f4" )]
42-
43- DATA ['mean_torsion' ] = [np .array ([1.22 ], dtype = "f4" ),
44- np .array ([2.22 ], dtype = "f4" ),
45- np .array ([3.22 ], dtype = "f4" )]
46-
47- DATA ['mean_colors' ] = [np .array ([1 , 0 , 0 ], dtype = "f4" ),
48- np .array ([0 , 1 , 0 ], dtype = "f4" ),
49- np .array ([0 , 0 , 1 ], dtype = "f4" )]
99+
100+ DATA ['streamlines' ] = []
101+ DATA ['fa' ] = []
102+ DATA ['colors' ] = []
103+ DATA ['mean_curvature' ] = []
104+ DATA ['mean_torsion' ] = []
105+ DATA ['mean_colors' ] = []
106+ for nb_points in [1 , 2 , 5 ]:
107+ data = make_dummy_streamline (nb_points )
108+ streamline , data_per_point , data_for_streamline = data
109+ DATA ['streamlines' ].append (streamline )
110+ DATA ['fa' ].append (data_per_point ['fa' ])
111+ DATA ['colors' ].append (data_per_point ['colors' ])
112+ DATA ['mean_curvature' ].append (data_for_streamline ['mean_curvature' ])
113+ DATA ['mean_torsion' ].append (data_for_streamline ['mean_torsion' ])
114+ DATA ['mean_colors' ].append (data_for_streamline ['mean_colors' ])
50115
51116 DATA ['data_per_point' ] = {'colors' : DATA ['colors' ],
52117 'fa' : DATA ['fa' ]}
@@ -63,17 +128,13 @@ def setup():
63128 affine_to_rasmm = np .eye (4 ))
64129
65130 DATA ['streamlines_func' ] = lambda : (e for e in DATA ['streamlines' ])
66- fa_func = lambda : (e for e in DATA ['fa' ])
67- colors_func = lambda : (e for e in DATA ['colors' ])
68- mean_curvature_func = lambda : (e for e in DATA ['mean_curvature' ])
69- mean_torsion_func = lambda : (e for e in DATA ['mean_torsion' ])
70- mean_colors_func = lambda : (e for e in DATA ['mean_colors' ])
71-
72- DATA ['data_per_point_func' ] = {'colors' : colors_func ,
73- 'fa' : fa_func }
74- DATA ['data_per_streamline_func' ] = {'mean_curvature' : mean_curvature_func ,
75- 'mean_torsion' : mean_torsion_func ,
76- 'mean_colors' : mean_colors_func }
131+ DATA ['data_per_point_func' ] = {
132+ 'colors' : lambda : (e for e in DATA ['colors' ]),
133+ 'fa' : lambda : (e for e in DATA ['fa' ])}
134+ DATA ['data_per_streamline_func' ] = {
135+ 'mean_curvature' : lambda : (e for e in DATA ['mean_curvature' ]),
136+ 'mean_torsion' : lambda : (e for e in DATA ['mean_torsion' ]),
137+ 'mean_colors' : lambda : (e for e in DATA ['mean_colors' ])}
77138
78139 DATA ['lazy_tractogram' ] = LazyTractogram (DATA ['streamlines_func' ],
79140 DATA ['data_per_streamline_func' ],
@@ -130,6 +191,11 @@ def assert_tractogram_equal(t1, t2):
130191 t2 .data_per_streamline , t2 .data_per_point )
131192
132193
194+ def extender (a , b ):
195+ a .extend (b )
196+ return a
197+
198+
133199class TestPerArrayDict (unittest .TestCase ):
134200
135201 def test_per_array_dict_creation (self ):
@@ -181,6 +247,53 @@ def test_getitem(self):
181247 assert_arrays_equal (sdict [- 1 ][k ], v [- 1 ])
182248 assert_arrays_equal (sdict [[0 , - 1 ]][k ], v [[0 , - 1 ]])
183249
250+ def test_extend (self ):
251+ sdict = PerArrayDict (len (DATA ['tractogram' ]),
252+ DATA ['data_per_streamline' ])
253+
254+ new_data = {'mean_curvature' : 2 * np .array (DATA ['mean_curvature' ]),
255+ 'mean_torsion' : 3 * np .array (DATA ['mean_torsion' ]),
256+ 'mean_colors' : 4 * np .array (DATA ['mean_colors' ])}
257+ sdict2 = PerArrayDict (len (DATA ['tractogram' ]),
258+ new_data )
259+
260+ sdict .extend (sdict2 )
261+ assert_equal (len (sdict ), len (sdict2 ))
262+ for k in DATA ['tractogram' ].data_per_streamline :
263+ assert_arrays_equal (sdict [k ][:len (DATA ['tractogram' ])],
264+ DATA ['tractogram' ].data_per_streamline [k ])
265+ assert_arrays_equal (sdict [k ][len (DATA ['tractogram' ]):],
266+ new_data [k ])
267+
268+ # Extending with an empty PerArrayDict should change nothing.
269+ sdict_orig = copy .deepcopy (sdict )
270+ sdict .extend (PerArrayDict ())
271+ for k in sdict_orig .keys ():
272+ assert_arrays_equal (sdict [k ], sdict_orig [k ])
273+
274+ # Test incompatible PerArrayDicts.
275+ # Other dict has more entries.
276+ new_data = {'mean_curvature' : 2 * np .array (DATA ['mean_curvature' ]),
277+ 'mean_torsion' : 3 * np .array (DATA ['mean_torsion' ]),
278+ 'mean_colors' : 4 * np .array (DATA ['mean_colors' ]),
279+ 'other' : 5 * np .array (DATA ['mean_colors' ])}
280+ sdict2 = PerArrayDict (len (DATA ['tractogram' ]), new_data )
281+ assert_raises (ValueError , sdict .extend , sdict2 )
282+
283+ # Other dict has not the same entries (key mistmached).
284+ new_data = {'mean_curvature' : 2 * np .array (DATA ['mean_curvature' ]),
285+ 'mean_torsion' : 3 * np .array (DATA ['mean_torsion' ]),
286+ 'other' : 4 * np .array (DATA ['mean_colors' ])}
287+ sdict2 = PerArrayDict (len (DATA ['tractogram' ]), new_data )
288+ assert_raises (ValueError , sdict .extend , sdict2 )
289+
290+ # Other dict has the right number of entries but wrong shape.
291+ new_data = {'mean_curvature' : 2 * np .array (DATA ['mean_curvature' ]),
292+ 'mean_torsion' : 3 * np .array (DATA ['mean_torsion' ]),
293+ 'mean_colors' : 4 * np .array (DATA ['mean_torsion' ])}
294+ sdict2 = PerArrayDict (len (DATA ['tractogram' ]), new_data )
295+ assert_raises (ValueError , sdict .extend , sdict2 )
296+
184297
185298class TestPerArraySequenceDict (unittest .TestCase ):
186299
@@ -233,6 +346,62 @@ def test_getitem(self):
233346 assert_arrays_equal (sdict [- 1 ][k ], v [- 1 ])
234347 assert_arrays_equal (sdict [[0 , - 1 ]][k ], v [[0 , - 1 ]])
235348
349+ def test_extend (self ):
350+ total_nb_rows = DATA ['tractogram' ].streamlines .total_nb_rows
351+ sdict = PerArraySequenceDict (total_nb_rows , DATA ['data_per_point' ])
352+
353+ # Test compatible PerArraySequenceDicts.
354+ list_nb_points = [2 , 7 , 4 ]
355+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
356+ "fa" : DATA ['fa' ][0 ].shape [1 :]}
357+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
358+ data_per_point_shapes ,
359+ rng = DATA ['rng' ])
360+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
361+
362+ sdict .extend (sdict2 )
363+ assert_equal (len (sdict ), len (sdict2 ))
364+ for k in DATA ['tractogram' ].data_per_point :
365+ assert_arrays_equal (sdict [k ][:len (DATA ['tractogram' ])],
366+ DATA ['tractogram' ].data_per_point [k ])
367+ assert_arrays_equal (sdict [k ][len (DATA ['tractogram' ]):],
368+ new_data [k ])
369+
370+ # Extending with an empty PerArraySequenceDicts should change nothing.
371+ sdict_orig = copy .deepcopy (sdict )
372+ sdict .extend (PerArraySequenceDict ())
373+ for k in sdict_orig .keys ():
374+ assert_arrays_equal (sdict [k ], sdict_orig [k ])
375+
376+ # Test incompatible PerArraySequenceDicts.
377+ # Other dict has more entries.
378+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
379+ "fa" : DATA ['fa' ][0 ].shape [1 :],
380+ "other" : (7 ,)}
381+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
382+ data_per_point_shapes ,
383+ rng = DATA ['rng' ])
384+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
385+ assert_raises (ValueError , sdict .extend , sdict2 )
386+
387+ # Other dict has not the same entries (key mistmached).
388+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
389+ "other" : DATA ['fa' ][0 ].shape [1 :]}
390+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
391+ data_per_point_shapes ,
392+ rng = DATA ['rng' ])
393+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
394+ assert_raises (ValueError , sdict .extend , sdict2 )
395+
396+ # Other dict has the right number of entries but wrong shape.
397+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
398+ "fa" : DATA ['fa' ][0 ].shape [1 :] + (3 ,)}
399+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
400+ data_per_point_shapes ,
401+ rng = DATA ['rng' ])
402+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
403+ assert_raises (ValueError , sdict .extend , sdict2 )
404+
236405
237406class TestLazyDict (unittest .TestCase ):
238407
@@ -570,6 +739,28 @@ def test_tractogram_to_world(self):
570739 tractogram .affine_to_rasmm = None
571740 assert_raises (ValueError , tractogram .to_world )
572741
742+ def test_tractogram_extend (self ):
743+ # Load tractogram that contains some metadata.
744+ t = DATA ['tractogram' ].copy ()
745+
746+ for op , in_place in ((operator .add , False ), (operator .iadd , True ),
747+ (extender , True )):
748+ first_arg = t .copy ()
749+ new_t = op (first_arg , t )
750+ assert_equal (new_t is first_arg , in_place )
751+ assert_tractogram_equal (new_t [:len (t )], DATA ['tractogram' ])
752+ assert_tractogram_equal (new_t [len (t ):], DATA ['tractogram' ])
753+
754+ # Test extending an empty Tractogram.
755+ t = Tractogram ()
756+ t += DATA ['tractogram' ]
757+ assert_tractogram_equal (t , DATA ['tractogram' ])
758+
759+ # and the other way around.
760+ t = DATA ['tractogram' ].copy ()
761+ t += Tractogram ()
762+ assert_tractogram_equal (t , DATA ['tractogram' ])
763+
573764
574765class TestLazyTractogram (unittest .TestCase ):
575766
@@ -580,11 +771,12 @@ def test_lazy_tractogram_creation(self):
580771 # Streamlines and other data as generators
581772 streamlines = (x for x in DATA ['streamlines' ])
582773 data_per_point = {"colors" : (x for x in DATA ['colors' ])}
583- data_per_streamline = {'mean_torsion ' : (x for x in DATA ['mean_torsion' ]),
584- 'mean_colors ' : (x for x in DATA ['mean_colors' ])}
774+ data_per_streamline = {'torsion ' : (x for x in DATA ['mean_torsion' ]),
775+ 'colors ' : (x for x in DATA ['mean_colors' ])}
585776
586777 # Creating LazyTractogram with generators is not allowed as
587- # generators get exhausted and are not reusable unlike generator function.
778+ # generators get exhausted and are not reusable unlike generator
779+ # function.
588780 assert_raises (TypeError , LazyTractogram , streamlines )
589781 assert_raises (TypeError , LazyTractogram ,
590782 data_per_streamline = data_per_streamline )
@@ -610,12 +802,11 @@ def test_lazy_tractogram_creation(self):
610802
611803 def test_lazy_tractogram_from_data_func (self ):
612804 # Create an empty `LazyTractogram` yielding nothing.
613- _empty_data_gen = lambda : iter ([])
614-
615- tractogram = LazyTractogram .from_data_func (_empty_data_gen )
805+ tractogram = LazyTractogram .from_data_func (lambda : iter ([]))
616806 check_tractogram (tractogram )
617807
618- # Create `LazyTractogram` from a generator function yielding TractogramItem.
808+ # Create `LazyTractogram` from a generator function yielding
809+ # TractogramItem.
619810 data = [DATA ['streamlines' ], DATA ['fa' ], DATA ['colors' ],
620811 DATA ['mean_curvature' ], DATA ['mean_torsion' ],
621812 DATA ['mean_colors' ]]
@@ -641,6 +832,13 @@ def test_lazy_tractogram_getitem(self):
641832 assert_raises (NotImplementedError ,
642833 DATA ['lazy_tractogram' ].__getitem__ , 0 )
643834
835+ def test_lazy_tractogram_extend (self ):
836+ t = DATA ['lazy_tractogram' ].copy ()
837+ new_t = DATA ['lazy_tractogram' ].copy ()
838+
839+ for op in (operator .add , operator .iadd , extender ):
840+ assert_raises (NotImplementedError , op , new_t , t )
841+
644842 def test_lazy_tractogram_len (self ):
645843 modules = [module_tractogram ] # Modules for which to catch warnings.
646844 with clear_and_catch_warnings (record = True , modules = modules ) as w :
@@ -746,8 +944,8 @@ def test_lazy_tractogram_copy(self):
746944 # Check we copied the data and not simply created new references.
747945 assert_true (tractogram is not DATA ['lazy_tractogram' ])
748946
749- # When copying LazyTractogram, the generator function yielding streamlines
750- # should stay the same.
947+ # When copying LazyTractogram, the generator function yielding
948+ # streamlines should stay the same.
751949 assert_true (tractogram ._streamlines
752950 is DATA ['lazy_tractogram' ]._streamlines )
753951
@@ -759,12 +957,14 @@ def test_lazy_tractogram_copy(self):
759957 is not DATA ['lazy_tractogram' ]._data_per_point )
760958
761959 for key in tractogram .data_per_streamline :
762- assert_true (tractogram .data_per_streamline .store [key ]
763- is DATA ['lazy_tractogram' ].data_per_streamline .store [key ])
960+ data = tractogram .data_per_streamline .store [key ]
961+ expected = DATA ['lazy_tractogram' ].data_per_streamline .store [key ]
962+ assert_true (data is expected )
764963
765964 for key in tractogram .data_per_point :
766- assert_true (tractogram .data_per_point .store [key ]
767- is DATA ['lazy_tractogram' ].data_per_point .store [key ])
965+ data = tractogram .data_per_point .store [key ]
966+ expected = DATA ['lazy_tractogram' ].data_per_point .store [key ]
967+ assert_true (data is expected )
768968
769969 # The affine should be a copy.
770970 assert_true (tractogram ._affine_to_apply
0 commit comments