forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conv2d.py
790 lines (662 loc) · 30.7 KB
/
conv2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
# pylint: disable=invalid-name,unused-variable,no-else-return
"""Conv2D schedule for ARM CPU"""
from __future__ import absolute_import as _abs
import warnings
import numpy as np
import tvm
from tvm import autotvm
from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
from ..util import traverse_inline, get_const_tuple, const_matrix
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
conv2d_winograd_without_weight_transform, depthwise_conv2d_nchw
from ..nn.util import get_const_int, get_pad_tuple
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
    """TOPI compute callback for conv2d using the 'direct' (spatial pack) template.

    The declaration is delegated to `_decl_spatial_pack` when `layout` is "NCHW"
    and to `_decl_spatial_pack_nhwc` for every other layout value. All
    arguments are forwarded unchanged, with `num_tile=2`.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    data : tvm.Tensor
        4-D input tensor, interpreted according to `layout`

    kernel : tvm.Tensor
        For NCHW: 4-D with shape [num_filter, in_channel, filter_height, filter_width]
        or pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height,
        filter_width, num_filter_block]

    strides : list of two ints
        [stride_height, stride_width]

    padding : list of two ints
        [pad_height, pad_width]

    dilation : list of two ints
        [dilation_height, dilation_width]

    layout : str
        layout of data

    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.Tensor
        4-D convolution result in the same layout as `data`
    """
    declare = _decl_spatial_pack if layout == "NCHW" else _decl_spatial_pack_nhwc
    return declare(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
                   num_tile=2)
@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'arm_cpu', ['direct', 'winograd'])
def schedule_conv2d_nchw_arm_cpu(cfg, outs):
    """TOPI schedule callback for conv2d on ARM CPU.

    Walks the graph below `outs[0]` and schedules every op tagged as a
    spatial-pack output or as a winograd output.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    outs: Array of Tensor
        The computation graph description of conv2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for conv2d.
    """
    sch = tvm.create_schedule([tensor.op for tensor in outs])

    def _schedule_direct(op):
        """Schedule one spatial-pack conv2d output op (NCHW or NHWC variant)."""
        unpacked = op.output(0)
        conv = op.input_tensors[0]
        conv_inputs = conv.op.input_tensors

        data_vec = conv_inputs[0]
        # the stage feeding data_vec (the padded input) is folded into its consumer
        sch[data_vec.op.input_tensors[0]].compute_inline()

        kernel_vec = conv_inputs[1]
        # a stage named 'kernel_vec' packs the raw kernel; otherwise the
        # second conv input is taken as the kernel itself
        if kernel_vec.op.name == 'kernel_vec':
            raw_kernel = kernel_vec.op.input_tensors[0]
        else:
            raw_kernel = kernel_vec
        if isinstance(raw_kernel.op, tvm.tensor.ComputeOp) and "dilate" in raw_kernel.op.tag:
            sch[raw_kernel].compute_inline()

        schedule_fn = _schedule_spatial_pack_nhwc if 'nhwc' in op.tag else _schedule_spatial_pack
        schedule_fn(cfg, sch, data_vec, kernel_vec, conv, unpacked, outs[0])

    def _callback(op):
        # the two tags are checked independently, as two separate `if`s
        if 'spatial_conv2d_output' in op.tag:
            _schedule_direct(op)
        if 'winograd_conv2d_output' in op.tag:
            _schedule_winograd(cfg, sch, op.output(0), outs[0])

    traverse_inline(sch, outs[0].op, _callback)
    return sch
