-
Notifications
You must be signed in to change notification settings - Fork 433
/
Copy pathnn.py
2057 lines (1782 loc) · 107 KB
/
nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-License-Identifier: Apache-2.0
"""
nn
"""
import logging
import numpy as np
from onnx import onnx_pb, helper
from onnx.onnx_pb import TensorProto
from tf2onnx import constants, utils
from tf2onnx.graph_builder import GraphBuilder
from tf2onnx.handler import tf_op
from tf2onnx.onnx_opset import common, controlflow, tensor
logger = logging.getLogger(__name__)
# pylint: disable=unused-argument,missing-docstring,unused-variable
def spatial_map(shape, perm):
    """Return a copy of ``shape`` with its entries rearranged by ``perm``.

    For every axis index ``dst`` listed in ``perm``, the result holds
    ``shape[perm[dst]]`` at position ``dst``; positions not listed in ``perm``
    keep their original value. When ``perm`` is a full permutation this is the
    shape a tensor of shape ``shape`` has after ``Transpose(perm)``.
    The input sequence itself is not modified.
    """
    remapped = shape[:]  # shallow copy so the caller's shape is left untouched
    for dst in perm:
        src = perm[dst]
        remapped[dst] = shape[src]
    return remapped
def is_channels_last(node):
    """Tell whether ``node`` uses a channels-last layout, i.e. (N, ..., C).

    Any ``data_format`` that does not begin with "NC" is treated as channels last.
    """
    channels_first = node.data_format.startswith("NC")
    return not channels_first
def make_shape_channels_first(shape):
    """Reorder an (N, ..., C) shape into (N, C, ...).

    Works on any sliceable sequence; the result has the same sequence type.
    """
    batch = shape[:1]
    channels = shape[-1:]
    spatial_dims = shape[1:-1]
    return batch + channels + spatial_dims
def make_shape_channels_last(shape):
    """Reorder an (N, C, ...) shape into (N, ..., C).

    This is the inverse of ``make_shape_channels_first``. The spatial part is
    ``shape[2:]`` (everything after the channel axis). The previous
    ``shape[1:-1]`` slice re-inserted C and dropped the last spatial dim,
    e.g. (N, C, H, W) -> (N, C, H, C).

    Works on any sliceable sequence; the result has the same sequence type.
    """
    return shape[:1] + shape[2:] + shape[1:2]
def get_channels_first_permutation(spatial):
    """Return the Transpose ``perm`` that turns (N, ..., C) into (N, C, ...).

    ``spatial`` is the number of spatial axes, e.g. 2 gives ``[0, 3, 1, 2]``.
    """
    channel_axis = spatial + 1
    return [0, channel_axis, *range(1, channel_axis)]
def get_channels_last_permutation(spatial):
    """Return the Transpose ``perm`` that turns (N, C, ...) into (N, ..., C).

    ``spatial`` is the number of spatial axes, e.g. 2 gives ``[0, 2, 3, 1]``.
    """
    return [0, *range(2, spatial + 2), 1]
def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
                        input_indices=None, output_indices=None, spatial=2):
    """Convert a TF convolution-like node's inputs, kernel and outputs to ONNX layout.

    When ``node`` is channels-last, every input in ``input_indices`` is made
    channels-first, (N, ..., C) -> (N, C, ...). A constant input that has this
    node as its only consumer is transposed in place; otherwise a Transpose op is
    inserted.

    When ``with_kernel`` is set, input[1] is treated as the kernel:
      * it is first reshaped to ``new_kernel_shape`` if one is given (e.g.
        depthwise_conv2d);
      * it is then permuted from (..., C_in, C_out) to (C_out, C_in, ...),
        folding the permutation into the weights when possible.

    When ``node`` is channels-last, a Transpose is inserted after every output in
    ``output_indices`` to restore (N, ..., C) for downstream consumers.
    ``node.data_format`` is then set to "NCHW".

    Args:
        ctx: The parent graph.
        node: Node of the convolution op; modified in place.
        with_kernel: Whether input[1] is a kernel that must be reshaped/transposed.
        new_kernel_shape: If given, reshape the kernel to this shape (still in TF
            layout) before it is transposed.
        input_indices: Indices of data inputs to make channels-first; defaults to [0].
        output_indices: Indices of outputs to transpose back; defaults to [0].
        spatial: Number of spatial dimensions (2 for 2D ops, 3 for 3D ops, ...).
    """
    if input_indices is None:
        input_indices = [0]
    if output_indices is None:
        output_indices = [0]

    # Transpose inputs if needed.
    if is_channels_last(node):
        # Get channels first permutation.
        permutation = get_channels_first_permutation(spatial)

        # Transpose input if needed, no need to record shapes on input
        for idx in input_indices:
            # If input is a constant, transpose that one if we are the only consumer.
            input_node = node.inputs[idx]
            input_name = node.input[idx]

            if input_node.is_const() and len(ctx.find_output_consumers(input_name)) == 1:
                # Transpose constant to make it channels first.
                val = input_node.get_tensor_value(as_list=False)
                val = np.transpose(val, permutation)

                input_node.set_tensor_value(val)
            else:
                # Insert transpose op.
                transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
                transpose.set_attr("perm", permutation)
                transpose.skip_conversion = True

                shape = ctx.get_shape(input_name)
                if shape is not None:
                    new_shape = make_shape_channels_first(shape)

                    ctx.set_shape(transpose.output[0], new_shape)

    # Transpose kernel if needed.
    if with_kernel:
        # Some ONNX convolution ops require to reshape the kernel (ie. depthwise_conv2d).
        if new_kernel_shape:
            kernel_name = node.input[1]
            if ctx.opset < 5:
                # Old reshape takes new shape as attribute.
                reshape = ctx.insert_new_node_on_input(node, "Reshape", kernel_name)
                reshape.set_attr("shape", new_kernel_shape)
                reshape.skip_conversion = True
            else:
                # New reshape takes new shape as input[1].
                shape_name = utils.make_name(node.name)
                ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))

                reshape = ctx.make_node("Reshape", [kernel_name, shape_name])
                ctx.replace_input(node, kernel_name, reshape.output[0], 1)

                reshape.skip_conversion = True
            ctx.set_shape(reshape.output[0], new_kernel_shape)

        # Get kernel (may have be changed to a reshape above).
        kernel_node = node.inputs[1]
        kernel_name = node.input[1]

        # Transpose kernel from (..., C_in, C_out) to (C_out, C_in, ...)
        permutation = [spatial + 1, spatial] + list(range(spatial))

        # If kernel is a constant, transpose that one if we are the only consumer.
        need_transpose = True
        if (kernel_node.is_const() or kernel_node.op.op_type == "DequantizeLinear") \
                and len(ctx.find_output_consumers(kernel_name)) == 1:
            if kernel_node.op.op_type == 'DequantizeLinear':
                # Assuming the model was trained in NHWC in TF,
                # the weights would be in [fH, fW, C_in, C_out].
                # orig_conv_weights -> Q -> DQ -> new_conv_weights -> conv
                # NOTE(review): this assumes the DQ's input is a Q whose input[0] is a
                # constant, and that only this path consumes those weights -- confirm.
                weights_node = kernel_node.inputs[0].inputs[0]
                val = weights_node.get_tensor_value(as_list=False)
                val = np.transpose(val, permutation)

                weights_node.set_tensor_value(val)
                need_transpose = False
                # Change the quantization axis for Q and DQ node accordingly
                # (after the permutation C_out sits on axis 0; presumably the original
                # per-channel axis was C_out -- TODO confirm for other quant schemes).
                kernel_node.set_attr("axis", 0)  # DQ node
                kernel_node.inputs[0].set_attr("axis", 0)  # Q node
            else:
                val = kernel_node.get_tensor_value(as_list=False)
                val = np.transpose(val, permutation)

                kernel_node.set_tensor_value(val)
                need_transpose = False

        if need_transpose:
            transpose = ctx.insert_new_node_on_input(node, "Transpose", kernel_name)
            transpose.set_attr("perm", permutation)
            transpose.skip_conversion = True
            # spatial_map applies the same permutation to the shape list.
            new_shape = spatial_map(ctx.get_shape(kernel_name), permutation)
            ctx.set_shape(transpose.output[0], new_shape)

    # Transpose outputs back if needed.
    if is_channels_last(node):
        for idx in output_indices:
            # Make output channels last again by transposing.
            output_name = node.output[idx]
            output_shape = ctx.get_shape(node.output[idx])

            permutation = get_channels_last_permutation(spatial)
            op_name = utils.make_name(node.name)
            transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)

            transpose.set_attr("perm", permutation)
            transpose.skip_conversion = True

            # Set tensorflow channels last shape as the transpose node shape.
            ctx.set_shape(transpose.output[0], output_shape)

            # Make the current ONNX convolution output shape channels first.
            ctx.set_shape(output_name, make_shape_channels_first(output_shape))

        # NOTE: Not strictly correct as it can also be NCW or NCDHW for example.
        # NOTE: Generally speaking it's channels first.
        node.data_format = "NCHW"
