Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff #6078

Merged
merged 25 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7f6c5fc
[Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff
yzhliu Jul 16, 2020
5972033
fix lint
yzhliu Jul 17, 2020
018f67b
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Jul 17, 2020
2bd109c
fix clang-format
yzhliu Jul 17, 2020
004ae0b
add comments and magic number
yzhliu Jul 27, 2020
737cdf9
clang-lint
yzhliu Jul 27, 2020
a3e6b7e
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Jul 27, 2020
d158851
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Jul 27, 2020
f9f4c18
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Jul 28, 2020
5951a94
address some comments
yzhliu Aug 1, 2020
914de04
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Aug 1, 2020
69ea436
remove FreeVarsVisitor
yzhliu Aug 6, 2020
e866720
fix constexpr lint
yzhliu Aug 6, 2020
d7178b3
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Aug 6, 2020
6596d7d
fix lint
yzhliu Aug 6, 2020
e6c0b9b
fix lint
yzhliu Aug 6, 2020
79172d8
add Map.Merge
yzhliu Aug 7, 2020
bb19a32
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Aug 7, 2020
9731d0f
lint
yzhliu Aug 7, 2020
47a5852
change Array::Concat & Map::Merge to global functions
yzhliu Aug 12, 2020
fd3a4a5
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Aug 12, 2020
4ede7fb
fix lint
yzhliu Aug 12, 2020
357510b
move functions to global
yzhliu Aug 13, 2020
c80ef74
static -> inline
yzhliu Aug 13, 2020
e5746ed
Merge remote-tracking branch 'upstream/master' into opt-autodiff
yzhliu Aug 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ using tir::IterVar;
using tir::Var;
using tir::VarNode;

// According to experiments, the two best simplification orders were can->rw and rw->can->rw,
// but rw->can->rw is better in a couple of cases.
// We should also end with rw because it factors multipliers out.
#define ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE 3
yzhliu marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Represent integer grouped bounds which are classified into
* lower bounds (inclusive), upper bounds (inclusive) and equalities.
Expand Down
24 changes: 15 additions & 9 deletions src/arith/solve_linear_inequality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t

// Simplify each inequality into the form `expr <= 0` and add to current formulas
for (const PrimExpr& ineq : system_to_solve->relations) {
AddInequality(&current_ineq_set_to_solve, NormalizeComparisons()(analyzer.Simplify(ineq, 3)),
&analyzer);
AddInequality(
&current_ineq_set_to_solve,
NormalizeComparisons()(analyzer.Simplify(ineq, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE)),
&analyzer);
}

Map<Var, IntGroupBounds> res_bounds;
Expand All @@ -278,8 +280,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// Add bounds from vranges
if (system_to_solve->ranges.count(v)) {
const Range& range = system_to_solve->ranges[v];
PrimExpr range_lbound = analyzer.Simplify(range->min, 3);
PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1, 3);
PrimExpr range_lbound =
analyzer.Simplify(range->min, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1,
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
coef_neg.push_back({-1, range_lbound});
coef_pos.push_back({1, -range_ubound});
}
Expand All @@ -300,7 +304,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// we need rewrite_simplify -> canonical_simplify -> rewrite_simplify
// to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0
// with steps = 2 it's (y*2) - 10 <= 0
new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 3));
new_ineq = NormalizeComparisons()(
analyzer.Simplify(new_ineq, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE));
AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer);
}
}
Expand All @@ -325,7 +330,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t

