Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Validate StructInfo of variable bindings #17332

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -982,10 +982,25 @@ class StructInfoLCAFinder
StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final {
auto* rhs = other.as<PrimStructInfoNode>();
if (rhs == nullptr) return ObjectStructInfo(lhs->span);
if (lhs->dtype == rhs->dtype) return GetRef<StructInfo>(lhs);
// PrimType will be treated as their boxed(object) values
// as a result we can unify to object.
return ObjectStructInfo(lhs->span);
if (lhs->dtype != rhs->dtype) {
// PrimType will be treated as their boxed(object) values
// as a result we can unify to object.
return ObjectStructInfo(lhs->span);
}
if (!lhs->value.defined() || !rhs->value.defined() ||
!analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) {
// The two values are known to contain the same dtype, but may
// contain different values.
if (!lhs->value.defined()) {
// If the mismatch was due to extra information in the RHS,
// prefer to avoid constructing a new object.
return GetRef<StructInfo>(lhs);
} else {
return PrimStructInfo(lhs->dtype, lhs->span);
}
}

return GetRef<StructInfo>(lhs);
}

StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final {
Expand Down
12 changes: 12 additions & 0 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,18 @@ class WellFormedChecker : public relax::ExprVisitor,
}

this->VisitVarDef(binding->var);

if (check_struct_info_ && binding->var->struct_info_.defined() &&
binding->value->struct_info_.defined()) {
auto expr_sinfo = GetStructInfo(binding->value);
auto var_sinfo = GetStructInfo(binding->var);
if (!IsBaseOf(var_sinfo, expr_sinfo)) {
Malformed(Diagnostic::Error(binding->var)
<< "Expression of type " << expr_sinfo
<< " cannot be assigned to a variable of type " << var_sinfo);
}
}

if (is_lambda) {
recur_vars_.erase(binding->var);
}
Expand Down
6 changes: 5 additions & 1 deletion src/relax/transform/normalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ class NormalizeMutator : public ExprMutatorBase {

Expr VisitWithNewScope(const Expr& expr, Optional<Array<Var>> params = NullOpt) {
builder_->BeginBindingBlock();
builder_->BeginScope(params);
if (params.defined()) {
builder_->BeginScope(params);
} else {
builder_->BeginInnerScope();
}
Expr ret = this->VisitExpr(expr);
BindingBlock prologue = builder_->EndBlock();
if (!prologue->bindings.empty()) {
Expand Down
94 changes: 93 additions & 1 deletion tests/python/relax/test_analysis_struct_info_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm import TVMError
from tvm import relax as rx
from tvm import tir, ir
from tvm.script import relax as R
from tvm.script import relax as R, tir as T


def test_get_static_type_basic():
Expand Down Expand Up @@ -620,6 +620,98 @@ def fn_info_erased():
_check_lca(fopaque2(), fn_info_shape(1), fopaque2())


def _generate_prim_test_cases():
dtypes = [
"bool",
"int8",
"uint8",
"int16",
"uint16",
"int32",
"uint32",
"int64",
"uint64",
"float16",
"float32",
"float64",
]

for dtype in dtypes:
# LCA of a PrimStructInfo with itself yields itself
yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype))

# The LCA of two values, each statically known to be the same
# value, is known to have that value.
yield (
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(0, dtype)),
)

# The LCA of two values, each of which is statically known to
# have a different value, no longer knows the contained value.
yield (
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(1, dtype)),
R.Prim(dtype=dtype),
)

# LCA of a known variable with itself yields itself
var_N = tir.Var("N", dtype)
yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N))

# LCA of a known variable with a known static value is no
# longer known to have a specific value.
yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)), R.Prim(dtype=dtype))
yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype))

var_M = tir.Var("M", dtype)
yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype))

for dtype_a in dtypes:
for dtype_b in dtypes:
if dtype_a != dtype_b:
# Unlike R.Tensor, R.Prim does not currently support a
# value with an unknown datatype. If the dtype
# differs between the two annotations, the next wider
# category is R.Object.
yield (R.Prim(dtype_a), R.Prim(dtype_b), R.Object)