def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
    """Translate the TF ``padding`` attribute of ``node`` into ONNX padding attributes.

    * ``SAME``: compute explicit ``pads`` from the static input/output shapes.
      Any odd extra padding goes at the end of each axis (``pad - pad // 2``).
      If a shape or a spatial dim is unknown, fall back to ``auto_pad="SAME_UPPER"``.
    * ``VALID``: nothing is set.
    * ``EXPLICIT``: copy the spatial entries of ``explicit_paddings`` into ``pads``.

    Args:
        ctx: The parent graph.
        node: The convolution/pooling node; its attributes are updated in place.
        kernel_shape: Kernel size; only the first ``spatial`` entries are read.
        strides: Per-spatial-axis strides.
        dilations: Per-spatial-axis dilations; defaults to all ones.
        spatial: Number of spatial dimensions.

    Raises:
        ValueError: If a known input/output rank is not ``spatial + 2``, or the
            padding value is not recognised.
    """
    padding = node.get_attr("padding")
    if not padding:
        return

    if dilations is None:
        dilations = [1] * spatial

    padding = padding.s.decode("utf-8")
    if padding == "SAME":
        # Initialize with all zeros.
        # Paddings are in (x_begin, y_begin, ..., x_end, y_end, ...) order.
        pads = [0] * (spatial * 2)

        # Get shapes and check whether valid.
        input_shape = ctx.get_shape(node.input[0])
        output_shape = ctx.get_shape(node.output[0])

        # Without any shape information we cannot compute pads; use the same
        # auto_pad fallback as for unknown dims instead of crashing on len(None).
        if input_shape is None or output_shape is None:
            logger.debug(
                "node %s has unknown shape for pads calculation, fallback to auto_pad: "
                "input_shape=%s, output_shape=%s",
                node.name,
                input_shape,
                output_shape,
            )
            node.set_attr("auto_pad", "SAME_UPPER")
            return

        if len(input_shape) != spatial + 2:
            raise ValueError(
                "node {} input needs to be rank {}, is {}".format(
                    node.name, spatial + 2, len(input_shape)
                )
            )
        if len(output_shape) != spatial + 2:
            raise ValueError(
                "node {} output needs to be rank {}, is {}".format(
                    node.name, spatial + 2, len(output_shape)
                )
            )

        # Transpose to channels first if not so.
        if is_channels_last(node):
            input_shape = make_shape_channels_first(input_shape)
            output_shape = make_shape_channels_first(output_shape)

        # Check for unknown input/output dimensions. Fall back to auto padding if so.
        if any(input_shape[i + 2] == -1 or output_shape[i + 2] == -1 for i in range(spatial)):
            logger.debug(
                "node %s has unknown dim for pads calculation, fallback to auto_pad: "
                "input_shape=%s, output_shape=%s",
                node.name,
                input_shape,
                output_shape,
            )
            node.set_attr("auto_pad", "SAME_UPPER")
            return

        # Calculate paddings: total padding needed so that the (dilated) kernel
        # stepping with `strides` produces exactly `output_shape` on each axis.
        for i in range(spatial):
            pad = (
                (output_shape[i + 2] - 1) * strides[i]
                + dilations[i] * (kernel_shape[i] - 1) + 1
                - input_shape[i + 2]
            )
            pad = max(pad, 0)
            pads[i] = pad // 2
            pads[i + spatial] = pad - pad // 2

        node.set_attr("pads", pads)
    elif padding == "VALID":
        pass
    elif padding == "EXPLICIT":
        # explicit_paddings holds (begin, end) pairs for every axis of the full
        # rank; skip the batch axis (and the channel axis for channels-first).
        pads = node.get_attr_value("explicit_paddings")
        start_pads = []
        end_pads = []
        d = 1 if is_channels_last(node) else 2
        for i in range(spatial):
            start_pads.append(pads[(d + i) * 2])
            end_pads.append(pads[(d + i) * 2 + 1])
        node.set_attr("pads", start_pads + end_pads)
    else:
        raise ValueError("invalid padding value: {}".format(padding))
def parse_dims_attr(node, dims, spatial):
    """Reduce a per-axis TF attribute (strides, dilations, ksize, ...) to its spatial part.

    If ``dims`` already has ``spatial`` entries it is returned unchanged.
    Otherwise it is assumed to cover the full rank:
      * channels-last (N, ..., C): drop the first and last entries;
      * channels-first (N, C, ...): drop the first two entries.
    """
    channels_last = is_channels_last(node)
    if len(dims) == spatial:
        # Already spatial-only.
        return dims
    return dims[1:-1] if channels_last else dims[2:]
