Skip to content

Commit

Permalink
[REFACTOR][TIR] Introduce ExprDeepEqual, Remove IRDeepCompare (#5206)
Browse files Browse the repository at this point in the history
* [REFACTOR][TIR] Introduce ExprDeepEqual, Remove IRDeepCompare

This PR introduces ExprDeepEqual which reuses the StructuralEqual infra.
We migrated the usecases of ir_pass::Equal to ExprDeepEqual and StructuralEqual.

* Address comments
  • Loading branch information
tqchen authored Apr 2, 2020
1 parent 0449966 commit e60003c
Show file tree
Hide file tree
Showing 45 changed files with 419 additions and 641 deletions.
9 changes: 8 additions & 1 deletion docs/api/python/tir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@ tvm.tir
:autosummary:



tvm.tir.transform
-----------------
.. automodule:: tvm.tir.transform
:members:
:imported-members:
:autosummary:


tvm.tir.analysis
----------------
.. automodule:: tvm.tir.analysis
:members:
:imported-members:
:autosummary:
54 changes: 54 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/tir/analysis.h
* \brief Analysis utilitie and passes for TIR.
*/
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

namespace tvm {
namespace tir {

/*!
* \brief Compare two expressions recursively and check if they are equal
* to each other without var remapping.
*
* This function does not remap variable bindings, it will not
* return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y).
*
* Use StructuralEqual for such cases.
*
* Due to the restriction of not remapping variables, this function can run
* faster than StructuralEqual and can be used as a utility function during arithmetic
* simplifications.
*
* \sa StructuralEqual
*/
struct ExprDeepEqual {
public:
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
11 changes: 11 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,17 @@ class FunctionBaseNode : public Object {
virtual const std::string& func_name() const = 0;
/*! \return the number of outputs of this function */
virtual int num_outputs() const = 0;

// fall back to pointer equality now before refactor.
bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const {
return this == other;
}

void SHashReduce(SHashReducer hash_reduce) const {
}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
};

/*! \brief reference to a function */
Expand Down
29 changes: 0 additions & 29 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,35 +76,6 @@ Stmt CanonicalSimplify(Stmt stmt,
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>());

/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
TVM_DLL bool Equal(const PrimExpr& lhs, const PrimExpr& rhs);

/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
bool Equal(const Stmt& lhs, const Stmt& rhs);

/*!
* \brief Deep compare lhs and rhs.
*
* If you only want equality comparison, use Equal
* which will also tie definitions. The compare mode
* will give order of expression in total order.
*
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
int Compare(const PrimExpr& lhs, const PrimExpr& rhs);

/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
from tvm.tir import ir_pass
from tvm.tir import call_pure_intrin
from tvm.tir.stmt import For

Expand All @@ -47,7 +46,7 @@ def _range(annotation, args):
else:
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
low, ext = args[0], args[1]
if not ir_pass.Equal(low, const(0, dtype='int32')):
if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype='int32')):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
iter_var = None
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def concat_list_to_block(lst):
def visit_list_to_block(visit, lst):
"""Visit and concatenate a list of Python IR nodes to HalideIR Block"""
lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
lst = [stmt for stmt in lst if not tvm.ir.structural_equal(stmt, util.make_nop())]
if not lst:
return util.make_nop()
return concat_list_to_block(lst)
Expand Down Expand Up @@ -178,7 +178,7 @@ def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
self.binds[val.var.name] = val
return
val_ = self.binds[val.var.name]
_internal_assert(_ir_pass.Equal(val_.dom.extent, val.dom.extent),
_internal_assert(tvm.tir.analysis.expr_deep_equal(val_.dom.extent, val.dom.extent),
"Thread extents should be uniform!")
self.symbols[key] = ty, val_

Expand Down Expand Up @@ -525,7 +525,7 @@ def visit_For(self, node):
if iter_var is None:
_internal_assert(for_type is not None, "The loop iterating function parse error!")
offset = iter_var = tvm.te.var(_name)
if not _ir_pass.Equal(low, tvm.runtime.const(0, 'int32')):
if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, 'int32')):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
_body = visit_list_to_block(self.visit, node.body)
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def structural_equal(lhs, rhs, map_free_vars=False):
structural_hash
assert_strucural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
return bool(tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars))

Expand Down Expand Up @@ -225,6 +227,8 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
--------
structural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, True, map_free_vars)

Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@
from . import ir_builder
from . import ir_pass
from . import transform
from . import analysis
20 changes: 20 additions & 0 deletions python/tvm/tir/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace of all TIR analysis utils."""
# pylint: disable=wildcard-import, invalid-name

from .analysis import *
21 changes: 21 additions & 0 deletions python/tvm/tir/analysis/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir.analysis"""
import tvm._ffi


tvm._ffi._init_api("tir.analysis", __name__)
57 changes: 57 additions & 0 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Wrapping existing analysis utils."""
# pylint: disable=invalid-name

from . import _ffi_api


def expr_deep_equal(lhs, rhs):
"""Deeply compare two nested expressions.
Parameters
----------
lhs : PrimExpr
The left operand.
rhs : PrimExpr
The right operand.
Returns
-------
result : bool
The comparison result
Note
----
This function does not remap variable bindings, it will not
return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y).
Use py:func:`tvm.ir.structural_equal` to handle structural variable remapping.
Due to the restriction of not remapping variables, this function can run
faster than StructuralEqual and can be used as a utility function during arithmetic
simplifications.
Always consider py:func:`tvm.ir.structural_equal` first, which handles
the structural remapping.
See Also
--------
tvm.ir.structural_equal
"""
return _ffi_api.expr_deep_equal(lhs, rhs)
4 changes: 3 additions & 1 deletion src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/analysis.h>

#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"
Expand Down Expand Up @@ -157,7 +159,7 @@ class SplitExpr : public PrimExpr {

inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
if (index.same_as(other->index)) return true;
return tir::Equal(index, other->index);
return tir::ExprDeepEqual()(index, other->index);
}

inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const {
Expand Down
3 changes: 2 additions & 1 deletion src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,11 @@ class ConstIntBoundAnalyzer::Impl :

Entry VisitExpr(const PrimExpr& expr) final {
Entry res = ExprFunctor::VisitExpr(expr);
tir::ExprDeepEqual equal;
// a linear search over additional info
// assume we won't have a lot of conditions
for (const BoundInfo& info : additional_info_) {
if (tir::Equal(expr, info.expr)) {
if (equal(expr, info.expr)) {
res = Intersect(res, info.bound);
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/arith/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
#define TVM_ARITH_PATTERN_MATCH_H_

#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tuple>
#include "const_fold.h"

Expand Down Expand Up @@ -135,7 +136,7 @@ class PEqualChecker<PrimExpr> {
public:
bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
if (lhs.same_as(rhs)) return true;
return tir::Equal(lhs, rhs);
return tir::ExprDeepEqual()(lhs, rhs);
}
};

Expand Down
9 changes: 5 additions & 4 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ TryCompare(const PrimExpr& x, int64_t val) {
}

void RewriteSimplifier::Impl::
Update(const Var& var, const PrimExpr& info, bool override) {
if (!override) {
Update(const Var& var, const PrimExpr& info, bool can_override) {
if (!can_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(Equal(it->second, info))
CHECK(ExprDeepEqual()(it->second, info))
<< "Trying to update var \'" << var << "\'"
<< " with a different value: "
<< "original=" << it->second
Expand Down Expand Up @@ -1716,10 +1716,11 @@ VisitExpr_(const CallNode* op) {
return op->args[0] & op->args[1];
}
}
ExprDeepEqual expr_equal;
if (op->is_intrinsic(CallNode::likely)) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (Equal(constraint, op->args[0])) {
if (expr_equal(constraint, op->args[0])) {
return make_const(op->dtype, true);
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/arith/stmt_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/arith/analyzer.h>

#include <tvm/tir/op.h>
#include <tvm/arith/analyzer.h>
#include "ir_mutator_with_analyzer.h"
Expand Down Expand Up @@ -83,7 +85,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
op = stmt.as<StoreNode>();
if (const LoadNode* load = op->value.as<LoadNode>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
Equal(load->index, op->index)) {
tir::ExprDeepEqual()(load->index, op->index)) {
return EvaluateNode::make(0);
}
}
Expand Down
1 change: 0 additions & 1 deletion src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ class RemapVarSEqualHandler :
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_rhs_;
};


TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs,
const ObjectRef& rhs,
Expand Down
Loading

0 comments on commit e60003c

Please sign in to comment.