diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 6c9225070d3f..611f4348d55e 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1199,14 +1199,29 @@ class Squeeze(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
         axis = get_constant(inputs[1], params)
         if isinstance(axis, relax.Constant):
-            axis = [int(x) for x in axis.data.numpy()]
+            axis = tuple([int(x) for x in axis.data.numpy()])
+
         # If data is constant, perform computation directly.
-        if isinstance(inputs[0], relax.Constant):
-            out_data = _np.squeeze(inputs[0].data.numpy(), axis)
-            return relax.const(out_data, inputs[0].struct_info.dtype)
-        return relax.op.squeeze(inputs[0], axis)
+        if isinstance(data, relax.Constant):
+            if isinstance(axis, (tuple, type(None))):
+                out_data = _np.squeeze(data.data.numpy(), axis)
+            else:
+                raise NotImplementedError("Squeeze with symbolic axes not supported")
+
+            return relax.const(out_data, data.struct_info.dtype)
+
+        if isinstance(data, relax.ShapeExpr):
+            if axis == (0,):
+                return relax.PrimValue(data[0])
+            else:
+                raise NotImplementedError(
+                    "Squeeze with symbolic axes and non-zero axes is not supported."
+                )
+
+        return relax.op.squeeze(data, axis)
 
 
 class Constant(OnnxOpConverter):
@@ -1559,12 +1574,12 @@ def _impl_v13(cls, bb, inputs, attr, params):
             splits_rank = splits.checked_type.ndim
         if splits is not None and splits_rank > 0:
             if isinstance(splits, relax.Constant):
-                splits = splits.data.asnumpy()
+                splits = splits.data.numpy()
                 indices = []
                 index = 0
                 for i in splits[:-1]:
                     index += i
-                    indices.append(index)
+                    indices.append(index.item())
             else:
                 raise ValueError("Dynamic Split not yet supported")
         # When splits isnt specified divide evenly over axis.
@@ -1611,11 +1626,16 @@ def _impl_v13(cls, bb, inputs, attr, params):
             steps = [1] * len(axes)
         # If input is a shape tensor, we can directly extract it.
         if isinstance(data, relax.ShapeExpr):
-            shape_data = [dim.value for dim in data]
+            shape_data = list(data)
             # Starts, ends, and steps must be 1-d for shape operation.
             assert all(len(i) == 1 for i in [starts, ends, steps])
             sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
-            return relax.const(sliced_values, "int64")
+
+            if all([isinstance(val, (tir.IntImm, int)) for val in sliced_values]):
+                return relax.const([int(x) for x in sliced_values], "int64")
+            else:
+                return relax.ShapeExpr(sliced_values)
+
         # If all `starts`, `ends`, and `steps` are constant, use strict mode
         # Otherwise, we assume the slice is inbound.
         assume_inbound = not all(
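
The constant-folding branch in `Squeeze` above passes `axis` to `_np.squeeze` only when it is a `tuple` or `None`; a non-constant `axes` input stays a Relax expression and is rejected. A standalone numpy sketch of that folding path, assuming a `(1, 32, 1, 32)` constant like the one used in `test_squeeze_constant` further down:

```python
import numpy as np

# Constant Squeeze folds to a plain numpy call at import time.
data = np.random.rand(1, 32, 1, 32).astype("float32")

# Explicit axes: only the listed unit dimensions are removed.
assert np.squeeze(data, (0, 2)).shape == (32, 32)

# No "axes" input (axis=None): every unit dimension is removed.
assert np.squeeze(data, None).shape == (32, 32)
```
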
@@ -2237,8 +2257,24 @@ class Flatten(OnnxOpConverter):
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         axis = attr.get("axis", 1)
-        data_shape = [i.value for i in inputs[0].struct_info.shape]
-        new_shape = (1, -1) if axis == 0 else (_np.prod(data_shape[0:axis]).astype("int64"), -1)
+        data_shape = list(inputs[0].struct_info.shape)
+
+        if axis == 0:
+            new_shape = (1, -1)
+        else:
+            shape_flags = [isinstance(x, tir.IntImm) for x in data_shape[0:axis]]
+
+            if all(shape_flags):
+                data_shape = [x.value for x in data_shape[0:axis]]
+                new_shape = (_np.prod(data_shape).astype("int64"), -1)
+            else:
+                batch_size = 1
+
+                for el in data_shape[0:axis]:
+                    batch_size = batch_size * el
+
+                new_shape = (batch_size, -1)
+
         return relax.op.reshape(inputs[0], new_shape)
 
 
@@ -3220,6 +3256,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
             "Equal",
             "Where",
             "Cast",
+            "Squeeze",
         ]
         return_tuple_ops = [
             "SequenceConstruct",
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 46373510b101..9faa441138fc 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -53,27 +53,33 @@ def generate_random_inputs(
         for dim in i.type.tensor_type.shape.dim:
             shape.append(dim.dim_value)
 
-        # Extract datatype for the input.
-        if i.type.tensor_type.elem_type:
-            dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[i.type.tensor_type.elem_type])
-        else:
-            dtype = "float32"
-
-        # Generate random inputs for each input.
-        if dtype == "bool":
-            # random_value = np.random.choice(a=[False, True], size=shape)
-            random_value = rg.choice(a=[False, True], size=shape)
-        elif dtype.startswith("int"):
-            # Keep non-zero values
-            random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
-            random_value[random_value <= 0] -= 1
-        else:
-            random_value = rg.standard_normal(size=shape).astype(dtype)
-        input_values[i.name] = random_value
+        input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)
 
     return input_values
 
 
+def generate_random_value(shape, elem_type) -> np.ndarray:
+
+    # Extract datatype for the input.
+    if elem_type:
+        dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
+    else:
+        dtype = "float32"
+
+    # Generate random inputs for each input.
+    if dtype == "bool":
+        # random_value = np.random.choice(a=[False, True], size=shape)
+        random_value = rg.choice(a=[False, True], size=shape)
+    elif dtype.startswith("int"):
+        # Keep non-zero values
+        random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
+        random_value[random_value <= 0] -= 1
+    else:
+        random_value = rg.standard_normal(size=shape).astype(dtype)
+
+    return random_value
+
+
 def check_correctness(
     model: ModelProto,
     inputs: Optional[Dict[str, np.ndarray]] = None,
@@ -156,12 +162,14 @@ def _check_output(tvm_out, ort_out):
         elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
             shape_out = tvm.nd.array([int(i) for i in tvm_out])
             tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
+        elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
+            tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
         else:
             raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
 
     # Check that number of outputs match.
     assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
-    for (tvm_out, ort_out) in zip(tvm_output, ort_output):
+    for tvm_out, ort_out in zip(tvm_output, ort_output):
         # TODO Allow configurable tolerance.
         if ort_out is not None:
             _check_output(tvm_out, ort_out)
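
In the symbolic branch of `Flatten`, multiplying the leading dimensions one at a time promotes the running product to a `tir.PrimExpr` as soon as a `tir.Var` is encountered, and `relax.op.reshape` accepts that expression as a target dimension. A minimal sketch of the arithmetic in isolation, with illustrative dimension values rather than anything taken from the frontend:

```python
from tvm import tir

# Leading dims of a shape like (1, "A", 32): static and symbolic mixed.
dims = [tir.IntImm("int64", 1), tir.Var("A", "int64"), tir.IntImm("int64", 32)]

batch_size = 1
for el in dims:
    batch_size = batch_size * el  # int * PrimExpr promotes to PrimExpr

assert isinstance(batch_size, tir.PrimExpr)
```
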
@@ -219,6 +227,31 @@ def verify_unary(
     check_correctness(model, opset=opset)
 
 
+def verify_unary_dynamic_shape(
+    op_name,
+    shape,
+    shape_instance,
+    attrs={},
+    domain=None,
+    input_dtype=TensorProto.FLOAT,
+    output_dtype=TensorProto.FLOAT,
+    opset=14,
+):
+    test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain)
+    graph = helper.make_graph(
+        [test_node],
+        "elemwise_test",
+        inputs=[
+            helper.make_tensor_value_info("x", input_dtype, shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", output_dtype, shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="elemwise_test")
+    inputs = {"x": generate_random_value(shape_instance, input_dtype)}
+    check_correctness(model, inputs, opset=opset)
+
+
 def verify_binary(
     op_name, shape_a, shape_b, shape_c, attrs={}, domain=None, dtype=TensorProto.FLOAT, opset=14
 ):
@@ -1013,6 +1046,87 @@ def test_squeeze(axis):
     check_correctness(model, opset=13)
 
 
+@pytest.mark.parametrize("axis", [[0, 2], None])
+def test_squeeze_constant(axis):
+    shape = [1, 32, 1, 32]
+    constant = make_constant_node(
+        "x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32")
+    )
+    if axis:
+        squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+    else:
+        squeeze_node = helper.make_node("Squeeze", ["x"], ["y"])
+
+    initializer = (
+        [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None
+    )
+
+    graph = helper.make_graph(
+        [constant, squeeze_node],
+        "squeeze_test",
+        inputs=[],
+        initializer=initializer,
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])],
+    )
+
+    model = helper.make_model(graph, producer_name="squeeze_test")
+    check_correctness(model, opset=13)
+
+
+@pytest.mark.parametrize("axis", [[0]])
+@pytest.mark.parametrize("A", [8, 16, 32])
+@pytest.mark.parametrize("B", [8, 16, 32])
+def test_dynamic_squeeze(axis, A, B):
+
+    squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+    shape = [1, "A", "B"]
+
+    initializer = (
+        [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None
+    )
+
+    graph = helper.make_graph(
+        [squeeze_node],
+        "squeeze_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+        ],
+        initializer=initializer,
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["A", "B"])],
+    )
+
+    model = helper.make_model(graph, producer_name="squeeze_test")
+    inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")}
+    check_correctness(model, inputs, opset=13)
+
+
+@pytest.mark.parametrize("axis", [[0]])
+@pytest.mark.parametrize("A", [8, 16, 32])
+def test_dynamic_shape_squeeze(axis, A):
+
+    shape_node = helper.make_node("Shape", ["x"], ["y"])
+    squeeze_node = helper.make_node("Squeeze", ["y", "axes"], ["z"])
+    shape = ["A"]
+
+    initializer = (
+        [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)] if axis else None
+    )
+
+    graph = helper.make_graph(
+        [shape_node, squeeze_node],
+        "squeeze_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+        ],
+        initializer=initializer,
+        outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [])],
+    )
+
+    model = helper.make_model(graph, producer_name="squeeze_test")
+    inputs = {"x": rg.standard_normal(size=[A]).astype("float32")}
+    check_correctness(model, inputs, opset=13)
+
+
 def test_const():
     shape = [32, 32]
     const_node = helper.make_node(
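
`test_dynamic_shape_squeeze` exercises the new `relax.ShapeExpr` branch of the `Squeeze` converter: `Shape` yields a rank-1 shape value, and squeezing its only axis produces a scalar, which the converter models as a `relax.PrimValue` (hence the new `(int, float, bool)` case in `_check_output`). The same computation in plain numpy terms, with a hypothetical concrete size:

```python
import numpy as np

x = np.zeros([16], dtype="float32")
shape = np.array(x.shape, dtype="int64")  # ONNX Shape: rank-1 int64 tensor [16]
z = np.squeeze(shape, 0)                  # ONNX Squeeze with axes=[0]
assert z.shape == () and z == 16          # scalar result
```
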
@@ -1548,6 +1662,68 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None):
     # )
 
 
+def test_slice_dynamic_shape():
+    def verify_slice(
+        data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None
+    ):
+        if isinstance(starts, list):
+            starts = np.array(starts, "int64")
+        if isinstance(ends, list):
+            ends = np.array(ends, "int64")
+        if isinstance(axes, list):
+            axes = np.array(axes, "int64")
+        if isinstance(steps, list):
+            steps = np.array(steps, "int64")
+
+        slice_inputs = ["y", "starts", "ends"]
+        initializer = [
+            helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts),
+            helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends),
+        ]
+
+        if axes is not None:
+            initializer.append(helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes))
+            slice_inputs.append("axes")
+        if steps is not None:
+            initializer.append(helper.make_tensor("steps", TensorProto.INT64, steps.shape, steps))
+            slice_inputs.append("steps")
+
+        shape_node = helper.make_node("Shape", inputs=["x"], outputs=["y"])
+        slice_node = helper.make_node("Slice", inputs=slice_inputs, outputs=["z"])
+
+        graph = helper.make_graph(
+            [shape_node, slice_node],
+            "slice_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, data_shape),
+            ],
+            outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, output_shape)],
+            initializer=initializer,
+        )
+
+        model = helper.make_model(graph, producer_name="slice_test")
+        inputs = {"x": rg.standard_normal(size=data_instance_shape).astype("float32")}
+        check_correctness(model, inputs)
+
+    verify_slice([20, 10, 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+    verify_slice(["A", 10, 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+    verify_slice(["A", "B", 5], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+    verify_slice([20, 10, "C"], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+    verify_slice(["A", "B", "C"], [20, 10, 5], [2], starts=[0], ends=[2], axes=[0])
+
+    verify_slice([20, 10, 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+    verify_slice(["A", 10, 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+    verify_slice(["A", "B", 5], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+    verify_slice([20, 10, "C"], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+    verify_slice(["A", "B", "C"], [20, 10, 5], [1], starts=[1], ends=[2], axes=[0])
+
+    verify_slice([20, 10, 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+    verify_slice(["A", 10, 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+    verify_slice(["A", "B", 5], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+    verify_slice([20, 10, "C"], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+    verify_slice(["A", "B", "C"], [20, 10, 5], [2], starts=[1], ends=[3], axes=[0])
+
+
 # TODO Enable dynamism
 @pytest.mark.parametrize("dynamic", [False])
 def test_attention(dynamic):
@@ -1795,12 +1971,15 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
         )
     ]
 
+    split_constant = None
     if pass_split:
         if opset >= 13:
             np_split = np.array(split).astype(np.int64)
-            initializer.append(
-                helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split)
+            split_constant = make_constant_node(
+                "split", onnx.TensorProto.INT64, list(np_split.shape), np_split
             )
+            input_names.append("split")
+
     node = helper.make_node(
         "Split",
         inputs=input_names,
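
For opset >= 13, `Split` receives `split` as a tensor input, and `verify_split` now feeds it through a `Constant` node so that the converter, rather than initializer handling, resolves it. The converter's loop over `splits[:-1]` accumulates sizes into split boundaries, which is just a cumulative sum; a numpy sketch of the equivalence (standalone example, not frontend code):

```python
import numpy as np

splits = np.array([3, 5, 8], dtype="int64")  # ONNX "split" sizes
indices = np.cumsum(splits[:-1]).tolist()    # [3, 8]: boundaries, as in the converter loop
parts = np.split(np.arange(16), indices)
assert [p.size for p in parts] == [3, 5, 8]
```
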
@@ -1812,8 +1991,10 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
             split_attr = helper.make_attribute("split", split)
             node.attribute.append(split_attr)
 
+    nodes = [split_constant, node] if split_constant else [node]
+
     graph = helper.make_graph(
-        [node],
+        nodes,
         "split_test",
         inputs=inputs,
         initializer=initializer,
@@ -2226,6 +2407,12 @@ def test_flatten():
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 2})
 
 
+def test_flatten_dynamic():
+    verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 0})
+    verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1})
+    verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2})
+
+
 def test_onehot():
     one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["y"], axis=1)
     graph = helper.make_graph(
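
Taken together, these changes let a model whose inputs carry symbolic dimensions (such as `[1, "A", "B", 32]` above) import cleanly, with `Squeeze`, `Slice`, and `Flatten` lowering to shape expressions instead of requiring static shapes. A rough end-to-end sketch; the file name is a placeholder and the option shown is not required:

```python
import onnx
from tvm.relax.frontend.onnx import from_onnx

model = onnx.load("model.onnx")  # hypothetical model with symbolic dims
mod = from_onnx(model, keep_params_in_input=True)
mod.show()  # symbolic ONNX dims appear as tir.Var dims in the Relax IRModule
```
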