This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
io.py
1022 lines (893 loc) · 35.7 KB
/
io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=unnecessary-pass
"""Data iterators for common data formats."""
from collections import namedtuple
import sys
import ctypes
import logging
import threading
import numpy as np
from ..base import _LIB
from ..base import c_str_array, mx_uint, py_str
from ..base import DataIterHandle, NDArrayHandle
from ..base import mx_real_t
from ..base import check_call, build_param_doc as _build_param_doc
from ..ndarray import NDArray
from ..ndarray.sparse import CSRNDArray
from ..util import is_np_array
from ..ndarray import array
from ..ndarray import concat, tile
from .utils import _init_data, _has_instance, _getdata_by_idx
class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
    """DataDesc is used to store name, shape, type and layout
    information of the data or the label.

    The `layout` describes how the axes in `shape` should be interpreted,
    for example for image data setting `layout=NCHW` indicates
    that the first axis is number of examples in the batch(N),
    C is number of channels, H is the height and W is the width of the image.

    For sequential data, `layout` is typically set to ``NTC``, where
    N is number of examples in the batch, T the temporal axis representing time
    and C is the number of channels. When no layout is given, the constructor
    uses ``NCHW``.

    Parameters
    ----------
    cls : DataDesc
        The class.
    name : str
        Data name.
    shape : tuple of int
        Data shape.
    dtype : np.dtype, optional
        Data type.
    layout : str, optional
        Data layout. Defaults to ``'NCHW'``.
    """
    def __new__(cls, name, shape, dtype=mx_real_t, layout='NCHW'): # pylint: disable=super-on-old-class
        # ``super(DataDesc, cls)`` is the correct two-argument form; the former
        # ``super(cls, DataDesc)`` raised TypeError for any subclass of DataDesc.
        ret = super(DataDesc, cls).__new__(cls, name, shape)
        # The namedtuple subclass has no __slots__, so instances carry a
        # __dict__ where the extra ``dtype``/``layout`` attributes are stored.
        ret.dtype = dtype
        ret.layout = layout
        return ret

    def __repr__(self):
        return f"DataDesc[{self.name},{self.shape},{self.dtype},{self.layout}]"

    @staticmethod
    def get_batch_axis(layout):
        """Get the dimension that corresponds to the batch size.

        When data parallelism is used, the data will be automatically split and
        concatenated along the batch-size dimension. Axis can be -1, which means
        the whole array will be copied for each data-parallelism device.

        Parameters
        ----------
        layout : str
            layout string. For example, "NCHW".

        Returns
        -------
        int
            An axis indicating the batch_size dimension; 0 when `layout` is
            None, and -1 when `layout` contains no ``'N'``.
        """
        if layout is None:
            return 0
        return layout.find('N')

    @staticmethod
    def get_list(shapes, types):
        """Get DataDesc list from attribute lists.

        Parameters
        ----------
        shapes : a tuple of (name, shape)
        types : a tuple of (name, np.dtype), or None to use the default dtype

        Returns
        -------
        list of DataDesc
            One entry per element of `shapes`, in the same order.
        """
        if types is not None:
            type_dict = dict(types)
            return [DataDesc(x[0], x[1], type_dict[x[0]]) for x in shapes]
        return [DataDesc(x[0], x[1]) for x in shapes]
