diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index cc64294c92ca..a9a0bed6712a 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -52,6 +52,13 @@ class Analyzer; using tir::Var; +enum DivMode { + /*! \brief Truncated division. */ + kTruncDiv, + /*! \brief Floor division. */ + kFloorDiv +}; + /*! * \brief Constant integer up and lower bound(inclusive). * Useful for value bound analysis. diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index ad044b288941..0ef74ce0d5ce 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -41,6 +41,11 @@ using tir::IterVar; using tir::Var; using tir::VarNode; +// According to experiments two best simplifications orders were can->rw and rw->can->rw, +// but rw->can->rw is better for a couple of cases. +// Also we should end with rw because it factors multipliers out. +constexpr int kSimplifyRewriteCanonicalRewrite = 3; + /*! * \brief Represent integer grouped bounds which are classified into * lower bounds (inclusive), upper bounds (inclusive) and equalities. @@ -251,6 +256,15 @@ class IntConstraintsTransform : public ObjectRef { TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst, Map src_to_dst, Map dst_to_src); + /*! + * \brief Chain-compose two IntConstraintsTransform together. + * this->dst must be the same as other->src. + * @param other another IntConstraintsTransform whose src is same as this->dst. + * @return composed IntConstraintsTransform(this->src, other->dst) + * with its variables and ranges are properly modified. + */ + IntConstraintsTransform operator+(const IntConstraintsTransform& other) const; + TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; @@ -306,6 +320,16 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol */ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve); +/*! + * \brief Combine the information into an array of (in)equalities. + * \param variables The variables in \p bounds. + * It is used to determine the iteration order to avoid indeterministic results. + * \param bounds grouped boundary of the variables. + * \param relations other relations. + */ +Array AsConditions(const Array& variables, const Map& bounds, + const Array& relations); + /*! * \brief Solve linear inequalities and infer the range of each variable. * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 365eb60400d3..c2a198def720 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -1427,6 +1427,22 @@ class Map : public ObjectRef { MapNode* GetMapNode() const { return static_cast(data_.get()); } }; +/*! + * \brief Merge two Maps. + * \param lhs the first Map to merge. + * \param rhs the second Map to merge. + * @return The merged Array. Original Maps are kept unchanged. + */ +template ::value>::type, + typename = typename std::enable_if::value>::type> +inline Map Merge(Map lhs, const Map& rhs) { + for (const auto& p : rhs) { + lhs.Set(p.first, p.second); + } + return std::move(lhs); +} + } // namespace tvm namespace tvm { diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 423ea896aae8..997278589aef 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -996,6 +996,21 @@ class Array : public ObjectRef { } }; +/*! + * \brief Concat two Arrays. + * \param lhs first Array to be concatenated. + * \param rhs second Array to be concatenated. + * \return The concatenated Array. Original Arrays are kept unchanged. + */ +template ::value>::type> +inline Array Concat(Array lhs, const Array& rhs) { + for (const auto& x : rhs) { + lhs.push_back(x); + } + return std::move(lhs); +} + // Specialize make_object to make sure it is correct. template <> inline ObjectPtr make_object() { diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index cbc7a51d5b25..e5b2c2b6957c 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -57,13 +57,20 @@ struct ExprDeepEqual { }; /*! - * \brief Find undefined vars in the statment. + * \brief Find undefined vars in the statement. * \param stmt The function to be checked. * \param defs The vars that is defined. * \return Array of undefined vars. */ TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); +/*! + * \brief Find undefined vars in the expression. + * \param expr The expression to be checked. + * \return Array of undefined vars. + */ +TVM_DLL Array UndefinedVars(const PrimExpr& expr); + /*! * \brief Analyze the side effect * \param expr The expression to be checked. diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 726289cebf09..a8ef6a1d162c 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -59,13 +59,6 @@ class CanonicalExprNode : public PrimExprNode { TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode); }; -enum DivMode { - /*! \brief Truncated division. */ - kTruncDiv, - /*! \brief Floor division. */ - kFloorDiv -}; - inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { if (mode == kTruncDiv) { return truncmod(a, b); diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index c95f7f855ceb..189869bd64e7 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -38,6 +38,32 @@ namespace tvm { namespace arith { +Array AsConditions(const Array& variables, const Map& bounds, + const Array& relations) { + Array res; + // use variables to keep the order of iteration + // so as to get rid of any non-determinism. + CHECK_EQ(variables.size(), bounds.size()); + for (const auto v : variables) { + CHECK(bounds.count(v)); + const auto& bnds = bounds[v]; + PrimExpr lhs = bnds->coef * v; + for (const PrimExpr& rhs : bnds->equal) { + res.push_back(tir::EQ(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->lower) { + res.push_back(tir::GE(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->upper) { + res.push_back(tir::LE(lhs, rhs)); + } + } + for (const PrimExpr& e : relations) { + res.push_back(e); + } + return res; +} + IntGroupBounds::IntGroupBounds(PrimExpr coef, Array lower, Array equal, Array upper) { CHECK(coef.dtype().is_int() || coef.dtype().is_uint()) @@ -231,6 +257,26 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstrai data_ = std::move(node); } +IntConstraintsTransform IntConstraintsTransform::operator+( + const IntConstraintsTransform& other) const { + CHECK(other->src.same_as(operator->()->dst)); + Map dst_to_src; + Map src_to_dst; + + Analyzer ana_first; + ana_first.Bind(operator->()->src->ranges); + for (auto p : other->dst_to_src) { + dst_to_src.Set(p.first, ana_first.Simplify(Substitute(p.second, operator->()->dst_to_src))); + } + + Analyzer ana_second; + ana_second.Bind(other->dst->ranges); + for (auto p : operator->()->src_to_dst) { + src_to_dst.Set(p.first, ana_second.Simplify(Substitute(p.second, other->src_to_dst))); + } + return IntConstraintsTransform(operator->()->src, other->dst, src_to_dst, dst_to_src); +} + TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform") diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index f489d046835d..5744cfdac78f 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -94,35 +94,6 @@ struct ExprLess { } }; -/*! - * \brief Combine the information into an array of (in)equalities. - */ -Array as_conditions(const Array& variables, const Map& bounds, - const Array& relations) { - Array res; - // use variables to keep the order of iteration - // so as to get rid of any non-determinism. - CHECK_EQ(variables.size(), bounds.size()); - for (const auto v : variables) { - CHECK(bounds.count(v)); - const auto& bnds = bounds[v]; - PrimExpr lhs = bnds->coef * v; - for (const PrimExpr& rhs : bnds->equal) { - res.push_back(tir::EQ(lhs, rhs)); - } - for (const PrimExpr& rhs : bnds->lower) { - res.push_back(tir::GE(lhs, rhs)); - } - for (const PrimExpr& rhs : bnds->upper) { - res.push_back(tir::LE(lhs, rhs)); - } - } - for (const PrimExpr& e : relations) { - res.push_back(e); - } - return res; -} - void DebugPrint( const std::unordered_set& current_ineq_set, const std::unordered_set& next_ineq_set, @@ -290,7 +261,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Simplify each inequality into the form `expr <= 0` and add to current formulas for (const PrimExpr& ineq : system_to_solve->relations) { - AddInequality(¤t_ineq_set_to_solve, NormalizeComparisons()(analyzer.Simplify(ineq, 3)), + AddInequality(¤t_ineq_set_to_solve, + NormalizeComparisons()(analyzer.Simplify(ineq, kSimplifyRewriteCanonicalRewrite)), &analyzer); } @@ -307,8 +279,9 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Add bounds from vranges if (system_to_solve->ranges.count(v)) { const Range& range = system_to_solve->ranges[v]; - PrimExpr range_lbound = analyzer.Simplify(range->min, 3); - PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1, 3); + PrimExpr range_lbound = analyzer.Simplify(range->min, kSimplifyRewriteCanonicalRewrite); + PrimExpr range_ubound = + analyzer.Simplify(range->min + range->extent - 1, kSimplifyRewriteCanonicalRewrite); coef_neg.push_back({-1, range_lbound}); coef_pos.push_back({1, -range_ubound}); } @@ -329,7 +302,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // with steps = 2 it's (y*2) - 10 <= 0 - new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 3)); + new_ineq = + NormalizeComparisons()(analyzer.Simplify(new_ineq, kSimplifyRewriteCanonicalRewrite)); AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer); } } @@ -354,7 +328,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t for (const auto& pos : coef_pos) { PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second; - bound = analyzer.Simplify(bound, 3); + bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(upper_bounds.begin(), upper_bounds.end(), [&bound, &analyzer](const PrimExpr& o) { @@ -375,7 +349,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } for (const auto& neg : coef_neg) { PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second; - bound = analyzer.Simplify(bound, 3); + bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(lower_bounds.begin(), lower_bounds.end(), [&bound, &analyzer](const PrimExpr& o) { @@ -414,7 +388,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Everything that is left goes to res.relations Array other_conditions; for (const PrimExpr& e : current_ineq_set_to_solve) { - PrimExpr e_simp = analyzer.Simplify(e, 3); + PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite); if (is_const_int(e_simp, 0)) { // contradiction detected other_conditions = {const_false()}; @@ -465,7 +439,8 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // There is an equation of the form `v == expr`, so this variable can be completely removed. // Note that we use the 0-th expression because they are ordered by complexity, // so it must be the simplest one. - Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1, 3)); + Range best_range(bnd->equal[0], + analyzer.Simplify(bnd->equal[0] + 1, kSimplifyRewriteCanonicalRewrite)); res_ranges.Set(var, best_range); vranges.Set(var, best_range); } else { @@ -491,7 +466,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { arith::Analyzer analyzer; analyzer.Bind(vranges); for (const PrimExpr& old_cond : - as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) { + AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) { if (!analyzer.CanProve(old_cond)) { // those not represented in vranges (res_ranges) res_relations.push_back(old_cond); @@ -584,7 +559,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ // Add the original conditions (with variables substituted) to the resulting conditions for (const PrimExpr& old_cond : - as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) { + AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) { PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst)); if (!is_const_int(new_cond, 1)) { // those not represented in vranges (res_ranges) @@ -615,7 +590,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " << args.size(); } - *ret = as_conditions(problem->variables, ret_ineq.first, ret_ineq.second); + *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second); }); TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange").set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 79a8da4fea75..b769b65d4b4a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -886,7 +886,6 @@ RELAY_REGISTER_OP("scatter_add") .set_attr("TOpPattern", kOpaque) .set_support_level(10); -//// // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc new file mode 100644 index 000000000000..3860c0038568 --- /dev/null +++ b/src/te/autodiff/ad_simplify.cc @@ -0,0 +1,1231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ad_simplify.cc + * \brief Simplify tensor compute generated by tensor-level autodiff. + * + * The major simplification we do in this file is to eliminate + * the Jacobian tensor created by autodiff. + * + * Jacobian tensor is sparse because one output element usually relates + * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping + * between input tensor and output tensor, thus the Jacobian is diagonal. + * + * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, + * \alpha and \beta are vectors represent the indices of In and Out respectively. + * i.e., the non-zero Jacobian indices is a linear combination of the input indices. + * Thereby we solve linear equations of \beta = A \alpha, + * as well as linear inequalities of their domain ranges. + * + * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J]. + * arXiv preprint arXiv:1711.01348, 2017. for more details. + * + * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition, + * replace the compute expression with solved new axes, and create a selection node + * (non-zero-condition ? new_compute_expression : 0). + * + * Due to TVM's restriction, we also lift the reduction to the top of the compute stage. + * + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "ad_util.h" + +namespace tvm { +namespace te { + +using arith::DivMode; +using arith::kFloorDiv; +using arith::kSimplifyRewriteCanonicalRewrite; +using arith::kTruncDiv; + +// Combine all expressions from the container using &&. +template +PrimExpr All(const container& c) { + PrimExpr res; + for (const auto& e : c) { + if (res.get()) { + res = res && e; + } else { + res = e; + } + } + if (res.get()) { + return res; + } else { + return const_true(); + } +} + +Map IterVarsToMap(const Array& itervars) { + Map res; + for (const IterVar& v : itervars) { + res.Set(v->var, v->dom); + } + return res; +} + +// Given a map from vars to ranges create an array of itervars +Array IterVarsFromMap(const Array& vars, const Map& vranges, + IterVarType iter_type = kDataPar, std::string thread_tag = "") { + Array res; + for (const Var& v : vars) { + CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map " + << vranges; + res.push_back(IterVar(vranges[v], v, iter_type, thread_tag)); + } + return res; +} + +Array IterVarsToVars(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(v->var); + } + return res; +} + +template +bool is_const_value(const PrimExpr& e, ValueType value) { + static_assert(std::is_integral::value, + "Comparison to non-integer values is forbidden."); + if (const tir::IntImmNode* i = e.as()) { + return i->value == value; + } else if (const tir::FloatImmNode* i = e.as()) { + return i->value == value; + } else if (const tir::CastNode* c = e.as()) { + return is_const_value(c->value, value); + } else if (const tir::BroadcastNode* b = e.as()) { + return is_const_value(b->value, value); + } else { + return false; + } +} + +// Return true if this combiner is just a sum. +bool IsSumCombiner(const CommReducer& combiner, const Map& vranges) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + if (combiner->result.size() != 1) { + return false; + } + + if (!is_const_value( + analyzer.Simplify(combiner->identity_element[0], kSimplifyRewriteCanonicalRewrite), 0)) { + return false; + } + + PrimExpr combiner_result = + analyzer.Simplify(combiner->result[0], kSimplifyRewriteCanonicalRewrite); + + return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) || + tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]); +} + +bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, + const Map& vranges) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index], + kSimplifyRewriteCanonicalRewrite), + 0)) { + return false; + } + + PrimExpr zero = make_zero(combiner->result[value_index].dtype()); + PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero}, + {combiner->rhs[value_index], zero}}); + in = analyzer.Simplify(in, kSimplifyRewriteCanonicalRewrite); + + return is_const_value(in, 0); +} + +struct NonzeroConditionResult { + PrimExpr cond; + PrimExpr value; + + PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); } + + friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) { + return os << r.to_expr(); + } +}; + +// The implementation of NonzeroCondition +// transform expression to cond ? value : 0 +class NonzeroConditionFunctor : public ExprFunctor { + public: + NonzeroConditionResult NonzeroCondition(const PrimExpr& e) { + if (e.dtype().is_bool()) { + // Boolean expressions are non-zero whenever they are true themselves + return {e, const_true()}; + } else { + return VisitExpr(e); + } + } + + // Most of the cases are implemented using helpers below + result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef(op)); } + result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef(op)); } + result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef(op)); } + result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef(op)); } + result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef(op)); } + result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef(op)); } + result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef(op)); } + result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef
(op)); } + result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef(op)); } + result_type VisitExpr_(const FloorDivNode* op) final { + return BinOpDivLike_(GetRef(op)); + } + result_type VisitExpr_(const FloorModNode* op) final { + return BinOpDivLike_(GetRef(op)); + } + result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef(op)); } + result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef(op)); } + + result_type VisitExpr_(const CastNode* op) final { + auto nz_a = NonzeroCondition(op->value); + return {nz_a.cond, Cast(op->dtype, nz_a.value)}; + } + + result_type VisitExpr_(const SelectNode* op) final { + PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value; + auto nz_a = NonzeroCondition(true_val); + auto nz_b = NonzeroCondition(false_val); + + // If the false part is zero, we can get rid of the select + if (is_const_value(nz_b.value, 0)) { + PrimExpr new_cond = analyzer_.Simplify(nz_a.cond && cond, kSimplifyRewriteCanonicalRewrite); + return {new_cond, nz_a.value}; + } + + // If the true part is zero, we can also get rid of the select + if (is_const_value(nz_a.value, 0)) { + PrimExpr new_cond = analyzer_.Simplify(nz_b.cond && !cond, kSimplifyRewriteCanonicalRewrite); + return {new_cond, nz_b.value}; + } + + // Otherwise we retain the select and combine the conditions into this + PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), + kSimplifyRewriteCanonicalRewrite); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, GetRef(op)}; + } else { + return {new_cond, Select(cond, nz_a.value, nz_b.value)}; + } + } + + result_type VisitExpr_(const CallNode* op) final { + if (op->op.same_as(op_if_then_else_)) { + PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2]; + auto nz_a = NonzeroCondition(true_val); + auto nz_b = NonzeroCondition(false_val); + + // We don't have as much freedom here as in the select case + // since the `if` must be preserved in any case + PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), + kSimplifyRewriteCanonicalRewrite); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, GetRef(op)}; + } else { + return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)}; + } + } else { + return Default_(GetRef(op)); + } + } + + result_type VisitExpr_(const ProducerLoadNode* op) final { + return Default_(GetRef(op)); + } + + NonzeroConditionResult Default_(const PrimExpr& e) { + // This is always correct, so it's the default + return {const_true(), e}; + } + + template + NonzeroConditionResult Const_(const T& op) { + if (op->value == 0) { + return {const_false(), op}; + } else { + return {const_true(), op}; + } + } + + template + NonzeroConditionResult BinOpAddLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + auto nz_b = NonzeroCondition(op->b); + + // For addition and similar ops the result may be nonzero if either of the arguments is + // nonzero, so we combine the conditions with Or. + if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) { + // If the conditions are the same, we don't need Or + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {nz_a.cond, op}; + } else { + return {nz_a.cond, T(nz_a.value, nz_b.value)}; + } + } else { + // Otherwise use Or + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond || nz_b.cond, kSimplifyRewriteCanonicalRewrite); + // A little optimization: if the combined condition is the same as one of the inner + // conditions, we don't need to guard the inner value with a select, otherwise + // we create a select in the `to_expr` call. + PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); + PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); + PrimExpr new_expr = T(new_a, new_b); + return {new_cond, new_expr}; + } + } + + template + NonzeroConditionResult BinOpMulLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + auto nz_b = NonzeroCondition(op->b); + + // For multiplication and similar ops the result may be nonzero if + // both the arguments are nonzero, so we combine with And. + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond && nz_b.cond, kSimplifyRewriteCanonicalRewrite); + + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {new_cond, op}; + } else { + return {new_cond, T(nz_a.value, nz_b.value)}; + } + } + + template + NonzeroConditionResult BinOpDivLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + + // For Div we simply use the condition of the numerator. + + if (nz_a.value.same_as(op->a)) { + return {nz_a.cond, op}; + } else { + return {nz_a.cond, T(nz_a.value, op->b)}; + } + } + + private: + arith::Analyzer analyzer_; + const Op& op_if_then_else_ = Op::Get("tir.if_then_else"); +}; + +inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) { + return NonzeroConditionFunctor().NonzeroCondition(expr); +} + +struct FactorOutAtomicFormulasResult { + std::vector atomic_formulas; + PrimExpr rest; + + PrimExpr to_expr() const { + PrimExpr res = rest; + for (const PrimExpr& e : atomic_formulas) { + res = And(e, res); + } + return res; + } + + Array to_array() const { + Array res = atomic_formulas; + res.push_back(rest); + return res; + } +}; + +// The implementation of FactorOutAtomicFormulas +class FactorOutAtomicFormulasFunctor + : public ExprFunctor { + public: + result_type Atomic_(const PrimExpr& e) { + // For atomic expressions the result is the expr itself with True as the residual + return {{e}, make_const(e.dtype(), 1)}; + } + + // This is basically the list of expression kinds that are considered atomic + result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef(op)); } + result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef(op)); } + result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef(op)); } + result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef(op)); } + result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef(op)); } + result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef(op)); } + result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef(op)); } + result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef(op)); } + result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef(op)); } + + result_type VisitExpr_(const SelectNode* op) final { + // Select can be rewritten through other logical ops + PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value); + return VisitExpr(expr); + } + + result_type VisitExpr_(const NotNode* op) final { + // Not should be moved down + if (const OrNode* or_expr = op->a.as()) { + PrimExpr expr = !or_expr->a && !or_expr->b; + return VisitExpr(expr); + } else if (const AndNode* and_expr = op->a.as()) { + PrimExpr expr = !and_expr->a || !and_expr->b; + return VisitExpr(expr); + } else if (const SelectNode* sel_expr = op->a.as()) { + PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) && + (sel_expr->condition || !sel_expr->false_value)); + return VisitExpr(expr); + } + return Atomic_(GetRef(op)); + } + + result_type VisitExpr_(const AndNode* op) final { + auto res_a = VisitExpr(op->a); + auto res_b = VisitExpr(op->b); + + // For the And case we return the union of the sets of atomic formulas + std::unordered_set res_set; + res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + std::inserter(res_set, res_set.end())); + std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::inserter(res_set, res_set.end())); + + std::vector res{res_set.begin(), res_set.end()}; + + // And the residuals are combined with && + return {res, res_a.rest && res_b.rest}; + } + + result_type VisitExpr_(const MulNode* op) final { + // Since we work with bools, for multiplication we do the same thing as for And + PrimExpr e_and = op->a && op->b; + return VisitExpr(e_and); + } + + result_type VisitExpr_(const OrNode* op) final { + auto res_a = VisitExpr(op->a); + auto res_b = VisitExpr(op->b); + + std::unordered_set res_a_set{ + res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()}; + std::unordered_set res_b_set{ + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()}; + + // For the Or case we intersect the sets of atomic formulas + std::unordered_set res_set; + res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); + for (const auto& res_b_formula : res_b_set) { + if (res_a_set.count(res_b_formula)) { + res_set.insert(res_b_formula); + } + } + + // Computing the residual is more complex: we have to compute the sets of atomic formulas + // which are left behind, and then combine them with the residuals into the new residual. + std::vector new_cond_a; + new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size()); + for (const auto& formula : res_a_set) { + if (!res_set.count(formula)) new_cond_a.emplace_back(formula); + } + + std::vector new_cond_b; + new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size()); + for (const auto& formula : res_b_set) { + if (!res_set.count(formula)) new_cond_b.emplace_back(formula); + } + + res_a.atomic_formulas = std::move(new_cond_a); + res_b.atomic_formulas = std::move(new_cond_b); + + PrimExpr new_rest = res_a.to_expr() || res_b.to_expr(); + std::vector res{res_set.begin(), res_set.end()}; + + return {res, new_rest}; + } +}; + +// Transform the given formula into a conjunction of atomic formulas (represented as an array) +// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b, +// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level. +FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) { + CHECK(e.dtype().is_bool()); + return FactorOutAtomicFormulasFunctor().VisitExpr(e); +} + +struct EliminateDivModResult { + PrimExpr expr; + Map substitution; + Array new_variables; + Array conditions; + Map ranges; +}; + +inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { + if (mode == kTruncDiv) { + return truncmod(a, b); + } else { + CHECK_EQ(mode, kFloorDiv); + return floormod(a, b); + } +} + +inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { + if (mode == kTruncDiv) { + return truncdiv(a, b); + } else { + CHECK_EQ(mode, kFloorDiv); + return floordiv(a, b); + } +} + +class EliminateDivModMutator : public ExprMutator { + public: + Map substitution; + Array new_variables; + Array conditions; + Map ranges; + + explicit EliminateDivModMutator(Map ranges) : ranges(std::move(ranges)) {} + + virtual PrimExpr VisitExpr_(const DivNode* op) { + const IntImmNode* imm = op->b.as(); + if (imm && imm->value != 0) { + if (imm->value < 0) { + // x / -c == -(x/c) for truncated division + return make_zero(op->dtype) - + VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value))); + } + + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value)); + if (it != expr_to_vars_.end()) { + return it->second.first; + } + + // Otherwise recursively mutate the left hand side, and create new variables + PrimExpr mutated_a = VisitExpr(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) { + return var_pair_opt.value().first; + } else { + return truncdiv(mutated_a, op->b); + } + } + + return truncdiv(VisitExpr(op->a), VisitExpr(op->b)); + } + + virtual PrimExpr VisitExpr_(const ModNode* op) { + const IntImmNode* imm = op->b.as(); + if (imm && imm->value != 0) { + if (imm->value < 0) { + // x % -c == x % c for truncated division + return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value))); + } + + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value)); + if (it != expr_to_vars_.end()) { + return it->second.second; + } + + // Otherwise recursively mutate the left hand side, and create new variables + PrimExpr mutated_a = VisitExpr(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) { + return var_pair_opt.value().second; + } else { + return truncmod(mutated_a, op->b); + } + } + + return truncmod(VisitExpr(op->a), VisitExpr(op->b)); + } + + virtual PrimExpr VisitExpr_(const FloorDivNode* op) { + const IntImmNode* imm = op->b.as(); + if (imm && imm->value != 0) { + if (imm->value < 0) { + // x / -c == (-x) / c for flooring division + return VisitExpr( + floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value))); + } + + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value)); + if (it != expr_to_vars_.end()) { + return it->second.first; + } + + // Otherwise recursively mutate the left hand side, and create new variables + PrimExpr mutated_a = VisitExpr(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) { + return var_pair_opt.value().first; + } else { + return floordiv(mutated_a, op->b); + } + } + + return floordiv(VisitExpr(op->a), VisitExpr(op->b)); + } + + virtual PrimExpr VisitExpr_(const FloorModNode* op) { + const IntImmNode* imm = op->b.as(); + if (imm && imm->value != 0) { + if (imm->value < 0) { + // x % -c == -(-x % c) for flooring division + return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a, + make_const(op->dtype, -imm->value))); + } + + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value)); + if (it != expr_to_vars_.end()) { + return it->second.second; + } + + // Otherwise recursively mutate the left hand side, and create new variables + PrimExpr mutated_a = VisitExpr(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) { + return var_pair_opt.value().second; + } else { + return floormod(mutated_a, op->b); + } + } + + return floormod(VisitExpr(op->a), VisitExpr(op->b)); + } + + private: + dmlc::optional> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut, + int64_t val, DivMode mode) { + using tresult = dmlc::optional>; + + // Try to find the variables using the mutated expressions + if (!e.same_as(mut)) { + auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val)); + if (it != expr_to_vars_.end()) { + return tresult(it->second); + } + } + + PrimExpr val_e = make_const(e.dtype(), val); + idx_ += 1; + + // Convert `ranges` to IntSets + std::unordered_map var_intsets; + for (const auto& p : ranges) { + var_intsets[p.first.get()] = IntSet::FromRange(p.second); + } + + // Infer ranges for the expressions we want to replace with variables + Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range()); + Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range()); + + // We don't want to add unbounded variables + if (!div_range.get() || !mod_range.get()) { + LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode) + << " because its bounds cannot be inferred"; + return tresult(); + } + if (!mod_range.get()) { + LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode) + << " because its bounds cannot be inferred"; + return tresult(); + } + + // Create new variables for the expressions + auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype()); + auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype()); + + new_variables.push_back(div); + new_variables.push_back(mod); + + // Note that we have to perform substitution to mut because mut may contain new variables + substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode)); + substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode)); + + ranges.Set(div, div_range); + ranges.Set(mod, mod_range); + + // This additional condition works as a definition for the new variables + conditions.push_back(mut == div * val_e + mod); + + if (!analyzer_.CanProve(mod_range->extent <= val_e)) { + // If we use the C/C++ definition of mod, there may be multiple values of `mod` + // satisfying the added condition if the expr `e` may change its sign, so we + // have to add another condition. + LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because " + << ModImpl(e, val_e, mode) << " probably may change its sign"; + conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0)); + } + + auto p = std::make_pair(div, mod); + expr_to_vars_[std::make_tuple(mode, e, val)] = p; + if (!e.same_as(mut)) { + expr_to_vars_[std::make_tuple(mode, mut, val)] = p; + } + return tresult(p); + } + + class TupleEqual_ { + public: + bool operator()(const std::tuple& lhs, + const std::tuple& rhs) const { + return std::get<0>(lhs) == std::get<0>(rhs) && + tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) && + std::get<2>(lhs) == std::get<2>(rhs); + } + }; + + class TupleHasher_ { + public: + size_t operator()(const std::tuple& key) const { + return ((std::hash()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >> + 1) ^ + (std::hash()(std::get<2>(key)) << 1); + } + }; + + // A counter for naming new variables + int idx_{0}; + // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod) + // such that `div = e / n` and `mod = e % n` + std::unordered_map, std::pair, TupleHasher_, + TupleEqual_> + expr_to_vars_; + arith::Analyzer analyzer_; +}; + +// Replace every subexpr of the form e/const and e % const with a new variable. +// Syntactically equal expressions will be mapped to the same variable. +EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map ranges) { + EliminateDivModResult res; + EliminateDivModMutator mutator(ranges); + res.expr = mutator(expr); + res.conditions = std::move(mutator.conditions); + res.new_variables = std::move(mutator.new_variables); + res.substitution = std::move(mutator.substitution); + res.ranges = std::move(mutator.ranges); + return res; +} + +arith::IntConstraintsTransform EliminateDivModFromDomainConditions( + const arith::IntConstraints& domain) { + auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges); + + Map new_vranges = elim_res.ranges; + Array new_axis = Concat(domain->variables, elim_res.new_variables); + PrimExpr new_cond = elim_res.expr && All(elim_res.conditions); + + arith::IntConstraints new_domain(new_axis, new_vranges, + FactorOutAtomicFormulas(new_cond).to_array()); + + Map src_to_dst; + Map dst_to_src = elim_res.substitution; + for (const Var& v : domain->variables) { + src_to_dst.Set(v, v); + dst_to_src.Set(v, v); + } + + return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src); +} + +inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) { + Map identity_map; + for (const Var& v : domain->variables) { + identity_map.Set(v, v); + } + return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map); +} + +// Simplify an iteration domain. +arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains, + bool eliminate_div_mod) { + arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains); + + if (eliminate_div_mod) { + transf = transf + EliminateDivModFromDomainConditions(transf->dst); + } + + // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably + // should find a better terminating criterion (like stop when the domain volume stops decreasing) + // Also 2 steps seems to be slightly better than 3 + for (size_t i = 0; i < 2; ++i) { + transf = transf + arith::SolveLinearEquations(transf->dst); + transf = transf + arith::SolveInequalitiesDeskewRange(transf->dst); + } + + return transf; +} + +// Use the condition of a reduction op to simplify its domain (axis) +PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map& outer_vranges) { + if (const ReduceNode* red = expr.as()) { + Array vars = IterVarsToVars(red->axis); + Map vranges = Merge(outer_vranges, IterVarsToMap(red->axis)); + Array relations = FactorOutAtomicFormulas(red->condition).to_array(); + + arith::IntConstraints domain(vars, vranges, relations); + auto res = SimplifyDomain(domain); + + Array new_source; + for (const PrimExpr& src : red->source) { + new_source.push_back(Substitute(src, res->src_to_dst)); + } + + Array new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce); + + // Perform simplification mainly to remove a possibly empty reduction. + arith::Analyzer analyzer; + return analyzer.Simplify( + Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index), + kSimplifyRewriteCanonicalRewrite); + } else { + return expr; + } +} + +// Extract from cond an implication of cond not containing vars +std::pair ImplicationNotContainingVars( + const PrimExpr& cond, const std::unordered_set& vars) { + CHECK(cond.dtype().is_bool()) << "The type of cond must be bool"; + // TODO(sgrechanik-h): NOTs could be pushed down using De Morgan laws + // before running this function but this case didn't seem to be important enough. + if (const AndNode* op = cond.as()) { + auto pair_a = ImplicationNotContainingVars(op->a, vars); + auto pair_b = ImplicationNotContainingVars(op->b, vars); + return {pair_a.first && pair_b.first, pair_a.second && pair_b.second}; + } else if (const OrNode* op = cond.as()) { + auto pair_a = ImplicationNotContainingVars(op->a, vars); + auto pair_b = ImplicationNotContainingVars(op->b, vars); + return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) && + (pair_b.first || pair_a.second) && + (pair_a.second || pair_b.second)}; + } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { + return {cond, const_true()}; + } else { + return {const_true(), cond}; + } +} + +// Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out +// (in)equalities which do not depend on the reduction variables. +std::pair LiftConditionsThroughReduction(const PrimExpr& cond, + const Array& red_axis, + const Array& outer_axis) { + // Factor out atomics so that we can consider this as a system of inequalities + auto factor_atomic_res = FactorOutAtomicFormulas(cond); + Array atomics = factor_atomic_res.atomic_formulas; + const PrimExpr& rest = factor_atomic_res.rest; + + Array allvars; + for (const IterVar& v : red_axis) { + allvars.push_back(v->var); + } + for (const IterVar& v : outer_axis) { + allvars.push_back(v->var); + } + + auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis)); + // start from reduction vars, so that input vars don't depend on them + arith::IntConstraints ineq_to_solve(allvars, vranges, atomics); + auto res_ineq = arith::SolveLinearInequalities(ineq_to_solve); + atomics = arith::AsConditions(allvars, res_ineq.first, res_ineq.second); + + // Append the rest part + PrimExpr rewritten_cond = All(atomics) && rest; + + std::unordered_set vset; + for (const IterVar& v : red_axis) { + vset.insert(v->var.get()); + } + + // The outer (first) condition does not contain reduction vars, + // the inner (second) condition is everything else + auto res = ImplicationNotContainingVars(rewritten_cond, vset); + return res; +} + +// Convert an array of itervars to an array of inequalities +Array IterVarsToInequalities(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(GE(v->var, v->dom->min)); + res.push_back(LT(v->var, v->dom->min + v->dom->extent)); + } + return res; +} + +class RemoveRedundantInequalitiesMutator : public ExprMutator { + public: + explicit RemoveRedundantInequalitiesMutator(Array known) { + for (const PrimExpr& cond : known) { + known_.push_back(analyzer_.Simplify(cond, kSimplifyRewriteCanonicalRewrite)); + } + } + + virtual PrimExpr VisitExpr_(const SelectNode* op) { + bool has_side_effect = (SideEffect(GetRef(op)) > CallEffectKind::kReadState); + PrimExpr new_cond = + analyzer_.Simplify(VisitExpr(op->condition), kSimplifyRewriteCanonicalRewrite); + if (is_one(new_cond) && !has_side_effect) { + return VisitExpr(op->true_value); + } else if (is_zero(new_cond) && !has_side_effect) { + return VisitExpr(op->false_value); + } else { + Array new_known = known_; + for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return Select(new_cond, new_mutator(op->true_value), VisitExpr(op->false_value)); + } + } + + virtual PrimExpr VisitExpr_(const CallNode* op) { + if (op->op.same_as(op_if_then_else_)) { + PrimExpr new_cond = + analyzer_.Simplify(VisitExpr(op->args[0]), kSimplifyRewriteCanonicalRewrite); + if (is_one(new_cond)) { + return VisitExpr(op->args[1]); + } else if (is_zero(new_cond)) { + return VisitExpr(op->args[2]); + } else { + Array new_known = known_; + for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return if_then_else(new_cond, new_mutator(op->args[1]), VisitExpr(op->args[2])); + } + } else { + return ExprMutator::VisitExpr_(op); + } + } + + virtual PrimExpr VisitExpr_(const ReduceNode* op) { + Array known_with_axes = known_; + for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) { + known_with_axes.push_back(axis_cond); + } + RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes); + + PrimExpr new_cond = mutator_with_axes(op->condition); + + Array new_known = known_with_axes; + for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + + Array new_source; + for (const PrimExpr& src : op->source) { + new_source.push_back(new_mutator(src)); + } + + return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index); + } + + virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef(op)); } + virtual PrimExpr VisitExpr_(const NENode* op) { return MutateAtomic_(GetRef(op)); } + virtual PrimExpr VisitExpr_(const LTNode* op) { return MutateAtomic_(GetRef(op)); } + virtual PrimExpr VisitExpr_(const LENode* op) { return MutateAtomic_(GetRef(op)); } + virtual PrimExpr VisitExpr_(const GTNode* op) { return MutateAtomic_(GetRef(op)); } + virtual PrimExpr VisitExpr_(const GENode* op) { return MutateAtomic_(GetRef(op)); } + + virtual PrimExpr VisitExpr_(const AndNode* op) { return VisitExpr(op->a) && VisitExpr(op->b); } + + private: + PrimExpr MutateAtomic_(const PrimExpr& e) { + PrimExpr simplified = analyzer_.Simplify(e, kSimplifyRewriteCanonicalRewrite); + for (const PrimExpr& other : known_) { + if (ExprDeepEqual()(simplified, other)) { + return const_true(); + } + } + return simplified; + } + + Array known_; + arith::Analyzer analyzer_; + const Op& op_if_then_else_ = Op::Get("tir.if_then_else"); +}; + +// Propagate information from conditions and remove redundant inequalities +inline PrimExpr RemoveRedundantInequalities(const PrimExpr& expr, const Array& known) { + return RemoveRedundantInequalitiesMutator(known)(expr); +} + +// Extract the given expr under the given condition as a separate tensor if the volume of the +// extracted tensor will be less than the volume of the outer_axis +PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond, + const Array& outer_axis, const Map& vranges) { + // solve cond, e.g., (jac_i0 == i) && (jac_i1 == j) + arith::IntConstraints domain_to_solve(outer_axis, vranges, + FactorOutAtomicFormulas(cond).to_array()); + auto res = SimplifyDomain(domain_to_solve); + + arith::Analyzer analyzer; + analyzer.Bind(res->dst->ranges); + PrimExpr new_expr = + analyzer.Simplify(Substitute(expr, res->src_to_dst), kSimplifyRewriteCanonicalRewrite); + // TODO(yzhliu): This is mostly done to simplify if_then_else + // which is not realized by the canonical simplifier + new_expr = RemoveRedundantInequalities(new_expr, res->dst->relations); + + // Keep only those variables of the new vars which are used in the new_expr + Array used_res_variables; + for (const Var& var : res->dst->variables) { + if (ExprUseVar(new_expr, var)) { + CHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred."; + used_res_variables.push_back(var); + } + } + + // If the expression does not use vars then it is probably better to keep it inlined + if (used_res_variables.empty()) { + // We can return the new_expr here instead of the old expr because it doesn't use variables + // otherwise we would need to replace the new vars or create a let-expression + return new_expr; + } + + // If it's already tensor[...] then it will probably be useless to further simplify it. + if (new_expr.as()) { + return expr; + } + + // Compute volumes before and after + PrimExpr old_volume = make_const(DataType::Int(64), 1); + for (const Var& var : outer_axis) { + CHECK(vranges.count(var)) << "Range of " << var << " was not provided."; + old_volume = old_volume * vranges[var]->extent; + } + + PrimExpr new_volume = make_const(DataType::Int(64), 1); + for (const Var& var : used_res_variables) { + new_volume = new_volume * res->dst->ranges[var]->extent; + } + + // if we can prove that the old volume is not greater than the new volume then + // prefer the old expression. + arith::Analyzer ana_vranges; + ana_vranges.Bind(vranges); + if (ana_vranges.CanProve(old_volume <= new_volume)) { + return expr; + } + + Tensor tensor = TensorFromExpr(new_expr, IterVarsFromMap(used_res_variables, res->dst->ranges), + "extracted_tensor"); + + Array args; + for (const Var& var : used_res_variables) { + args.push_back(res->dst_to_src[var]); + } + + return ProducerLoad(tensor, args); +} + +class ReductionAsTensorAccessMutator : public ExprMutator { + public: + explicit ReductionAsTensorAccessMutator(const Array& outer_axis, Map vranges, + std::string name = "extracted_reduction") + : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {} + + PrimExpr VisitExpr_(const ReduceNode* op) final { + ReductionAsTensorAccessMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_), + Merge(vranges_, IterVarsToMap(op->axis)), name_); + + Array new_source; + for (const PrimExpr& src : op->source) { + new_source.push_back(new_mutator(src)); + } + + PrimExpr new_reduce = + Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index); + + Array undefined_vars = UndefinedVars(new_reduce); + std::unordered_set undefined_var_set; + for (const Var& var : undefined_vars) { + undefined_var_set.insert(var.get()); + } + + // Vars of the tensor we are going to create for this reduction + Array vars; + for (const Var& v : outer_axis_) { + // We take variables from the outer_axis_ which are also present in the new reduction + if (undefined_var_set.count(v.get())) { + vars.push_back(v); + } + } + + auto new_axis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_)); + Array new_axis = new_axis_vmap_pair.first; + arith::Analyzer analyzer; + analyzer.Bind(IterVarsToMap(new_axis)); + new_reduce = analyzer.Simplify(Substitute(new_reduce, new_axis_vmap_pair.second), + kSimplifyRewriteCanonicalRewrite); + + Tensor tensor = TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_); + + Array args; + for (const Var& v : vars) { + args.push_back(v); + } + + return ProducerLoad(tensor, args); + } + + private: + Array outer_axis_; + Map vranges_; + std::string name_; + std::string tag_; + Map attrs_; +}; + +// Extract reductions as separate tensors. +inline PrimExpr ReductionAsTensorAccess(const PrimExpr& expr, const Array& outer_axis, + const Map& vranges) { + return ReductionAsTensorAccessMutator(outer_axis, vranges)(expr); +} + +PrimExpr LiftReductions(const PrimExpr& expr, const Array& outer_axis, + const Map& vranges) { + if (const ReduceNode* red = expr.as()) { + Array new_outer_axis = Concat(IterVarsToVars(red->axis), outer_axis); + Map new_vranges = Merge(vranges, IterVarsToMap(red->axis)); + Array new_source; + for (const PrimExpr& src : red->source) { + new_source.push_back(ReductionAsTensorAccess(src, new_outer_axis, new_vranges)); + } + PrimExpr new_condition = ReductionAsTensorAccess(red->condition, new_outer_axis, new_vranges); + + return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index); + } else { + return ReductionAsTensorAccess(expr, outer_axis, vranges); + } +} + +PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const Array& axis, + const Map& vranges) { + PrimExpr result; + Map combined_vranges = Merge(vranges, IterVarsToMap(axis)); + arith::Analyzer analyzer; + analyzer.Bind(combined_vranges); + + // Simplify the original expression first, mostly to simplify combiners + PrimExpr expr = analyzer.Simplify(expr_orig, kSimplifyRewriteCanonicalRewrite); + + if (const ReduceNode* red = expr.as()) { + // TODO(sgrechanik-h): There are some other operations which behave like sum + bool is_sum = IsSumCombiner(red->combiner, vranges); + if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) { + PrimExpr new_red = expr; + + // Here we simplify the reduction + PrimExpr cond = red->condition; + Array source = red->source; + + // If it is a summation then we can lift nonzeroness conditions from the source + // and add them to the reduction conditions + if (is_sum) { + auto nz = NonzeronessCondition(red->source[red->value_index]); + cond = nz.cond && cond; + source.Set(0, nz.value); + } + + new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index); + new_red = SimplifyReductionDomain(new_red, combined_vranges); + // If the reduction disappears completely then transform the result as a non-reduction + if (!new_red.as()) { + return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges); + } + + PrimExpr new_outer_cond, new_reduce_cond; + Array new_source = red->source; + + // Partially lift conditions from the reduce condition + std::tie(new_outer_cond, new_reduce_cond) = + LiftConditionsThroughReduction(red->condition, red->axis, axis); + + // If it's not sum then we haven't yet lifted nonzeroness cond from the source + if (!is_sum) { + PrimExpr outer_nz_cond, nz_cond, nz_source; + auto nz = NonzeronessCondition(red->source[red->value_index]); + // Append conditions from the reduction + nz_cond = new_reduce_cond && nz.cond; + nz_source = nz.value; + std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis); + new_outer_cond = new_outer_cond && outer_nz_cond; + new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype()))); + } + + PrimExpr new_reduce = + Reduce(red->combiner, new_source, red->axis, new_reduce_cond, red->value_index); + new_reduce = + TrySimplifyCompute(new_reduce, new_outer_cond, IterVarsToVars(axis), combined_vranges); + result = Select(new_outer_cond, new_reduce, make_zero(new_reduce.dtype())); + } else { + return SimplifyReductionDomain(expr, combined_vranges); + } + } else { + auto nz = NonzeronessCondition(expr); + PrimExpr new_expr = + TrySimplifyCompute(nz.value, nz.cond, IterVarsToVars(axis), combined_vranges); + result = Select(nz.cond, new_expr, make_zero(new_expr.dtype())); + } + + // Note that RemoveRedundantInequalities can sometimes propagate equalities which + // other simplifiers cannot, like (i % 3) == 0. + Array axis_conds = IterVarsToInequalities(axis); + result = RemoveRedundantInequalities(result, axis_conds); + + // Currently in TVM reductions are only allowed at the top level of compute, + // we need to extract intermediate inlined reduction as a separate stage (tensor). + // Sometimes TrySimplifyCompute doesn't perform lift / extraction, + // so there may be some non-top reductions left, take care of them. + result = LiftReductions(result, IterVarsToVars(axis), combined_vranges); + return analyzer.Simplify(result, kSimplifyRewriteCanonicalRewrite); +} + +Tensor RemoveJacobianAndLiftNonzeroCond(const Tensor& tensor, const Map& vranges) { + auto transform_func = [&vranges](const PrimExpr& expr, const Array& axis) { + return RemoveJacobianAndLiftNonzeroCondImpl(expr, axis, vranges); + }; + return TransformTensorBody(tensor, transform_func); +} + +} // namespace te +} // namespace tvm diff --git a/src/te/autodiff/ad_util.cc b/src/te/autodiff/ad_util.cc index 89ff96d4724b..995c8e0c3170 100644 --- a/src/te/autodiff/ad_util.cc +++ b/src/te/autodiff/ad_util.cc @@ -26,8 +26,11 @@ #include #include +#include #include +#include "../schedule/operation_inline.h" + namespace tvm { namespace te { @@ -60,5 +63,134 @@ PrimExpr CloneReduction(const PrimExpr& expr) { } } +Operation ComputeOpFromExprs(const Array& exprs, const Array& axis, + const std::string& name, const std::string& tag, + const Map& attrs, bool clone_axis) { + if (clone_axis) { + Array new_axis = axis; + Map vmap; + std::tie(new_axis, vmap) = CloneIterVars(axis); + Array new_exprs; + for (const PrimExpr& e : exprs) { + new_exprs.push_back(Substitute(CloneReduction(e), vmap)); + } + return ComputeOpFromExprs(new_exprs, new_axis, name, tag, attrs, false); + } + + Array new_exprs; + + // If this is a reduction then we have to replicate it + if (const ReduceNode* red = exprs[0].as()) { + for (size_t i = 0; i < red->source.size(); ++i) { + PrimExpr ith_red = Reduce(red->combiner, red->source, red->axis, red->condition, i); + new_exprs.push_back(ith_red); + } + } else { + new_exprs = exprs; + } + + return ComputeOp(name, tag, attrs, axis, new_exprs); +} + +Tensor TensorFromExpr(const PrimExpr& expr, const Array& axis, const std::string& name, + const std::string& tag, const Map& attrs, + bool clone_axis) { + int new_value_index = 0; + if (const ReduceNode* red = expr.as()) { + new_value_index = red->value_index; + } + return ComputeOpFromExprs({expr}, axis, name, tag, attrs, clone_axis).output(new_value_index); +} + +Tensor TransformTensorBody( + const Tensor& tensor, + const std::function&)>& func) { + if (const ComputeOpNode* op = tensor->op.as()) { + // Transform only one body + PrimExpr new_body = func(op->body[tensor->value_index], op->axis); + + // If the body didn't change then we can return the same tensor + if (new_body.same_as(op->body[tensor->value_index])) { + return tensor; + } + + return TensorFromExpr(new_body, op->axis, op->name, op->tag, op->attrs); + } else { + return tensor; + } +} + +Tensor TransformTensorBody(const Tensor& tensor, + const std::function& func) { + return TransformTensorBody(tensor, + [func](const PrimExpr& e, const Array&) { return func(e); }); +} + +// If expr is a Tensor Access node, perform inlining, otherwise do nothing +PrimExpr InlineImmediateTensorAccess(const PrimExpr& expr) { + if (const ProducerLoadNode* op = expr.as()) { + auto tensor = Downcast(op->producer); + if (const ComputeOpNode* op_comp = tensor->op.as()) { + Array tensor_axes; + for (const auto& var : op_comp->axis) { + tensor_axes.push_back(var->var); + } + + Stmt inlined = + Inline(Evaluate(expr), tensor->op, tensor_axes, op_comp->body[tensor->value_index]); + if (const EvaluateNode* ev = inlined.as()) { + // If it is a reduction, clone it + return CloneReduction(ev->value); + } + } + } + return expr; +} + +// Implements InlineTensors by trying to inline every Call of the given Expr +class InlineTensorsMutator : public ExprMutator { + public: + explicit InlineTensorsMutator(const Array& inlineable, bool inline_reductions = false) + : inline_reductions_(inline_reductions) { + for (const Tensor& tensor : inlineable) { + inlineable_.emplace(tensor->op.operator->(), tensor->value_index); + } + } + + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + auto tensor = Downcast(op->producer); + if (const ComputeOpNode* op_comp = tensor->op.as()) { + // Inline only if the array of inlineable tensors is empty or contains this tensor + if (inlineable_.empty() || inlineable_.count({op_comp, tensor->value_index})) { + // Inline only compute nodes that are not reductions (unless inline reductions is allowed) + if (inline_reductions_ || !op_comp->body[0].as()) { + PrimExpr expr = GetRef(op); + // Inline this tensor access and then try to perform further inlining + return VisitExpr(InlineImmediateTensorAccess(expr)); + } + } + } + // If we cannot inline this call, we should try to do inlining in its arguments + return ExprMutator::VisitExpr_(op); + } + + private: + // Tensors which are allowed to be inlined, represented as pairs (op_node, value_index) + std::set> inlineable_; + bool inline_reductions_; +}; + +Tensor InlineTensorAccess(const Tensor& tensor, const Array& inlineable, + bool inline_reductions) { + auto transformation = [inlineable, inline_reductions](const PrimExpr& e) { + return InlineTensorsMutator(inlineable, inline_reductions)(e); + }; + return TransformTensorBody(tensor, transformation); +} + +Tensor InlineTailTensorAccess(const Tensor& tensor) { + return TransformTensorBody(tensor, InlineImmediateTensorAccess); +} + } // namespace te } // namespace tvm diff --git a/src/te/autodiff/ad_util.h b/src/te/autodiff/ad_util.h index 56ab6c18b929..21de61cc46c2 100644 --- a/src/te/autodiff/ad_util.h +++ b/src/te/autodiff/ad_util.h @@ -24,9 +24,11 @@ #ifndef TVM_TE_AUTODIFF_AD_UTIL_H_ #define TVM_TE_AUTODIFF_AD_UTIL_H_ +#include #include #include +#include #include #include #include @@ -48,6 +50,86 @@ std::pair, Map> CloneIterVars(const Array */ PrimExpr CloneReduction(const PrimExpr& expr); +/*! + * \brief Create a tensor from an expression. The expression may be a reduction, in which + * case its body will be correctly duplicated if it is a multi-valued reduction. + * + * \param expr The expr which will be the tensor's body. + * \param axis The input variables with ranges. + * \param name The tensor's name. + * \param tag The tensor's tag. + * \param attrs The tensor's attrs. + * \param clone_axis Whether to clone the given axis and perform substitution. + * \return A tensor. + */ +Tensor TensorFromExpr(const PrimExpr& expr, const Array& axis, + const std::string& name = "tensor", const std::string& tag = "", + const Map& attrs = {}, bool clone_axis = true); + +Tensor TransformTensorBody( + const Tensor& tensor, + const std::function&)>& func); + +Tensor TransformTensorBody(const Tensor& tensor, + const std::function& func); + +/*! + * \brief Inline tensors access recursively. + * + * This function will inline tensors recursively until it reaches a tensor which is impossible to + * inline (a reduction if \p inline_reductions is false, a non-compute tensor, a tensor which is + * not from \p inlineable). It won't descend into non-inlinable tensors' bodies. + * + * \param tensor The tensor whose body to transform. + * \param inlineable A list of tensors which are allowed to be inlined. If empty, try + * to inline all tensors. + * \param inline_reductions Whether to inline reductions (this may result in top-level reduction + * nodes). + * + * \return An inlined tensor + */ +TVM_DLL Tensor InlineTensorAccess(const Tensor& tensor, + const Array& inlineable = Array(), + bool inline_reductions = false); + +/*! + * \brief Inline tensors access at the tail. + * \param tensor The tensor whose body to transform. + * \return An inlined tensor + */ +TVM_DLL Tensor InlineTailTensorAccess(const Tensor& tensor); + +/*! + * \brief Simplify an iteration domain. + * + * An iteration domain is basically an array of variables and a condition. The function will do the + * following: + * - Replace div and mod operations with new variables (optional). + * - Extract (in)equalities from the condition. + * - Perform Fourier-Motzkin elimination. + * - Shear the domain of iteration (e.g. if `y <= x <= y + 2` then x will be replaced with `y + d` + * where `d` is a new variable such that `0 <= d <= 2`). + * - Remove redundant variables. + * - Infer new variable ranges (hopefully more precise). + * + * \param iter_domains The original domain. + * \param eliminate_div_mod Whether to eliminate div and mod by introducing new variables. + */ +TVM_DLL arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains, + bool eliminate_div_mod = true); + +/*! + * \brief Perform lifting of conditions of being possible to be non-zero together with + * applying some transformations like simplifying the reduction domain. Works only with + * this particular tensor's body, i.e. doesn't perform inlining. + * + * \param tensor The original tensor; + * \param vranges Optional map from free variables to their value ranges. + * \return An optimized tensor. + */ +TVM_DLL Tensor RemoveJacobianAndLiftNonzeroCond(const Tensor& tensor, + const Map& vranges = Map()); + } // namespace te } // namespace tvm #endif // TVM_TE_AUTODIFF_AD_UTIL_H_ diff --git a/src/te/autodiff/adjoint.cc b/src/te/autodiff/adjoint.cc index 8b7c428ac8a4..d027b3913d39 100644 --- a/src/te/autodiff/adjoint.cc +++ b/src/te/autodiff/adjoint.cc @@ -39,6 +39,8 @@ #include #include +#include "ad_util.h" + namespace tvm { namespace te { @@ -63,6 +65,10 @@ Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Te Tensor jac = Jacobian(output, input); Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(), output->op->name + "." + input->op->name + ".grad"); + result = InlineTensorAccess(result, {jac}, false); + result = RemoveJacobianAndLiftNonzeroCond(result); + // inline tail call + result = InlineTailTensorAccess(result); return result; } diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index e2479d8f133e..e769e54e96be 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -356,7 +356,9 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { new_shape.push_back(e); } - return Tensor(new_shape, output->dtype, new_op, value_index); + Tensor ret = Tensor(new_shape, output->dtype, new_op, value_index); + ret = RemoveJacobianAndLiftNonzeroCond(ret); + return ret; } } // namespace te diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index d5b51cbf2236..5298added83b 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -134,6 +134,13 @@ class VarUseDefAnalysis : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } + PrimExpr VisitExpr_(const ReduceNode* op) final { + for (const auto& iv : op->axis) { + this->HandleDef(iv->var.get()); + } + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr VisitExpr_(const LoadNode* op) final { this->HandleUse(op->buffer_var); return StmtExprMutator::VisitExpr_(op); @@ -187,6 +194,13 @@ Array UndefinedVars(const Stmt& stmt, const Array& args) { return m.undefined_; } +Array UndefinedVars(const PrimExpr& expr) { + VarUseDefAnalysis m; + m.simplify_let_ = false; + m(expr); + return m.undefined_; +} + class HostDeviceSplitter : public StmtMutator { public: explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py index 54158745e1bd..25accde760ff 100644 --- a/tests/python/unittest/test_te_autodiff.py +++ b/tests/python/unittest/test_te_autodiff.py @@ -24,7 +24,7 @@ import numpy as np -def check_grad(out, inputs, data_range=(-10, 10), desired_grads=None): +def check_grad(out, inputs, args=[], data_range=(-10, 10), desired_grads=None, assert_no_jacobian=True): inputs = inputs if isinstance(inputs, list) else [inputs] def check_device(device, host="llvm"): @@ -36,26 +36,32 @@ def check_device(device, host="llvm"): return sout = te.create_schedule(out.op) - mout = tvm.build(sout, [out] + inputs) + mout = tvm.build(sout, [out] + inputs + args) out_shape = get_const_tuple(out.shape) l, h = data_range input_data = [tvm.nd.array( np.random.uniform(l, h, size=get_const_tuple(input.shape)).astype(input.dtype)) for input in inputs] + arg_vals = [tvm.nd.array( + np.random.uniform(l, h, size=get_const_tuple(arg.shape)).astype(arg.dtype)) + for arg in args] ones = topi.full_like(out, 1.0) # we provide head to sum and reduce the output dimension, # which equals to grad(out.sum(), inputs) grads = te.gradient(out, inputs, head=ones) grad_sched = te.create_schedule([grad.op for grad in grads]) - mgrad = tvm.build(grad_sched, list(grads) + inputs) - # print(tvm.lower(grad_sched, list(grads) + inputs, simple_mode=True)) + mgrad = tvm.build(grad_sched, list(grads) + inputs + args) + if assert_no_jacobian: + # TODO(yzhliu): it is better to visit the expression and do assertion + lowered_ir = str(tvm.lower(grad_sched, list(grads) + inputs + args, simple_mode=True)) + assert "jacobian" not in lowered_ir, lowered_ir grad_data = [tvm.nd.empty(get_const_tuple(i.shape), g.dtype) for i, g in zip(inputs, grads)] - mgrad(*grad_data, *input_data) + mgrad(*grad_data, *input_data, *arg_vals) g_res = [g.asnumpy() for g in grad_data] if desired_grads: @@ -67,7 +73,7 @@ def forward(*in_data): out_data = tvm.nd.empty(out_shape, out.dtype) mout(out_data, *[tvm.nd.array(d) for d in list(in_data)]) return out_data.asnumpy().sum() - check_numerical_grads(forward, [d.asnumpy() for d in input_data], g_res) + check_numerical_grads(forward, [d.asnumpy() for d in input_data + arg_vals], g_res) check_device("cpu") @@ -158,15 +164,168 @@ def fidentity(t0): check_grad(Y, X) -def test_conv2d(): - np.random.seed(0) +def test_topi(): X = te.placeholder((1, 2, 4, 4), name='X') W = te.placeholder((5, 2, 3, 3), name='W') + W1 = te.placeholder((2, 5, 3, 3), name='W1') + W2 = te.placeholder((1,), name='W2') R = topi.nn.conv2d(X, W, 1, 1, 1) check_grad(R, [X, W]) + R1 = topi.nn.conv2d(topi.nn.relu(R), W1, 1, 0, 1) + check_grad(R1, [X, W, W1]) + + R = topi.broadcast_to(W2, (5, 2, 3, 3)) + check_grad(R, [W2]) + + R = topi.nn.conv2d(X, topi.broadcast_to(W2, (5, 2, 3, 3)), 1, 1, 1) + check_grad(R, [X, W2]) + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'avg') + check_grad(R, X) + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') + check_grad(R, X) + + X = te.placeholder((1, 2, 5, 5), name='X') + R = topi.reshape(X, (1, 32)) + check_grad(R, [X]) + + X = te.placeholder((1, 2, 5, 5), name='X') + W = te.placeholder((2, 2, 3, 3), name='W') + + S = topi.reshape(X, (1, 50)) + check_grad(S, [X]) + + R = X + topi.nn.conv2d(X + topi.nn.conv2d(X, W, 1, 1, 1), W, 1, 1, 1) + check_grad(R, [X, W]) + + S = topi.nn.softmax(topi.reshape(R, (1, 50))) + check_grad(S, [X, W]) + + S = topi.sigmoid(topi.reshape(R, (1, 50))) + check_grad(S, [X, W]) + + S = topi.tanh(topi.reshape(R, (1, 50))) + check_grad(S, [X, W]) + + S = topi.nn.log_softmax(topi.reshape(R, (1, 50))) + check_grad(S, [X, W]) + check_grad(S, [W], [X]) + + X = te.placeholder((1, 2, 3, 5), name='X') + Y = te.placeholder((1, 2, 7, 5), name='Y') + S = topi.concatenate((X, Y), 2) + check_grad(S, [X, Y]) + + X = te.placeholder((1, 2, 6, 5), name='X') + (S, R) = topi.split(X, 2, 2) + check_grad(S, [X]) + check_grad(R, [X]) + R1 = topi.concatenate((S, R), 2) + check_grad(R1, [X]) + R2 = topi.concatenate((R, S), 2) + check_grad(R2, [X]) + + X = te.placeholder((4, 5), name='X') + I = te.placeholder((100,), name='I', dtype='int32') + R = topi.take(X, topi.abs(I)) + check_grad(R, [X], [I]) + + W = te.placeholder((5, 5), name='W') + exps = topi.exp(topi.nn.dense(X, W)) + sumexps = topi.sum(exps, axis=-1, keepdims=True) + R = exps/sumexps + check_grad(R, [X, W], data_range=(-1, 1)) + + +def test_stride_dilation(): + X = te.placeholder((1, 2, 10, 10), name='X') + W = te.placeholder((2, 2, 1, 1), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + check_grad(Y, [X, W]) + + W = te.placeholder((2, 2, 2, 2), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + check_grad(Y, [X, W]) + + W = te.placeholder((2, 2, 3, 3), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + check_grad(Y, [X, W]) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + check_grad(Y, [X, W]) + + Y = topi.nn.pool(X, [1, 1], [1, 1], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + Y = topi.nn.pool(X, [1, 1], [2, 2], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + Y = topi.nn.pool(X, [1, 1], [3, 3], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + Y = topi.nn.pool(X, [2, 2], [1, 1], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + Y = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + Y = topi.nn.pool(X, [2, 2], [3, 3], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + Y = topi.nn.pool(X, [3, 3], [1, 1], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + Y = topi.nn.pool(X, [3, 3], [2, 2], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], 'max') + check_grad(Y, [X]) + if __name__ == "__main__": test_basic_operation() - test_conv2d() + test_topi() + test_stride_dilation()