def conv_dims_attr(node, name, new_name=None, spatial=2):
    """Rewrite the int-list attribute ``name`` of ``node`` to its spatial part.

    The spatial values are stored under ``new_name``, or under ``name`` itself
    when ``new_name`` is omitted.

    Returns:
        The spatial list that was stored, or ``None`` when ``node`` has no
        (or an empty) ``name`` attribute.
    """
    attr = node.get_attr(name)
    if not attr:
        return None
    spatial_dims = parse_dims_attr(node, attr.ints, spatial)
    node.set_attr(name if new_name is None else new_name, spatial_dims)
    return spatial_dims
def conv_kernel_shape(ctx, node, input_idx, spatial=2):
    """Record the spatial kernel size of ``node`` as its ``kernel_shape`` attribute.

    The kernel at ``node.input[input_idx]`` is laid out as (..., C_in, C_out), so
    its rank must be ``spatial + 2``. The leading ``spatial`` dims are stored and
    returned.

    Raises:
        ValueError: If the kernel rank is not ``spatial + 2``.
    """
    full_shape = ctx.get_shape(node.input[input_idx])
    expected_rank = spatial + 2
    if len(full_shape) != expected_rank:
        raise ValueError("kernel rank must be spatial+2")
    window = full_shape[:spatial]
    node.set_attr("kernel_shape", window)
    return window
def build_dynamic_target_size(ctx, transposed_intput, target_hw):
    """Build the full target shape [n, c, nh, nw] for a Resize op.

    Args:
        ctx: The graph context.
        transposed_intput: Name of a tensor whose first two dims are [n, c]
            (e.g. a rank-4 [n, c, h, w] tensor).
        target_hw: Name of a 1-D tensor holding the target spatial size [nh, nw].
            It is cast to int64 when it has a different dtype.

    Returns:
        The Concat node (not its output name) whose single output is the 1-D
        int64 tensor [n, c, nh, nw].
    """
    # [n, c]: the first two entries of the runtime shape of the input.
    batch_and_channels = GraphBuilder(ctx).make_slice(
        {"data": ctx.make_node("Shape", [transposed_intput]).output[0], "ends": [2], "starts": [0]})

    hw_int64 = target_hw
    if ctx.get_dtype(hw_int64) != TensorProto.INT64:
        cast = ctx.make_node("Cast", [hw_int64], attr={'to': TensorProto.INT64})
        hw_int64 = cast.output[0]

    # Append [nh, nw] to [n, c].
    return ctx.make_node("Concat", [batch_and_channels, hw_int64], {'axis': 0})
@tf_op(["Conv1D", "Conv2D", "Conv3D"])
class ConvOp:
    """Convert TF Conv1D/Conv2D/Conv3D to ONNX Conv."""

    @staticmethod
    def _infer_group(ctx, node, spatial):
        """Return the ONNX ``group`` attribute for a (possibly grouped) convolution.

        TF expresses grouped convolution implicitly: the kernel, laid out as
        (..., C_in, C_out), has fewer input channels (``kernel[spatial]``) than
        the data has channels. The group count is their ratio.

        Returns 1 in these cases:
          * either shape is unknown;
          * either channel count is unknown or non-positive.
        The previous implementation produced a negative group for an unknown
        (-1) kernel dim, and only recognised "NHWC"/"NCHW", so 1D/3D grouped
        convolutions silently got group=1.
        """
        input_shape = ctx.get_shape(node.input[0])
        kernel_shape = ctx.get_shape(node.input[1])
        if not input_shape or not kernel_shape:
            return 1
        in_channels = input_shape[-1] if is_channels_last(node) else input_shape[1]
        kernel_in_channels = kernel_shape[spatial]
        if in_channels is None or kernel_in_channels is None:
            return 1
        if in_channels <= 0 or kernel_in_channels <= 0:
            return 1
        return in_channels // kernel_in_channels

    @classmethod
    def any_version(cls, opset, ctx, node, **kwargs):
        # ONNX specification:
        #
        # T output = Conv2D(T input, T filter, @list(int) strides, @bool use_cudnn_on_gpu,
        #                       @string padding, @string data_format)
        #
        # T Y = Conv(T X, T W, T B, @AttrType.STRING auto_pad, @AttrType.INTS dilations, @AttrType.INT group,
        #                       @AttrType.INTS kernel_shape, @AttrType.INTS pads, @AttrType.INTS strides)
        #

        # Determine number of spatial dimensions from the op name ("Conv2D" -> 2).
        spatial = int(node.type[-2])

        # Make it a convolution node.
        node.type = "Conv"

        # Determine kernel spatial shape, strides and dilations.
        kernel_shape = conv_kernel_shape(ctx, node, 1, spatial=spatial)
        strides = conv_dims_attr(node, "strides", spatial=spatial)
        dilations = conv_dims_attr(node, "dilations", spatial=spatial)

        # prefix with batch dim of [1] to satisfy rank requirements
        input_shape = ctx.get_shape(node.input[0])
        if input_shape is not None and len(input_shape) == spatial + 1:
            gb = GraphBuilder(ctx)
            usq_node = gb.make_unsqueeze({"axes": [0], 'data': node.input[0]}, return_node=True)
            ctx.replace_inputs(node, [usq_node.output[0]] + node.input[1:])

        # Set padding.
        add_padding(
            ctx, node, kernel_shape, strides, dilations=dilations, spatial=spatial
        )

        # Must run before conv_convert_inputs, which rewrites data_format to NCHW.
        node.set_attr("group", cls._infer_group(ctx, node, spatial))

        # Convert input and filters.
        conv_convert_inputs(ctx, node, with_kernel=True, spatial=spatial)

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        cls.any_version(1, ctx, node, **kwargs)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        # No change.
        cls.any_version(11, ctx, node, **kwargs)

    @classmethod
    def version_13(cls, ctx, node, **kwargs):
        # Signature change for operator Unsqueeze.
        cls.any_version(13, ctx, node, **kwargs)
def get_shape_from_const_or_concat(ctx, node):
    """Try to recover a static shape vector from the node that produces it.

    Returns:
        * the tensor value, when ``node`` is a constant;
        * for a ``Concat`` whose inputs are all 1-element vectors, a list with
          one entry per input:
            - the value ``x.get_tensor_value()`` for inputs of the form
              ``Unsqueeze(x)`` where ``x.is_scalar()``;
            - -1 for any other first input (the batch dimension, whose value
              is not needed);
        * ``None`` in every other case.
    """
    if node.is_const():
        return node.get_tensor_value()
    if node.type != 'Concat':
        return None

    # Sometimes the shape is formed by concatenating a bunch of consts together;
    # every piece must then be a single-element vector.
    if not all(ctx.get_shape(name) == [1] for name in node.input):
        return None

    dims = []
    for position, piece in enumerate(node.inputs):
        # The Concat is converted from a Pack, which wraps each scalar in an Unsqueeze.
        if piece.type == 'Unsqueeze' and piece.inputs[0].is_scalar():
            dims.append(piece.inputs[0].get_tensor_value())
        elif position == 0:
            # For the batch dimension we don't care if it is unknown.
            dims.append(-1)
        else:
            return None
    return dims