class DataBatch(object):
    """A data batch.

    MXNet's data iterator returns a batch of data for each `next` call.
    This data contains `batch_size` number of examples.

    If the input data consists of images, then shape of these images depend on
    the `layout` attribute of `DataDesc` object in `provide_data` parameter.

    If `layout` is set to 'NCHW' then, images should be stored in a 4-D matrix
    of shape ``(batch_size, num_channel, height, width)``.
    If `layout` is set to 'NHWC' then, images should be stored in a 4-D matrix
    of shape ``(batch_size, height, width, num_channel)``.
    The channels are often in RGB order.

    Parameters
    ----------
    data : list of `NDArray`, each array containing `batch_size` examples.
        A list of input data. May be None.
    label : list of `NDArray`, each array often containing a 1-dimensional array. optional
        A list of input labels.
    pad : int, optional
        The number of examples padded at the end of a batch. It is used when the
        total number of examples read is not divisible by the `batch_size`.
        These extra padded examples are ignored in prediction.
    index : numpy.array, optional
        The example indices in this batch.
    bucket_key : int, optional
        The bucket key, used for bucketing module.
    provide_data : list of `DataDesc`, optional
        A list of `DataDesc` objects. `DataDesc` is used to store
        name, shape, type and layout information of the data.
        The *i*-th element describes the name and shape of ``data[i]``.
    provide_label : list of `DataDesc`, optional
        A list of `DataDesc` objects. `DataDesc` is used to store
        name, shape, type and layout information of the label.
        The *i*-th element describes the name and shape of ``label[i]``.
    """
    def __init__(self, data, label=None, pad=None, index=None,
                 bucket_key=None, provide_data=None, provide_label=None):
        if data is not None:
            assert isinstance(data, (list, tuple)), "Data must be list of NDArrays"
        if label is not None:
            assert isinstance(label, (list, tuple)), "Label must be list of NDArrays"
        self.data = data
        self.label = label
        self.pad = pad
        self.index = index

        self.bucket_key = bucket_key
        self.provide_data = provide_data
        self.provide_label = provide_label

    def __str__(self):
        # ``data`` may legitimately be None (see __init__); report it rather
        # than failing while trying to iterate it.
        data_shapes = [d.shape for d in self.data] if self.data is not None else None
        label_shapes = [l.shape for l in self.label] if self.label else None
        return "{}: data shapes: {} label shapes: {}".format(
            self.__class__.__name__,
            data_shapes,
            label_shapes)
class DataIter(object):
    """Abstract base for every MXNet data iterator.

    Concrete iterators specialise this class. Like ordinary Python
    iterators, each call to `next` hands back a `DataBatch` holding the
    following batch of data, and a `StopIteration` exception signals that
    the data has been exhausted.

    Parameters
    ----------
    batch_size : int, optional
        Number of items in each batch.

    See Also
    --------
    NDArrayIter : Data-iterator for MXNet NDArray or numpy-ndarray objects.
    CSVIter : Data-iterator for csv data.
    LibSVMIter : Data-iterator for libsvm data.
    ImageIter : Data-iterator for images.
    """
    def __init__(self, batch_size=0):
        self.batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        """Rewind the iterator to the start of the data (no-op in the base class)."""
        return None

    def next(self):
        """Advance and return the next batch.

        Returns
        -------
        DataBatch
            Batch assembled from `getdata`, `getlabel`, `getpad` and
            `getindex` after a successful `iter_next`.

        Raises
        ------
        StopIteration
            When `iter_next` reports that no data is left.
        """
        if not self.iter_next():
            raise StopIteration
        return DataBatch(data=self.getdata(), label=self.getlabel(),
                         pad=self.getpad(), index=self.getindex())

    def iter_next(self):
        """Step to the following batch.

        Returns
        -------
        boolean
            True if a batch is available; the base class returns None.
        """
        return None

    def getdata(self):
        """Return the data of the current batch.

        Returns
        -------
        list of NDArray
            Data arrays of the current batch (None in the base class).
        """
        return None

    def getlabel(self):
        """Return the labels of the current batch.

        Returns
        -------
        list of NDArray
            Label arrays of the current batch (None in the base class).
        """
        return None

    def getindex(self):
        """Return the example indices of the current batch.

        Returns
        -------
        index : numpy.array
            Indices of the examples in the current batch (None in the base class).
        """
        return None

    def getpad(self):
        """Return how many padding examples the current batch holds.

        Returns
        -------
        int
            Number of padded examples (None in the base class).
        """
        return None
