diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index ec0c96fb9..a2a6a42a8 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -1772,9 +1772,12 @@ class MatrixDiagPart: def version_11(cls, ctx, node, **kwargs): # MatrixDiagPart by slice and gather const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64)) + const_zero_ = ctx.make_const(utils.make_name(node.name) + 'const_zero_', np.array(0).astype(np.int64)) + const_zero_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero_zero', np.array([0, 0]).astype(np.int64)) const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64)) + const_one_ = ctx.make_const(utils.make_name(node.name) + 'const_one_', np.array(1).astype(np.int64)) const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64)) const_negative_one = ctx.make_const(utils.make_name(node.name) + 'const_negative_one', np.array([-1]).astype(np.int64)) @@ -1802,7 +1805,9 @@ def version_11(cls, ctx, node, **kwargs): const_negative_one.output[0]]) sliced_input_shape_new = ctx.make_node('Concat', [sliced_input_shape_half.output[0], const_one.output[0]], attr={'axis': -1}) - matrice_range = ctx.make_node('Range', [const_zero.output[0], min_matrice_dim.output[0], const_one.output[0]]) + min_matrice_dim_ = ctx.make_node('Squeeze', [min_matrice_dim.output[0]], {'axes': [0]}) + matrice_range = ctx.make_node('Range', [const_zero_.output[0], min_matrice_dim_.output[0], + const_one_.output[0]]) unsqueezed_matrice_range = ctx.make_node('Unsqueeze', [matrice_range.output[0]], attr={"axes": [-1]}) expanded_range = ctx.make_node('Expand', [unsqueezed_matrice_range.output[0], sliced_input_shape_new.output[0]]) gathered_result = ctx.make_node('GatherElements', [sliced_input.output[0], expanded_range.output[0]], @@ -1893,6 +1898,8 @@ def version_11(cls, ctx, node, **kwargs): new_width = body_graph.make_node('Slice', [processed_shape.output[0], const_neg_one.output[0], shape_processed_shape.output[0]]) abs_k = body_graph.make_node('Abs', [current_k.output[0]]) + + range_k = body_graph.make_node('Range', [abs_k.output[0], new_width.output[0], const_one.output[0]], domain="com.microsoft") sliced_range = body_graph.make_node('Slice', [range_k.output[0], const_zero.output[0], new_depth.output[0]])