# Because the dtypes are different, even `R.Prim` containing
# the same value in different representations (e.g.
# `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`.
yield (
R.Prim(value=tir.const(0, dtype_a)),
R.Prim(value=tir.const(0, dtype_b)),
R.Object,
)

# And the same is true for known variable values
var_N = tir.Var("N", dtype_a)
var_M = tir.Var("M", dtype_b)
yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object)


@pytest.mark.parametrize("test_case", list(_generate_prim_test_cases()))
def test_prim_struct_info_lca(test_case):
def _normalize_sinfo(sinfo):
if isinstance(sinfo, tvm.relax.StructInfo):
return sinfo
elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy):
return sinfo.as_struct_info()
elif callable(sinfo):
return sinfo()
else:
raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo")

lhs, rhs, expected = map(_normalize_sinfo, test_case)

lca = rx.analysis.struct_info_lca(lhs, rhs)
assert tvm.ir.structural_equal(
lca, expected
), f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found {lca}"


def _generate_tir_var_test_cases():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
shape0 = rx.ShapeStructInfo([1, n, 3])
Expand Down
87 changes: 87 additions & 0 deletions tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,5 +1208,92 @@ def add_one(
assert rx.analysis.well_formed(Module)


def test_var_binding_must_have_compatible_struct_info():
"""Variables must accurately describe their contents

To be well-formed, the inferred struct info must not conflict with
the StructInfo annotations.

"""

# The function is equivalent to the TVMScript below. However,
# TVMScript applies additional checks that would catch this error
# while parsing. In order to validate the well-formed checker
# itself, this test directly constructs the function withoutusing
# TVMScript, skipping the TVMScript-specific checks.
#
# @R.function
# def main(
# A: R.Tensor(shape=[128, 32], dtype="float32"),
# ):
# B: R.Tensor(shape=[128, 32], dtype="int32") = A
# return B

param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32"))
var = tvm.relax.Var("B", R.Tensor(shape=[128, 32], dtype="int32"))
binding = tvm.relax.VarBinding(var, param)
body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var)
tvm.relax.expr._update_struct_info(body, var.struct_info)
main = tvm.relax.Function([param], body)

assert not rx.analysis.well_formed(main)


def test_var_binding_may_have_less_constrained_struct_info():
"""StructInfo of variable may be less specific than expression

The StructInfo annotation of a variable is not required to be an
exact match to the expression's StructInfo, and may provide less
specific information than the inference would provide.

"""

@I.ir_module
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
):
B: R.Object = R.add(A, A)
return B

assert isinstance(
Module["main"].body.blocks[0].bindings[0].var.struct_info, tvm.relax.ObjectStructInfo
), "Validity of this test requires a variable with R.Object struct info"

assert rx.analysis.well_formed(Module)


def test_var_binding_with_incomplete_struct_info_must_be_consistent():
"""StructInfo of variable must be accurate

Even though StructInfo annotation may be less specific, the
information that they do contain must be correct.

"""

# The function is equivalent to the TVMScript below. However,
# TVMScript applies additional checks that would catch this error
# while parsing. In order to validate the well-formed checker
# itself, this test directly constructs the function withoutusing
# TVMScript, skipping the TVMScript-specific checks.
#
# @R.function
# def main(
# A: R.Tensor(shape=[128, 32], dtype="float32"),
# ):
# B: R.Tensor(ndim=3) = A
# return B

param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32"))
var = tvm.relax.Var("B", R.Tensor(ndim=3, dtype="int32"))
binding = tvm.relax.VarBinding(var, param)
body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var)
tvm.relax.expr._update_struct_info(body, var.struct_info)
main = tvm.relax.Function([param], body)

assert not rx.analysis.well_formed(main)


if __name__ == "__main__":
tvm.testing.main()
Loading