diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index a7e5404c20ce..6fe8f36020bf 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -982,10 +982,25 @@ class StructInfoLCAFinder StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) return ObjectStructInfo(lhs->span); - if (lhs->dtype == rhs->dtype) return GetRef(lhs); - // PrimType will be treated as their boxed(object) values - // as a result we can unify to object. - return ObjectStructInfo(lhs->span); + if (lhs->dtype != rhs->dtype) { + // PrimType will be treated as their boxed(object) values + // as a result we can unify to object. + return ObjectStructInfo(lhs->span); + } + if (!lhs->value.defined() || !rhs->value.defined() || + !analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) { + // The two values are known to contain the same dtype, but may + // contain different values. + if (!lhs->value.defined()) { + // If the mismatch was due to extra information in the RHS, + // prefer to avoid constructing a new object. + return GetRef(lhs); + } else { + return PrimStructInfo(lhs->dtype, lhs->span); + } + } + + return GetRef(lhs); } StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 235059ece2aa..7688c4a64291 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -429,6 +429,18 @@ class WellFormedChecker : public relax::ExprVisitor, } this->VisitVarDef(binding->var); + + if (check_struct_info_ && binding->var->struct_info_.defined() && + binding->value->struct_info_.defined()) { + auto expr_sinfo = GetStructInfo(binding->value); + auto var_sinfo = GetStructInfo(binding->var); + if (!IsBaseOf(var_sinfo, expr_sinfo)) { + Malformed(Diagnostic::Error(binding->var) + << "Expression of type " << expr_sinfo + << " cannot be assigned to a variable of type " << var_sinfo); + } + } + if (is_lambda) { recur_vars_.erase(binding->var); } diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 89080ebc3eb1..5493b44f822b 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -65,7 +65,11 @@ class NormalizeMutator : public ExprMutatorBase { Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { builder_->BeginBindingBlock(); - builder_->BeginScope(params); + if (params.defined()) { + builder_->BeginScope(params); + } else { + builder_->BeginInnerScope(); + } Expr ret = this->VisitExpr(expr); BindingBlock prologue = builder_->EndBlock(); if (!prologue->bindings.empty()) { diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index 83b1ddd4fc9e..b2931549e92b 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -24,7 +24,7 @@ from tvm import TVMError from tvm import relax as rx from tvm import tir, ir -from tvm.script import relax as R +from tvm.script import relax as R, tir as T def test_get_static_type_basic(): @@ -620,6 +620,98 @@ def fn_info_erased(): _check_lca(fopaque2(), fn_info_shape(1), fopaque2()) +def _generate_prim_test_cases(): + dtypes = [ + "bool", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", + "float16", + "float32", + "float64", + ] + + for dtype in dtypes: + # LCA of a PrimStructInfo with itself yields itself + yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype)) + + # The LCA of two values, each statically known to be the same + # value, is known to have that value. + yield ( + R.Prim(value=tir.const(0, dtype)), + R.Prim(value=tir.const(0, dtype)), + R.Prim(value=tir.const(0, dtype)), + ) + + # The LCA of two values, each of which is statically known to + # have a different value, no longer knows the contained value. + yield ( + R.Prim(value=tir.const(0, dtype)), + R.Prim(value=tir.const(1, dtype)), + R.Prim(dtype=dtype), + ) + + # LCA of a known variable with itself yields itself + var_N = tir.Var("N", dtype) + yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N)) + + # LCA of a known variable with a known static value is no + # longer known to have a specific value. + yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)), R.Prim(dtype=dtype)) + yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype)) + + var_M = tir.Var("M", dtype) + yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype)) + + for dtype_a in dtypes: + for dtype_b in dtypes: + if dtype_a != dtype_b: + # Unlike R.Tensor, R.Prim does not currently support a + # value with an unknown datatype. If the dtype + # differs between the two annotations, the next wider + # category is R.Object. + yield (R.Prim(dtype_a), R.Prim(dtype_b), R.Object) + + # Because the dtypes are different, even `R.Prim` containing + # the same value in different representations (e.g. + # `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`. + yield ( + R.Prim(value=tir.const(0, dtype_a)), + R.Prim(value=tir.const(0, dtype_b)), + R.Object, + ) + + # And the same is true for known variable values + var_N = tir.Var("N", dtype_a) + var_M = tir.Var("M", dtype_b) + yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object) + + +@pytest.mark.parametrize("test_case", list(_generate_prim_test_cases())) +def test_prim_struct_info_lca(test_case): + def _normalize_sinfo(sinfo): + if isinstance(sinfo, tvm.relax.StructInfo): + return sinfo + elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy): + return sinfo.as_struct_info() + elif callable(sinfo): + return sinfo() + else: + raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo") + + lhs, rhs, expected = map(_normalize_sinfo, test_case) + + lca = rx.analysis.struct_info_lca(lhs, rhs) + assert tvm.ir.structural_equal( + lca, expected + ), f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found {lca}" + + def _generate_tir_var_test_cases(): n, m = tir.Var("n", "int64"), tir.Var("m", "int64") shape0 = rx.ShapeStructInfo([1, n, 3]) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index c0b962c3f3a0..3db3efee1afc 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -1208,5 +1208,92 @@ def add_one( assert rx.analysis.well_formed(Module) +def test_var_binding_must_have_compatible_struct_info(): + """Variables must accurately describe their contents + + To be well-formed, the inferred struct info must not conflict with + the StructInfo annotations. + + """ + + # The function is equivalent to the TVMScript below. However, + # TVMScript applies additional checks that would catch this error + # while parsing. In order to validate the well-formed checker + # itself, this test directly constructs the function withoutusing + # TVMScript, skipping the TVMScript-specific checks. + # + # @R.function + # def main( + # A: R.Tensor(shape=[128, 32], dtype="float32"), + # ): + # B: R.Tensor(shape=[128, 32], dtype="int32") = A + # return B + + param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32")) + var = tvm.relax.Var("B", R.Tensor(shape=[128, 32], dtype="int32")) + binding = tvm.relax.VarBinding(var, param) + body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var) + tvm.relax.expr._update_struct_info(body, var.struct_info) + main = tvm.relax.Function([param], body) + + assert not rx.analysis.well_formed(main) + + +def test_var_binding_may_have_less_constrained_struct_info(): + """StructInfo of variable may be less specific than expression + + The StructInfo annotation of a variable is not required to be an + exact match to the expression's StructInfo, and may provide less + specific information than the inference would provide. + + """ + + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + ): + B: R.Object = R.add(A, A) + return B + + assert isinstance( + Module["main"].body.blocks[0].bindings[0].var.struct_info, tvm.relax.ObjectStructInfo + ), "Validity of this test requires a variable with R.Object struct info" + + assert rx.analysis.well_formed(Module) + + +def test_var_binding_with_incomplete_struct_info_must_be_consistent(): + """StructInfo of variable must be accurate + + Even though StructInfo annotation may be less specific, the + information that they do contain must be correct. + + """ + + # The function is equivalent to the TVMScript below. However, + # TVMScript applies additional checks that would catch this error + # while parsing. In order to validate the well-formed checker + # itself, this test directly constructs the function withoutusing + # TVMScript, skipping the TVMScript-specific checks. + # + # @R.function + # def main( + # A: R.Tensor(shape=[128, 32], dtype="float32"), + # ): + # B: R.Tensor(ndim=3) = A + # return B + + param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32")) + var = tvm.relax.Var("B", R.Tensor(ndim=3, dtype="int32")) + binding = tvm.relax.VarBinding(var, param) + body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var) + tvm.relax.expr._update_struct_info(body, var.struct_info) + main = tvm.relax.Function([param], body) + + assert not rx.analysis.well_formed(main) + + if __name__ == "__main__": tvm.testing.main()