class ResizeIter(DataIter):
    """Resize a data iterator to a given number of batches.

    Parameters
    ----------
    data_iter : DataIter
        The data iterator to be resized.
    size : int
        The number of batches per epoch to resize to.
    reset_internal : bool
        Whether to reset internal iterator on ResizeIter.reset.

    Examples
    --------
    >>> nd_iter = mx.io.NDArrayIter(mx.nd.ones((100,10)), batch_size=25)
    >>> resize_iter = mx.io.ResizeIter(nd_iter, 2)
    >>> for batch in resize_iter:
    ...     print(batch.data)
    [<NDArray 25x10 @cpu(0)>]
    [<NDArray 25x10 @cpu(0)>]
    """
    def __init__(self, data_iter, size, reset_internal=True):
        super(ResizeIter, self).__init__()
        self.data_iter = data_iter
        self.size = size
        self.reset_internal = reset_internal
        self.cur = 0
        self.current_batch = None

        self.provide_data = data_iter.provide_data
        self.provide_label = data_iter.provide_label
        self.batch_size = data_iter.batch_size
        if hasattr(data_iter, 'default_bucket_key'):
            self.default_bucket_key = data_iter.default_bucket_key

    def reset(self):
        """Restart the batch counter and, if requested, the wrapped iterator."""
        self.cur = 0
        if self.reset_internal:
            self.data_iter.reset()

    def iter_next(self):
        """Fetch the next wrapped batch, restarting the wrapped iterator at its end.

        Returns False once `size` batches have been produced in this epoch.
        """
        if self.cur == self.size:
            return False
        try:
            self.current_batch = self.data_iter.next()
        except StopIteration:
            # Wrap around so the epoch always yields exactly `size` batches.
            self.data_iter.reset()
            self.current_batch = self.data_iter.next()

        self.cur += 1
        return True

    def next(self):
        """Return the next batch produced by the wrapped iterator.

        The wrapped iterator's own ``DataBatch`` is returned as-is, so fields
        such as ``bucket_key``, ``provide_data`` and ``provide_label`` are
        preserved. The generic ``DataIter.next`` rebuilds the batch from
        data/label/pad/index only and would drop them.

        Raises
        ------
        StopIteration
            When `size` batches have already been produced in this epoch.
        """
        if not self.iter_next():
            raise StopIteration
        return self.current_batch

    def getdata(self):
        return self.current_batch.data

    def getlabel(self):
        return self.current_batch.label

    def getindex(self):
        return self.current_batch.index

    def getpad(self):
        return self.current_batch.pad