for (const auto& pos : coef_pos) {
PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second;
bound = analyzer.Simplify(bound, 3);
bound = analyzer.Simplify(bound, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
// Don't add if any of the existing bounds is better
if (std::any_of(upper_bounds.begin(), upper_bounds.end(),
[&bound, &analyzer](const PrimExpr& o) {
Expand All @@ -346,7 +351,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
for (const auto& neg : coef_neg) {
PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second;
bound = analyzer.Simplify(bound, 3);
bound = analyzer.Simplify(bound, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
// Don't add if any of the existing bounds is better
if (std::any_of(lower_bounds.begin(), lower_bounds.end(),
[&bound, &analyzer](const PrimExpr& o) {
Expand Down Expand Up @@ -385,7 +390,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// Everything that is left goes to res.relations
Array<PrimExpr> other_conditions;
for (const PrimExpr& e : current_ineq_set_to_solve) {
PrimExpr e_simp = analyzer.Simplify(e, 3);
PrimExpr e_simp = analyzer.Simplify(e, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
if (is_const_int(e_simp, 0)) {
// contradiction detected
other_conditions = {const_false()};
Expand Down Expand Up @@ -436,7 +441,8 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) {
// There is an equation of the form `v == expr`, so this variable can be completely removed.
// Note that we use the 0-th expression because they are ordered by complexity,
// so it must be the simplest one.
Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1, 3));
Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1,
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE));
res_ranges.Set(var, best_range);
vranges.Set(var, best_range);
} else {
Expand Down
79 changes: 59 additions & 20 deletions src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,30 @@

/*!
* \file ad_simplify.cc
* \brief Simplify tensor compute generated by autodiff.
* \brief Simplify tensor compute generated by tensor-level autodiff.
*
* The major simplification we do in this file is to eliminate
* the Jacobian tensor created by autodiff.
*
* The Jacobian tensor is sparse because one output element usually relates
* to a small portion of the inputs. For example, an element-wise function has
* a one-to-one mapping between input tensor and output tensor, thus the Jacobian is diagonal.
*
* Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix,
* \alpha and \beta are vectors representing the indices of In and Out respectively.
* i.e., the non-zero Jacobian indices are a linear combination of the input indices.
* Thereby we solve the linear equations \beta = A \alpha,
* as well as the linear inequalities of their domain ranges.
*
* Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
* arXiv preprint arXiv:1711.01348, 2017. for more details.
*
* Implementation-wise, we extract the equations in the compute definition via
* NonzeronessCondition, replace the compute expression with solved new axes,
* and create a selection node (non-zero-condition ? new_compute_expression : 0).
*
* Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
*
*/
#include <dmlc/optional.h>
#include <tvm/arith/analyzer.h>
Expand Down Expand Up @@ -129,11 +152,14 @@ bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges)
return false;
}

if (!is_const_value(analyzer.Simplify(combiner->identity_element[0], 3), 0)) {
if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
0)) {
return false;
}

PrimExpr combiner_result = analyzer.Simplify(combiner->result[0], 3);
PrimExpr combiner_result =
analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);

return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
Expand All @@ -143,14 +169,16 @@ bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
const Map<Var, Range>& vranges) {
arith::Analyzer analyzer;
analyzer.Bind(vranges);
if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index], 3), 0)) {
if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
0)) {
return false;
}

PrimExpr zero = make_zero(combiner->result[value_index].dtype());
PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
{combiner->rhs[value_index], zero}});
in = analyzer.Simplify(in, 3);
in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);

return is_const_value(in, 0);
}
Expand Down Expand Up @@ -215,18 +243,21 @@ class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const

// If the false part is zero, we can get rid of the select
if (is_const_value(nz_b.value, 0)) {
PrimExpr new_cond = analyzer_.Simplify(nz_a.cond && cond, 3);
PrimExpr new_cond =
analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
return {new_cond, nz_a.value};
}

// If the true part is zero, we can also get rid of the select
if (is_const_value(nz_a.value, 0)) {
PrimExpr new_cond = analyzer_.Simplify(nz_b.cond && !cond, 3);
PrimExpr new_cond =
analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
return {new_cond, nz_b.value};
}

// Otherwise we retain the select and combine the conditions into this
PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), 3);
PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
return {new_cond, GetRef<PrimExpr>(op)};
} else {
Expand All @@ -242,7 +273,8 @@ class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const

// We don't have as much freedom here as in the select case
// since the `if` must be preserved in any case
PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), 3);
PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
return {new_cond, GetRef<PrimExpr>(op)};
} else {
Expand Down Expand Up @@ -287,7 +319,8 @@ class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const
}
} else {
// Otherwise use Or
PrimExpr new_cond = analyzer_.Simplify(nz_a.cond || nz_b.cond, 3);
PrimExpr new_cond =
analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
// A little optimization: if the combined condition is the same as one of the inner
// conditions, we don't need to guard the inner value with a select, otherwise
// we create a select in the `to_expr` call.
Expand All @@ -305,7 +338,8 @@ class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const

// For multiplication and similar ops the result may be nonzero if
// both the arguments are nonzero, so we combine with And.
PrimExpr new_cond = analyzer_.Simplify(nz_a.cond && nz_b.cond, 3);
PrimExpr new_cond =
analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);

