-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathdali.py
319 lines (292 loc) · 11 KB
/
dali.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import copy
import os
import numpy as np
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import paddle
from nvidia.dali import fn
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from nvidia.dali.plugin.paddle import DALIGenericIterator
class HybridTrainPipe(Pipeline):
    """DALI pipeline used for training.

    The graph reads (file, label) pairs, decodes each file with a random
    crop, resizes the crop to ``crop x crop``, applies a coin-flip mirror
    together with mean/std normalization, and casts the labels to int64.

    Args:
        file_root: Forwarded to the file reader as ``file_root``.
        file_list: Forwarded to the file reader as ``file_list``.
        batch_size: Pipeline batch size.
        resize_shorter: Accepted but not used by this pipeline.
        crop: Side length used for both the resize and the final crop.
        min_area: Lower bound of the random crop area range
            ``[min_area, 1.0]``.
        lower: Lower bound of the random crop aspect ratio.
        upper: Upper bound of the random crop aspect ratio.
        interp: Interpolation type passed to the resize operator.
        mean: Per-channel mean passed to the normalization operator.
        std: Per-channel std passed to the normalization operator.
        device_id: GPU id handed to the base ``Pipeline``.
        shard_id: Index of the shard read by this pipeline.
        num_shards: Total number of shards.
        random_shuffle: Whether the reader shuffles its input.
        num_threads: Thread count handed to the base ``Pipeline``.
        seed: Seed handed to the base ``Pipeline``.
        pad_output: Forwarded to the normalization operator.
        output_dtype: Output dtype of the normalization operator.
        dataset: Accepted but not used by this pipeline.
    """

    def __init__(self,
                 file_root,
                 file_list,
                 batch_size,
                 resize_shorter,
                 crop,
                 min_area,
                 lower,
                 upper,
                 interp,
                 mean,
                 std,
                 device_id,
                 shard_id=0,
                 num_shards=1,
                 random_shuffle=True,
                 num_threads=4,
                 seed=42,
                 pad_output=False,
                 output_dtype=types.FLOAT,
                 dataset='Train'):
        super().__init__(batch_size, num_threads, device_id, seed=seed)

        reader_args = dict(
            file_root=file_root,
            file_list=file_list,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=random_shuffle)
        self.input = ops.readers.File(**reader_args)

        # Size the internal nvJPEG buffers so that full-sized ImageNet images
        # are handled without additional reallocations.
        nvjpeg_padding = dict(
            device_memory_padding=211025920,
            host_memory_padding=140544512)
        self.decode = ops.decoders.ImageRandomCrop(
            device='mixed',
            output_type=types.DALIImageType.RGB,
            random_aspect_ratio=[lower, upper],
            random_area=[min_area, 1.0],
            num_attempts=100,
            **nvjpeg_padding)

        self.res = ops.Resize(
            device='gpu',
            resize_x=crop,
            resize_y=crop,
            interp_type=interp)

        normalize_args = dict(
            device="gpu",
            dtype=output_dtype,
            output_layout='CHW',
            crop=(crop, crop),
            mean=mean,
            std=std,
            pad_output=pad_output)
        self.cmnp = ops.CropMirrorNormalize(**normalize_args)

        self.coin = ops.random.CoinFlip(probability=0.5)
        self.to_int64 = ops.Cast(
            dtype=types.DALIDataType.INT64, device="gpu")

    def define_graph(self):
        """Wire the operators together and return ``[images, labels]``."""
        # The operators are invoked in the same order as before: coin flip
        # first, then the reader, decoder, resize and normalization.
        mirror_flag = self.coin()
        encoded, targets = self.input(name="Reader")
        decoded = self.decode(encoded)
        resized = self.res(decoded)
        normalized = self.cmnp(resized.gpu(), mirror=mirror_flag)
        return [normalized, self.to_int64(targets.gpu())]

    def __len__(self):
        """Return the epoch size reported for the reader named "Reader"."""
        return self.epoch_size("Reader")