class PrefetchingIter(DataIter):
    """Performs pre-fetch for other data iterators.

    This iterator will create another thread to perform ``iter_next`` and then
    store the data in memory. It potentially accelerates the data read, at the
    cost of more memory usage.

    Parameters
    ----------
    iters : DataIter or list of DataIter
        The data iterators to be pre-fetched.
    rename_data : None or list of dict
        The *i*-th element is a renaming map for the *i*-th iter, in the form of
        {'original_name' : 'new_name'}. Should have one entry for each entry
        in iter[i].provide_data.
    rename_label : None or list of dict
        Similar to ``rename_data``.

    Examples
    --------
    >>> iter1 = mx.io.NDArrayIter({'data':mx.nd.ones((100,10))}, batch_size=25)
    >>> iter2 = mx.io.NDArrayIter({'data':mx.nd.ones((100,10))}, batch_size=25)
    >>> piter = mx.io.PrefetchingIter([iter1, iter2],
    ...                               rename_data=[{'data': 'data_1'}, {'data': 'data_2'}])
    >>> print(piter.provide_data)
    [DataDesc[data_1,(25, 10),<class 'numpy.float32'>,NCHW],
     DataDesc[data_2,(25, 10),<class 'numpy.float32'>,NCHW]]
    """
    def __init__(self, iters, rename_data=None, rename_label=None):
        super(PrefetchingIter, self).__init__()
        if not isinstance(iters, list):
            iters = [iters]
        self.n_iter = len(iters)
        assert self.n_iter > 0
        self.iters = iters
        self.rename_data = rename_data
        self.rename_label = rename_label
        self.batch_size = self.provide_data[0][1][0]
        # Per-iterator handshake: ``data_ready`` is set by the worker once
        # ``next_batch[i]`` is filled; ``data_taken`` is set by the consumer to
        # ask the worker for another batch.
        self.data_ready = [threading.Event() for i in range(self.n_iter)]
        self.data_taken = [threading.Event() for i in range(self.n_iter)]
        for i in self.data_taken:
            i.set()
        self.started = True
        self.current_batch = [None for i in range(self.n_iter)]
        self.next_batch = [None for i in range(self.n_iter)]
        def prefetch_func(self, i):
            """Thread entry"""
            while True:
                self.data_taken[i].wait()
                if not self.started:
                    break
                try:
                    self.next_batch[i] = self.iters[i].next()
                except StopIteration:
                    # None marks the end of the epoch for iterator i.
                    self.next_batch[i] = None
                self.data_taken[i].clear()
                self.data_ready[i].set()
        # ``daemon=True`` in the constructor replaces ``Thread.setDaemon``,
        # which is deprecated since Python 3.10.
        self.prefetch_threads = [
            threading.Thread(target=prefetch_func, args=[self, i], daemon=True)
            for i in range(self.n_iter)]
        for thread in self.prefetch_threads:
            thread.start()

    def __del__(self):
        self.started = False
        # getattr guards against __init__ having failed before these were
        # created (e.g. the ``n_iter > 0`` assertion), which previously made
        # __del__ raise AttributeError.
        for i in getattr(self, 'data_taken', []):
            i.set()
        for thread in getattr(self, 'prefetch_threads', []):
            thread.join()

    @property
    def provide_data(self):
        if self.rename_data is None:
            return sum([i.provide_data for i in self.iters], [])
        else:
            return sum([[
                DataDesc(r[x.name], x.shape, x.dtype)
                if isinstance(x, DataDesc) else DataDesc(*x)
                for x in i.provide_data
            ] for r, i in zip(self.rename_data, self.iters)], [])

    @property
    def provide_label(self):
        if self.rename_label is None:
            return sum([i.provide_label for i in self.iters], [])
        else:
            return sum([[
                DataDesc(r[x.name], x.shape, x.dtype)
                if isinstance(x, DataDesc) else DataDesc(*x)
                for x in i.provide_label
            ] for r, i in zip(self.rename_label, self.iters)], [])

    def reset(self):
        """Wait for in-flight prefetches, reset every iterator, then restart prefetching."""
        for i in self.data_ready:
            i.wait()
        for i in self.iters:
            i.reset()
        for i in self.data_ready:
            i.clear()
        for i in self.data_taken:
            i.set()

    def iter_next(self):
        """Combine the prefetched batches of all iterators into ``current_batch``.

        Returns False when the first iterator is exhausted. Asserts that all
        iterators end together and agree on ``pad``.
        """
        for i in self.data_ready:
            i.wait()
        if self.next_batch[0] is None:
            for i in self.next_batch:
                assert i is None, "Number of entry mismatches between iterators"
            return False
        else:
            for batch in self.next_batch:
                assert batch.pad == self.next_batch[0].pad, \
                    "Number of entry mismatches between iterators"
            self.current_batch = DataBatch(sum([batch.data for batch in self.next_batch], []),
                                           sum([batch.label for batch in self.next_batch], []),
                                           self.next_batch[0].pad,
                                           self.next_batch[0].index,
                                           provide_data=self.provide_data,
                                           provide_label=self.provide_label)
            for i in self.data_ready:
                i.clear()
            for i in self.data_taken:
                i.set()
            return True

    def next(self):
        if self.iter_next():
            return self.current_batch
        else:
            raise StopIteration

    def getdata(self):
        return self.current_batch.data

    def getlabel(self):
        return self.current_batch.label

    def getindex(self):
        return self.current_batch.index

    def getpad(self):
        return self.current_batch.pad