def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, num_tile):
    """Declare the spatial-pack conv2d compute for NCHW data.

    Parameters
    ----------
    cfg : ConfigEntity
        Config whose 'tile_co', 'tile_oh', 'tile_ow', 'reorder_0',
        'ann_reduce' and 'ann_spatial' knobs are defined here.
    data : tvm.Tensor
        4-D input, unpacked as (N, CI, IH, IW).
    kernel : tvm.Tensor
        4-D (CO, CI, KH, KW), or pre-packed 5-D (CO // VC, CI, KH, KW, VC).
    strides : int or tuple/list of two ints
    padding : value accepted by `get_pad_tuple`
    dilation : int or pair of ints
    layout : str
        Must be "NCHW".
    out_dtype : str
        Accumulation/output dtype; falls back to `data.dtype` when falsy.
    num_tile : int
        2 (arm cpu) or 3 (mali gpu): number of outputs of each split knob.

    Returns
    -------
    output : tvm.Tensor
        4-D tensor of shape (N, CO, OH, OW) tagged 'spatial_conv2d_output'.
    """
    assert layout == "NCHW", "Only support NCHW"

    # create workload according to raw arguments
    out_dtype = out_dtype or data.dtype
    N, CI, IH, IW = get_const_tuple(data.shape)

    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
    dilated = dilation_h != 1 or dilation_w != 1

    # anything that is not 4-D is treated as an already packed kernel
    pre_packed = len(kernel.shape) != 4
    if pre_packed:
        CO, _, KH, KW, VC = get_const_tuple(kernel.shape)
        CO = CO * VC
    else:
        CO, _, KH, KW = get_const_tuple(kernel.shape)

    if isinstance(strides, (tuple, list)):
        HSTR, WSTR = strides
    else:
        HSTR = WSTR = strides

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))

    OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])

    # ==================== define configuration space ====================
    n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)

    if num_tile not in (2, 3):
        raise RuntimeError("Invalid num_tile")
    if num_tile == 2:     # for arm cpu
        co, vc = cfg.define_split('tile_co', co, num_outputs=2)
        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
    else:                 # for mali gpu
        co, _, vc = cfg.define_split('tile_co', co, num_outputs=3)
        oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3)
        ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3)

    outer_order = [n, co, oh, ow, ci, kh, kw]
    cfg.define_reorder("reorder_0",
                       outer_order + [vh, vw, vc],
                       policy='candidate', candidate=[
                           outer_order + [vh, vw, vc],
                           outer_order + [vc, vh, vw]])

    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')

    # fallback support
    if cfg.is_fallback:
        # num_tile is known to be 2 (arm cpu) or 3 (mali gpu) at this point
        ref_target = 'arm_cpu' if num_tile == 2 else 'mali'
        ref_log = autotvm.tophub.load_reference_log(ref_target, 'rk3399', 'conv2d', 'direct')
        cfg.fallback_with_reference_log(ref_log)
    # ====================================================================

    # innermost split factors chosen by the config
    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

    kvshape = (CO // VC, CI, KH, KW, VC)
    ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)
    oshape = (N, CO, OH, OW)

    if dilated:
        # undilate input data
        dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW)
        data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw:
                               data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h]
                               [(w*VW+vw)*WSTR+kw*dilation_w],
                               name='data_vec_undilated')
    else:
        dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
        data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
                               data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
                               name='data_vec')

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc:
                                 kernel[co*VC+vc][ci][kh][kw],
                                 name='kernel_vec')

    ci = tvm.reduce_axis((0, CI), name='ci')
    kh = tvm.reduce_axis((0, KH), name='kh')
    kw = tvm.reduce_axis((0, KW), name='kw')

    if dilated:
        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
            tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) *
                    kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
                    axis=[ci, kh, kw]), name='conv')
    else:
        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
            tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
                    kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
                    axis=[ci, kh, kw]), name='conv')

    # unpack the tiled result back to plain (N, CO, OH, OW)
    return tvm.compute(oshape, lambda n, co, h, w:
                       conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
                       name='output_unpack', tag='spatial_conv2d_output')
