[Relax] Remove segfault in R.call_tir_inplace validation #17242

Merged
src/relax/op/op.cc (39 additions, 41 deletions)
@@ -419,13 +419,19 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
   // may result in an error if performed before normalization.
   call = Downcast<Call>(NormalizeCallTIR(ctx, std::move(call)));
 
+  Array<StructInfo> sinfo_outputs = [&]() -> Array<StructInfo> {
+    auto out_sinfo = call->sinfo_args[0];
+    if (auto* tuple_output = out_sinfo.as<TupleStructInfoNode>()) {
+      return tuple_output->fields;
+    } else {
+      return {out_sinfo};
+    }
+  }();
+
   // there must be an inplace index for each output
   const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
-  size_t num_outputs = 1U;
-  if (auto* tup_info = call->sinfo_args[0].as<TupleStructInfoNode>()) {
-    num_outputs = tup_info->fields.size();
-  }
-  if (attrs->inplace_indices.size() != num_outputs) {
+  ICHECK(attrs);
+  if (attrs->inplace_indices.size() != sinfo_outputs.size()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "There must be an in-place index specified for each output");
  }
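
Note on the arity check above: `R.call_tir_inplace` requires exactly one entry in `inplace_indices` per output, where a non-negative entry names the input that the output mutates in place and `-1` requests a freshly allocated output. The sketch below shows a call that satisfies the check; it is illustrative only (the module name `Valid`, the two-output layout, and the placeholder `T.evaluate(0)` body are not part of this PR):

from tvm.script import ir as I, relax as R, tir as T

@I.ir_module
class Valid:
    @T.prim_func(private=True)
    def func(
        A: T.Buffer((2, 3), "int32"),
        B: T.Buffer((2, 3), "int32"),
        C: T.Buffer((2, 3), "int32"),
    ):
        T.evaluate(0)  # placeholder body, matching the style of the tests below

    @R.function
    def main(
        x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")
    ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
        R.func_attr({"relax.force_pure": True})
        # Two outputs, so inplace_indices needs two entries: output 0 mutates
        # input 0 in place, and -1 allocates output 1 as a fresh tensor.
        gv0 = R.call_tir_inplace(
            Valid.func,
            (x, y),
            inplace_indices=[0, -1],
            out_sinfo=[R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
        )
        return gv0

At least one index must be non-negative; `test_call_tir_inplace_all_new` below covers the all-fresh case.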
@@ -459,45 +465,37 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
   // input shape
   // TODO(@slyubomirsky): eventually we will want to handle cases where that is not true
   Tuple call_args = Downcast<Tuple>(call->args[1]);
-  if (attrs->inplace_indices.size() == 1) {
-    auto* out_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>();
-    if (!out_sinfo) {
-      ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
+
+  for (size_t i_output = 0; i_output < attrs->inplace_indices.size(); i_output++) {
+    auto i_input = attrs->inplace_indices[i_output].IntValue();
+    if (i_input == -1) {
+      continue;
     }
-    auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
-        call_args->fields[attrs->inplace_indices[0].IntValue()]);
-    if (!input_sinfo || !input_sinfo->shape.defined() ||
-        !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
-                            ctx->GetAnalyzer())) {
+
+    auto sinfo_output = sinfo_outputs[i_output];
+    auto tinfo_output = sinfo_output.as<TensorStructInfoNode>();
+
+    if (!tinfo_output || !tinfo_output->shape.defined() || tinfo_output->IsUnknownDtype()) {
       ctx->ReportFatal(Diagnostic::Error(call)
-                       << "The shape of output 0 must match input "
-                       << attrs->inplace_indices[0].IntValue() << ", whereas we have "
-                       << out_sinfo->shape.value() << " in output 0 versus "
-                       << input_sinfo->shape.value() << " in input "
-                       << attrs->inplace_indices[0].IntValue());
+                       << "The output struct info for an in-place mutation must be a tensor "
+                       << "with a defined shape and dtype, "
+                       << "but output " << i_output << " has struct info " << sinfo_output);
     }
-  } else {
-    auto out_sinfos = call->sinfo_args[0].as<TupleStructInfoNode>()->fields;
-    for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
-      if (attrs->inplace_indices[i].IntValue() == -1) {
-        continue;
-      }
-      auto* out_sinfo = out_sinfos[i].as<TensorStructInfoNode>();
-      if (!out_sinfo) {
-        ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
-      }
-      auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
-          call_args->fields[attrs->inplace_indices[i].IntValue()]);
-      if (!input_sinfo || !input_sinfo->shape.defined() ||
-          !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
-                              ctx->GetAnalyzer())) {
-        ctx->ReportFatal(Diagnostic::Error(call)
-                         << "The shape of output " << i << " must match that of input "
-                         << attrs->inplace_indices[i].IntValue() << ", whereas we have "
-                         << out_sinfo->shape.value() << " in output " << i << " versus "
-                         << input_sinfo->shape.value() << " in input "
-                         << attrs->inplace_indices[i].IntValue());
-      }
+
+    auto sinfo_input = GetStructInfo(call_args->fields[i_input]);
+    auto tinfo_input = sinfo_input.as<TensorStructInfoNode>();
+
+    if (!tinfo_input ||
+        (tinfo_output->IsUnknownDtype() || tinfo_output->dtype != tinfo_input->dtype) ||
+        (!tinfo_input->shape.defined() ||
+         !CanProveShapeEqual(tinfo_input->shape.value(), tinfo_output->shape.value(),
+                             ctx->GetAnalyzer()))) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The input used for an in-place mutation must be "
+                       << "a tensor with identical shape and dtype as the output. "
+                       << "However, output " << i_output << " with struct info " << sinfo_output
+                       << " is specified as an in-place mutation of input " << i_input
+                       << " with struct info " << sinfo_input);
+    }
   }
 
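With the single-output and tuple-output branches collapsed into one loop, each in-place output must be a tensor with a defined shape and dtype, and its designated input must be a tensor with a provably equal shape and identical dtype; anything else now reports a diagnostic instead of dereferencing a null pointer. Since the TVMScript parser runs this normalization, a well-formed single-output call parses cleanly. The module below is a hypothetical positive counterpart to the regression tests further down (the name `WellFormed` is illustrative; the PrimFunc mirrors the tests' `multiply_by_two`):

from tvm.script import ir as I, relax as R, tir as T

@I.ir_module
class WellFormed:
    @R.function
    def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        R.func_attr({"relax.force_pure": True})
        # Input 0 is a (16,) float32 tensor and out_sinfo is also a (16,)
        # float32 tensor, so the shape and dtype checks both pass.
        gv1 = R.call_tir_inplace(
            WellFormed.multiply_by_two,
            [A],
            out_sinfo=R.Tensor((16,), dtype="float32"),
            inplace_indices=[0],
        )
        return gv1

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer((16,), "float32")):
        for i in range(16):
            A[i] = A[i] * T.float32(2)
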
tests/python/relax/test_transform.py (163 additions, 34 deletions)
@@ -20,7 +20,7 @@
 from tvm import relax
 
 import tvm.script
-from tvm.script import tir as T, relax as R
+from tvm.script import ir as I, tir as T, relax as R
 
 
 def test_to_non_dataflow():
@@ -446,45 +446,174 @@ def foo(
     tvm.ir.assert_structural_equal(Expected["foo"], new_mod["foo"], map_free_vars=True)
 
 
-@pytest.mark.xfail()
 def test_call_tir_inplace_repeated_input():
-    @tvm.script.ir_module
-    class Input:
-        @T.prim_func
-        def func(
-            A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32")
-        ):
-            T.evaluate(0)
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @tvm.script.ir_module
+        class Input:
+            @T.prim_func
+            def func(
+                A: T.Buffer((2, 3), "int32"),
+                B: T.Buffer((2, 3), "int32"),
+                C: T.Buffer((2, 3), "int32"),
+            ):
+                T.evaluate(0)
 
-    @R.function
-    def foo(
-        x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32")
-    ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
-        R.func_attr({"relax.force_pure": True})
-        gv0 = R.call_tir_inplace(
-            Input.func,
-            (x, y, z),
-            # repeated 0 -> that's an error
-            [0, 0],
-            [R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
-        )
-        return gv0
+        @R.function
+        def foo(
+            x: R.Tensor((2, 3), "int32"),
+            y: R.Tensor((2, 3), "int32"),
+            z: R.Tensor((2, 3), "int32"),
+        ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
+            R.func_attr({"relax.force_pure": True})
+            gv0 = R.call_tir_inplace(
+                Input.func,
+                (x, y, z),
+                # repeated 0 -> that's an error
+                [0, 0],
+                [R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
+            )
+            return gv0
 
 
-@pytest.mark.xfail()
 def test_call_tir_inplace_all_new():
-    @tvm.script.ir_module
-    class Input:
-        @T.prim_func
-        def func(A: T.Buffer((2, 3), "int32")):
-            T.evaluate(0)
+    with pytest.raises(tvm.error.DiagnosticError):
+
-    @R.function
-    def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
-        R.func_attr({"relax.force_pure": True})
-        # cannot make the only output a fresh one
-        gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32"))
-        return gv0
+        @tvm.script.ir_module
+        class Input:
+            @T.prim_func
+            def func(A: T.Buffer((2, 3), "int32")):
+                T.evaluate(0)
+
+        @R.function
+        def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            R.func_attr({"relax.force_pure": True})
+            # cannot make the only output a fresh one
+            gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32"))
+            return gv0
+
+
+def test_inplace_mutation_with_tuple_argument_raises_error():
+    """TIR PrimFuncs do not support Tuple arguments
+
+    The `R.call_tir_inplace` operator must receive an in-line tuple of
+    arguments, where each argument in the tuple can be expressed in
+    TIR. Here, `[[A]]` specifies a tuple of arguments, where the
+    first argument is itself a tuple. Since PrimFuncs do not support
+    Tuple arguments, this is invalid.
+
+    This is a regression test. In previous implementations, this
+    triggered a segfault rather than raising an exception.
+
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+                cls = Module
+                gv1 = R.call_tir_inplace(
+                    cls.multiply_by_two,
+                    [[A]],
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                return gv1
+
+            @T.prim_func(private=True)
+            def multiply_by_two(A: T.Buffer((16,), "float32")):
+                for i in range(16):
+                    A[i] = A[i] * T.float32(2)
+
+
+def test_inplace_mutation_with_non_tensor_argument_raises_error():
+    """In-place argument must be a tensor
+
+    The `R.call_tir_inplace` operator must receive an in-line tuple of
+    arguments, where each argument in the tuple can be expressed in
+    TIR. Here, the argument `A` is not a tensor.
+
+    This is a regression test. In previous implementations, this
+    triggered a segfault rather than raising an exception.
+
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(A: R.Object):
+                gv1 = R.call_tir_inplace(
+                    Module.multiply_by_two,
+                    [A],
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                return gv1
+
+            @T.prim_func(private=True)
+            def multiply_by_two(A: T.Buffer((16,), "float32")):
+                for i in range(16):
+                    A[i] = A[i] * T.float32(2)
+
+
+def test_inplace_mutation_with_incompatible_tensor_shape_raises_error():
+    """In-place argument must have a compatible shape
+
+    The `R.call_tir_inplace` operator must receive an in-line tuple of
+    arguments, where the shape of each in-place argument is compatible
+    with the corresponding output. Here, the shape of argument `A`
+    differs from the output's shape (`[32]` as opposed to `[16]`).
+
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(A: R.Tensor([32], dtype="float32")):
+                gv1 = R.call_tir_inplace(
+                    Module.multiply_by_two,
+                    [A],
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                return gv1
+
+            @T.prim_func(private=True)
+            def multiply_by_two(A: T.Buffer((16,), "float32")):
+                for i in range(16):
+                    A[i] = A[i] * T.float32(2)
+
+
+def test_inplace_mutation_with_incompatible_tensor_dtype_raises_error():
+    """In-place argument must have a compatible dtype
+
+    The `R.call_tir_inplace` operator must receive an in-line tuple of
+    arguments, where the dtype of each in-place argument is compatible
+    with the corresponding output. Here, the dtype of argument `A`
+    differs from the output's dtype (`int32` as opposed to `float32`).
+
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(A: R.Tensor([16], dtype="int32")):
+                gv1 = R.call_tir_inplace(
+                    Module.multiply_by_two,
+                    [A],
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                return gv1
+
+            @T.prim_func(private=True)
+            def multiply_by_two(A: T.Buffer((16,), "float32")):
+                for i in range(16):
+                    A[i] = A[i] * T.float32(2)
 
 
 if __name__ == "__main__":
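
To run just these regression tests from a TVM checkout, a pytest keyword filter should suffice (invocation illustrative, assuming pytest is installed):

    python -m pytest tests/python/relax/test_transform.py -k inplace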