if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
return {new_cond, op};
Expand Down Expand Up @@ -800,7 +834,8 @@ PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& ou
// Perform simplification mainly to remove a possibly empty reduction.
arith::Analyzer analyzer;
return analyzer.Simplify(
Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index), 3);
Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
} else {
return expr;
}
Expand Down Expand Up @@ -880,13 +915,14 @@ class RemoveRedundantInequalitiesMutator : public ExprMutator {
public:
explicit RemoveRedundantInequalitiesMutator(Array<PrimExpr> known) {
for (const PrimExpr& cond : known) {
known_.push_back(analyzer_.Simplify(cond, 3));
known_.push_back(analyzer_.Simplify(cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE));
}
}

virtual PrimExpr VisitExpr_(const SelectNode* op) {
bool has_side_effect = (SideEffect(GetRef<PrimExpr>(op)) > CallEffectKind::kReadState);
PrimExpr new_cond = analyzer_.Simplify(VisitExpr(op->condition), 3);
PrimExpr new_cond =
analyzer_.Simplify(VisitExpr(op->condition), ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
if (is_one(new_cond) && !has_side_effect) {
return VisitExpr(op->true_value);
} else if (is_zero(new_cond) && !has_side_effect) {
Expand All @@ -905,7 +941,8 @@ class RemoveRedundantInequalitiesMutator : public ExprMutator {

virtual PrimExpr VisitExpr_(const CallNode* op) {
if (op->op.same_as(Op::Get("tir.if_then_else"))) {
PrimExpr new_cond = analyzer_.Simplify(VisitExpr(op->args[0]), 3);
PrimExpr new_cond =
analyzer_.Simplify(VisitExpr(op->args[0]), ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
if (is_one(new_cond)) {
return VisitExpr(op->args[1]);
} else if (is_zero(new_cond)) {
Expand Down Expand Up @@ -959,7 +996,7 @@ class RemoveRedundantInequalitiesMutator : public ExprMutator {

private:
PrimExpr MutateAtomic_(const PrimExpr& e) {
PrimExpr simplified = analyzer_.Simplify(e, 3);
PrimExpr simplified = analyzer_.Simplify(e, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
for (const PrimExpr& other : known_) {
if (ExprDeepEqual()(simplified, other)) {
return const_true();
Expand Down Expand Up @@ -988,7 +1025,8 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,

arith::Analyzer analyzer;
analyzer.Bind(res->dst->ranges);
PrimExpr new_expr = analyzer.Simplify(Substitute(expr, res->src_to_dst), 3);
PrimExpr new_expr = analyzer.Simplify(Substitute(expr, res->src_to_dst),
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
// TODO(yzhliu): This is mostly done to simplify if_then_else
// which is not realized by the canonical simplifier
new_expr = RemoveRedundantInequalities(new_expr, res->dst->relations);
Expand Down Expand Up @@ -1130,7 +1168,8 @@ class ReductionAsTensorAccessMutator : public ExprMutator {
Array<IterVar> new_axis = new_axis_vmap_pair.first;
arith::Analyzer analyzer;
analyzer.Bind(IterVarsToMap(new_axis));
new_reduce = analyzer.Simplify(Substitute(new_reduce, new_axis_vmap_pair.second), 3);
new_reduce = analyzer.Simplify(Substitute(new_reduce, new_axis_vmap_pair.second),
ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);

Tensor tensor = TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_);

Expand Down Expand Up @@ -1181,7 +1220,7 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
analyzer.Bind(combined_vranges);

// Simplify the original expression first, mostly to simplify combiners
PrimExpr expr = analyzer.Simplify(expr_orig, 3);
PrimExpr expr = analyzer.Simplify(expr_orig, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);

if (const ReduceNode* red = expr.as<ReduceNode>()) {
// TODO(sgrechanik-h): There are some other operations which behave like sum
Expand Down Expand Up @@ -1252,7 +1291,7 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
// Sometimes TrySimplifyCompute doesn't perform lift / extraction,
// so there may be some non-top reductions left, take care of them.
result = LiftReductions(result, IterVarsToVars(axis), combined_vranges);
return analyzer.Simplify(result, 3);
return analyzer.Simplify(result, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
}

Tensor RemoveJacobianAndLiftNonzeroCond(const Tensor& tensor, const Map<Var, Range>& vranges) {
Expand Down