def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, num_tile):
    """Declare the spatial-pack conv2d compute for NHWC data.

    Parameters
    ----------
    cfg : ConfigEntity
        Config whose 'tile_co', 'tile_oh', 'tile_ow', 'reorder_0',
        'ann_reduce' and 'ann_spatial' knobs are defined here.
    data : tvm.Tensor
        4-D input, unpacked as (N, IH, IW, CI).
    kernel : tvm.Tensor
        4-D kernel, unpacked as (KH, KW, CI, CO).
    strides : int or tuple/list of two ints
    padding : int, str, pair of ints, or 4-element (top, left, bottom, right)
        A 4-element value is used as-is; anything else goes through
        `get_pad_tuple`.
    dilation : int or pair of ints
    layout : str
        Must be "NHWC".
    out_dtype : str
        Accumulation/output dtype; falls back to `data.dtype` when falsy.
    num_tile : int
        2 (arm cpu) or 3 (mali gpu): number of outputs of each split knob.

    Returns
    -------
    output : tvm.Tensor
        4-D tensor of shape (N, OH, OW, CO) tagged 'spatial_conv2d_output_nhwc'.

    Raises
    ------
    ValueError
        If `kernel` is not 4-D.
    RuntimeError
        If `num_tile` is neither 2 nor 3.
    """
    assert layout == "NHWC", "Only support NHWC"

    # create workload according to raw arguments
    out_dtype = out_dtype or data.dtype
    N, IH, IW, CI = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) != 4:
        # previously this fell through and failed later on unbound KH/KW/CO
        raise ValueError("NHWC spatial pack expects a 4-D kernel, got %d-D"
                         % len(kernel.shape))

    if dilation_h != 1 or dilation_w != 1:
        # Dilation used to be silently ignored. Materialise the dilated kernel
        # (zeros inserted along KH/KW) so the dense loops below compute the
        # dilated convolution; the schedule callback inlines ops tagged "dilate".
        kernel = dilate(kernel, (dilation_h, dilation_w, 1, 1))
    # KH/KW are the (possibly dilated) kernel extents from here on
    KH, KW, _, CO = get_const_tuple(kernel.shape)

    if not isinstance(padding, (int, str)) and len(padding) == 4:
        # explicit (top, left, bottom, right), the only form accepted before
        pad_top, pad_left, pad_bottom, pad_right = get_const_tuple(padding)
    else:
        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW))

    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    OH = (IH + pad_top + pad_bottom - KH) // HSTR + 1
    OW = (IW + pad_left + pad_right - KW) // WSTR + 1
    data_pad = pad(data, [0, pad_top, pad_left, 0], [0, pad_bottom, pad_right, 0])

    # ==================== define configuration space ====================
    n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)

    if num_tile == 2:     # for arm cpu
        co, vc = cfg.define_split('tile_co', co, num_outputs=2)
        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
    elif num_tile == 3:   # for mali gpu
        co, _, vc = cfg.define_split('tile_co', co, num_outputs=3)
        oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3)
        ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3)
    else:
        raise RuntimeError("Invalid num_tile")

    cfg.define_reorder("reorder_0",
                       [n, oh, ow, co, ci, kh, kw, vh, vw, vc],
                       policy='candidate', candidate=[
                           [n, oh, ow, co, ci, kh, kw, vh, vw, vc],
                           [n, oh, ow, co, ci, kh, kw, vc, vh, vw]])

    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')

    # NOTE: unlike the NCHW declaration, no reference-log fallback is applied here.
    # ====================================================================

    # innermost split factors chosen by the config
    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

    kvshape = (CO // VC, CI, KH, KW, VC)
    ovshape = (N, OH // VH, OW // VW, CO // VC, VH, VW, VC)
    oshape = (N, OH, OW, CO)

    # pack the input window of every output tile; data_pad is NHWC, so the
    # channel index goes last
    dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
    data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
                           data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][ci],
                           name='data_vec')

    # kernel is (KH, KW, CI, CO); repack to (CO // VC, CI, KH, KW, VC)
    kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc:
                             kernel[kh][kw][ci][co*VC+vc],
                             name='kernel_vec')

    ci = tvm.reduce_axis((0, CI), name='ci')
    kh = tvm.reduce_axis((0, KH), name='kh')
    kw = tvm.reduce_axis((0, KW), name='kw')

    conv = tvm.compute(ovshape, lambda n, h, w, co, vh, vw, vc: \
        tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
                kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
                axis=[ci, kh, kw]), name='conv')

    # unpack; the channel-chunk index must follow the output channel `co`
    # (it was the constant CO // VC before, i.e. one past the last chunk)
    output = tvm.compute(oshape, lambda n, h, w, co:
                         conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
                         name='output_unpack', tag='spatial_conv2d_output_nhwc')
    return output