class HybridValPipe(Pipeline):
    """DALI pipeline used for evaluation.

    The graph reads (file, label) pairs, decodes each file, resizes the
    shorter side to ``resize_shorter``, crops to ``crop x crop`` while
    normalizing with mean/std, and casts the labels to int64.

    Args:
        file_root: Forwarded to the file reader as ``file_root``.
        file_list: Forwarded to the file reader as ``file_list``.
        batch_size: Pipeline batch size.
        resize_shorter: Target length of the shorter image side.
        crop: Side length of the final square crop.
        interp: Interpolation type passed to the resize operator.
        mean: Per-channel mean passed to the normalization operator.
        std: Per-channel std passed to the normalization operator.
        device_id: GPU id handed to the base ``Pipeline``.
        shard_id: Index of the shard read by this pipeline.
        num_shards: Total number of shards.
        random_shuffle: Whether the reader shuffles its input.
        num_threads: Thread count handed to the base ``Pipeline``.
        seed: Seed handed to the base ``Pipeline``.
        pad_output: Forwarded to the normalization operator.
        output_dtype: Output dtype of the normalization operator.
    """

    def __init__(self,
                 file_root,
                 file_list,
                 batch_size,
                 resize_shorter,
                 crop,
                 interp,
                 mean,
                 std,
                 device_id,
                 shard_id=0,
                 num_shards=1,
                 random_shuffle=False,
                 num_threads=4,
                 seed=42,
                 pad_output=False,
                 output_dtype=types.FLOAT):
        super().__init__(batch_size, num_threads, device_id, seed=seed)

        reader_args = dict(
            file_root=file_root,
            file_list=file_list,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=random_shuffle)
        self.input = ops.readers.File(**reader_args)

        self.decode = ops.decoders.Image(device="mixed")

        self.res = ops.Resize(
            device="gpu",
            resize_shorter=resize_shorter,
            interp_type=interp)

        normalize_args = dict(
            device="gpu",
            dtype=output_dtype,
            output_layout='CHW',
            crop=(crop, crop),
            mean=mean,
            std=std,
            pad_output=pad_output)
        self.cmnp = ops.CropMirrorNormalize(**normalize_args)

        self.to_int64 = ops.Cast(
            dtype=types.DALIDataType.INT64, device="gpu")

    def define_graph(self):
        """Wire the operators together and return ``[images, labels]``."""
        encoded, targets = self.input(name="Reader")
        decoded = self.decode(encoded)
        resized = self.res(decoded)
        normalized = self.cmnp(resized)
        return [normalized, self.to_int64(targets.gpu())]

    def __len__(self):
        """Return the epoch size reported for the reader named "Reader"."""
        return self.epoch_size("Reader")
def _dali_shard_from_env(env, device_id):
    """Return ``(shard_id, num_shards, device_id)`` for the current process.

    When both ``PADDLE_TRAINER_ID`` and ``PADDLE_TRAINERS_NUM`` are present
    in ``env`` they define the shard, and ``FLAGS_selected_gpus`` (if set)
    overrides ``device_id``. Otherwise a single shard on the given
    ``device_id`` is returned.
    """
    if 'PADDLE_TRAINER_ID' not in env or 'PADDLE_TRAINERS_NUM' not in env:
        return 0, 1, device_id
    shard_id = int(env['PADDLE_TRAINER_ID'])
    num_shards = int(env['PADDLE_TRAINERS_NUM'])
    # Fall back to the id parsed from the device string instead of raising
    # KeyError when FLAGS_selected_gpus is not exported.
    device_id = int(env.get('FLAGS_selected_gpus', device_id))
    return shard_id, num_shards, device_id


