Skip to content

Commit

Permalink
[Relax][Bugfix] LCA of PrimStructInfo must check known values
Browse files Browse the repository at this point in the history
The `StructInfoLCA` determines the lowest common ancestor between two
`StructInfo` annotations.  This is primarily used in Relax to
determine the appropriate `StructInfo` annotation for a `relax::If`
node, given the `StructInfo` of each branch.  Prior to this commit,
when determining the LCA of two `PrimStructInfo` annotations, the
`StructInfoLCA` function only inspected the datatype of
`PrimStructInfo` annotations, and did not check for known values.  For
example, the LCA of `R.Prim(value=T.int64(128))` and
`R.Prim(value=T.int64(64))` is `R.Prim("int64")`, but was incorrectly
determined as `R.Prim(value=T.int64(128))` by the `StructInfoLCA`
function.

This commit updates `StructInfoLCA` to inspect the known values of a
`PrimStructInfo`, as well as the datatype.
  • Loading branch information
Lunderberg committed Sep 4, 2024
1 parent 3b04ddf commit 8b7d373
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 5 deletions.
23 changes: 19 additions & 4 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -982,10 +982,25 @@ class StructInfoLCAFinder
StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final {
auto* rhs = other.as<PrimStructInfoNode>();
if (rhs == nullptr) return ObjectStructInfo(lhs->span);
if (lhs->dtype == rhs->dtype) return GetRef<StructInfo>(lhs);
// PrimType will be treated as their boxed(object) values
// as a result we can unify to object.
return ObjectStructInfo(lhs->span);
if (lhs->dtype != rhs->dtype) {
// PrimType will be treated as their boxed(object) values
// as a result we can unify to object.
return ObjectStructInfo(lhs->span);
}
if (!lhs->value.defined() || !rhs->value.defined() ||
!analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) {
// The two values are known to contain the same dtype, but may
// contain different values.
if (!lhs->value.defined()) {
// If the mismatch was due to extra information in the RHS,
// prefer to avoid constructing a new object.
return GetRef<StructInfo>(lhs);
} else {
return PrimStructInfo(lhs->dtype, lhs->span);
}
}

return GetRef<StructInfo>(lhs);
}

StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final {
Expand Down
94 changes: 93 additions & 1 deletion tests/python/relax/test_analysis_struct_info_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm import TVMError
from tvm import relax as rx
from tvm import tir, ir
from tvm.script import relax as R
from tvm.script import relax as R, tir as T


def test_get_static_type_basic():
Expand Down Expand Up @@ -620,6 +620,98 @@ def fn_info_erased():
_check_lca(fopaque2(), fn_info_shape(1), fopaque2())


def _generate_prim_test_cases():
dtypes = [
"bool",
"int8",
"uint8",
"int16",
"uint16",
"int32",
"uint32",
"int64",
"uint64",
"float16",
"float32",
"float64",
]

for dtype in dtypes:
# LCA of a PrimStructInfo with itself yields itself
yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype))

# The LCA of two values, each statically known to be the same
# value, is known to have that value.
yield (
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(0, dtype)),
)

# The LCA of two values, each of which is statically known to
# have a different value, no longer knows the contained value.
yield (
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(1, dtype)),
R.Prim(dtype=dtype),
)

# LCA of a known variable with itself yields itself
var_N = tir.Var("N", dtype)
yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N))

# LCA of a known variable with a known static value is no
# longer known to have a specific value.
yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)), R.Prim(dtype=dtype))
yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype))

var_M = tir.Var("M", dtype)
yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype))

for dtype_a in dtypes:
for dtype_b in dtypes:
if dtype_a != dtype_b:
# Unlike R.Tensor, R.Prim does not currently support a
# value with an unknown datatype. If the dtype
# differs between the two annotations, the next wider
# category is R.Object.
yield (R.Prim(dtype_a), R.Prim(dtype_b), R.Object)

# Because the dtypes are different, even `R.Prim` containing
# the same value in different representations (e.g.
# `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`.
yield (
R.Prim(value=tir.const(0, dtype_a)),
R.Prim(value=tir.const(0, dtype_b)),
R.Object,
)

# And the same is true for known variable values
var_N = tir.Var("N", dtype_a)
var_M = tir.Var("M", dtype_b)
yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object)


@pytest.mark.parametrize("test_case", list(_generate_prim_test_cases()))
def test_prim_struct_info_lca(test_case):
def _normalize_sinfo(sinfo):
if isinstance(sinfo, tvm.relax.StructInfo):
return sinfo
elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy):
return sinfo.as_struct_info()
elif callable(sinfo):
return sinfo()
else:
raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo")

lhs, rhs, expected = map(_normalize_sinfo, test_case)

lca = rx.analysis.struct_info_lca(lhs, rhs)
assert tvm.ir.structural_equal(
lca, expected
), f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found {lca}"


def _generate_tir_var_test_cases():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
shape0 = rx.ShapeStructInfo([1, n, 3])
Expand Down

0 comments on commit 8b7d373

Please sign in to comment.