def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
                           conv, output, last):
    """Apply the tuned spatial-pack schedule to an NCHW conv2d.

    `conv` is the packed convolution stage, `output` its unpack stage and
    `last` the final stage of the graph (which may be `output` itself).
    Returns the schedule `s`.
    """
    def _spatial_lens():
        # extents of the innermost (vh, vw, vc) tiles, read from the config
        return [cfg[knob].size[-1] for knob in ('tile_oh', 'tile_ow', 'tile_co')]

    c_n, c_co, c_oh, c_ow, c_vh, c_vw, c_vc = s[conv].op.axis
    r_ci, r_kh, r_kw = s[conv].op.reduce_axis

    # schedule conv
    cfg["reorder_0"].apply(s, conv, [c_n, c_co, c_oh, c_ow, r_ci, r_kh, r_kw,
                                     c_vh, c_vw, c_vc])
    reduce_lens = [get_const_int(r_kh.dom.extent), get_const_int(r_kw.dom.extent)]
    cfg["ann_reduce"].apply(s, conv, [r_kh, r_kw],
                            axis_lens=reduce_lens,
                            max_unroll=16,
                            cfg=cfg)
    cfg["ann_spatial"].apply(s, conv, [c_vh, c_vw, c_vc],
                             axis_lens=_spatial_lens(),
                             max_unroll=16,
                             cfg=cfg)

    # schedule fusion
    l_n, l_co, l_h, l_w = s[last].op.axis
    l_co, l_vc = cfg['tile_co'].apply(s, last, l_co)
    l_oh, l_vh = cfg['tile_oh'].apply(s, last, l_h)
    l_ow, l_vw = cfg['tile_ow'].apply(s, last, l_w)
    s[last].reorder(l_n, l_co, l_oh, l_ow, l_vh, l_vw, l_vc)
    if last != output:
        # another stage consumes the unpack stage: fold it in and annotate
        # the inner tiles of the final stage instead
        s[output].compute_inline()
        cfg["ann_spatial"].apply(s, last, [l_vh, l_vw, l_vc],
                                 axis_lens=_spatial_lens(),
                                 max_unroll=16,
                                 cfg=cfg)
    s[conv].compute_at(s[last], l_ow)

    # mark parallel
    s[last].parallel(l_co)

    # the second axis of data_vec is parallelised for both packing variants
    dv_axes = s[data_vec].op.axis
    if data_vec.op.name == 'data_vec_undilated':
        _, dv_h, _, _, _, _, _, _ = dv_axes
    else:
        _, dv_h, _, _, _, _ = dv_axes
    s[data_vec].parallel(dv_h)

    kv_name = kernel_vec.op.name
    if kv_name == 'kernel_vec':
        kv_co, _, _, _, _ = s[kernel_vec].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # kernel packing will be pre-computed during compilation, so we skip
            # this part to make tuning records correct
            s[kernel_vec].pragma(kv_co, 'debug_skip_region')
        else:
            s[kernel_vec].parallel(kv_co)
    elif kv_name == 'kernel_vec_conv2d_transpose':  # for conv2d transpose
        kv_co, _, _, _, _ = s[kernel_vec].op.axis
        s[kernel_vec].parallel(kv_co)

    return s