class NDArrayIter(DataIter):
    """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
    ``mx.nd.sparse.CSRNDArray`` or ``scipy.sparse.csr_matrix``.

    Examples
    --------
    >>> data = np.arange(40).reshape((10,2,2))
    >>> labels = np.ones([10, 1])
    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
    >>> for batch in dataiter:
    ...     print(batch.data[0].asnumpy())
    ...     print(batch.data[0].shape)
    ...
    [[[ 36.  37.]
      [ 38.  39.]]
     [[ 16.  17.]
      [ 18.  19.]]
     [[ 12.  13.]
      [ 14.  15.]]]
    (3, 2, 2)
    [[[ 32.  33.]
      [ 34.  35.]]
     [[  4.   5.]
      [  6.   7.]]
     [[ 24.  25.]
      [ 26.  27.]]]
    (3, 2, 2)
    [[[  8.   9.]
      [ 10.  11.]]
     [[ 20.  21.]
      [ 22.  23.]]
     [[ 28.  29.]
      [ 30.  31.]]]
    (3, 2, 2)
    >>> dataiter.provide_data # Returns a list of `DataDesc`
    [DataDesc[data,(3, 2, 2),<class 'numpy.float32'>,NCHW]]
    >>> dataiter.provide_label # Returns a list of `DataDesc`
    [DataDesc[softmax_label,(3, 1),<class 'numpy.float32'>,NCHW]]

    In the above example, data is shuffled as `shuffle` parameter is set to `True`
    and remaining examples are discarded as `last_batch_handle` parameter is set to `discard`.

    Usage of `last_batch_handle` parameter:

    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='pad')
    >>> batchidx = 0
    >>> for batch in dataiter:
    ...     batchidx += 1
    ...
    >>> batchidx  # Padding added after the examples read are over. So, 10/3+1 batches are created.
    4
    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
    >>> batchidx = 0
    >>> for batch in dataiter:
    ...     batchidx += 1
    ...
    >>> batchidx # Remaining examples are discarded. So, 10/3 batches are created.
    3
    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
    >>> batchidx = 0
    >>> for batch in dataiter:
    ...     batchidx += 1
    ...
    >>> batchidx # Remaining examples are rolled over to the next iteration.
    3
    >>> dataiter.reset()
    >>> dataiter.next().data[0].asnumpy()
    [[[ 36.  37.]
      [ 38.  39.]]
     [[  0.   1.]
      [  2.   3.]]
     [[  4.   5.]
      [  6.   7.]]]

    `NDArrayIter` also supports multiple input and labels.

    >>> data = {'data1':np.zeros(shape=(10,2,2)), 'data2':np.zeros(shape=(20,2,2))}
    >>> label = {'label1':np.zeros(shape=(10,1)), 'label2':np.zeros(shape=(20,1))}
    >>> dataiter = mx.io.NDArrayIter(data, label, 3, True, last_batch_handle='discard')

    `NDArrayIter` also supports ``mx.nd.sparse.CSRNDArray``
    with `last_batch_handle` set to `discard`.

    >>> csr_data = mx.nd.array(np.arange(40).reshape((10,4))).tostype('csr')
    >>> labels = np.ones([10, 1])
    >>> dataiter = mx.io.NDArrayIter(csr_data, labels, 3, last_batch_handle='discard')
    >>> [batch.data[0] for batch in dataiter]
    [
    <CSRNDArray 3x4 @cpu(0)>,
    <CSRNDArray 3x4 @cpu(0)>,
    <CSRNDArray 3x4 @cpu(0)>]

    Parameters
    ----------
    data: array or list of array or dict of string to array
        The input data.
    label: array or list of array or dict of string to array, optional
        The input label.
    batch_size: int
        Batch size of data.
    shuffle: bool, optional
        Whether to shuffle the data.
        Only supported if no h5py.Dataset inputs are used.
    last_batch_handle : str, optional
        How to handle the last batch. This parameter can be 'pad', 'discard' or
        'roll_over'.
        If 'pad', the last batch will be padded with data starting from the beginning
        If 'discard', the last batch will be discarded
        If 'roll_over', the remaining elements will be rolled over to the next iteration and
        note that it is intended for training and can cause problems if used for prediction.
    data_name : str, optional
        The data name.
    label_name : str, optional
        The label name.
    """
    def __init__(self, data, label=None, batch_size=1, shuffle=False,
                 last_batch_handle='pad', data_name='data',
                 label_name='softmax_label'):
        # The base class stores ``self.batch_size``.
        super(NDArrayIter, self).__init__(batch_size)

        self.data = _init_data(data, allow_empty=False, default_name=data_name)
        self.label = _init_data(label, allow_empty=True, default_name=label_name)

        if ((_has_instance(self.data, CSRNDArray) or
             _has_instance(self.label, CSRNDArray)) and
                (last_batch_handle != 'discard')):
            raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``"
                                      " with `last_batch_handle` set to `discard`.")

        self.idx = np.arange(self.data[0][1].shape[0])
        self.shuffle = shuffle
        self.last_batch_handle = last_batch_handle
        self.cursor = -self.batch_size
        self.num_data = self.idx.shape[0]
        # shuffle
        self.reset()

        self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
        self.num_source = len(self.data_list)
        # used for 'roll_over'
        self._cache_data = None
        self._cache_label = None

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator."""
        return [
            DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
            for k, v in self.data
        ]

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator."""
        # Batches are always sliced along axis 0 (see ``_getdata``), so the
        # label descriptors use the same batch axis as ``provide_data``. The
        # previous implementation read ``self.layout``, which this class never
        # sets, and therefore raised AttributeError.
        return [
            DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
            for k, v in self.label
        ]

    def hard_reset(self):
        """Ignore roll over data and set to start."""
        if self.shuffle:
            self._shuffle_data()
        self.cursor = -self.batch_size
        self._cache_data = None
        self._cache_label = None

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        if self.shuffle:
            self._shuffle_data()
        # the range below indicate the last batch
        if self.last_batch_handle == 'roll_over' and \
                self.num_data - self.batch_size < self.cursor < self.num_data:
            # (self.cursor - self.num_data) represents the data we have for the last batch
            self.cursor = self.cursor - self.num_data - self.batch_size
        else:
            self.cursor = -self.batch_size

    def iter_next(self):
        """Increments the cursor by batch_size for next batch
        and check current cursor if it exceed the number of data points."""
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def next(self):
        """Returns the next batch of data."""
        if not self.iter_next():
            raise StopIteration
        data = self.getdata()
        label = self.getlabel()
        # iter should stop when last batch is not complete
        if data[0].shape[0] != self.batch_size:
            # in this case, cache it for next epoch
            self._cache_data = data
            self._cache_label = label
            raise StopIteration
        return DataBatch(data=data, label=label,
                         pad=self.getpad(), index=None)

    def _getdata(self, data_source, start=None, end=None):
        """Load data from underlying arrays."""
        assert start is not None or end is not None, 'should at least specify start or end'
        start = start if start is not None else 0
        if end is None:
            end = data_source[0][1].shape[0] if data_source else 0
        s = slice(start, end)
        # NOTE(review): for the h5py branch the reorder index below equals
        # ``argsort(self.idx[s])``, which is the inverse of the permutation that
        # would restore ``idx[s]`` order. The two coincide when ``idx[s]`` is
        # sorted (no shuffling, which the class docstring says is unsupported
        # for h5py) - confirm before enabling shuffle for h5py inputs.
        return [
            x[1][s]
            if isinstance(x[1], (np.ndarray, NDArray)) else
            # h5py (only supports indices in increasing order)
            array(x[1][sorted(self.idx[s])][[
                list(self.idx[s]).index(i)
                for i in sorted(self.idx[s])
            ]]) for x in data_source
        ]

    def _concat(self, first_data, second_data):
        """Helper function to concat two NDArrays."""
        if (not first_data) or (not second_data):
            return first_data if first_data else second_data
        assert len(first_data) == len(
            second_data), 'data source should contain the same size'
        return [
            concat(
                first_data[i],
                second_data[i],
                dim=0
            ) for i in range(len(first_data))
        ]

    def _tile(self, data, repeats):
        """Repeat every array in `data` `repeats` times along axis 0."""
        if not data:
            return []
        res = []
        for datum in data:
            reps = [1] * len(datum.shape)
            reps[0] = repeats
            res.append(tile(datum, reps))
        return res

    def _batchify(self, data_source):
        """Load data from underlying arrays, internal use only."""
        assert self.cursor < self.num_data, 'DataIter needs reset.'
        # first batch of next epoch with 'roll_over'
        if self.last_batch_handle == 'roll_over' and \
                -self.batch_size < self.cursor < 0:
            assert self._cache_data is not None or self._cache_label is not None, \
                'next epoch should have cached data'
            # getdata() consumes the data cache first, getlabel() the label cache.
            cache_data = self._cache_data if self._cache_data is not None else self._cache_label
            second_data = self._getdata(
                data_source, end=self.cursor + self.batch_size)
            if self._cache_data is not None:
                self._cache_data = None
            else:
                self._cache_label = None
            return self._concat(cache_data, second_data)
        # last batch with 'pad'
        elif self.last_batch_handle == 'pad' and \
                self.cursor + self.batch_size > self.num_data:
            pad = self.batch_size - self.num_data + self.cursor
            first_data = self._getdata(data_source, start=self.cursor)
            if pad > self.num_data:
                # The padding is larger than the dataset: tile whole copies,
                # then append the remaining partial copy.
                repeats = pad // self.num_data
                second_data = self._tile(self._getdata(data_source, end=self.num_data), repeats)
                if pad % self.num_data != 0:
                    second_data = self._concat(second_data, self._getdata(data_source, end=pad % self.num_data))
            else:
                second_data = self._getdata(data_source, end=pad)
            return self._concat(first_data, second_data)
        # normal case
        else:
            if self.cursor + self.batch_size < self.num_data:
                end_idx = self.cursor + self.batch_size
            # get incomplete last batch
            else:
                end_idx = self.num_data
            return self._getdata(data_source, self.cursor, end_idx)

    def getdata(self):
        """Get data."""
        return self._batchify(self.data)

    def getlabel(self):
        """Get label."""
        return self._batchify(self.label)

    def getpad(self):
        """Get pad value of DataBatch."""
        if self.last_batch_handle == 'pad' and \
                self.cursor + self.batch_size > self.num_data:
            return self.cursor + self.batch_size - self.num_data
        # check the first batch
        elif self.last_batch_handle == 'roll_over' and \
                -self.batch_size < self.cursor < 0:
            return -self.cursor
        else:
            return 0

    def _shuffle_data(self):
        """Shuffle the data."""
        # shuffle index
        np.random.shuffle(self.idx)
        # get the data by corresponding index
        self.data = _getdata_by_idx(self.data, self.idx)
        self.label = _getdata_by_idx(self.label, self.idx)