def dali_dataloader(config, mode, device, seed=None):
    """Build a DALI iterator for the ``config[mode]`` dataloader section.

    Args:
        config: Mapping of dataloader sections. ``config[mode]`` must provide
            ``dataset`` (``transform_ops``, ``image_root``,
            ``cls_label_path``) and ``sampler`` (``batch_size`` and an
            optional ``name``).
        mode: Key of the section to use. ``"train"`` (case-insensitive)
            builds a ``HybridTrainPipe``; any other value builds a
            ``HybridValPipe``.
        device: Device string such as ``"gpu:0"``. A bare ``"gpu"`` selects
            device 0.
        seed: Base seed of the training pipeline (the shard id is added to
            it). Defaults to 42 when ``None``. It is not forwarded to the
            evaluation pipeline.

    Returns:
        A ``DALIGenericIterator`` over the built pipeline with the outputs
        named ``data`` and ``label``.

    Raises:
        AssertionError: If ``device`` is not a gpu device, if the configured
            ``transform_ops`` differ from the supported set for the mode, or
            if the sampler name is not supported.
    """
    assert "gpu" in device, "gpu training is required for DALI"
    device_parts = device.split(':')
    device_id = int(device_parts[1]) if len(device_parts) > 1 else 0
    config_dataloader = config[mode]
    seed = 42 if seed is None else seed
    is_train = mode.lower() == 'train'
    transform_ops = config_dataloader["dataset"]["transform_ops"]

    # Validate the configured ops before touching any DALI object.
    op_names = [list(x.keys())[0] for x in transform_ops]
    if is_train:
        supported_ops = [
            "DecodeImage", "NormalizeImage", "RandFlipImage", "RandCropImage"
        ]
        assert set(op_names) == set(supported_ops), (
            "The supported transform_ops for train_dataset in dali is : {}"
            .format(",".join(supported_ops)))
    else:
        supported_ops = [
            "DecodeImage", "ResizeImage", "CropImage", "NormalizeImage"
        ]
        assert set(op_names) == set(supported_ops), (
            "The supported transform_ops for eval_dataset in dali is : {}"
            .format(",".join(supported_ops)))

    # Flatten the list of single-op dicts into one {op_name: params} lookup.
    transforms = {k: v for d in transform_ops for k, v in d.items()}
    normalize_cfg = transforms["NormalizeImage"]

    channel_num = normalize_cfg.get("channel_num", 3)
    use_fp16 = normalize_cfg.get("output_fp16", False)
    output_dtype = types.FLOAT16 if use_fp16 else types.FLOAT
    pad_output = channel_num == 4

    batch_size = config_dataloader["sampler"]["batch_size"]
    file_root = config_dataloader["dataset"]["image_root"]
    file_list = config_dataloader["dataset"]["cls_label_path"]

    interp = 1  # default to linear
    interp_map = {
        0: types.DALIInterpType.INTERP_NN,  # cv2.INTER_NEAREST
        1: types.DALIInterpType.INTERP_LINEAR,  # cv2.INTER_LINEAR
        2: types.DALIInterpType.INTERP_CUBIC,  # cv2.INTER_CUBIC
        # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
        3: types.DALIInterpType.INTERP_LANCZOS3,
    }
    assert interp in interp_map, "interpolation method not supported by DALI"
    interp = interp_map[interp]

    pixel_scale = normalize_cfg.get("scale", 1.0 / 255)
    # SECURITY NOTE(review): eval() runs arbitrary code; it is kept so that
    # string expressions such as "1.0/255.0" keep working, but the config
    # must come from a trusted source.
    pixel_scale = eval(pixel_scale) if isinstance(pixel_scale,
                                                  str) else pixel_scale
    mean = normalize_cfg.get("mean", [0.485, 0.456, 0.406])
    std = normalize_cfg.get("std", [0.229, 0.224, 0.225])
    # Dividing by the scale expresses mean/std in unscaled pixel units.
    mean = [v / pixel_scale for v in mean]
    std = [v / pixel_scale for v in std]

    sampler_name = config_dataloader["sampler"].get(
        "name", "DistributedBatchSampler")
    assert sampler_name in ["DistributedBatchSampler", "BatchSampler"]

    env = os.environ

    if is_train:
        resize_shorter = 256
        rand_crop = transforms["RandCropImage"]
        crop = rand_crop["size"]
        area_range = rand_crop.get("scale", [0.08, 1.])
        ratio = rand_crop.get("ratio", [3.0 / 4, 4.0 / 3])

        shard_id, num_shards, device_id = _dali_shard_from_env(env,
                                                               device_id)
        pipe = HybridTrainPipe(
            file_root,
            file_list,
            batch_size,
            resize_shorter,
            crop,
            area_range[0],
            ratio[0],
            ratio[1],
            interp,
            mean,
            std,
            device_id=device_id,
            shard_id=shard_id,
            num_shards=num_shards,
            # Each shard gets its own seed; shard 0 keeps the base seed.
            seed=seed + shard_id,
            pad_output=pad_output,
            output_dtype=output_dtype)
        pipe.build()
        return DALIGenericIterator(
            [pipe], ['data', 'label'], reader_name='Reader')

    resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
    crop = transforms["CropImage"]["size"]
    shard_id, num_shards = 0, 1
    # Evaluation data is only sharded when a distributed sampler is requested.
    if sampler_name == "DistributedBatchSampler":
        shard_id, num_shards, device_id = _dali_shard_from_env(env,
                                                               device_id)
    pipe = HybridValPipe(
        file_root,
        file_list,
        batch_size,
        resize_shorter,
        crop,
        interp,
        mean,
        std,
        device_id=device_id,
        shard_id=shard_id,
        num_shards=num_shards,
        pad_output=pad_output,
        output_dtype=output_dtype)
    pipe.build()
    return DALIGenericIterator(
        [pipe], ['data', 'label'], reader_name="Reader")