Skip to content

Commit

Permalink
[FIX] Fix structure equal hash for MatchShape (#112)
Browse files Browse the repository at this point in the history
The pattern field of the match shape can define variables,
as a result, we need to add DefEqual and Hash here.

Added a regression testcase.

Lesson: we would benefit from more testcases
with check_save_roundtrip checks(like this one) for more relax example.

Additional change:
- Redirected TVMScript printer to be able to print relax fragements useful for debugging.
  • Loading branch information
tqchen authored and junrushao committed Feb 5, 2023
1 parent 1a70cbe commit 2118966
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 2 deletions.
6 changes: 4 additions & 2 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,15 @@ class MatchShapeNode : public BindingNode {
}

bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
return equal(value, other->value) && equal(pattern, other->pattern) &&
// NOTE: pattern can contain ShapeExpr which defines the vars
return equal(value, other->value) && equal.DefEqual(pattern, other->pattern) &&
equal.DefEqual(var, other->var);
}

void SHashReduce(SHashReducer hash_reduce) const {
// NOTE: pattern can contain ShapeExpr which defines the vars
hash_reduce(value);
hash_reduce(pattern);
hash_reduce.DefHash(pattern);
hash_reduce.DefHash(var);
}

Expand Down
3 changes: 3 additions & 0 deletions src/relay/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
RelaxScriptPrinter* parent_;
};
};

String AsRelaxScript(const ObjectRef& mod, bool show_meta_data);

} // namespace relax
} // namespace tvm

Expand Down
8 changes: 8 additions & 0 deletions src/relay/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2013,6 +2013,14 @@ Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) {
}

String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) {
// Temporary redirect possibly relax related printing to relax script
// TODO(tvm-team): make relax script printer handle all possible cases and
// make that as a default of TVMScript printer
if (mod->IsInstance<IRModuleNode>() || mod->IsInstance<relax::FunctionNode>()) {
// TODO(tvm-team) support tir_prefix in relax printer
return relax::AsRelaxScript(mod, show_meta);
}

ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
Doc doc;
doc << TVMScriptPrinter::PrintHeader(tir_prefix)
Expand Down
18 changes: 18 additions & 0 deletions tests/python/relax/test_structual_equal_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations # must import to defer parsing of annotations
import pytest
import sys
import tvm
from tvm import relax as rx, tir
from tvm.script import tir as T, relax as R


def _check_equal(x, y):
Expand All @@ -31,6 +33,11 @@ def _check_equal(x, y):
assert xhash == yhash


def _check_save_roundtrip(x):
y = tvm.ir.load_json(tvm.ir.save_json(x))
_check_equal(x, y)


def test_var_binding():
dtype = rx.DynTensorType(1)
x = rx.Var("x", [10], dtype)
Expand Down Expand Up @@ -106,5 +113,16 @@ def generator():
_check_equal(mod0, mod1)


def test_match_shape_symbolic():
@tvm.script.ir_module
class InputModule:
@R.function
def f(x: Tensor[(_, _), "float32"]):
x0 = R.match_shape(x, (n, m))
return (x0, (n + 1, m))

_check_save_roundtrip(InputModule)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 2118966

Please sign in to comment.