class MXDataIter(DataIter):
    """A python wrapper a C++ data iterator.

    This iterator is the Python wrapper to all native C++ data iterators, such
    as `CSVIter`, `ImageRecordIter`, `MNISTIter`, etc. When initializing
    `CSVIter` for example, you will get an `MXDataIter` instance to use in your
    Python code. Calls to `next`, `reset`, etc will be delegated to the
    underlying C++ data iterators.

    Usually you don't need to interact with `MXDataIter` directly unless you are
    implementing your own data iterators in C++. To do that, please refer to
    examples under the `src/io` folder.

    Parameters
    ----------
    handle : DataIterHandle, required
        The handle to the underlying C++ Data Iterator.
    data_name : str, optional
        Data name. Default to "data".
    label_name : str, optional
        Label name. Default to "softmax_label".

    See Also
    --------
    src/io : The underlying C++ data iterator implementation, e.g., `CSVIter`.
    """
    def __init__(self, handle, data_name='data', label_name='softmax_label', **kwargs):
        super(MXDataIter, self).__init__()
        from ..ndarray import _ndarray_cls
        from ..numpy.multiarray import _np_ndarray_cls
        self._create_ndarray_fn = _np_ndarray_cls if is_np_array() else _ndarray_cls
        self.handle = handle
        self._kwargs = kwargs
        # debug option, used to test the speed with io effect eliminated
        self._debug_skip_load = False
        # Initialised explicitly here; next() below sets it to False once
        # the first batch has actually been read.
        self._debug_at_begin = True

        # load the first batch to get shape information
        self.first_batch = None
        self.first_batch = self.next()
        data = self.first_batch.data[0]
        label = self.first_batch.label[0]

        # properties
        self.provide_data = [DataDesc(data_name, data.shape, data.dtype)]
        self.provide_label = [DataDesc(label_name, label.shape, label.dtype)]
        self.batch_size = data.shape[0]

    def __del__(self):
        check_call(_LIB.MXDataIterFree(self.handle))

    def debug_skip_load(self):
        """Make `next` keep returning the current batch without reading new data.

        This can be used to test the speed of a network without taking the
        loading delay into account.
        """
        self._debug_skip_load = True
        logging.info('Set debug_skip_load to be true, will simply return first batch')

    def reset(self):
        """Rewind the native iterator to before its first batch."""
        self._debug_at_begin = True
        self.first_batch = None
        check_call(_LIB.MXDataIterBeforeFirst(self.handle))

    def next(self):
        """Return the next batch, serving the batch pre-loaded in __init__ first."""
        if self._debug_skip_load and not self._debug_at_begin:
            return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(),
                             index=self.getindex())
        if self.first_batch is not None:
            batch = self.first_batch
            self.first_batch = None
            return batch
        self._debug_at_begin = False
        next_res = ctypes.c_int(0)
        check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
        if next_res.value:
            return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(),
                             index=self.getindex())
        else:
            raise StopIteration

    def iter_next(self):
        """Advance the native iterator.

        Returns
        -------
        bool
            Whether a batch is available through `getdata`/`getlabel`.
        """
        if self.first_batch is not None:
            # The native iterator is already positioned on the batch that was
            # pre-loaded in __init__. Consume it so the next call advances;
            # otherwise this would return True forever without moving.
            self.first_batch = None
            return True
        next_res = ctypes.c_int(0)
        check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
        return bool(next_res.value)

    def getdata(self):
        """Return the data array of the current native batch."""
        hdl = NDArrayHandle()
        check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl)))
        return self._create_ndarray_fn(hdl, False)

    def getlabel(self):
        """Return the label array of the current native batch."""
        hdl = NDArrayHandle()
        check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl)))
        return self._create_ndarray_fn(hdl, False)

    def getindex(self):
        """Return a copy of the current batch's example indices, or None if empty."""
        index_size = ctypes.c_uint64(0)
        index_data = ctypes.POINTER(ctypes.c_uint64)()
        check_call(_LIB.MXDataIterGetIndex(self.handle,
                                           ctypes.byref(index_data),
                                           ctypes.byref(index_size)))
        if index_size.value:
            address = ctypes.addressof(index_data.contents)
            dbuffer = (ctypes.c_uint64* index_size.value).from_address(address)
            np_index = np.frombuffer(dbuffer, dtype=np.uint64)
            # Copy so the result does not alias the native buffer.
            return np_index.copy()
        else:
            return None

    def getpad(self):
        """Return the number of padding examples in the current native batch."""
        pad = ctypes.c_int(0)
        check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad)))
        return pad.value

    def getitems(self):
        """Return all output arrays of the current native batch as a tuple."""
        output_vars = ctypes.POINTER(NDArrayHandle)()
        num_output = ctypes.c_int(0)
        check_call(_LIB.MXDataIterGetItems(self.handle,
                                           ctypes.byref(num_output),
                                           ctypes.byref(output_vars)))
        out = [self._create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
                                       False) for i in range(num_output.value)]
        return tuple(out)

    def __len__(self):
        """Length hint from the native iterator; 0 when it reports none (negative)."""
        length = ctypes.c_int64(-1)
        check_call(_LIB.MXDataIterGetLenHint(self.handle, ctypes.byref(length)))
        if length.value < 0:
            return 0
        return length.value