@tf_op(["Conv2DBackpropInput", "Conv3DBackpropInputV2"])
class ConvTranspose:
    """Convert TF Conv2DBackpropInput / Conv3DBackpropInputV2 to ONNX ConvTranspose."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Rewrite ``node`` in place as an ONNX ConvTranspose.

        The TF inputs are (input_sizes, filter, out_backprop). input_sizes is the
        full shape of the tensor this op produces.

        * If input_sizes can be recovered statically (via
          ``get_shape_from_const_or_concat``), it becomes explicit ``pads``. When
          the spatial input dims are unknown it becomes an ``output_shape``
          attribute instead.
        * Otherwise (opset >= 10 required) no pads are set. The ConvTranspose
          result is cropped at runtime with a Slice. For strides > 1 on
          opset >= 12, a Pad also grows the result back when it came out smaller
          than input_sizes.

        Finally:
          * the input_sizes input is removed;
          * data and kernel are swapped into ONNX order (X, W);
          * ``conv_convert_inputs`` converts everything to channels-first.
        """
        # T output = Conv2DBackpropInput(int32 input_sizes, T filter, T out_backprop,
        #    @list(int) strides, @bool use_cudnn_on_gpu, @string padding, @string data_format, @list(int) dilations)
        # T Y = ConvTranspose(T X, T W, T B, @STRING auto_pad, @INTS dilations,
        #    @INT group, @INTS kernel_shape, @INTS output_shape, @INTS pads, @INTS strides)

        if node.type == "Conv3DBackpropInputV2":
            spatial = 3
        else:
            spatial = 2
        node.type = "ConvTranspose"
        # Note: inputs are reversed from what one would expect.
        conv_kernel_shape(ctx, node, 1, spatial=spatial)
        # input[2] (out_backprop) is the data tensor fed into the ConvTranspose.
        input_shape = ctx.get_shape(node.input[2])
        input_batch_dim = input_shape[0]
        # Dim -2 of the TF filter is used below as the channel dim of this op's output.
        output_c_dim = ctx.get_shape(node.input[1])[-2]
        if is_channels_last(node):
            input_dims = input_shape[1:1+spatial]
        else:
            input_dims = input_shape[2:2+spatial]
        output_shape_orig = node.output_shapes

        # output_shape is explicitly specified here and then converted to explicit pads.
        output_shape = get_shape_from_const_or_concat(ctx, node.inputs[0])
        if output_shape is not None:
            # Keep only the spatial part of input_sizes.
            if is_channels_last(node):
                new_output_shape = [output_shape[1], output_shape[2]]
                if spatial == 3:
                    new_output_shape.append(output_shape[3])
            else:
                new_output_shape = [output_shape[2], output_shape[3]]
                if spatial == 3:
                    new_output_shape.append(output_shape[4])

            utils.make_sure(new_output_shape.count(-1) <= 0, "output dims need to be known")
            utils.make_sure(all(new_output_shape[i] >= input_dims[i] for i in range(spatial)),
                            "output dims cannot be smaller than input dims.")

            if -1 in input_dims:
                # Pads cannot be computed without the input dims; let ONNX derive
                # them from the requested output_shape instead.
                node.set_attr("output_shape", new_output_shape)
            else:
                if "strides" in node.attr:
                    strides = parse_dims_attr(node, node.get_attr("strides").ints, spatial)
                else:
                    strides = [1] * spatial
                if "dilations" in node.attr:
                    dilations = parse_dims_attr(node, node.get_attr("dilations").ints, spatial)
                else:
                    dilations = [1] * spatial

                kernel_shape = parse_dims_attr(node, node.get_attr("kernel_shape").ints, spatial)
                total_padding = [-1] * spatial
                pads = [1] * (spatial * 2)
                for i in range(spatial):
                    # Un-padded ConvTranspose extent minus the requested extent is
                    # the total padding to remove on this axis.
                    total_padding[i] = (strides[i] * (input_dims[i] - 1)
                                        + ((kernel_shape[i] - 1) * dilations[i] + 1)
                                        - new_output_shape[i])
                    start_i = i
                    end_i = i + spatial
                    # Split it between begin and end, putting any odd remainder at the end.
                    pads[start_i] = int(total_padding[i] / 2)
                    pads[end_i] = total_padding[i] - pads[start_i]
                node.set_attr("pads", pads)
                node.set_attr("auto_pad", "NOTSET")
        else:
            utils.make_sure(ctx.opset >= 10, "Opset 10 needed for Conv Backprop Input with non-constant shape")
            strides = parse_dims_attr(node, node.get_attr('strides').ints, spatial)
            if 'dilations' in node.attr:
                dilations = parse_dims_attr(node, node.get_attr('dilations').ints, spatial)
            else:
                dilations = [1] * spatial

            kernel_shape = parse_dims_attr(node, node.get_attr('kernel_shape').ints, spatial)

            # Static shape of the un-padded ConvTranspose output on each spatial axis.
            new_dims = [-1] * spatial
            for i in range(spatial):
                new_dims[i] = strides[i] * (input_dims[i] - 1) + ((kernel_shape[i] - 1) * dilations[i] + 1)
            if is_channels_last(node):
                new_shape = [input_batch_dim] + new_dims + [output_c_dim]
            else:
                new_shape = [input_batch_dim, output_c_dim] + new_dims
            ctx.set_shape(node.output[0], new_shape)

            use_strides_workaround = any(d > 1 for d in strides)
            if use_strides_workaround and ctx.opset < 12:
                # When strides > 1, ONNX and TF have an implementation difference in ConvTranspose. ONNX outputs a
                # slightly smaller tensor which must be padded with a row of 0s. Pad with dynamic shape requires
                # opset >= 11 and Max of int64 needs opset >= 12. Depending on the output_shape, this row of 0s might
                # be shaved off, in which case TF and ONNX agree. When output_shape is dynamic it is impossible to
                # know at conversion time whether this is the case and the workaround is needed.
                logger.warning("Conv Backprop Input with strides > 1 and non-constant shape has known bug. "
                               "Workaround requires opset 12.")
                use_strides_workaround = False

            # Crop the ConvTranspose output to input_sizes at runtime:
            #   start = (actual - expected) / 2, end = start + expected, per spatial axis.
            # NOTE(review): the slices below read H/W (and D) from positions 1..spatial of
            # both shapes, which looks like a channels-last assumption; confirm behavior
            # for NCHW/NCDHW data_format.
            input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
            output_shape = ctx.make_node("Shape", [node.output[0]])
            output_h = GraphBuilder(ctx).make_slice(
                {"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
            output_w = GraphBuilder(ctx).make_slice(
                {"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
            expect_h = GraphBuilder(ctx).make_slice(
                {"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
            expect_w = GraphBuilder(ctx).make_slice(
                {"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
            diff_h = ctx.make_node("Sub", [output_h, expect_h])
            diff_w = ctx.make_node("Sub", [output_w, expect_w])
            nonneg_diff_h = diff_h
            nonneg_diff_w = diff_w

            if use_strides_workaround:
                # Clamp the crop offset at 0 when the ONNX output is smaller than expected;
                # the missing rows are padded back further below.
                const_zero = ctx.make_const(utils.make_name(node.name + "_const_zero"), np.array([0], dtype=np.int64))
                nonneg_diff_h = ctx.make_node("Max", [diff_h.output[0], const_zero.output[0]])
                nonneg_diff_w = ctx.make_node("Max", [diff_w.output[0], const_zero.output[0]])

            const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
            start_h = ctx.make_node("Div", [nonneg_diff_h.output[0], const_two.output[0]])
            start_w = ctx.make_node("Div", [nonneg_diff_w.output[0], const_two.output[0]])
            end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
            end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
            if spatial == 3:
                output_d = GraphBuilder(ctx).make_slice(
                    {"data": output_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
                expect_d = GraphBuilder(ctx).make_slice(
                    {"data": input_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
                diff_d = ctx.make_node("Sub", [output_d, expect_d])
                nonneg_diff_d = diff_d
                if use_strides_workaround:
                    nonneg_diff_d = ctx.make_node("Max", [diff_d.output[0], const_zero.output[0]])
                start_d = ctx.make_node("Div", [nonneg_diff_d.output[0], const_two.output[0]])
                end_d = ctx.make_node("Add", [start_d.output[0], expect_d])

                starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0], start_d.output[0]],
                                       attr={"axis": 0})
                ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0], end_d.output[0]], attr={"axis": 0})
                slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
                                            np.array([1, 2, 3], dtype=np.int64))
            else:
                starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
                ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
                slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
                                            np.array([1, 2], dtype=np.int64))

            slice_node = ctx.make_node("Slice",
                                       [node.output[0], starts.output[0], ends.output[0], slice_axes.output[0]],
                                       shapes=output_shape_orig)

            final_node = slice_node

            if use_strides_workaround:
                # Pad the end of each spatial axis by max(expected - actual, 0).
                # ONNX Pad order: all begin pads for every axis, then all end pads.
                cz = const_zero.output[0]

                neg_diff_h = ctx.make_node("Neg", [diff_h.output[0]])
                shrink_h_by = ctx.make_node("Max", [neg_diff_h.output[0], const_zero.output[0]])
                shb = shrink_h_by.output[0]

                neg_diff_w = ctx.make_node("Neg", [diff_w.output[0]])
                shrink_w_by = ctx.make_node("Max", [neg_diff_w.output[0], const_zero.output[0]])
                swb = shrink_w_by.output[0]

                if spatial == 3:
                    neg_diff_d = ctx.make_node("Neg", [diff_d.output[0]])
                    shrink_d_by = ctx.make_node("Max", [neg_diff_d.output[0], const_zero.output[0]])
                    sdb = shrink_d_by.output[0]
                    pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0})
                    padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])
                else:
                    pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0})
                    padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])

                final_node = padded_node

            # Redirect every other consumer to the cropped/padded result. The Shape and
            # Slice built above must keep reading the raw ConvTranspose output.
            downstream_nodes = ctx.find_output_consumers(node.output[0])
            downstream_nodes.remove(output_shape)
            downstream_nodes.remove(slice_node)
            ctx.replace_all_inputs(node.output[0], final_node.output[0], ops=downstream_nodes)

        conv_dims_attr(node, "strides", spatial=spatial)
        conv_dims_attr(node, "dilations", spatial=spatial)

        # remove output_shapes input
        ctx.remove_input(node, node.input[0], 0)
        # swap data and kernel
        t = node.input[0]
        ctx.replace_input(node, node.input[0], node.input[1], 0)
        ctx.replace_input(node, node.input[1], t, 1)

        conv_convert_inputs(ctx, node, with_kernel=True, spatial=spatial)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        cls.version_1(ctx, node, **kwargs)