def _schedule_spatial_pack_nhwc(cfg, s, data_vec, kernel_vec,
                                conv, output, last):
    """Apply the tuned spatial-pack schedule to an NHWC conv2d.

    `conv` is the packed convolution stage, `output` its unpack stage and
    `last` the final stage of the graph (which may be `output` itself).
    Returns the schedule `s`.
    """
    def _spatial_lens():
        # extents of the innermost (vh, vw, vc) tiles, read from the config
        return [cfg[knob].size[-1] for knob in ('tile_oh', 'tile_ow', 'tile_co')]

    c_n, c_oh, c_ow, c_co, c_vh, c_vw, c_vc = s[conv].op.axis
    r_ci, r_kh, r_kw = s[conv].op.reduce_axis

    # schedule conv
    cfg["reorder_0"].apply(s, conv, [c_n, c_oh, c_ow, c_co, r_ci, r_kh, r_kw,
                                     c_vh, c_vw, c_vc])
    reduce_lens = [get_const_int(r_kh.dom.extent), get_const_int(r_kw.dom.extent)]
    cfg["ann_reduce"].apply(s, conv, [r_kh, r_kw],
                            axis_lens=reduce_lens,
                            max_unroll=16,
                            cfg=cfg)
    cfg["ann_spatial"].apply(s, conv, [c_vh, c_vw, c_vc],
                             axis_lens=_spatial_lens(),
                             max_unroll=16,
                             cfg=cfg)

    # schedule fusion
    l_n, l_h, l_w, l_co = s[last].op.axis
    l_co, l_vc = cfg['tile_co'].apply(s, last, l_co)
    l_oh, l_vh = cfg['tile_oh'].apply(s, last, l_h)
    l_ow, l_vw = cfg['tile_ow'].apply(s, last, l_w)
    s[last].reorder(l_n, l_oh, l_ow, l_co, l_vh, l_vw, l_vc)
    if last != output:
        # another stage consumes the unpack stage: fold it in and annotate
        # the inner tiles of the final stage instead
        s[output].compute_inline()
        cfg["ann_spatial"].apply(s, last, [l_vh, l_vw, l_vc],
                                 axis_lens=_spatial_lens(),
                                 max_unroll=16,
                                 cfg=cfg)
    s[conv].compute_at(s[last], l_co)

    # mark parallel
    s[last].parallel(l_oh)

    _, dv_h, _, _, _, _ = s[data_vec].op.axis
    s[data_vec].parallel(dv_h)

    kv_co, _, _, _, _ = s[kernel_vec].op.axis
    s[kernel_vec].parallel(kv_co)

    return s
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
    """TOPI compute callback for the winograd template.

    Forwards every argument to `_decl_winograd` with a fixed tile size of 4.
    """
    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout,
                          out_dtype, 4)
