From 9cab4bc9c1d3716d20a1fc9aad9631bd0bf410c9 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 5 Aug 2024 10:41:31 -0500
Subject: [PATCH] [Relax] Remove segfault in R.call_tir_inplace validation

Prior to this commit, the error message produced when validating
`R.call_tir_inplace` included the shape of the argument that would be
mutated in-place.  This correctly caught and raised an error when the
argument was a tensor with a known shape incompatible with the output
tensor's shape.  However, the same error message could also be
reached if the input did not have `TensorStructInfo` at all, which
would trigger a segfault.

This commit updates the validation to print the argument's
`StructInfo` directly, rather than a field from the struct info.
This correctly raises an error for the cases where the argument is
not a tensor, or is a tensor with unknown dimensionality, while still
printing the explicit shape of the mismatched tensor when available.
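
For example, the module below (adapted from the regression tests
added in this commit) declares the in-place argument `A` as
`R.Object` rather than as a tensor.  Previously, validation
dereferenced the null `TensorStructInfoNode*` while building the
error message and segfaulted; it now raises a
`tvm.error.DiagnosticError` instead:

    @I.ir_module
    class Module:
        @R.function
        def main(A: R.Object):
            gv1 = R.call_tir_inplace(
                Module.multiply_by_two,
                [A],
                out_sinfo=R.Tensor((16,), dtype="float32"),
                inplace_indices=[0],
            )
            return gv1

        @T.prim_func(private=True)
        def multiply_by_two(A: T.Buffer((16,), "float32")):
            for i in range(16):
                A[i] = A[i] * T.float32(2)

The same validation covers tuple arguments, mismatched shapes, and
mismatched dtypes; each case has a corresponding regression test in
tests/python/relax/test_transform.py.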
---
 src/relax/op/op.cc                   |  80 ++++++-----
 tests/python/relax/test_transform.py | 197 ++++++++++++++++++++++-----
 2 files changed, 202 insertions(+), 75 deletions(-)

diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 77cf4a2c6fd0..0a840248ffe8 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -419,13 +419,19 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
   // may result in an error if performed before normalization.
   call = Downcast<Call>(NormalizeCallTIR(ctx, std::move(call)));
 
+  Array<StructInfo> sinfo_outputs = [&]() -> Array<StructInfo> {
+    auto out_sinfo = call->sinfo_args[0];
+    if (auto* tuple_output = out_sinfo.as<TupleStructInfoNode>()) {
+      return tuple_output->fields;
+    } else {
+      return {out_sinfo};
+    }
+  }();
+
   // there must be an inplace index for each output
   const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
-  size_t num_outputs = 1U;
-  if (auto* tup_info = call->sinfo_args[0].as<TupleStructInfoNode>()) {
-    num_outputs = tup_info->fields.size();
-  }
-  if (attrs->inplace_indices.size() != num_outputs) {
+  ICHECK(attrs);
+  if (attrs->inplace_indices.size() != sinfo_outputs.size()) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "There must be an in-place index specified for each output");
   }
@@ -459,45 +465,37 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
   // input shape
   // TODO(@slyubomirsky): eventually we will want to handle cases where that is not true
   Tuple call_args = Downcast<Tuple>(call->args[1]);
-  if (attrs->inplace_indices.size() == 1) {
-    auto* out_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>();
-    if (!out_sinfo) {
-      ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
+
+  for (size_t i_output = 0; i_output < attrs->inplace_indices.size(); i_output++) {
+    auto i_input = attrs->inplace_indices[i_output].IntValue();
+    if (i_input == -1) {
+      continue;
     }
-    auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
-        call_args->fields[attrs->inplace_indices[0].IntValue()]);
-    if (!input_sinfo || !input_sinfo->shape.defined() ||
-        !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
-                            ctx->GetAnalyzer())) {
+
+    auto sinfo_output = sinfo_outputs[i_output];
+    auto tinfo_output = sinfo_output.as<TensorStructInfoNode>();
+
+    if (!tinfo_output || !tinfo_output->shape.defined() || tinfo_output->IsUnknownDtype()) {
       ctx->ReportFatal(Diagnostic::Error(call)
-                       << "The shape of output 0 must match input "
-                       << attrs->inplace_indices[0].IntValue() << ", whereas we have "
-                       << out_sinfo->shape.value() << " in output 0 versus "
-                       << input_sinfo->shape.value() << " in input "
-                       << attrs->inplace_indices[0].IntValue());
+                       << "The output struct info for an in-place mutation must be a tensor "
+                       << "with a defined shape and dtype, "
+                       << "but output " << i_output << " has struct info " << sinfo_output);
     }
-  } else {
-    auto out_sinfos = call->sinfo_args[0].as<TupleStructInfoNode>()->fields;
-    for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
-      if (attrs->inplace_indices[i].IntValue() == -1) {
-        continue;
-      }
-      auto* out_sinfo = out_sinfos[i].as<TensorStructInfoNode>();
-      if (!out_sinfo) {
-        ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
-      }
-      auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
-          call_args->fields[attrs->inplace_indices[i].IntValue()]);
-      if (!input_sinfo || !input_sinfo->shape.defined() ||
-          !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
-                              ctx->GetAnalyzer())) {
-        ctx->ReportFatal(Diagnostic::Error(call)
-                         << "The shape of output " << i << " must match that of input "
-                         << attrs->inplace_indices[i].IntValue() << ", whereas we have "
-                         << out_sinfo->shape.value() << " in output " << i << " versus "
-                         << input_sinfo->shape.value() << " in input "
-                         << attrs->inplace_indices[i].IntValue());
-      }
+
+    auto sinfo_input = GetStructInfo(call_args->fields[i_input]);
+    auto tinfo_input = sinfo_input.as<TensorStructInfoNode>();
+
+    if (!tinfo_input ||
+        (tinfo_output->IsUnknownDtype() || tinfo_output->dtype != tinfo_input->dtype) ||
+        (!tinfo_input->shape.defined() ||
+         !CanProveShapeEqual(tinfo_input->shape.value(), tinfo_output->shape.value(),
+                             ctx->GetAnalyzer()))) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The input used for an in-place mutation must be "
+                       << "a tensor with identical shape and dtype as the output. "
+                       << "However, output " << i_output << " with struct info " << sinfo_output
+                       << " is specified as an in-place mutation of input " << i_input
+                       << " with struct info " << sinfo_input);
     }
   }

diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index e7e8f94fc2ac..ee2df866fb35 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -20,7 +20,7 @@
 from tvm import relax
 import tvm.script
-from tvm.script import tir as T, relax as R
+from tvm.script import ir as I, tir as T, relax as R
 
 
 def test_to_non_dataflow():
@@ -446,45 +446,174 @@ def foo(
     tvm.ir.assert_structural_equal(Expected["foo"], new_mod["foo"], map_free_vars=True)
 
 
-@pytest.mark.xfail()
 def test_call_tir_inplace_repeated_input():
-    @tvm.script.ir_module
-    class Input:
-        @T.prim_func
-        def func(
-            A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32")
-        ):
-            T.evaluate(0)
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @tvm.script.ir_module
+        class Input:
+            @T.prim_func
+            def func(
+                A: T.Buffer((2, 3), "int32"),
+                B: T.Buffer((2, 3), "int32"),
+                C: T.Buffer((2, 3), "int32"),
+            ):
+                T.evaluate(0)
 
-        @R.function
-        def foo(
-            x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32")
-        ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
-            R.func_attr({"relax.force_pure": True})
-            gv0 = R.call_tir_inplace(
-                Input.func,
-                (x, y, z),
-                # repeated 0 -> that's an error
-                [0, 0],
-                [R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
-            )
-            return gv0
+            @R.function
+            def foo(
+                x: R.Tensor((2, 3), "int32"),
+                y: R.Tensor((2, 3), "int32"),
+                z: R.Tensor((2, 3), "int32"),
+            ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
+                R.func_attr({"relax.force_pure": True})
+                gv0 = R.call_tir_inplace(
+                    Input.func,
+                    (x, y, z),
+                    # repeated 0 -> that's an error
+                    [0, 0],
+                    [R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
+                )
+                return gv0
 
 
-@pytest.mark.xfail()
 def test_call_tir_inplace_all_new():
-    @tvm.script.ir_module
-    class Input:
-        @T.prim_func
-        def func(A: T.Buffer((2, 3), "int32")):
-            T.evaluate(0)
+    with pytest.raises(tvm.error.DiagnosticError):
 
-        @R.function
-        def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
-            R.func_attr({"relax.force_pure": True})
-            # cannot make the only output a fresh one
-            gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32"))
-            return gv0
+        @tvm.script.ir_module
+        class Input:
+            @T.prim_func
+            def func(A: T.Buffer((2, 3), "int32")):
+                T.evaluate(0)
+
+            @R.function
+            def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+                R.func_attr({"relax.force_pure": True})
+                # cannot make the only output a fresh one
+                gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32"))
+                return gv0
+
+
+def test_inplace_mutation_with_tuple_argument_raises_error():
+    """TIR PrimFuncs do not support Tuple arguments
+
+    The `R.call_tir_inplace` operator must receive an in-line tuple of
+    arguments, where each argument in the tuple may be expressed in
+    TIR.  Here, `[[A]]` specifies a tuple of arguments, where the
+    first argument is itself a tuple.  Since PrimFuncs do not support
+    Tuple arguments, this is invalid.
+
+    This is a regression test.  In previous implementations, this
+    triggered a segfault rather than raising an exception.
+
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+                cls = Module
+                gv1 = R.call_tir_inplace(
+                    cls.multiply_by_two,
+                    [[A]],
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                return gv1
+
+            @T.prim_func(private=True)
+            def multiply_by_two(A: T.Buffer((16,), "float32")):
+                for i in range(16):
+                    A[i] = A[i] * T.float32(2)
+
+
+def test_inplace_mutation_with_non_tensor_argument_raises_error():
+    """In-place argument must be a tensor
+
+    The `R.call_tir_inplace` operator must receive an in-line tuple of
+    arguments, where each argument in the tuple may be expressed in
+    TIR.  Here, the argument `A` is not a tensor.
+
+    This is a regression test.  In previous implementations, this
+    triggered a segfault rather than raising an exception.
+
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(A: R.Object):
+                gv1 = R.call_tir_inplace(
+                    Module.multiply_by_two,
+                    [A],
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                return gv1
+
+            @T.prim_func(private=True)
+            def multiply_by_two(A: T.Buffer((16,), "float32")):
+                for i in range(16):
+                    A[i] = A[i] * T.float32(2)
+
+
+def test_inplace_mutation_with_incompatible_tensor_shape_raises_error():
+    """In-place argument must have compatible shape
+
+    The `R.call_tir_inplace` operator must receive an in-line tuple of
+    arguments, where the shape of each in-place argument is compatible
+    with the corresponding output.  Here, the shape of argument `A` is
+    different from the output's shape (`[32]` as opposed to `[16]`).
+
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(A: R.Tensor([32], dtype="float32")):
+                gv1 = R.call_tir_inplace(
+                    Module.multiply_by_two,
+                    [A],
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                return gv1
+
+            @T.prim_func(private=True)
+            def multiply_by_two(A: T.Buffer((16,), "float32")):
+                for i in range(16):
+                    A[i] = A[i] * T.float32(2)
+
+
+def test_inplace_mutation_with_incompatible_tensor_dtype_raises_error():
+    """In-place argument must have compatible dtype
+
+    The `R.call_tir_inplace` operator must receive an in-line tuple of
+    arguments, where the dtype of each in-place argument is compatible
+    with the corresponding output.  Here, the dtype of argument `A` is
+    different from the output's dtype (`int32` as opposed to
+    `float32`).
+
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(A: R.Tensor([16], dtype="int32")):
+                gv1 = R.call_tir_inplace(
+                    Module.multiply_by_two,
+                    [A],
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                return gv1
+
+            @T.prim_func(private=True)
+            def multiply_by_two(A: T.Buffer((16,), "float32")):
+                for i in range(16):
+                    A[i] = A[i] * T.float32(2)
 
 
 if __name__ == "__main__":