@tf_op(["DepthwiseConv2d", "DepthwiseConv2dNative"])
class DepthwiseConv2d:
    """Convert TF depthwise 2D convolution to a grouped ONNX Conv."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        # T output = DepthwiseConv2dNative(T input, T filter, @list(int) strides, @string padding, @string data_format)
        # T Y = Conv(T X, T W, T B, @AttrType.STRING auto_pad, @AttrType.INTS dilations, @AttrType.INT group,
        #            @AttrType.INTS kernel_shape, @AttrType.INTS pads, @AttrType.INTS strides)
        #
        # A depthwise convolution is a grouped convolution with group == in_channels
        # and out_channels == K * in_channels, where K is the channel multiplier
        # (see the PyTorch Conv2d documentation on `groups`,
        # http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d).
        node.type = "Conv"

        data_shape = ctx.get_shape(node.input[0])
        if len(data_shape) != 4:
            raise ValueError("only Conv2D is supported")

        # TF depthwise filter layout: (k_h, k_w, in_channels, channel_multiplier).
        kernel_shape = ctx.get_shape(node.input[1])
        if len(kernel_shape) != 4:
            raise ValueError("only Conv2D is supported")

        filter_h, filter_w, in_channels, multiplier = kernel_shape
        if "depth_multiplier" in node.attr:
            extra = node.get_attr_int("depth_multiplier")
            in_channels //= extra
            multiplier *= extra
        if in_channels < 1:
            raise ValueError("input channel must be positive")

        node.set_attr("kernel_shape", [filter_h, filter_w])
        strides = conv_dims_attr(node, "strides")
        dilations = conv_dims_attr(node, "dilations")
        node.set_attr("group", in_channels)
        add_padding(ctx, node, kernel_shape, strides, dilations)

        # Fold the multiplier into the output-channel axis: (k_h, k_w, 1, C_in * K),
        # which conv_convert_inputs then transposes to (C_in * K, 1, k_h, k_w).
        conv_convert_inputs(ctx, node, with_kernel=True,
                            new_kernel_shape=[filter_h, filter_w, 1, in_channels * multiplier])
@tf_op(["AvgPool", "AvgPool3D"], onnx_op="AveragePool")
@tf_op(["MaxPool", "MaxPoolV2", "MaxPool3D"], onnx_op="MaxPool")
class PoolOp:
    """Convert TF (Avg|Max)Pool, AvgPool3D/MaxPool3D and MaxPoolV2 to ONNX pooling ops."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        cls._convert(ctx, node, **kwargs)

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        cls._convert(ctx, node, **kwargs)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        # no change
        cls._convert(ctx, node, **kwargs)

    @classmethod
    def _convert(cls, ctx, node, **kwargs):
        """Rewrite ``node`` in place as an ONNX pooling op.

        Steps:
          * Non-float inputs are pooled in float32, with a Cast in front of this
            node and a Cast back to the original dtype after it.
          * ksize/strides come from attributes, or from constant inputs 1 and 2
            (MaxPoolV2); those inputs are then removed.
          * ksize/strides are reduced to their spatial part.
          * Padding is translated to ONNX form.
          * Data is converted to channels-first.
        """
        # T output = MaxPool(T input, @list(int) ksize, @list(int) strides, @string padding, @string data_format)
        # T Y = MaxPool(T X, @AttrType.STRING auto_pad, @AttrType.INTS kernel_shape, @AttrType.INTS pads,
        #               @AttrType.INTS strides)
        # above seems wrong - input[1] is ksize, input[2] is strides
        # stride and ksize in tf is not always NHWC, so watch out when converting into onnx's NCHW
        if kwargs["tf_op"] in ["AvgPool3D", "MaxPool3D"]:
            spatial = 3
        else:
            spatial = 2

        origin_dtype = ctx.get_dtype(node.output[0])
        if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
            # the onnx spec doesn't allow int types for pool ops
            input_shapes = [ctx.get_shape(node.input[0])]
            output_shapes = [ctx.get_shape(node.output[0])]
            cast_node = ctx.make_node("Cast", [node.input[0]], dtypes=[onnx_pb.TensorProto.FLOAT], shapes=input_shapes,
                                      name=node.name + "_cast", attr={"to": onnx_pb.TensorProto.FLOAT})
            # Rewire only this node's data input. Inserting the Cast on the producer's
            # output would hand the float tensor to every other consumer of it too.
            # Reading node.inputs[0].output[0] also picks the wrong tensor when the
            # producer has several outputs.
            ctx.replace_input(node, node.input[0], cast_node.output[0], 0)
            cast_back_node = ctx.make_node("Cast", [node.output[0]], dtypes=[origin_dtype], shapes=output_shapes,
                                           name=node.name + "_castback", attr={"to": origin_dtype})
            # All consumers of the pool result should see the original dtype again.
            ctx.insert_node_on_output(cast_back_node, node.output[0])
            ctx.set_dtype(node.output[0], onnx_pb.TensorProto.FLOAT)

        if len(node.input) < 3:
            kernel_shape_tf = node.get_attr("ksize").ints
            strides_tf = node.get_attr("strides").ints
        else:
            # MaxPoolV2: ksize and strides are (constant) inputs rather than attributes.
            kernel_shape_tf = node.inputs[1].get_tensor_value()
            strides_tf = node.inputs[2].get_tensor_value()
            ctx.remove_input(node, node.input[2], 2)
            ctx.remove_input(node, node.input[1], 1)

        kernel_shape_hw = parse_dims_attr(node, kernel_shape_tf, spatial)
        strides_hw = parse_dims_attr(node, strides_tf, spatial)

        node.set_attr("kernel_shape", kernel_shape_hw)
        node.set_attr("strides", strides_hw)
        dilations = conv_dims_attr(node, "dilations", spatial=spatial)
        add_padding(ctx, node, kernel_shape_hw, strides_hw, dilations=dilations, spatial=spatial)
        conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial)
@tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool")
class MaxPoolWithArgmaxOp:
    """Convert TF MaxPoolWithArgmax to ONNX MaxPool and rebuild TF-style argmax indices."""

    @classmethod
    def version_8(cls, ctx, node, **kwargs):
        """Rewrite ``node`` as ONNX MaxPool and remap its Indices output to TF's encoding.

        TF's argmax encodes (y * W + x) * C + c, plus n * H * W * C when
        include_batch_in_index is set. Here H, W, C come from the runtime shape
        of the NHWC input.

        The code takes ONNX's index modulo H*W and rebuilds that encoding:
          * multiply by C;
          * add the channel index (a [0, C) range);
          * optionally add n * H * W * C.

        The original consumers of output[1] are redirected to the rebuilt index.
        """
        # T output = MaxPool(T input, @list(int) ksize, @list(int) strides, @string padding, @string data_format)
        # Set kernel_shape attribute (spatial entries 1, 2 of the NHWC ksize)
        kernel_shape = node.get_attr("ksize").ints
        kernel_shape = [kernel_shape[1], kernel_shape[2]]
        node.set_attr("kernel_shape", kernel_shape)

        # Set strides attribute (spatial entries 1, 2 of the NHWC strides)
        strides = node.get_attr("strides").ints
        strides = [strides[1], strides[2]]
        node.set_attr("strides", strides)

        # The input data_format is NHWC for TF MaxPoolWithArgmax
        node.set_attr("data_format", "NHWC")

        # Convert indices from NCHW to NHWC format
        input_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
        input_shape_guess = ctx.get_shape(node.input[0])
        # Runtime N, H, W, C of the NHWC input, each as a 1-element tensor.
        n, h, w, c = ctx.make_node("Split", [input_shape], attr={'axis': 0}, output_count=4).output
        hw = ctx.make_node("Mul", [h, w]).output[0]
        chw = ctx.make_node("Mul", [hw, c]).output[0]
        # Capture the original consumers before the nodes below start reading output[1].
        consumers = ctx.find_output_consumers(node.output[1])
        # xy = index mod (H * W); without Mod (opset < 10) compute a - (a / b) * b.
        if ctx.opset >= 10:
            xy = ctx.make_node("Mod", [node.output[1], hw]).output[0]
        else:
            xy_div = ctx.make_node("Div", [node.output[1], hw]).output[0]
            xy_mul = ctx.make_node("Mul", [xy_div, hw]).output[0]
            xy = ctx.make_node("Sub", [node.output[1], xy_mul]).output[0]
        xy_scale_c = ctx.make_node("Mul", [xy, c]).output[0]
        const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
        const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
        # Channel indices [0, C): a constant when C is statically known, otherwise a Range (opset 11).
        if input_shape_guess is not None and input_shape_guess[3] > 0:
            c_range_np = np.arange(input_shape_guess[3], dtype=np.int64)
            c_range = ctx.make_const(utils.make_name("c_range"), c_range_np).output[0]
        else:
            utils.make_sure(ctx.opset >= 11, "opset 11 required for MaxPoolWithArgmax with non-const num channels")
            c_sq = GraphBuilder(ctx).make_squeeze({'data': c, 'axes': [0]})
            c_range = ctx.make_node("Range", [const_zero, c_sq, const_one]).output[0]
        # c_range broadcasts along the last axis. NOTE(review): this presumably relies on
        # output[1] being transposed back to NHWC by conv_convert_inputs below (which
        # inserts a Transpose on output[1]) -- confirm.
        xyc = ctx.make_node("Add", [xy_scale_c, c_range]).output[0]
        # With a static batch of 1 the batch offset is always zero.
        single_batch = input_shape_guess is not None and input_shape_guess[0] == 1
        # Documentation says include_batch_in_index has default False, but tf 1.13 excludes it and assumes True
        if node.get_attr_value('include_batch_in_index', True) and not single_batch:
            utils.make_sure(ctx.opset >= 11, "opset 11 required for MaxPoolWithArgmax with include_batch_in_index")
            # Add n * H * W * C, with n broadcast over the (N, 1, 1, 1) batch axis.
            n_sq = GraphBuilder(ctx).make_squeeze({'data': n, 'axes': [0]})
            n_range = ctx.make_node("Range", [const_zero, n_sq, const_one]).output[0]
            n_range_unsq = GraphBuilder(ctx).make_unsqueeze({'data': n_range, 'axes': [1, 2, 3]})
            n_range_scale = ctx.make_node("Mul", [n_range_unsq, chw]).output[0]
            result = ctx.make_node("Add", [xyc, n_range_scale]).output[0]
        else:
            result = xyc
        ctx.replace_all_inputs(node.output[1], result, ops=consumers)

        add_padding(ctx, node, kernel_shape, strides)
        conv_convert_inputs(ctx, node, with_kernel=False, input_indices=[0], output_indices=[0, 1])