def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
    """Declare a winograd conv2d compute for 3x3, stride-1, pad-1 NCHW convolution.

    Parameters
    ----------
    cfg : ConfigEntity
        Config; the 'tile_p' and 'tile_k' split knobs are defined here.
    data : tvm.Tensor
        4-D input, unpacked as (N, CI, IH, IW).
    kernel : tvm.Tensor
        Either a raw 4-D kernel (CO, CI, KH, KW), or a pre-transformed 5-D
        kernel unpacked as (H_CAT, W_CAT, CO // VC, CI, VC).
    strides : int or tuple/list of two ints
    padding : value accepted by `get_pad_tuple`
    dilation : int or pair of ints
        Only (1, 1) is accepted together with a pre-transformed kernel; a raw
        kernel is dilated explicitly.
    layout : str
        Must be 'NCHW'.
    out_dtype : str
        dtype of the transform matrices B and A and of the accumulations.
    tile_size : int
        Winograd output tile size; 4 and 2 are supported.

    Returns
    -------
    output : tvm.Tensor
        4-D tensor of shape (N, CO, H, W) tagged 'winograd_conv2d_output'.

    Raises
    ------
    ValueError
        If `tile_size` is not 2 or 4.
    """
    N, CI, IH, IW = get_const_tuple(data.shape)
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        if dilation_h != 1 or dilation_w != 1:
            kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
        pre_computed = False
        CO, _, KH, KW = get_const_tuple(kernel.shape)
    else:
        assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
        pre_computed = True
        H_CAT, W_CAT, CO, CI, VC = get_const_tuple(kernel.shape)
        CO *= VC
        # recover the spatial kernel size from the transformed extents
        KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    # pass the kernel *size*: the tensor itself was passed before, which is not
    # a valid (height, width) pair, least of all for the 5-D pre-computed form
    HPAD, WPAD, _, _ = get_pad_tuple(padding, (KH, KW))

    assert layout == 'NCHW'
    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    # constant transform matrices: G (kernel), B (input), A (output)
    if tile_size == 4:
        G_data = np.array([
            [1 / 4.0, 0, 0],
            [-1 / 6.0, -1 / 6.0, -1 / 6.0],
            [-1 / 6.0, 1 / 6.0, -1 / 6.0],
            [1 / 24.0, 1 / 12.0, 1 / 6.0],
            [1 / 24.0, -1 / 12.0, 1 / 6.0],
            [0, 0, 1]], dtype=np.float32)

        B_data = np.array([
            [4, 0, 0, 0, 0, 0],
            [0, -4, 4, -2, 2, 4],
            [-5, -4, -4, -1, -1, 0],
            [0, 1, -1, 2, -2, -5],
            [1, 1, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]], out_dtype)

        A_data = np.array([
            [1, 0, 0, 0],
            [1, 1, 1, 1],
            [1, -1, 1, -1],
            [1, 2, 4, 8],
            [1, -2, 4, -8],
            [0, 0, 0, 1]], out_dtype)
    elif tile_size == 2:
        G_data = np.array([
            [1, 0, 0],
            [1.0/2, 1.0/2, 1.0/2],
            [1.0/2, -1.0/2, 1.0/2],
            [0, 0, 1]], np.float32)

        B_data = np.array([
            [1, 0, 0, 0],
            [0, 1, -1, 1],
            [-1, 1, 1, 0],
            [0, 0, 0, -1]], out_dtype)

        A_data = np.array([
            [1, 0],
            [1, 1],
            [1, -1],
            [0, -1]], out_dtype)
    else:
        raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

    m = A_data.shape[1]     # output tile size
    r = 3                   # kernel size
    alpha = m + r - 1       # input tile size
    K = CO
    C = CI

    H = (IH + 2 * HPAD - 3) // HSTR + 1
    W = (IW + 2 * WPAD - 3) // WSTR + 1
    nH, nW = (H + m-1) // m, (W + m-1) // m     # tiles per column / row (rounded up)
    P = N * nH * nW                             # total number of tiles

    cfg.define_split('tile_p', cfg.axis(P), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
    cfg.define_split('tile_k', cfg.axis(K), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
    VP = cfg['tile_p'].size[-1]
    VK = cfg['tile_k'].size[-1]

    # pack input tile
    input_tile = tvm.compute((C, P // VP, alpha, alpha, VP),
                             lambda c, b, eps, nu, bb:
                             data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps]
                             [(b*VP+bb) % nW * m + nu],
                             name='d')

    # transform kernel
    if pre_computed:
        U = kernel
    else:
        G = const_matrix(G_data, 'G')
        r_kh = tvm.reduce_axis((0, KH), 'r_kh')
        r_kw = tvm.reduce_axis((0, KW), 'r_kw')
        U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk:
                        tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
                                G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')

    # transform image
    B = const_matrix(B_data, 'B')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb:
                    tvm.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) *
                            B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V')

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
                    tvm.sum(U[eps][nu][k // VK][c][k % VK] *
                            V[eps][nu][b // VP][c][b % VP], axis=c), name='M')

    # inverse transform
    A = const_matrix(A_data, 'A')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
                    tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
                            axis=[r_eps, r_nu]), name='Y')

    # unpack output
    output = tvm.compute((N, K, H, W), lambda n, k, h, w:
                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
                         name='output', tag='winograd_conv2d_output')

    # we have to manually assign effective GFLOP for winograd
    cfg.add_flop(2 * N * K * H * W * KH * KW * C)
    return output
def _schedule_winograd(cfg, s, output, last):
    """Schedule the winograd conv2d whose unpack stage is `output`.

    `last` is the final stage of the graph (possibly `output` itself). The
    'tile_c', 'ann_reduce' and 'ann_spatial' knobs are defined on `cfg` here;
    'tile_p' and 'tile_k' are expected to be defined already. The schedule
    `s` is modified in place and nothing is returned.
    """
    def _unroll_all(tensor, axes):
        for ax in axes:
            s[tensor].unroll(ax)

    # walk the compute graph backwards from the unpack stage
    Y = output.op.input_tensors[0]
    M, A = Y.op.input_tensors
    U, V = M.op.input_tensors
    d, B = V.op.input_tensors
    data_pad = d.op.input_tensors[0]

    # padding
    s[data_pad].compute_inline()

    # pack input tiles
    s[d].compute_inline()

    # transform kernel (only when U is computed here rather than passed in)
    if isinstance(U.op, tvm.tensor.ComputeOp):
        kernel, G = U.op.input_tensors
        s[G].compute_inline()
        u_eps, u_nu, u_k, u_c, u_kk, = s[U].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # kernel transformation will be pre-computed during compilation, so we skip
            # this part to make tuning records correct
            s[U].pragma(u_eps, 'debug_skip_region')
        else:
            u_rkh, u_rkw = s[U].op.reduce_axis
            s[U].reorder(u_k, u_c, u_eps, u_nu, u_rkh, u_rkw, u_kk)
            _unroll_all(U, [u_eps, u_nu, u_rkh, u_rkw])
            s[U].vectorize(u_kk)
            s[U].parallel(u_k)

        if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
            s[kernel].compute_inline()

    # transform image
    DD = s.cache_read(d, 'global', [V])
    s[B].compute_inline()
    v_eps, v_nu, v_b, v_c, v_bb = s[V].op.axis
    v_reps, v_rnu = s[V].op.reduce_axis
    s[V].reorder(v_b, v_c, v_eps, v_nu, v_reps, v_rnu, v_bb)
    _unroll_all(V, [v_eps, v_nu, v_reps, v_rnu])
    s[DD].compute_at(s[V], v_c)
    s[V].vectorize(v_bb)
    s[V].parallel(v_b)

    # batch gemm
    m_eps, m_nu, m_k, m_b = s[M].op.axis
    m_c = s[M].op.reduce_axis[0]
    cfg.define_split('tile_c', m_c, num_outputs=2, filter=lambda x: x.size[-1] <= 16)
    c_outer, c_inner = cfg['tile_c'].apply(s, M, m_c)
    p_outer, p_inner = cfg['tile_p'].apply(s, M, m_b)
    s[M].reorder(m_eps, m_nu, p_outer, c_outer, m_k, c_inner, p_inner)
    cfg.define_annotate('ann_reduce', [c_inner], policy='try_unroll')
    cfg.define_annotate('ann_spatial', [m_k, p_inner], policy='try_unroll_vec')
    cfg['ann_reduce'].apply(s, M, [c_inner],
                            axis_lens=[cfg['tile_c'].size[-1]],
                            max_unroll=16,
                            cfg=cfg)
    cfg['ann_spatial'].apply(s, M, [m_k, p_inner])

    # inverse transform
    s[A].compute_inline()
    y_k, y_b, y_vh, y_vw = s[Y].op.axis
    y_reps, y_rnu = s[Y].op.reduce_axis
    _unroll_all(Y, [y_vh, y_vw, y_reps, y_rnu])

    # output
    l_n, l_co, l_h, l_w = s[last].op.axis
    l_co, l_coi = cfg['tile_k'].apply(s, last, l_co)
    s[M].compute_at(s[last], l_co)
    s[last].parallel(l_co)

    MM = s.cache_read(M, 'global', [Y])
    # V's leading extent is alpha = m + 3 - 1, so this recovers the tile size m
    alpha = get_const_int(V.shape[0])
    m = alpha + 1 - 3
    ho, wo, hi, wi = s[last].tile(l_h, l_w, m, m)
    s[Y].compute_at(s[last], wo)
    s[MM].compute_at(s[last], wo)

    if output != last:
        s[output].compute_inline()
##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH PRE-COMPUTED WEIGHT TRANSFORM #####
@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'arm_cpu', ['winograd'])
def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
    """TOPI compute callback: winograd conv2d with a caller-chosen tile size.

    All arguments are handed to `_decl_winograd` unchanged.
    """
    winograd_args = (cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
                     tile_size)
    return _decl_winograd(*winograd_args)
@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
                                'arm_cpu', ['winograd'])