def _make_io_iterator(handle):
"""Create an io iterator by handle."""
name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = mx_uint()
arg_names = ctypes.POINTER(ctypes.c_char_p)()
arg_types = ctypes.POINTER(ctypes.c_char_p)()
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.MXDataIterGetIterInfo( \
handle, ctypes.byref(name), ctypes.byref(desc), \
ctypes.byref(num_args), \
ctypes.byref(arg_names), \
ctypes.byref(arg_types), \
ctypes.byref(arg_descs)))
iter_name = py_str(name.value)
narg = int(num_args.value)
param_str = _build_param_doc(
[py_str(arg_names[i]) for i in range(narg)],
[py_str(arg_types[i]) for i in range(narg)],
[py_str(arg_descs[i]) for i in range(narg)])
doc_str = (f'{desc.value}\n\n' +
f'{param_str}\n' +
'Returns\n' +
'-------\n' +
'MXDataIter\n'+
' The result iterator.')
def creator(*args, **kwargs):
"""Create an iterator.
The parameters listed below can be passed in as keyword arguments.
Parameters
----------
name : string, required.
Name of the resulting data iterator.
Returns
-------
dataiter: Dataiter
The resulting data iterator.
"""
param_keys = []
param_vals = []
for k, val in kwargs.items():
if iter_name == 'ThreadedDataLoader':
# convert ndarray to handle
if hasattr(val, 'handle'):
val = val.handle.value
elif isinstance(val, (tuple, list)):
val = [vv.handle.value if hasattr(vv, 'handle') else vv for vv in val]
elif isinstance(getattr(val, '_iter', None), MXDataIter):
val = val._iter.handle.value
param_keys.append(k)
param_vals.append(str(val))
# create atomic symbol
param_keys = c_str_array(param_keys)
param_vals = c_str_array(param_vals)
iter_handle = DataIterHandle()
check_call(_LIB.MXDataIterCreateIter(
handle,
mx_uint(len(param_keys)),
param_keys, param_vals,
ctypes.byref(iter_handle)))