@tf_op(["BiasAdd", "BiasAddV1"])
class BiasAdd:
    """Lower TF BiasAdd / BiasAddV1 to an ONNX Add.

    TF signatures:
        T output = BiasAdd(T value, T bias, @string data_format)
        T output = BiasAddV1(T value, T bias)
    """

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Opset 1: retype the node as Add and hand it to the legacy broadcast handler."""
        # TODO: for now use add. We may need to convert to NCHW.
        node.type = "Add"
        common.BroadcastOp.version_1(ctx, node, **kwargs)

    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        """Opset 7+: retype the node as Add; for non-NHWC layouts reshape a constant 1-D bias.

        Per TF's bias_add definition the bias input is always 1-D.
        """
        node.type = "Add"
        common.BroadcastOp.version_6(ctx, node, **kwargs)

        if node.is_nhwc():
            # NHWC: the bias broadcasts from the last dim, which is ONNX Add's default broadcasting.
            return

        # Channels-first: Add cannot know the bias belongs on the 2nd dim, so a constant 1-D bias
        # is reshaped to [C, 1, ..., 1] (input rank - 1 dims) before the Add.
        value_shape = ctx.get_shape(node.input[0])
        bias_shape = ctx.get_shape(node.input[1])
        if node.inputs[1].type != 'Const' or len(bias_shape) != 1:
            return

        bias_name = node.input[1]
        target_shape = [bias_shape[0]] + [1] * (len(value_shape) - 2)
        target_shape_name = utils.make_name(node.name)
        ctx.make_const(target_shape_name, np.array(target_shape, dtype=np.int64))
        reshaped_bias = ctx.make_node("Reshape", [bias_name, target_shape_name])
        ctx.replace_input(node, bias_name, reshaped_bias.output[0], 1)
        ctx.set_shape(reshaped_bias.output[0], target_shape)
@tf_op(["Pad", "PadV2", "MirrorPad"], onnx_op="Pad")
class Pad:
    """Convert TF Pad / PadV2 / MirrorPad to ONNX Pad.

    TF paddings are a [rank, 2] tensor of (begin, end) pairs; ONNX expects the flat layout
    [begin_0, ..., begin_{r-1}, end_0, ..., end_{r-1}], hence the transpose + flatten below.
    ONNX Pad has no symmetric mode, so MirrorPad(mode=SYMMETRIC) is emulated by
    convert_symmetric_pads.
    """

    @classmethod
    def _parse_mode(cls, node):
        """Lower-case the TF 'mode' attribute, write it back on the node and validate it.

        Returns the mode string, or None when the node has no 'mode' attribute.
        Raises ValueError for modes that cannot be expressed with ONNX Pad.
        """
        mode = node.get_attr("mode")
        if mode:
            mode = mode.s.decode("utf-8").lower()
            node.set_attr("mode", mode)
        if mode not in [None, "symmetric", "constant", "reflect"]:
            raise ValueError(mode + " pad mode is not supported")
        return mode

    @classmethod
    def _wrap_in_float_casts(cls, ctx, node, origin_dtype):
        """Run `node` in FLOAT: cast its data input (and a remaining constant_value input) in, and the result back."""
        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.FLOAT)
        ctx.set_dtype(cast_node.output[0], TensorProto.FLOAT)
        # Cast preserves shape: copy it from the tensor being cast (not from the node's name).
        ctx.copy_shape(cast_node.input[0], cast_node.output[0])
        if len(node.input) > 2:
            # PadV2's constant_value is still input 2 (opset >= 11); ONNX Pad requires it to have the
            # same element type as the data, which was just changed to FLOAT.
            value_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=TensorProto.FLOAT)
            ctx.set_dtype(value_cast.output[0], TensorProto.FLOAT)
        cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
                                                       name=utils.make_name(node.name) + "_castback",
                                                       to=origin_dtype)
        ctx.set_dtype(cast_back_node.output[0], origin_dtype)
        ctx.copy_shape(node.output[0], cast_back_node.output[0])

    @classmethod
    def convert_symmetric_pads(cls, ctx, node):
        """Emulate symmetric padding with reflect padding plus one dummy element per side.

        There isn't a symmetric padding mode in ONNX, so each affected axis first gets a dummy
        element on both sides (constant pad of 1), the node then reflect-pads, and the two dummies
        are removed with Compress. Ex (pads 2, 2): 1234 -> 012340 -> 2101234043 -> 21123443.
        When the pads are constant, only axes with non-zero pads are touched.
        """
        rank = ctx.get_rank(node.input[0])
        utils.make_sure(rank is not None, "Cannot convert pad with symmetric mode and unknown rank")
        utils.make_sure(ctx.opset >= 9, "opset 9 required for symmetric padding mode")
        node.set_attr("mode", "reflect")
        consumers = ctx.find_output_consumers(node.output[0])
        output_shape = ctx.get_shape(node.output[0])

        # Pads are an attribute before opset 11 and an input (possibly non-constant) afterwards.
        const_pads = None
        if ctx.opset < 11:
            const_pads = node.get_attr_value("pads")
        elif node.inputs[1].is_const():
            const_pads = node.inputs[1].get_tensor_value()

        if const_pads is None:
            non_zero_axes = list(range(rank))
        else:
            non_zero_axes = [i for i in range(rank) if const_pads[i] != 0 or const_pads[i + rank] != 0]
        if not non_zero_axes:
            # All pads are zero: reflect padding by zero is already the identity, nothing to emulate.
            return

        # Step 1: add one dummy element on both sides of every affected axis.
        inc_pads = [0] * (rank * 2)
        for a in non_zero_axes:
            inc_pads[a] = 1
            inc_pads[a + rank] = 1
        if ctx.opset < 11:
            padded_inp = ctx.make_node("Pad", [node.input[0]], attr={'mode': 'constant', 'pads': inc_pads}).output[0]
        else:
            pad1_pads_const = ctx.make_const(utils.make_name("pad1_pads"), np.array(inc_pads, np.int64)).output[0]
            padded_inp = ctx.make_node("Pad", [node.input[0], pad1_pads_const], attr={'mode': 'constant'}).output[0]
        ctx.replace_input(node, node.input[0], padded_inp, 0)
        ctx.update_node_shape_dtype(node, override=True)

        # Step 2: the node now reflect-pads the dummy-extended input; drop the two dummies per axis.
        output = node.output[0]
        shape = ctx.make_node("Shape", [output]).output[0]
        # Split-18 requires either a 'split' input or the 'num_outputs' attribute.
        split_attr = {'num_outputs': rank} if ctx.opset >= 18 else {}
        dims = ctx.make_node("Split", [shape], attr=split_attr, output_count=rank).output
        two_false = ctx.make_const(utils.make_name("two_false"), np.array([False, False], bool)).output[0]
        inv_second = dec_second = None
        if const_pads is None:
            # Turn dynamic [begin, end] pads into dummy indices [begin, -1 - end].
            inv_second = ctx.make_const(utils.make_name("inv_second"), np.array([1, -1], np.int64)).output[0]
            dec_second = ctx.make_const(utils.make_name("dec_second"), np.array([0, 1], np.int64)).output[0]
        # Loop invariants: the all-True fill value for the keep mask and the opset's scatter op.
        keep_value = helper.make_tensor("value", onnx_pb.TensorProto.BOOL, dims=[1], vals=[1])
        scatter_op = "ScatterElements" if ctx.opset >= 11 else "Scatter"
        for a in non_zero_axes:
            ones_of_shape = ctx.make_node("ConstantOfShape", [dims[a]], attr={'value': keep_value}).output[0]
            # After reflect padding the dummies sit at index `begin` and at index `-1 - end`.
            if const_pads is not None:
                to_remove_val = [const_pads[a], -1 - const_pads[a + rank]]
                to_remove = ctx.make_const(utils.make_name("to_remove"), np.array(to_remove_val, np.int64)).output[0]
            else:
                pads_idx = ctx.make_const(utils.make_name("pads_idx"), np.array([a, a + rank], np.int64)).output[0]
                pads_vals = ctx.make_node("Gather", [node.input[1], pads_idx]).output[0]
                pads_inv_second = ctx.make_node("Mul", [pads_vals, inv_second]).output[0]
                to_remove = ctx.make_node("Sub", [pads_inv_second, dec_second]).output[0]
            dims_to_keep = ctx.make_node(scatter_op, [ones_of_shape, to_remove, two_false]).output[0]
            output = ctx.make_node("Compress", [output, dims_to_keep], attr={'axis': a}).output[0]
        ctx.replace_all_inputs(node.output[0], output, consumers)
        ctx.set_shape(output, output_shape)

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Opset < 11: pads and the constant fill value are attributes of ONNX Pad."""
        node.type = "Pad"
        # T output = Pad(T input, int32 paddings, @type Tpaddings), CONST model using default value
        # or PadV2(T input, int32 paddings, T constant_value, @type Tpaddings), CONST mode - default value specified
        # or MirrorPad(T input, int32 paddings, @type Tpaddings, @STRING mode), other mode.
        # T output = Pad(T data, @STRING mode, @INTS pads, @FLOAT value)
        paddings = np.array(node.inputs[1].get_tensor_value()).transpose().flatten()
        mode = cls._parse_mode(node)
        if mode in [None, "constant"] and len(node.input) == 3:
            # Pad-2 declares 'value' as a FLOAT attribute; an int/bool scalar would yield a mistyped attribute.
            const_val = node.inputs[2].get_tensor_value()
            node.set_attr("value", float(const_val))
            ctx.remove_input(node, node.input[2], 2)
        ctx.remove_input(node, node.input[1], 1)
        node.set_attr("pads", paddings)

        origin_dtype = ctx.get_dtype(node.output[0])
        if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT,
                                onnx_pb.TensorProto.DOUBLE]:
            cls._wrap_in_float_casts(ctx, node, origin_dtype)

        if mode == "symmetric":
            cls.convert_symmetric_pads(ctx, node)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Opset 11+: pads (and PadV2's constant_value) are inputs of ONNX Pad."""
        mode = cls._parse_mode(node)

        if not node.inputs[1].is_const():
            # pads must be int64 and in ONNX order: transpose the [rank, 2] tensor, then flatten it.
            if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
                ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
            ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
            shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
            ctx.insert_new_node_on_input(node, "Reshape", [node.input[1], shape_const.name])
        else:
            paddings = node.inputs[1].get_tensor_value(as_list=False).astype(np.int64).transpose().flatten()
            pad_const = ctx.make_const(utils.make_name("pad_const"), paddings)
            ctx.replace_input(node, node.input[1], pad_const.output[0], 1)

        origin_dtype = ctx.get_dtype(node.output[0])
        if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
                                TensorProto.INT32, TensorProto.INT64]:
            cls._wrap_in_float_casts(ctx, node, origin_dtype)

        if mode == "symmetric":
            cls.convert_symmetric_pads(ctx, node)