def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
    """TOPI schedule callback.

    Builds a schedule over `outs` and applies `_schedule_winograd` to every
    op whose tag contains 'winograd_conv2d_output'. Returns the schedule.
    """
    sch = tvm.create_schedule([tensor.op for tensor in outs])

    def _visit(op):
        if 'winograd_conv2d_output' not in op.tag:
            return
        _schedule_winograd(cfg, sch, op.output(0), outs[0])

    traverse_inline(sch, outs[0].op, _visit)
    return sch
##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["arm_cpu"])
def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
    """Alter op layout for pre-computing kernel transformation

    Parameters
    ----------
    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
        Attributes of current convolution
    inputs : nnvm.symbol or tvm.relay.Expr
        Grouped input symbols
    tinfos : list
        Input shape and dtype
    F: symbol
        The context, can be either nnvm.sym or relay.op

    Returns
    -------
    The rewritten convolution expression, or None when the op is left as it
    is (non-NCHW layout, dilated convolution, fallback config, or a template
    this function does not handle).

    Note
    ----
    Unlike other TOPI functions, this function operates on both graph level and operator level,
    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
    """
    copy_inputs = list(inputs)

    new_attrs = {k: attrs[k] for k in attrs.keys()}

    dilation = attrs.get_int_tuple("dilation")
    strides = attrs.get_int_tuple("strides")
    padding = attrs.get_int_tuple("padding")
    groups = attrs.get_int('groups')
    # the attribute holding the data layout is named differently by the two IRs
    data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
    layout = attrs[data_layout_key]
    out_dtype = attrs["out_dtype"]
    if out_dtype in ("", "same"):
        out_dtype = tinfos[0].dtype

    if layout != 'NCHW':
        return None
    if dilation != (1, 1):
        warnings.warn("Does not support weight pre-transform for dilated convolution.")
        return None

    data, kernel = tinfos[0:2]
    N, CI, H, W = get_const_tuple(data.shape)
    CO, _, KH, KW = get_const_tuple(kernel.shape)

    target = tvm.target.current_target()
    dispatch_ctx = autotvm.DispatchContext.current

    if groups == 1:
        # query config of this workload
        workload = autotvm.task.args_to_workload(
            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
        cfg = dispatch_ctx.query(target, workload)

        if cfg.is_fallback:  # if is fallback, clear query cache and return None
            autotvm.task.clear_fallback_cache(target, workload)
            return None

        if cfg.template_key == 'direct':  # pack weight tensor
            VC = cfg['tile_co'].size[-1]
            new_attrs['kernel_layout'] = 'OIHW%do' % VC

            # Store the same config for the altered operator (workload)
            new_data = data
            new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
            dispatch_ctx.update(target, new_workload, cfg)

            return F.nn.conv2d(*copy_inputs, **new_attrs)

        # pre-compute weight transformation in winograd
        if "-device=arm_cpu" in target.options:
            tile_size = 4
            VC = cfg['tile_k'].size[-1]
        else:
            from ..mali.conv2d import _pick_tile_size
            tile_size = _pick_tile_size(tinfos[0], tinfos[1])
            VC = cfg['tile_bna'].val

        weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
                                                               tile_size=tile_size)
        weight = F.reshape(weight,
                           newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
        weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])

        copy_inputs[1] = weight
        new_attrs['tile_size'] = tile_size

        # Store the same config for the altered operator (workload).
        # The second extent uses KW so the placeholder matches the transposed
        # reshape above (it used KH for both spatial extents before).
        new_data = data
        new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CO // VC, CI, VC),
                                     kernel.dtype)
        new_workload = autotvm.task.args_to_workload(
            [new_data, new_weight, strides, padding, dilation,
             new_attrs[data_layout_key], out_dtype, tile_size],
            conv2d_winograd_without_weight_transform)
        dispatch_ctx.update(target, new_workload, cfg)

        return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)

    # groups != 1: looked up as a depthwise convolution workload
    workload = autotvm.task.args_to_workload(
        [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
    cfg = dispatch_ctx.query(target, workload)

    if cfg.is_fallback:  # if is fallback, clear query cache and return None
        autotvm.task.clear_fallback_cache(target, workload)
        return None

    if cfg.template_key == 'contrib_spatial_pack':
        VC = cfg['tile_co'].size[-1]
        new_attrs['kernel_layout'] = 'OIHW%do' % VC

        # Store the same config for the altered operator (workload)
        new_data = data
        CO, M, KH, KW = get_const_tuple(kernel.shape)
        new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype)
        new_workload = autotvm.task.args_to_workload(
            [new_data, new_kernel, strides, padding, dilation, out_dtype],
            depthwise_conv2d_nchw)
        dispatch_ctx.update(target, new_workload, cfg)

        return F.nn.conv2d(*copy_inputs, **new_attrs)

    # currently we only have contrib_spatial_pack and direct template
    # add more schedule templates.
    return None