@tf_op(["FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"])
class BatchNorm:
@classmethod
def version_6(cls, ctx, node, **kwargs):
tf_type = node.type
input_rank = len(ctx.get_shape(node.input[0]))
if input_rank == 4:
spatial = 2
elif input_rank == 5:
spatial = 3
else:
raise ValueError("node input must be 4 or 5-dimensional, is {} now".format(input_rank))
node.type = "BatchNormalization"
# tf inputs: x, scale, bias, mean, variance
# tf outputs: y, batch_mean, batch_var
# a: data_format, epsilon, is_training
# onnx inputs: X, scale, B, mean, variance, attributes: epsilon, momentum=0.9, spatial : 1
# output: y, mean, var, savedmean, savedvar,
# detach unused outputs. While we could let the unused outputs dangle,
# some runtimes like pytorch/caffe2 do complain about it.
# onnx batchnorm requires same T for all inputs
mean_type = ctx.get_dtype(node.input[3])
x_dtype = ctx.get_dtype(node.input[0])
if x_dtype != mean_type:
# TODO: this works but more efficient would be to flip the other inputs. We'd need to check
# TODO: first if this works with the onnx implementation so its a later for now
ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=mean_type)
# casting the input[0] will change the output dtype of bn so we need to cast back
cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
name=utils.make_name(node.name) + "_castback",
to=x_dtype)
ctx.set_dtype(cast_back_node.output[0], x_dtype)
ctx.copy_shape(node.name, cast_back_node.output[0])
ctx.set_dtype(node.output[0], mean_type)
consumers = [ctx.find_output_consumers(output_name) for output_name in node.output[1:]]
if not any(consumers):
new_output = [node.output[0]]
# the setter makes a copy of new_output
node.output = new_output
conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial)
inp_shape = ctx.get_shape(node.input[0])
inp_rank = len(inp_shape) if inp_shape is not None else None