diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index 3a1804f37d1e..f2ee86faaaa4 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -70,3 +70,8 @@ tvm.make ~~~~~~~~ .. automodule:: tvm.make :members: + +tvm.testing +~~~~~~~~~~~ +.. automodule:: tvm.testing + :members: diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index ddad9d10f8f9..b8361f2aaed9 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -15,6 +15,7 @@ Python API container function autotvm + autodiff graph_runtime rpc bridge diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 856bad198e88..b6e6df43f498 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -67,6 +67,7 @@ List of operators topi.not_equal topi.greater_equal topi.less_equal + topi.tensordot topi.image.resize @@ -123,6 +124,7 @@ topi .. autofunction:: topi.power .. autofunction:: topi.greater .. autofunction:: topi.less +.. autofunction:: topi.tensordot topi.nn ~~~~~~~ diff --git a/include/tvm/autodiff.h b/include/tvm/autodiff.h new file mode 100644 index 000000000000..d9f47009af2a --- /dev/null +++ b/include/tvm/autodiff.h @@ -0,0 +1,150 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file autodiff.h + * \brief Automatic differentiation of tensor expressions. + */ +#ifndef TVM_AUTODIFF_H_ +#define TVM_AUTODIFF_H_ + +#include +#include + +namespace tvm { +namespace ir { + +class DifferentiationResult; + +/*! \brief Node to represent a differentiation result */ +class DifferentiationResultNode : public Node { + public: + /*! \brief The requested adjoints, i.e. Jacobians or gradients wrt to the given inputs */ + Array result; + /*! \brief A map from tensors to the corresponding adjoints (including internal nodes) */ + Map adjoints; + /*! \brief Single summands of the adjoints*/ + Map> adjoint_summands; + /*! \brief constructor */ + DifferentiationResultNode() {} + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("result", &result); + v->Visit("adjoints", &adjoints); + v->Visit("adjoint_summands", &adjoint_summands); + } + TVM_DLL static DifferentiationResult make(Array result, + Map adjoints, + Map> adjoint_summands); + + static constexpr const char* _type_key = "DifferentiationResult"; + TVM_DECLARE_NODE_TYPE_INFO(DifferentiationResultNode, Node); +}; + +/*! + * \brief A result of differentiation. + */ +class DifferentiationResult : public NodeRef { + public: + /*! \brief default constructor, used internally */ + DifferentiationResult() {} + explicit DifferentiationResult(NodePtr n) : NodeRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const DifferentiationResultNode* operator->() const { + return static_cast(node_.get()); + } + /*! \brief specify container node */ + using ContainerType = DifferentiationResultNode; +}; + + +/*! \brief A type of a "local" differentiation function for reverse mode AD + * + * A function of this type is a building block for reverse-mode automatic differentiation. It + * should take three tensors: `output`, `input` and `head`, `head` being the adjoint corresponding + * to the `output`, and return (a summand of) the adjoint corresponding to the input. In other + * words, it should differentiate `output` wrt `input` and multiply the result by `head` with + * tensor dot product (`head` should be on the left of the multiplication). `input` should be an + * immediate dependency of `output` (should be called from within the body of `output`). + * + * See also ::DiffBuildingBlock, which might be considered the reference implementation. + */ +using FDiffBuildingBlock = std::function; + +/*! + * \brief Take the derivative of the expression with respect to the given variable. + * \param expr The expression to differentiate. + * \param var The variable to differentiate with respect to. + * \return The expression for the derivative. + */ +EXPORT Expr Derivative(const Expr& expr, const VarExpr& var); + +/*! + * \brief Get the tensor representing the Jacobian of the output with respect to the input. + * + * Note that if \p output depends on \p input indirectly (by using some other tensor + * depending on \p input), this dependency won't contribute to the resulting Jacobian. + * For such cases use the function ::Differentiate. + * + * \param output The tensor to differentiate. + * \param input The input tensor, which \p output should directly use. + * \param optimize Whether to perform optimizations like lifting of nonzeroness conditions. + * \return The tensor representing the Jacobian of shape `output.shape + input.shape`. + */ +EXPORT Tensor Jacobian(const Tensor& output, const Tensor& input, bool optimize = true); + +/*! + * \brief The building block for reverse-mode AD. + * + * Differentiate \p output wrt \p input and multiply the result by \p head on the left using tensor + * dot product. \p input must be an immediate dependency of \p output (must be called from within + * the body of \p output). That is, the function will compute a summand of the adjoint for \p input + * given the adjoint for \p output (which is called \p head here). + * + * \param output The tensor to differentiate. + * \param input The input tensor, which \p output should directly use. + * \param head The adjoint of \p output. Must be of shape `prefix + output.shape` + * \return The tensor representing the adjoint of \p input of shape `prefix + input.shape`. + */ +EXPORT Tensor DiffBuildingBlock(const Tensor& output, const Tensor& input, const Tensor& head); + +/*! + * \brief Perform reverse mode automatic differentiation. + * + * Each item of the `result` field of the result is an adjoint for the corresponding item of + * \p inputs, i.e. \p head multiplied by the Jacobian of \p output with respect to the + * corresponding item of \p inputs. + * + * \param output The tensor to differentiate. + * \param inputs The array of input tensors. When the array is empty, will perform differentiation + * wrt all tensors the output depends on. + * \param head The adjoint of the output, in other words, some tensor, by which the Jacobians + * will be multiplied. Its shape must be of the form `prefix + output.shape`. If the + * null pointer is provided, the identity tensor of shape + * `output.shape + output.shape` will be used. + * \param fdiff The function performing differentiation and multiplication, see + * ::FDiffBuildingBlock. + * \param override_deps A map from tensors to their dependencies (`InputTensors()` are used by + * default). Overriding dependencies may be useful to treat a group of tensors + * as a single supertensor. In this case the fdiff functions should also be + * modified accordingly. + * \return An object of type DifferentiationResult which contains three fields: + * - `result` An array of adjoints corresponding to \p inputs. + * - `adjoints` A map from tensors to the corresponding adjoints (includes intermediate + * tensors). + * - `adjoint_summands` A map from tensors to maps from parent tensors to individual + * summands of the adjoint. + */ +EXPORT DifferentiationResult Differentiate( + const Tensor& output, + const Array& inputs = Array(), + const Tensor& head = Tensor(), + const FDiffBuildingBlock& fdiff = DiffBuildingBlock, + const Map>& override_deps = Map>()); + +} // namespace ir +} // namespace tvm +#endif // TVM_AUTODIFF_H_ diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index f2c9c3d517a5..ded99fb52206 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -85,6 +85,16 @@ inline const uint64_t* as_const_uint(const Expr& x) { */ inline bool is_const_int(const Expr& x, int64_t value); +/*! + * \brief Check if the given expr is a const of any type equal to the given integer value. + * \param e The expression. + * \param value The value to compare to. + * \return Whether the expression is a const equal to the value. + * \tparam ValueType The value type + */ +template +inline bool is_const_value(const Expr& e, ValueType value); + /*! * \brief Check whether stmt is nop. * \param stmt The input statement @@ -515,6 +525,26 @@ inline bool is_const_int(const Expr& x, int64_t value) { return false; } +template +inline bool is_const_value(const Expr& e, ValueType value) { + static_assert(std::is_integral::value, + "Comparison to non-integer values is forbidden."); + // This implementation was copy-pasted from HalideIR + if (const ir::IntImm* i = e.as()) { + return i->value == value; + } else if (const ir::UIntImm* i = e.as()) { + return (value >= 0) && (i->value == (uint64_t)value); + } else if (const ir::FloatImm* i = e.as()) { + return i->value == value; + } else if (const ir::Cast* c = e.as()) { + return is_const_value(c->value, value); + } else if (const ir::Broadcast* b = e.as()) { + return is_const_value(b->value, value); + } else { + return false; + } +} + inline bool is_no_op(const Stmt& stmt) { if (!stmt.defined()) return true; if (const auto* op = stmt.as()) { diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 9c09dc5a4ac3..749b6b3dbb07 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -19,6 +19,7 @@ from . import generic from . import hybrid from . import testing +from . import autodiff from . import ndarray as nd from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl @@ -36,6 +37,7 @@ from .schedule import create_schedule from .build_module import build, lower, build_config from .tag import tag_scope +from .autodiff import differentiate # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel diff --git a/python/tvm/autodiff.py b/python/tvm/autodiff.py new file mode 100644 index 000000000000..1d496185eb76 --- /dev/null +++ b/python/tvm/autodiff.py @@ -0,0 +1,162 @@ +"""Automatic differentiation of tensor expressions.""" +from ._ffi.function import _init_api +from ._ffi.node import NodeBase, register_node + +_init_api("tvm.autodiff") + +@register_node +class DifferentiationResult(NodeBase): + """Result of differentiation. + + Parameters + ---------- + result : List[Tensor] + The requested adjoints, i.e. the jacobians or gradients of the given output + wrt to the given inputs. + + adjoints : Dict[Tensor, Tensor] + A map from tensors to the corresponding adjoints (including internal nodes). + + adjoint_summands : Dict[Tensor, Dict[Tensor, Tensor]] + Single summands of the adjoints. + """ + + # Here we convert tvm Maps to dicts because Map compares keys by reference which is + # wrong for Tensors. Hopefully, in the future Map gets fixed somehow, and these properties + # may be removed then. + + @property + def adjoints(self): + res = NodeBase.__getattr__(self, 'adjoints') + return dict(res.items()) + + @property + def adjoint_summands(self): + res = NodeBase.__getattr__(self, 'adjoint_summands') + return {k: dict(v.items()) for k, v in res.items()} + + def _check_not_empty(self): + if not self.result: + raise ValueError("The result of differentiation does not contain any explicitly " + "requested results, so using it as an iterable is probably a mistake. " + "Please explicitly use res.adjoints to get adjoints or res.result to " + "get the empty list.") + + def __getitem__(self, i): + self._check_not_empty() + return self.result[i] + + def __len__(self): + self._check_not_empty() + return len(self.result) + + +def differentiate(output, inputs=None, head=None, override=None, fdiff=None): + """Perform reverse-mode automatic differentiation. + + Parameters + ---------- + output : Tensor + The tensor to differentiate. + + inputs : List[Tensor] + The list of input tensors. When the list is empty or None, will perform + differentiation wrt all tensors the output depends on (i.e. will compute all + adjoints and populate the corresponding dict, but the list of results + will be empty). + + head : Tensor + The adjoint of the output, in other words, some tensor, by which the Jacobians + will be multiplied. Its shape must be of the form `prefix + output.shape`. + If `None` is passed, the identity tensor of shape `output.shape + output.shape` + will be used. + + override : Dict[Tensor, (List[Tensor], Callable[[Tensor, List[Tensor], Tensor], List[Tensor]])] + Override differentiation for certain tensors. This dict maps tensors `t` to pairs + `(dependencies, custom_diff)` where `dependencies` is a list of tensors which are considered + to be inputs of `t` (which may differ from the immediate inputs), and `custom_diff` is a + custom differentiation function which will be called as `custom_diff(t, dependencies, + adjoint)` and should return a list of adjoints corresponding to dependencies. Note that this + function differs from the one required for `fdiff` in that it takes a list of inputs instead + of a single input and returns a list of adjoints instead of a single adjoint. + + fdiff : Callable[[Tensor, Tensor, Tensor], Tensor] + The default function performing differentiation and multiplication, by default + `tvm.autodiff.DiffBuildingBlock` is used. The function must accept three + parameters: + - `output` - an output tensor + - `input` - an input tensor + - `head` - the adjoint of the output tensor + The result should be `head` multiplied by the jacobian of `output` wrt `input` + + Returns + ------- + differentiation_result : DifferentiationResult + + Example + ------- + .. code-block:: python + + x = tvm.placeholder((32, 3, 28, 28), name='x') + w1 = tvm.placeholder((10, 3, 3, 3), name='w1') + w2 = tvm.placeholder((10, 10, 3, 3), name='w2') + z1 = topi.nn.conv2d(x, w1, 1, 1, 1) + z2 = topi.nn.conv2d(z1, w2, 1, 1, 1) + y = topi.sum(z2) + + # produce gradients + [dw1, dw2] = tvm.differentiate(y, [w1, w2]) + + # produce Jacobians + [jw1, jw2] = tvm.differentiate(z2, [w1, w2]) + + # produce gradients, the head adjoint for z2 is provided manually + [dw1, dw2] = tvm.differentiate(z2, [w1, w2], topi.full_like(z2, 1.0)) + + # produce gradients wrt all inputs + res = tvm.differentiate(y) + dw1 = res.adjoints[w1] + dw2 = res.adjoints[w2] + + # a custom differentiation function + def my_fdiff(out, inp, head): + # this is the naive version, without any optimizations + return topi.tensordot(head, tvm.autodiff.Jacobian(out, inp, False), len(out.shape)) + + # using a custom differentiation function for everything + [dw1, dw2] = tvm.differentiate(y, [w1, w2], fdiff=my_fdiff) + + # accessing individual summands of the adjoint + y = z1 + z2 + res = tvm.differentiate(y, [w1, w2]) + [s1, s2] = res.adjoint_summands[z1].values() + + # a generalization of my_fdiff which works for non-immediate dependencies + # this is necessary because z1 is not an immediate dep of z2 because of padding + def my_diff(out, inputs, head): + return tvm.differentiate(out, inputs, head, fdiff=my_fdiff) + + # using a custom differentiation function only for z2 + res = tvm.differentiate(y, [w1, w2], override={z2: ([z1, w2], my_diff)}) + """ + if inputs is None: + inputs = [] + + if fdiff is None: + fdiff = DiffBuildingBlock + + if override is not None: + # pylint: disable=dangerous-default-value + def _modified_fdiff(out, inp, head, override=override, old_fdiff=fdiff, cache={}): + if out in override: + if (out, head) not in cache: + cache[(out, head)] = override[out][1](out, override[out][0], head) + idx = override[out][0].index(inp) + return cache[(out, head)][idx] + return old_fdiff(out, inp, head) + + fdiff = _modified_fdiff + + override_deps = {t: deps for t, (deps, _) in override.items()} + return Differentiate(output, inputs, head, fdiff, override_deps) + return Differentiate(output, inputs, head, fdiff) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 1a6666bdee2a..bca939881ea9 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -1,6 +1,7 @@ """ TVM testing utilities """ import logging import numpy as np +import tvm def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): """ Version of np.testing.assert_allclose with `atol` and `rtol` fields set @@ -14,7 +15,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): def check_numerical_grads(function, input_values, grad_values, function_value=None, - delta=1e-3, atol=1e-2, rtol=0.1): + delta=1e-3, atol=1e-2, rtol=0.1, acceptable_fail_fraction=None): """A helper function that checks that numerical gradients of a function are equal to gradients computed in some different way (analytical gradients). @@ -50,6 +51,10 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No rtol : float, optional Relative tolerance. + + acceptable_fail_fraction : float, optional + If not None, raise an error only when the fraction of wrong elements for a gradient is + higher than this value. """ # If input_values is a list then function accepts positional arguments # In this case transform it to a function taking kwargs of the form {"0": ..., "1": ...} @@ -93,7 +98,7 @@ def compare_derivative(j, n_der, grad): wrong_positions = [] # compute partial derivatives for each position in this variable - for j in range(np.prod(grad.shape)): + for j in range(int(np.prod(grad.shape))): # forward difference approximation nder = derivative(x_name, j, delta) @@ -116,7 +121,7 @@ def compare_derivative(j, n_der, grad): ngrad.reshape(-1)[j] = nder - wrong_percentage = int(100*len(wrong_positions)/np.prod(grad.shape)) + wrong_fraction = len(wrong_positions)/np.prod(grad.shape) dist = np.sqrt(np.sum((ngrad - grad)**2)) grad_norm = np.sqrt(np.sum(ngrad**2)) @@ -131,17 +136,208 @@ def compare_derivative(j, n_der, grad): sqrt_n = np.sqrt(float(np.prod(grad.shape))) if dist > atol*sqrt_n + rtol*grad_norm: - raise AssertionError( - "Analytical and numerical grads wrt '{}' differ too much\n" - "analytical grad = {}\n numerical grad = {}\n" - "{}% of elements differ, first 10 of wrong positions: {}\n" - "distance > atol*sqrt(n) + rtol*grad_norm\n" - "distance {} > {}*{} + {}*{}" - .format(x_name, grad, ngrad, wrong_percentage, wrong_positions[:10], - dist, atol, sqrt_n, rtol, grad_norm)) + enough_failures = (acceptable_fail_fraction is None or + wrong_fraction > acceptable_fail_fraction) + if enough_failures: + raise AssertionError( + "Analytical and numerical grads wrt '{}' differ too much\n" + "analytical grad = {}\n numerical grad = {}\n" + "{}% of elements differ, first 10 of wrong positions: {}\n" + "distance > atol*sqrt(n) + rtol*grad_norm\n" + "distance {} > {}*{} + {}*{}" + .format(x_name, grad, ngrad, int(100*wrong_fraction), wrong_positions[:10], + dist, atol, sqrt_n, rtol, grad_norm)) + else: + logging.warning("Analytical and numerical grads wrt '%s' differ, however " + "there were not enough wrong elements to raise an error " + "(only %d%%)", + x_name, int(100*wrong_fraction)) max_diff = np.max(np.abs(ngrad - grad)) avg_diff = np.mean(np.abs(ngrad - grad)) logging.info("Numerical grad test wrt '%s' of shape %s passes, " "dist = %f, max_diff = %f, avg_diff = %f", x_name, grad.shape, dist, max_diff, avg_diff) + + +class PerformanceEstimate: + """A result of static performance estimation. + + Parameters + ---------- + iterations : int + The total number of iterations of all the loops. + + multiplications : int + The total number of expensive operations like multiplications. + + memory : int + The amount of memory to allocate. + """ + def __init__(self, iterations=0, multiplications=0, memory=0): + self.iterations = iterations + self.multiplications = multiplications + self.memory = memory + + def as_tuple(self): + return (self.iterations, self.multiplications, self.memory) + + def __add__(self, other): + return PerformanceEstimate(iterations=self.iterations + other.iterations, + multiplications=self.multiplications + other.multiplications, + memory=self.memory + other.memory) + + def max(self, other): + return PerformanceEstimate( + iterations=max(self.iterations, other.iterations), + multiplications=max(self.multiplications, other.multiplications), + memory=max(self.memory, other.memory)) + + def times(self, iters): + return PerformanceEstimate(iterations=self.iterations*iters, + multiplications=self.multiplications*iters, + memory=self.memory) + + def __repr__(self): + return "PerformanceEstimate(iterations={}, multiplications={}, memory={})".format( + self.iterations, self.multiplications, self.memory) + + def __le__(self, other): + return \ + self.iterations <= other.iterations and \ + self.multiplications <= other.multiplications and \ + self.memory <= other.memory + + +def estimate_performance(s, param_values=None, processed_ops=None): + """Statically estimate performance of statements, expressions and tensors. Note that the + estimate is very rough, it mustn't be used to predict future performance, its only purpose is + to detect possible performance regressions. + + Parameters + ---------- + s + A statement, an expression, a tensor, an operation, or a list + of any of the above. + + param_values : Dict[tvm.expr.Var, int], optional + Values for parameters (free variables). + + Returns + ------- + estimate : PerformanceEstimate + """ + from tvm import stmt + from tvm import expr + + if param_values is None: + param_values = {} + + if processed_ops is None: + processed_ops = {} + res = estimate_performance(s, param_values=param_values, processed_ops=processed_ops) + for op_est in processed_ops.values(): + res += op_est + return res + + def est(expression, param_values=param_values, processed_ops=processed_ops): + return estimate_performance(expression, + param_values=param_values, + processed_ops=processed_ops) + + def _eval(expression, param_values=param_values): + return tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expression, param_values)).value + + def _prod(elems): + res = 1 + for x in elems: + res *= x + return res + + if s is None or isinstance(s, (stmt.AssertStmt, stmt.Free, stmt.Prefetch, + expr.ConstExpr, expr.Var, tvm.tensor.PlaceholderOp)): + return PerformanceEstimate() + elif isinstance(s, list): + res = PerformanceEstimate() + for item in s: + res += est(item) + return res + elif s in processed_ops: + return PerformanceEstimate() + elif isinstance(s, stmt.Allocate): + mem = _prod([_eval(e) for e in s.extents]) + return est(s.condition) + est(s.body) + PerformanceEstimate(memory=mem) + elif isinstance(s, stmt.Block): + return est(s.first) + est(s.rest) + elif isinstance(s, stmt.Evaluate): + return est(s.value) + elif isinstance(s, stmt.For): + body_est = est(s.body) + body_est.iterations = max(1, body_est.iterations) + return body_est.times(_eval(s.extent)) + elif isinstance(s, stmt.IfThenElse): + return est(s.condition) + est(s.then_case) + est(s.else_case) + elif isinstance(s, stmt.LetStmt): + return est(s.value) + est(s.body) + elif isinstance(s, (stmt.ProducerConsumer, stmt.AttrStmt)): + return est(s.body) + elif isinstance(s, stmt.Provide): + return est(s.value) + elif isinstance(s, stmt.Realize): + return est(s.condition) + est(s.body) + elif isinstance(s, stmt.Store): + return est(s.value) + est(s.index) + est(s.predicate) + elif isinstance(s, (expr.Mul, expr.Div, expr.Mod)): + return est(s.a) + est(s.b) + PerformanceEstimate(multiplications=1) + elif isinstance(s, (expr.BinaryOpExpr, expr.CmpExpr, expr.LogicalExpr)): + if not hasattr(s, 'b'): + return est(s.a) + return est(s.a) + est(s.b) + elif isinstance(s, expr.Call): + res = PerformanceEstimate() + for a in s.args: + res += est(a) + if s.call_type == expr.Call.Halide: + # The estimate is added to processed_ops, we don't need the result here + est(s.func) + elif s.name == "tvm_if_then_else": + pass + else: + # expr.If it is a non-halide call (e.g. exp or log), consider it a mul + res += PerformanceEstimate(multiplications=1) + return res + elif isinstance(s, expr.Cast): + return est(s.value) + elif isinstance(s, expr.Load): + return est(s.index) + est(s.predicate) + elif isinstance(s, expr.Select): + return est(s.condition) + est(s.true_value) + est(s.false_value) + elif isinstance(s, expr.Reduce): + iterations = _prod([_eval(iv.dom.extent) for iv in s.axis]) + res = PerformanceEstimate() + for id_elem in s.combiner.identity_element: + res += est(id_elem) + on_each_iter = est(s.condition) + for src in s.source: + on_each_iter += est(src) + for comb_res in s.combiner.result: + on_each_iter += est(comb_res) + on_each_iter.iterations = max(1, on_each_iter.iterations) + return res + on_each_iter.times(iterations) + elif isinstance(s, tvm.tensor.Tensor): + return est(s.op) + elif isinstance(s, tvm.tensor.ComputeOp): + iterations = _prod([_eval(iv.dom.extent) for iv in s.axis]) + if s.reduce_axis: + res = est(s.body[0]) + else: + res = PerformanceEstimate() + for b in s.body: + res += est(b) + res.iterations = max(1, res.iterations) + res = res.times(iterations) + PerformanceEstimate(memory=iterations*len(s.body)) + processed_ops[s] = res + return PerformanceEstimate() + + raise ValueError("Don't know how to estimate performance of {} of type {}" + .format(s, type(s))) diff --git a/src/op/op_util.cc b/src/op/op_util.cc index b18552d5c562..f80f5f1eaabb 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -206,6 +206,34 @@ Expr ReplaceTensor(Expr expr, } +void ReplaceTensorRecursivelyImpl(Tensor tensor, + std::unordered_map* replace) { + if (!replace->count(tensor)) { + for (const Tensor& subtensor : tensor->op->InputTensors()) { + ReplaceTensorRecursivelyImpl(subtensor, replace); + } + Operation new_op = tensor->op->ReplaceInputs(tensor->op, *replace); + if (new_op.same_as(tensor->op)) { + (*replace)[tensor] = tensor; + } else { + (*replace)[tensor] = + TensorNode::make(tensor->shape, tensor->dtype, new_op, tensor->value_index); + } + } +} + +Array ReplaceTensorRecursively(Array tensors, + const std::unordered_map& replace) { + auto new_replace = replace; + Array res; + for (const Tensor& t : tensors) { + ReplaceTensorRecursivelyImpl(t, &new_replace); + res.push_back(new_replace[t]); + } + return res; +} + + Stmt Substitute(Stmt s, const std::unordered_map& value_map) { std::unordered_map init; @@ -245,5 +273,48 @@ ir::ForType IterVarTypeToForType(IterVarType iter_type) { } } +Tensor TensorFromExpr(const Expr& expr, const Array& axis, + const std::string& name, const std::string& tag, + const Map& attrs) { + Array new_bodies; + int new_value_index = 0; + + // If this is a reduction then we have to clone its body + if (const Reduce* red = expr.as()) { + new_value_index = red->value_index; + + for (size_t i = 0; i < red->source.size(); ++i) { + Expr ith_red = Reduce::make(red->combiner, red->source, red->axis, red->condition, i); + new_bodies.push_back(ith_red); + } + } else { + new_value_index = 0; + new_bodies.push_back(expr); + } + + return ComputeOpNode::make(name, tag, attrs, axis, new_bodies).output(new_value_index); +} + +Tensor TransformBody(const Tensor& tensor, + std::function&)> func) { + if (const ComputeOpNode* op = tensor->op.as()) { + // Transform only one body + Expr new_body = func(op->body[tensor->value_index], op->axis); + + // If the body didn't change then we can return the same tensor + if (new_body.same_as(op->body[tensor->value_index])) { + return tensor; + } + + return TensorFromExpr(new_body, op->axis, op->name, op->tag, op->attrs); + } else { + return tensor; + } +} + +Tensor TransformBody(const Tensor& tensor, std::function func) { + return TransformBody(tensor, [func](const Expr& e, const Array&) { return func(e); }); +} + } // namespace op } // namespace tvm diff --git a/src/op/op_util.h b/src/op/op_util.h index de2e44c2ed59..f8cebe229112 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "../pass/ir_util.h" #include "../pass/arg_binder.h" #include "../schedule/message_passing.h" @@ -63,6 +64,15 @@ Stmt ReplaceTensor(Stmt stmt, Expr ReplaceTensor(Expr expr, const std::unordered_map& replace); +/*! + * \brief Replace tensor references in the given tensors recursively (not only in their bodies + * but also in the bodies of its dependencies). + * \param tensors The tensors to be processed. + * \param replace The replacement rule. + */ +Array ReplaceTensorRecursively(Array tensors, + const std::unordered_map& replace); + /*! * \brief Substitute the variables of stmt by value map. * \param stmt the statment @@ -84,6 +94,45 @@ IterVarType ForTypeToIterVarType(ir::ForType for_type); */ ir::ForType IterVarTypeToForType(IterVarType iter_type); +/*! + * \brief Create a tensor from an expression. The expression may be a reduction, in which + * case its body will be correctly duplicated if it is a multi-valued reduction. + * + * \param expr The expr which will be the tensor's body. + * \param axis The input variables with ranges. + * \param name The tensor's name. + * \param tag The tensor's tag. + * \param attrs The tensor's attrs. + * \return A tensor. + */ +Tensor TensorFromExpr(const Expr& expr, const Array& axis, + const std::string& name = "tensor", const std::string& tag = "", + const Map& attrs = {}); + +/*! + * \brief Transform the body of a tensor if it is a compute tensor, otherwise return it + * unchanged. Note that if the compute returns a tuple, it transforms only one element, + * other elements are discarded. + * + * \param tensor The tensor to transform. + * \param func The transformation function working on expressions and additionally taking + * the array of the tensor's itervars. + * \return The transformed tensor. + */ +Tensor TransformBody(const Tensor& tensor, + std::function&)> func); + +/*! + * \brief Transform the body of a tensor if it is a compute tensor, otherwise return it + * unchanged. Note that if the compute returns a tuple, it transforms only one element, + * other elements are discarded. + * + * \param tensor The tensor to transform. + * \param func The transformation function (working on expressions). + * \return The transformed tensor. + */ +Tensor TransformBody(const Tensor& tensor, std::function func); + } // namespace op } // namespace tvm #endif // TVM_OP_OP_UTIL_H_ diff --git a/src/pass/autodiff.cc b/src/pass/autodiff.cc new file mode 100644 index 000000000000..2ee806dfb7db --- /dev/null +++ b/src/pass/autodiff.cc @@ -0,0 +1,539 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file autodiff.cc + * \brief Automatic differentiation of tensor expressions + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/op_util.h" +#include "zero_elimination.h" + +namespace tvm { +namespace ir { + + +DifferentiationResult DifferentiationResultNode::make(Array result, + Map adjoints, + Map> summands) { + auto n = make_node(); + n->result = std::move(result); + n->adjoints = adjoints; + n->adjoint_summands = summands; + return DifferentiationResult(n); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const DifferentiationResultNode* r, IRPrinter* p) { + p->stream << "DifferentiationResult(result=" << r->result + << ", adjoints=" << r->adjoints + << ", adjoint_summands=" << r->adjoint_summands << ')'; + }); + +TVM_REGISTER_NODE_TYPE(DifferentiationResultNode); + + +#define NOT_IMPLEMENTED \ + { LOG(FATAL) << "Derivative of this expr is not implemented: " << e; throw; } + +/*! \brief Differentiate an expression wrt a variable or a tensor element */ +class JacobianMutator : public IRMutator { + public: + /*! + * \brief Differentiate wrt `input(indices)`. + * \param input The input tensor. + * \param indices The indices of the element with respect to which to differentiate. + */ + explicit JacobianMutator(Tensor input, Array indices) + : input_(input), indices_(indices) {} + /*! + * \brief Differentiate wrt the input variable. + * \param input The input variable. + */ + explicit JacobianMutator(VarExpr input) + : input_var_(input) {} + + virtual Expr Mutate(Expr e) { + if (e.type().is_int() || e.type().is_uint()) { + // Assume that the derivative of any integer expression is always 0 + return make_zero(e.type()); + } else { + return IRMutator::Mutate(e); + } + } + + Expr Mutate_(const Variable* op, const Expr& e) { + if (input_var_.operator->() && input_var_.get() == op && op->type.is_float()) { + return FloatImm::make(op->type, 1.0); + } else { + return make_zero(op->type); + } + } + + Expr Mutate_(const Load* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const Let* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const Call* op, const Expr& e) { + if (op->call_type == Call::CallType::Halide) { + if (input_.operator->() && op->func.same_as(input_->op) && + op->value_index == input_->value_index) { + Expr condition = const_true(); + for (size_t i = 0; i < input_.ndim(); ++i) { + condition = And::make(condition, EQ::make(indices_[i], op->args[i])); + } + return Cast::make(op->type, condition); + } else { + return make_zero(op->type); + } + } else if (op->call_type == Call::CallType::PureIntrinsic) { + static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; + if (op->name == "exp") { + return Mul::make(Mutate(op->args[0]), e); + } else if (op->name == "log") { + return Div::make(Mutate(op->args[0]), op->args[0]); + } else if (op->name == "sigmoid") { + return Mul::make(Mutate(op->args[0]), + Mul::make(e, Sub::make(FloatImm::make(e.type(), 1.0), e))); + } else if (op->name == "sqrt") { + return Div::make(Mutate(op->args[0]), Mul::make(e, FloatImm::make(e.type(), 2.0))); + } else if (op->name == "tanh") { + return Mul::make(Mutate(op->args[0]), + Sub::make(FloatImm::make(e.type(), 1.0), Mul::make(e, e))); + } else if (op->name == "pow") { + auto x = op->args[0], y = op->args[1]; + return e * (Mutate(y)*log(x) + Mutate(x)*y/x); + } else if (op->name == "fabs") { + auto type = op->args[0].type(); + return Mul::make(Mutate(op->args[0]), + Select::make(GE::make(op->args[0], make_zero(type)), + FloatImm::make(type, 1.0), FloatImm::make(type, -1.0))); + } else if (op->name == intrinsic::tvm_if_then_else) { + Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; + return Call::make(op->type, op->name, new_args, op->call_type, op->func, op->value_index); + } else if (piecewise_const.count(op->name)) { + return FloatImm::make(e.type(), 0.0); + } else { + throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); + } + } + NOT_IMPLEMENTED + } + + Expr Mutate_(const Add* op, const Expr& e) { + return op->make(Mutate(op->a), Mutate(op->b)); + } + + Expr Mutate_(const Sub* op, const Expr& e) { + return op->make(Mutate(op->a), Mutate(op->b)); + } + + Expr Mutate_(const Mul* op, const Expr& e) { + return Add::make(Mul::make(Mutate(op->a), op->b), Mul::make(op->a, Mutate(op->b))); + } + + Expr Mutate_(const Div* op, const Expr& e) { + return Div::make( + Sub::make(Mul::make(Mutate(op->a), op->b), Mul::make(op->a, Mutate(op->b))), + Mul::make(op->b, op->b)); + } + + Expr Mutate_(const Mod* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const Min* op, const Expr& e) { + return Select::make(LE::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); + } + + Expr Mutate_(const Max* op, const Expr& e) { + return Select::make(GE::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); + } + + Expr Mutate_(const EQ* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const NE* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const LT* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const LE* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const GT* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const GE* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const And* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const Or* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const Reduce*, const Expr& e) { + // This case is relatively difficult because a reduction expression + // may use an arbitrary combiner. + // The resulting reduction expression will return a tuple containing + // both derivatives and the original results (in exactly this order). + + // We have to clone the reduction axes because otherwise the original expression + // cannot be used together with the derivative (it will lead to errors during lowering) + Expr expr_with_new_axes = CloneReduction(e); + const Reduce* op = expr_with_new_axes.as(); + + // New lhs and rhs variables of the new combiner consist of variables + // representing derivatives followed by the original variables. + Array new_lhs; + for (const auto& var : op->combiner->lhs) { + new_lhs.push_back(var.copy_with_suffix(".der")); + } + for (const auto& var : op->combiner->lhs) { + new_lhs.push_back(var); + } + + Array new_rhs; + for (const auto& var : op->combiner->rhs) { + new_rhs.push_back(var.copy_with_suffix(".der")); + } + for (const auto& var : op->combiner->rhs) { + new_rhs.push_back(var); + } + + // The new combiner result also consists of the resulting derivatives + // followed by the original results. + Array new_result; + for (const auto& res : op->combiner->result) { + // Each resulting derivative is computed as a sum of derivatives + // wrt lhs and rhs multiplied by the derivatives of lhs and rhs + Expr new_res = make_zero(res.type()); + for (size_t i = 0; i < op->combiner->lhs.size(); ++i) { + Expr res_di = Derivative(res, op->combiner->lhs[i]); + // new_lhs[i] is the derivative of lhs[i] (wrt our input tensor) + new_res = Add::make(new_res, Mul::make(new_lhs[i], res_di)); + } + for (size_t i = 0; i < op->combiner->rhs.size(); ++i) { + Expr res_di = Derivative(res, op->combiner->rhs[i]); + new_res = Add::make(new_res, Mul::make(new_rhs[i], res_di)); + } + new_result.push_back(new_res); + } + for (const auto& res : op->combiner->result) { + new_result.push_back(res); + } + + // The identity is transformed in a similar way + Array new_identity; + for (const auto& id : op->combiner->identity_element) { + new_identity.push_back(Mutate(id)); + } + for (const auto& id : op->combiner->identity_element) { + new_identity.push_back(id); + } + + Array new_source; + for (const auto& src : op->source) { + new_source.push_back(Mutate(src)); + } + for (const auto& src : op->source) { + new_source.push_back(src); + } + + CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + // Also simplify the resulting combiner (mostly to get rid of unused components) + return Simplify( + Reduce::make(new_combiner, new_source, op->axis, op->condition, op->value_index)); + } + + Expr Mutate_(const Cast* op, const Expr& e) { + if (op->type.is_float()) { + return Cast::make(op->type, Mutate(op->value)); + } else { + return make_zero(op->type); + } + } + + Expr Mutate_(const Not* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const Select* op, const Expr& e) { + return Select::make(op->condition, Mutate(op->true_value), Mutate(op->false_value)); + } + + Expr Mutate_(const Ramp* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const Broadcast* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const IntImm* op, const Expr& e) { return op->make(op->type, 0); } + Expr Mutate_(const UIntImm* op, const Expr& e) { return op->make(op->type, 0); } + Expr Mutate_(const FloatImm* op, const Expr& e) { return op->make(op->type, 0); } + + Expr Mutate_(const StringImm* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const Shuffle* op, const Expr& e) NOT_IMPLEMENTED + + private: + Tensor input_; + Array indices_; + VarExpr input_var_; +}; + +Expr Jacobian(const Expr& expr, const Tensor& input, const Array& indices) { + return JacobianMutator(input, indices).Mutate(expr); +} + +Expr Derivative(const Expr& expr, const VarExpr& var) { + return JacobianMutator(var).Mutate(expr); +} + +Tensor Jacobian(const Tensor& output, const Tensor& input, bool optimize) { + if (const ComputeOpNode* op = output->op.as()) { + bool is_input_tensor = false; + for (const Tensor& child : op->InputTensors()) { + if (input == child) { + is_input_tensor = true; + break; + } + } + CHECK(is_input_tensor) << "Jacobian is called on a pair of tensors such that the output " + << "does not depend on the input. This is probably a mistake."; + + // We have to clone the iteration axes because otherwise the original expression + // cannot be used together with the derivative (it will lead to errors during lowering) + Array new_axis; + std::unordered_map vmap; + for (IterVar iv : op->axis) { + IterVar new_v = + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), + iv->iter_type, iv->thread_tag); + new_axis.push_back(new_v); + vmap[iv->var.operator->()] = new_v; + } + + // Generate new itervars for the input + Array input_itervars; + size_t i = 0; + for (Expr ext : input->shape) { + IterVar new_v = + IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i)), + IterVarType::kDataPar); + // Append them to new_axis + new_axis.push_back(new_v); + // We also need a separate array of these itervars + input_itervars.push_back(new_v); + ++i; + } + + // The differentiation itself happens here + Expr new_body = + Jacobian(Substitute(op->body[output->value_index], vmap), input, input_itervars); + new_body = Simplify(new_body); + + int value_index = 0; + Array new_bodies; + + // If this is a reduction then it may return a tuple and we have + // to repeat the body several times + if (const Reduce* red = new_body.as()) { + value_index = red->value_index; + for (size_t i = 0; i < red->source.size(); ++i) { + new_bodies.push_back( + Reduce::make(red->combiner, red->source, red->axis, red->condition, i)); + } + } else { + new_bodies.push_back(new_body); + } + + auto new_op = + ComputeOpNode::make(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); + + // new_shape = output.shape + input.shape + Array new_shape = output->shape; + for (const auto& e : input->shape) { + new_shape.push_back(e); + } + + Tensor tensor = TensorNode::make(new_shape, output->dtype, new_op, value_index); + + if (optimize) { + tensor = OptimizeAndLiftNonzeronessConditions(tensor); + } + + return tensor; + } else { + LOG(FATAL) << "Derivative of this op is not implemented: " << output->op; + throw; + } +} + +Tensor DiffBuildingBlock(const Tensor& output, const Tensor& input, const Tensor& head) { + Tensor jac_output_input = Jacobian(output, input); + Tensor result = topi::tensordot(head, jac_output_input, output->shape.size(), + output->op->name + "." + input->op->name + ".grad"); + // TODO(sgrechanik-h): Here we inline only jac_output_input because otherwise there will be + // performance problems. A better solution would be to inline smartly. + result = InlineTensors(result, {jac_output_input}); + result = OptimizeAndLiftNonzeronessConditions(result); + result = InlineTailCall(result); + return result; +} + +DifferentiationResult Differentiate(const Tensor& output, + const Array& inputs, + const Tensor& head_or_null, + const FDiffBuildingBlock& fdiff, + const Map>& override_deps) { + Tensor head = head_or_null; + + // If the head is a null pointer, create an identity tensor + if (!head.get()) { + Array shape = output->shape; + for (auto e : output->shape) { + shape.push_back(e); + } + auto func = + [&output](const Array& input_indices) { + Expr res = const_true(); + for (size_t i = 0; i < output->shape.size(); ++i) { + res = res && Expr(input_indices[i]) == Expr(input_indices[output->shape.size() + i]); + } + return Cast::make(output->dtype, res); + }; + head = tvm::compute(shape, func, "identity"); + } + + // This map maps a tensor to the list of tensors immediately depending on it (using it in their + // bodies) + std::unordered_map> reverse_dependencies; + + // Map doesn't work correctly for Tensors, so convert it to std::unordered_map + std::unordered_map> override_deps_map; + for (auto pair : override_deps) { + override_deps_map.insert(pair); + } + + // Collect reverse dependencies + std::vector stack({output}); + while (!stack.empty()) { + Tensor tensor = stack.back(); + stack.pop_back(); + + auto it = override_deps_map.find(tensor); + Array deps = it != override_deps_map.end() ? it->second : tensor->op->InputTensors(); + + for (const Tensor& child : deps) { + if (!reverse_dependencies.count(child)) { + stack.push_back(child); + } + reverse_dependencies[child].push_back(tensor); + } + } + + // Individual summands of the adjoints + std::unordered_map> summands; + + // This map maps tensors to the corresponding adjoints (dLoss/dTensor) + std::unordered_map adjoints; + // head is the adjoint of output by definition + adjoints[output] = head; + + // This is a recursive function that does all the work. It computes the adjoint for a given + // tensor, adds it to the map, and returns it + std::function compute_adjoint; + compute_adjoint = + [&compute_adjoint, &adjoints, &summands, &reverse_dependencies, &fdiff, &head, &output] + (const Tensor& tensor) { + if (!adjoints.count(tensor)) { + // Here the adjoint hasn't been computed yet + Tensor res_adjoint; + std::vector deps = reverse_dependencies[tensor]; + if (deps.empty()) { + // No reverse dependencies means that the output does not depend on this tensor, + // return a zero tensor of the appropriate shape + Array result_shape(head->shape.begin(), + head->shape.end() + (-output->shape.size())); + for (auto e : tensor->shape) { + result_shape.push_back(e); + } + res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); + } else { + // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied + // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian + // and the multiplication is done in the function fdiff (DiffBuildingBlock by default). + for (const Tensor& dep : deps) { + Tensor part = fdiff(dep, tensor, compute_adjoint(dep)); + res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part; + + // Add this part to summands + auto& summands_of_adjoint = summands[tensor]; + if (summands_of_adjoint.get()) { + summands_of_adjoint.Set(dep, part); + } else { + summands_of_adjoint = Map({{dep, part}}); + } + } + } + + adjoints[tensor] = res_adjoint; + return res_adjoint; + } else { + return adjoints[tensor]; + } + }; + + // Adjoints corresponding to inputs + Array result; + + // If inputs is empty, compute adjoints for all tensors, on which output depends + if (inputs.empty()) { + for (const auto& dep : reverse_dependencies) { + compute_adjoint(dep.first); + } + } + + // Compute an adjoint for each input + for (const Tensor& input : inputs) { + result.push_back(compute_adjoint(input)); + } + + return DifferentiationResultNode::make(result, adjoints, summands); +} + + +TVM_REGISTER_API("autodiff.Jacobian") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() > 2) { + *ret = Jacobian(args[0], args[1], args[2].operator bool()); + } else { + *ret = Jacobian(args[0], args[1]); + } + }); + +TVM_REGISTER_API("autodiff.Derivative") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = Derivative(args[0], args[1]); + }); + +TVM_REGISTER_API("autodiff.DiffBuildingBlock") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = DiffBuildingBlock(args[0], args[1], args[2]); + }); + +TVM_REGISTER_API("autodiff.Differentiate") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() <= 1) { + *ret = Differentiate(args[0]); + } else if (args.size() == 2) { + *ret = Differentiate(args[0], args[1]); + } else if (args.size() == 3) { + *ret = Differentiate(args[0], args[1], args[2]); + } else if (args.size() >= 4) { + auto pfunc = args[3].operator PackedFunc(); + auto fdiff = + [pfunc](const Tensor& o, const Tensor& i, const Tensor& h) { + return pfunc(o, i, h); + }; + if (args.size() >= 5) { + *ret = Differentiate(args[0], args[1], args[2], fdiff, args[4]); + } else { + *ret = Differentiate(args[0], args[1], args[2], fdiff); + } + } + }); + +} // namespace ir +} // namespace tvm diff --git a/src/pass/dump_tensor.cc b/src/pass/dump_tensor.cc new file mode 100644 index 000000000000..18927e07d1de --- /dev/null +++ b/src/pass/dump_tensor.cc @@ -0,0 +1,100 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file dump_tensor.cc + * \brief Print out tensors recursively. + */ +#include +#include +#include +#include + +namespace tvm { +namespace ir { + +std::string PrintTensorName(const Tensor& tensor) { + if (!tensor.get()) { + return "NULL_TENSOR"; + } + + std::ostringstream oss; + oss << tensor->op->name << "[" << tensor->value_index << "]"; + return oss.str(); +} + +std::string PrintIterVars(const Array& itervars) { + std::ostringstream oss; + oss << "("; + bool first = true; + for (const IterVar& iv : itervars) { + if (!first) oss << ", "; + first = false; + oss << iv->var << " : " << "[" << iv->dom->min + << ", " << (iv->dom->min + iv->dom->extent - 1) << "]"; + } + oss << ")"; + return oss.str(); +} + +std::string PrintTensorsRecursively(const Array& tensors) { + std::vector unprocessed; + std::unordered_set processed; + std::ostringstream oss; + + for (const Tensor& t : tensors) { + unprocessed.push_back(t); + } + + while (!unprocessed.empty()) { + Tensor cur = unprocessed.back(); + unprocessed.pop_back(); + processed.insert(cur); + + oss << "tensor " << PrintTensorName(cur) << " : " << cur->dtype << " " << cur->shape << "\n"; + if (const ComputeOpNode* comp = cur->op.as()) { + oss << "axes " << PrintIterVars(comp->axis) << "\n"; + Expr body = comp->body[cur->value_index]; + + for (const Tensor& t : comp->InputTensors()) { + if (processed.count(t) == 0) { + unprocessed.push_back(t); + } + } + + if (const Reduce* red = body.as()) { + oss << "Reduction\n"; + oss << " identity " << red->combiner->identity_element << "\n"; + oss << " lhs " << red->combiner->lhs << " rhs " << red->combiner->rhs << "\n"; + oss << " combiner " << red->combiner->result << "\n"; + oss << " axis " << PrintIterVars(red->axis) << "\n"; + oss << " condition " << red->condition << "\n"; + for (size_t i = 0; i < red->source.size(); ++i) { + oss << " source[" << i << "] = " << red->source[i] << "\n"; + } + } else { + oss << " " << body << "\n"; + } + } else { + oss << " " << cur->op << "\n"; + } + oss << "\n"; + } + + return oss.str(); +} + +std::string PrintTensorRecursively(const Tensor& tensor) { + return PrintTensorsRecursively({tensor}); +} + +TVM_REGISTER_API("PrintTensorRecursively") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = PrintTensorRecursively(args[0]); + }); + +TVM_REGISTER_API("PrintTensorsRecursively") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = PrintTensorsRecursively(args[0]); + }); + +} // namespace ir +} // namespace tvm diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc new file mode 100644 index 000000000000..56e4006c824a --- /dev/null +++ b/src/pass/zero_elimination.cc @@ -0,0 +1,1719 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file zero_elimination.cc + * \brief Transform tensors in such a way as to eliminate summation over zeros. + */ +#include "zero_elimination.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "arithmetic/ModulusRemainder.h" +#include "../op/op_util.h" + +namespace tvm { +namespace ir { + +using HalideIR::Internal::gcd; +using HalideIR::Internal::lcm; + +struct ExprLess { + bool operator()(const Expr& l, const Expr& r) const { + return Compare(l, r) < 0; + } +}; + +struct ExprEq { + bool operator()(const Expr& l, const Expr& r) const { + return Compare(l, r) == 0; + } +}; + +// Merge two maps, prefer the right one on conflict +template +Map Merge(Map original, const Map& update) { + for (const auto& p : update) { + original.Set(p.first, p.second); + } + return std::move(original); +} + +// Concatenate two arrays +template +Array Concat(Array a, const Array& b) { + for (const auto& x : b) { + a.push_back(x); + } + return std::move(a); +} + +// Combine all expressions from the container using &&. +template +Expr All(const container& c) { + Expr res; + for (const auto& e : c) { + if (res.get()) { + res = res && e; + } else { + res = e; + } + } + if (res.get()) { + return res; + } else { + return const_true(); + } +} + +// Create a select statement of the form cond ? on_true : 0 +Expr SelectElseZero(const Expr& cond, const Expr& on_true) { + return Select::make(cond, on_true, make_zero(on_true.type())); +} + +// Simplify the expression as thoroughly as possible by using all available simplifiers. +Expr SuperSimplify(Expr e, const Map& vranges = Map()) { + // For some reason no simplifier can detect that there is only one value of the variable + std::unordered_map vmap; + for (const auto& var_range : vranges) { + if (is_const_int(var_range.second->extent, 1)) { + vmap[var_range.first.get()] = var_range.second->min; + } + } + if (!vmap.empty()) { + e = Substitute(e, vmap); + } + + return CanonicalSimplify(Simplify(CanonicalSimplify(e, vranges), vranges), vranges); +} + +// Provability check that uses SuperSimplify +bool CanProve(Expr e, const Map& vranges = Map()) { + return is_one(SuperSimplify(e, vranges)); +} + +class ExprFreeVarsVisitor : public IRVisitor { + public: + std::vector free_array; + std::unordered_set bound; + std::unordered_set free; + + virtual void Visit(const NodeRef& node) { + if (const Variable* v = node.as()) { + if (!bound.count(v) && !free.count(v)) { + free.insert(v); + free_array.push_back(Var(node.node_)); + } + } else { + IRVisitor::Visit(node); + } + } + + void Visit_(const Variable* op) { + CHECK(false) << "This case shouldn't happen"; + } + + void Visit_(const LetStmt* op) { + bound.insert(op->var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const For* op) { + bound.insert(op->loop_var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const Let* op) { + bound.insert(op->var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const Reduce* op) { + for (const auto& iv : op->axis) { + bound.insert(iv->var.get()); + } + IRVisitor::Visit_(op); + } + + void Visit_(const Store* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Allocate* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Free* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Load* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } +}; + +// Get free variables of an expression +Array ExprFreeVars(const Expr& expr) { + ExprFreeVarsVisitor visitor; + visitor.Visit(expr); + return visitor.free_array; +} + +// Clone iter vars and return both the new vars and the substitution from old to new. +std::pair, std::unordered_map> CloneIterVars( + const Array& vars) { + Array new_vars; + std::unordered_map vmap; + for (const IterVar& iv : vars) { + IterVar new_v = + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), + iv->iter_type, iv->thread_tag); + new_vars.push_back(new_v); + vmap[iv->var.get()] = new_v; + } + return std::make_pair(std::move(new_vars), std::move(vmap)); +} + +// Clone reduction by cloning the axis variables. +Expr CloneReduction(const Expr& expr) { + if (const Reduce* red = expr.as()) { + Array new_axis; + std::unordered_map vmap; + std::tie(new_axis, vmap) = CloneIterVars(red->axis); + + Array src_with_newaxis; + for (const auto& src : red->source) { + src_with_newaxis.push_back(Substitute(src, vmap)); + } + + return Reduce::make(red->combiner, src_with_newaxis, + new_axis, Substitute(red->condition, vmap), red->value_index); + } else { + return expr; + } +} + +// Convert an array of itervars to an array of inequalities +Array IterVarsToInequalities(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(GE::make(v->var, v->dom->min)); + res.push_back(LT::make(v->var, v->dom->min + v->dom->extent)); + } + return res; +} + +// Convert an array of itervars to a map from vars to ranges +Map IterVarsToMap(const Array& itervars) { + Map res; + for (const IterVar& v : itervars) { + res.Set(v->var, v->dom); + } + return res; +} + +// Convert an array of itervars to an array of vars +Array IterVarsToVars(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(v->var); + } + return res; +} + +// Given a map from vars to ranges create an array of itervars +Array IterVarsFromMap(const Array& vars, const Map& vranges, + IterVarType iter_type = kDataPar, std::string thread_tag = "") { + Array res; + for (const Var& v : vars) { + CHECK(vranges.count(v)) << "A range for the variable " << v + << " was not provided in map " << vranges; + res.push_back(IterVarNode::make(vranges[v], v, iter_type, thread_tag)); + } + return res; +} + +// Return true if this combiner is just a sum. +bool IsSumCombiner(const CommReducer& combiner) { + if (combiner->result.size() != 1) { + return false; + } + + if (!is_const_value(SuperSimplify(combiner->identity_element[0]), 0)) { + return false; + } + + return is_const_value(SuperSimplify(combiner->result[0] - + (combiner->lhs[0] + combiner->rhs[0])), + 0); +} + +// Return true if zero may be factored out of a reduction with this combiner. +bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index) { + if (!is_const_value(combiner->identity_element[value_index], 0)) { + return false; + } + + Expr zero = make_zero(combiner->result[value_index].type()); + Expr in = Substitute(combiner->result[value_index], + {{combiner->lhs[value_index], zero}, + {combiner->rhs[value_index], zero}}); + in = SuperSimplify(in); + + return is_const_value(in, 0); +} + +Expr InlineThisCall(const Expr& expr) { + if (const Call* op = expr.as()) { + if (op->call_type == Call::CallType::Halide) { + if (const ComputeOpNode* op_comp = op->func.as()) { + Array tensor_axes; + for (const auto& var : op_comp->axis) { + tensor_axes.push_back(var->var); + } + + Stmt inlined = Inline(Evaluate::make(expr), op->func, tensor_axes, + op_comp->body[op->value_index]); + if (const ir::Evaluate* ev = inlined.as()) { + // If it is a reduction, clone it + return CloneReduction(ev->value); + } + } + } + } + + return expr; +} + +Tensor InlineTailCall(const Tensor& tensor) { + return op::TransformBody(tensor, InlineThisCall); +} + +class InlineTensorsMutator : public IRMutator { + public: + explicit InlineTensorsMutator(const Array& inlineable, bool inline_reductions = false) + : inline_reductions_(inline_reductions) { + for (const Tensor& tensor : inlineable) { + inlineable_.emplace(tensor->op.operator->(), tensor->value_index); + } + } + + Expr Mutate_(const Call* op, const Expr& e) { + if (op->call_type == Call::CallType::Halide) { + const ComputeOpNode* op_comp = op->func.as(); + if (inlineable_.empty() || inlineable_.count({op_comp, op->value_index})) { + if (op_comp && (inline_reductions_ || !op_comp->body[0].as())) { + Array tensor_axes; + for (const auto& var : op_comp->axis) { + tensor_axes.push_back(var->var); + } + + Stmt inlined = Inline(Evaluate::make(e), op->func, tensor_axes, + op_comp->body[op->value_index]); + if (const ir::Evaluate* ev = inlined.as()) { + // If it is a reduction, clone it + return Mutate(ev->value); + } + } + } + } + + return e; + } + + private: + std::set> inlineable_; + bool inline_reductions_; +}; + +Expr InlineTensors(const Expr& expr, const Array& inlineable, + bool inline_reductions) { + return InlineTensorsMutator(inlineable, inline_reductions).Mutate(expr); +} + +Tensor InlineTensors(const Tensor& tensor, const Array& inlineable, + bool inline_reductions) { + auto transformation = + [inlineable, inline_reductions](const Expr& e) { + return InlineTensorsMutator(inlineable, inline_reductions).Mutate(e); }; + return op::TransformBody(tensor, transformation); +} + + +struct NonzeronessConditionResult { + Expr cond; + Expr value; + + Expr to_expr() const { + return SelectElseZero(cond, value); + } +}; + +class NonzeronessConditionFunctor + : public ExprFunctor { + public: + NonzeronessConditionResult NonzeronessCondition(const Expr& e) { + return VisitExpr(e, e); + } + + result_type VisitExpr_(const Variable*, const Expr& e) final { return Default_(e); } + result_type VisitExpr_(const IntImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const UIntImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const FloatImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const StringImm*, const Expr& e) final { return Default_(e); } + result_type VisitExpr_(const Add* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Sub* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Mul* op, const Expr& e) final { return BinOpMulLike_(op, e); } + result_type VisitExpr_(const Div* op, const Expr& e) final { return BinOpDivLike_(op, e); } + result_type VisitExpr_(const Mod* op, const Expr& e) final { return BinOpDivLike_(op, e); } + result_type VisitExpr_(const Min* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Max* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const EQ* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const NE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const LE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const LT* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const GE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const GT* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const Not* op, const Expr& e) final { return Bool_(op, e); } + + result_type VisitExpr_(const Cast* op, const Expr& e) final { + if (op->value.type().is_bool()) { + return {op->value, make_const(e.type(), 1)}; + } else { + auto nz_a = NonzeronessCondition(op->value); + + if (nz_a.value.same_as(op->value)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, Cast::make(op->type, nz_a.value)}; + } + } + } + + result_type VisitExpr_(const Select* op, const Expr& e) final { + return SelectLike_(e, op->condition, op->true_value, op->false_value, Select::make); + } + + result_type VisitExpr_(const Call* op, const Expr& e) final { + if (op->name == intrinsic::tvm_if_then_else) { + return SelectLike_(e, op->args[0], op->args[1], op->args[2], if_then_else); + } else { + return Default_(e); + } + } + + NonzeronessConditionResult Default_(const Expr& e) { + return {const_true(), e}; + } + + template + NonzeronessConditionResult Const_(const TNode* op, const Expr& e) { + if (op->value == 0) { + return {const_false(), e}; + } else { + return {const_true(), e}; + } + } + + template + NonzeronessConditionResult SelectLike_(const Expr& e, const Expr& cond, const Expr& true_val, + const Expr& false_val, make_select_type make_select) { + auto nz_a = NonzeronessCondition(true_val); + auto nz_b = NonzeronessCondition(false_val); + + if (is_const_value(nz_b.value, 0)) { + Expr new_cond = SuperSimplify(nz_a.cond && cond); + return {new_cond, nz_a.value}; + } + + if (is_const_value(nz_a.value, 0)) { + Expr new_cond = SuperSimplify(nz_b.cond && !cond); + return {new_cond, nz_b.value}; + } + + Expr new_cond = + SuperSimplify(Or::make(cond && nz_a.cond, + !cond && nz_b.cond)); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, e}; + } else { + return {new_cond, make_select(cond, nz_a.value, nz_b.value)}; + } + } + + template + NonzeronessConditionResult BinOpAddLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + auto nz_b = NonzeronessCondition(op->b); + + if (Equal(nz_a.cond, nz_b.cond)) { + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, TNode::make(nz_a.value, nz_b.value)}; + } + } else { + Expr new_cond = SuperSimplify(Or::make(nz_a.cond, nz_b.cond)); + Expr new_a = Equal(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); + Expr new_b = Equal(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); + Expr new_expr = TNode::make(new_a, new_b); + return {new_cond, new_expr}; + } + } + + template + NonzeronessConditionResult BinOpMulLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + auto nz_b = NonzeronessCondition(op->b); + + Expr new_cond = SuperSimplify(nz_a.cond && nz_b.cond); + + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {new_cond, e}; + } else { + return {new_cond, TNode::make(nz_a.value, nz_b.value)}; + } + } + + template + NonzeronessConditionResult BinOpDivLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + + if (nz_a.value.same_as(op->a)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, TNode::make(nz_a.value, op->b)}; + } + } + + template + NonzeronessConditionResult Bool_(const TNode* op, const Expr& e) { + return {e, make_const(e.type(), 1)}; + } +}; + +NonzeronessConditionResult NonzeronessCondition(const Expr& expr) { + return NonzeronessConditionFunctor().NonzeronessCondition(expr); +} + +Expr LiftNonzeronessCondition(const Expr& expr) { + return NonzeronessCondition(expr).to_expr(); +} + + +class NormalizeComparisonsMutator : public IRMutator { + public: + virtual Expr Mutate_(const EQ* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return Make(op->b, op->a); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return Make(op->b, op->a); } + + private: + template + Expr Make(const Expr& a, const Expr& b) { + // rewrite LT to LE for ints + if (std::is_same::value && (a.type().is_int() || a.type().is_uint())) { + return LE::make(SuperSimplify(a - b + 1), make_zero(a.type())); + } + return TNode::make(SuperSimplify(a - b), make_zero(a.type())); + } +}; + +// Rewrite every comparison into the form a == 0, a != 0, a <= 0, and sometimes for floats a < 0 +Expr NormalizeComparisons(const Expr& expr) { + return NormalizeComparisonsMutator().Mutate(expr); +} + + +struct FactorOutAtomicFormulasResult { + std::vector atomic_formulas; + Expr rest; + + Expr to_expr() const { + Expr res = rest; + for (const Expr& e : atomic_formulas) { + res = And::make(e, res); + } + return res; + } +}; + +class FactorOutAtomicFormulasFunctor + : public ExprFunctor { + public: + result_type Atomic_(const Expr& e) { + return {{e}, make_const(e.type(), 1)}; + } + + result_type VisitExpr_(const Variable*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const Call*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const IntImm*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const UIntImm*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const EQ*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const NE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const LE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const LT*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const GE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const GT*, const Expr& e) final { return Atomic_(e); } + + result_type VisitExpr_(const And* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + return {res, res_a.rest && res_b.rest}; + } + + result_type VisitExpr_(const Mul* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + return {res, res_a.rest * res_b.rest}; + } + + result_type VisitExpr_(const Or* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); + std::set_intersection(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + std::vector new_cond_a; + new_cond_a.reserve(res_a.atomic_formulas.size() - res.size()); + std::set_difference(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res.begin(), res.end(), + std::back_inserter(new_cond_a), + ExprLess()); + + std::vector new_cond_b; + new_cond_b.reserve(res_b.atomic_formulas.size() - res.size()); + std::set_difference(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + res.begin(), res.end(), + std::back_inserter(new_cond_b), + ExprLess()); + + res_a.atomic_formulas = std::move(new_cond_a); + res_b.atomic_formulas = std::move(new_cond_b); + + Expr new_rest = Or::make(res_a.to_expr(), res_b.to_expr()); + + return {res, new_rest}; + } +}; + +// Transform the given formula into an array of atomic formulas and a non-atomic residual. +FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const Expr& e) { + return FactorOutAtomicFormulasFunctor().VisitExpr(e, e); +} + + +struct EliminateDivModResult { + Expr expr; + Map substitution; + Array new_variables; + Array conditions; + Map ranges; +}; + +class EliminateDivModMutator : public IRMutator { + public: + Map substitution; + Array new_variables; + Array conditions; + Map ranges; + + explicit EliminateDivModMutator(Map ranges) + : ranges(ranges) {} + + virtual Expr Mutate_(const Div* op, const Expr& e) { + const IntImm* imm = op->b.as(); + if (imm && imm->value > 0) { + auto it = expr_to_vars_.find({op->a.get(), imm->value}); + if (it != expr_to_vars_.end()) { + return it->second.first; + } + + Expr mutated_a = Mutate(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { + return var_pair_opt.value().first; + } else { + return Div::make(mutated_a, Mutate(op->b)); + } + } + + return Div::make(Mutate(op->a), Mutate(op->b)); + } + + virtual Expr Mutate_(const Mod* op, const Expr& e) { + const IntImm* imm = op->b.as(); + if (imm && imm->value > 0) { + auto it = expr_to_vars_.find({op->a.get(), imm->value}); + if (it != expr_to_vars_.end()) { + return it->second.second; + } + + Expr mutated_a = Mutate(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { + return var_pair_opt.value().second; + } else { + return Mod::make(mutated_a, Mutate(op->b)); + } + } + + return Mod::make(Mutate(op->a), Mutate(op->b)); + } + + private: + dmlc::optional> AddNewVarPair(const Expr& e, const Expr& mut, int64_t val) { + using tresult = dmlc::optional>; + + Expr val_e = make_const(e.type(), val); + idx_ += 1; + + std::unordered_map var_intsets; + for (const auto& p : ranges) { + var_intsets[p.first.get()] = IntSet::range(p.second); + } + + Range div_range = EvalSet(mut / val_e, var_intsets).cover_range(Range()); + Range mod_range = EvalSet(mut % val_e, var_intsets).cover_range(Range()); + + if (!div_range.get() || !mod_range.get()) { + LOG(WARNING) << "EliminateDivMod: won't eliminate div or mod of expr " << e + << " because its bounds cannot be inferred"; + return tresult(); + } + + auto div = Var("div" + std::to_string(idx_), e.type()); + auto mod = Var("mod" + std::to_string(idx_), e.type()); + + new_variables.push_back(div); + new_variables.push_back(mod); + + substitution.Set(div, mut / val_e); + substitution.Set(mod, mut % val_e); + + ranges.Set(div, div_range); + ranges.Set(mod, mod_range); + + conditions.push_back(mut == div*val_e + mod); + + if (!CanProve(mod_range->extent <= val_e)) { + LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod of expr " << e + << " (probably it may change its sign)"; + // We cannot prove that mod is unique, so add additional condition + conditions.push_back(Select::make(e >= 0, mod >= 0, mod <= 0)); + } + + auto p = std::make_pair(div, mod); + expr_to_vars_[{e.get(), val}] = p; + return tresult(p); + } + + int idx_{0}; + std::map, std::pair> + expr_to_vars_; +}; + +// replace every subexpr of the form e/const and e % const with a new variable +EliminateDivModResult EliminateDivMod(const Expr& expr, Map ranges) { + EliminateDivModResult res; + EliminateDivModMutator mutator(ranges); + res.expr = mutator.Mutate(expr); + res.conditions = std::move(mutator.conditions); + res.new_variables = std::move(mutator.new_variables); + res.substitution = std::move(mutator.substitution); + res.ranges = std::move(mutator.ranges); + return res; +} + +// run EliminateDivMod from the condition of a reduction +Expr EliminateDivModFromReductionCondition(const Expr& expr, + Map vranges = Map()) { + if (const Reduce* red = expr.as()) { + for (const IterVar& iv : red->axis) { + vranges.Set(iv->var, iv->dom); + } + + auto elim_res = EliminateDivMod(red->condition, vranges); + + vranges = elim_res.ranges; + + Array new_axis = + Concat(red->axis, IterVarsFromMap(elim_res.new_variables, vranges, kCommReduce)); + + Expr new_cond = elim_res.expr && All(elim_res.conditions); + + return Reduce::make(red->combiner, red->source, new_axis, new_cond, red->value_index); + } else { + return expr; + } +} + + +VarBounds VarBounds::substitute(const Map& subst) const { + auto apply_fun = [&subst](const Expr& e) { return Substitute(e, subst); }; + return {Substitute(coef, subst), + UpdateArray(lower, apply_fun), + UpdateArray(equal, apply_fun), + UpdateArray(upper, apply_fun)}; +} + +Array SolveSystemOfInequalitiesResult::as_conditions() const { + Array res; + for (const Var& v : variables) { + auto it = bounds.find(v.get()); + CHECK(it != bounds.end()); + const VarBounds& bnds = it->second; + Expr lhs = bnds.coef * v; + for (const Expr& rhs : bnds.equal) { + res.push_back(EQ::make(lhs, rhs)); + } + for (const Expr& rhs : bnds.lower) { + res.push_back(GE::make(lhs, rhs)); + } + for (const Expr& rhs : bnds.upper) { + res.push_back(LE::make(lhs, rhs)); + } + } + for (const Expr& e : other_conditions) { + res.push_back(e); + } + return res; +} + +// Rewrite the system of inequalities using Fourier-Motzkin elimination +// Note that variable ranges help a lot, so this parameter is even non-optional +SolveSystemOfInequalitiesResult SolveSystemOfInequalities(const Array& inequalities, + const Array& variables, + const Map& vranges) { + SolveSystemOfInequalitiesResult res; + res.variables = variables; + + // The algorithm consists in doing the following things for each variable v + // - Take formulas from `current` and classify them according to polarity wrt v + // - Combine each formula of positive polarity (wrt v) with each formula of negative polarity + // - Put the resulting combinations into `new_current` along with unclassifiable formulas + // - Replace `current` with `new_current` and move to the next variable + + // current and new_current are sorted to enable some heuristics + std::set current; + std::set new_current; + // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_pos; + // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_neg; + + // formulas we don't know what to do with + std::vector rest; + + // A helper that adds an inequality to new_current if it's not obviously redundant + auto add_to_new_current = [&new_current, &vranges] (const Expr& new_ineq) { + if (CanProve(new_ineq, vranges)) { + // redundant: follows from the vranges + return; + } + if (const LE* new_le = new_ineq.as()) { + // A heuristic: check if the new inequality is a consequence of one + // of its future neighbors (in this case don't add it) or if a future neighbor is + // a consequence of the new ineq (in which case remove the neighbor) + auto it_neighbor = new_current.lower_bound(new_ineq); + if (it_neighbor != new_current.begin()) { + const LE* le = std::prev(it_neighbor)->as(); + if (le && CanProve(new_le->a - le->a <= 0, vranges)) { + return; + } else if (le && CanProve(le->a - new_le->a <= 0, vranges)) { + new_current.erase(std::prev(it_neighbor)); + } + } + // Check the other neighbor + if (it_neighbor != new_current.end()) { + const LE* le = it_neighbor->as(); + if (le && CanProve(new_le->a - le->a <= 0, vranges)) { + return; + } else if (le && CanProve(le->a - new_le->a <= 0, vranges)) { + it_neighbor = new_current.erase(it_neighbor); + } + } + + new_current.insert(it_neighbor, new_ineq); + } else { + new_current.insert(new_ineq); + } + }; + + // Simplify each inequality into the form `expr <= 0` and add to new_current formulas + for (const Expr& ineq : inequalities) { + add_to_new_current(NormalizeComparisons(SuperSimplify(ineq, vranges))); + } + + std::swap(current, new_current); + + for (const Var& v : variables) { + CHECK(!res.bounds.count(v.get())) << + "Variable " << v << " appears several times in the `variables` which might be a bug"; + + new_current.clear(); + coef_pos.clear(); + coef_neg.clear(); + + // Add bounds from vranges + if (vranges.count(v)) { + const Range& range = vranges[v]; + Expr range_lbound = SuperSimplify(range->min, vranges); + Expr range_ubound = SuperSimplify(range->min + range->extent - 1, vranges); + coef_neg.push_back({-1, range_lbound}); + coef_pos.push_back({1, -range_ubound}); + } + + // Take formulas from `current` and classify them according to polarity wrt v + for (const Expr& ineq : current) { + if (const LE* le = ineq.as()) { + Array coef = arith::DetectLinearEquation(le->a, {v}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + add_to_new_current(ineq); + } else if (coef0 > 0) { + coef_pos.push_back({coef0, coef[1]}); + } else if (coef0 < 0) { + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } else if (const EQ* eq = ineq.as()) { + Array coef = arith::DetectLinearEquation(eq->a, {v}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + add_to_new_current(ineq); + } else if (coef0 > 0) { + // Equalities may be considered as pairs of two inequalities + coef_pos.push_back({coef0, coef[1]}); + coef_neg.push_back({-coef0, -coef[1]}); + } else if (coef0 < 0) { + coef_pos.push_back({-coef0, -coef[1]}); + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } + + // if nothing worked, put it in rest + rest.push_back(ineq); + } + + // Combine each positive inequality with each negative one (by adding them together) + for (const auto& pos : coef_pos) { + for (const auto& neg : coef_neg) { + auto first_gcd = gcd(pos.first, -neg.first); + Expr c_pos = make_const(v.type(), neg.first/first_gcd); + Expr c_neg = make_const(v.type(), pos.first/first_gcd); + Expr new_lhs = c_neg*neg.second - c_pos*pos.second; + Expr new_ineq = LE::make(new_lhs, make_zero(pos.second.type())); + new_ineq = NormalizeComparisons(SuperSimplify(new_ineq, vranges)); + add_to_new_current(new_ineq); + } + } + + // Now we have to generate resulting (in)equalities for the variable v + + // Find the common denominator in a sense + // We will generate formulas of the form coef_lcm*v <= bound + int64_t coef_lcm = 1; + for (const auto& pos : coef_pos) { + coef_lcm = lcm(coef_lcm, pos.first); + } + for (const auto& neg : coef_neg) { + coef_lcm = lcm(coef_lcm, -neg.first); + } + + // The resulting lower and upper bounds stored in sorted vectors + std::vector upper_bounds; + std::vector lower_bounds; + upper_bounds.reserve(coef_pos.size()); + lower_bounds.reserve(coef_neg.size()); + + for (const auto& pos : coef_pos) { + Expr bound = make_const(v.type(), -coef_lcm/pos.first)*pos.second; + bound = SuperSimplify(bound, vranges); + // Don't add if any of the existing bounds is better + if (std::any_of(upper_bounds.begin(), upper_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound <= 0, + vranges); })) { + continue; + } + // Erase all worse bounds + upper_bounds.erase( + std::remove_if(upper_bounds.begin(), upper_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound >= 0, + vranges); }), + upper_bounds.end()); + // Add + upper_bounds.push_back(bound); + } + for (const auto& neg : coef_neg) { + Expr bound = make_const(v.type(), -coef_lcm/neg.first)*neg.second; + bound = SuperSimplify(bound, vranges); + // Don't add if any of the existing bounds is better + if (std::any_of(lower_bounds.begin(), lower_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound >= 0, + vranges); })) { + continue; + } + // Erase all worse bounds + lower_bounds.erase( + std::remove_if(lower_bounds.begin(), lower_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound <= 0, + vranges); }), + lower_bounds.end()); + // Add + lower_bounds.push_back(bound); + } + + // Sort the vectors and remove duplicates + for (std::vector* bounds : {&upper_bounds, &lower_bounds}) { + std::sort(bounds->begin(), bounds->end(), ExprLess()); + bounds->erase(std::unique(bounds->begin(), bounds->end(), ExprEq()), bounds->end()); + } + + // Bounds which are both lower and upper should go to equal... + std::vector equal; + equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); + std::set_intersection(upper_bounds.begin(), upper_bounds.end(), + lower_bounds.begin(), lower_bounds.end(), + std::back_inserter(equal), ExprLess()); + + // ...and be removed from upper bounds... + std::vector new_upper; + new_upper.reserve(upper_bounds.size() - equal.size()); + std::set_difference(upper_bounds.begin(), upper_bounds.end(), + equal.begin(), equal.end(), + std::back_inserter(new_upper), ExprLess()); + + // ...and from lower bounds. + std::vector new_lower; + new_lower.reserve(lower_bounds.size() - equal.size()); + std::set_difference(lower_bounds.begin(), lower_bounds.end(), + equal.begin(), equal.end(), + std::back_inserter(new_lower), ExprLess()); + + // Write it to the result. + auto& bnds = res.bounds[v.get()]; + bnds.coef = make_const(v.type(), coef_lcm); + bnds.equal = equal; + bnds.lower = new_lower; + bnds.upper = new_upper; + + std::swap(current, new_current); + } + + // Everything that is left goes to res.other_conditions + for (const Expr& e : current) { + Expr e_simp = SuperSimplify(e, vranges); + if (is_const_int(e_simp, 0)) { + // contradiction detected + res.other_conditions = {const_false()}; + return res; + } else if (is_const_int(e_simp, 1)) { + continue; + } else { + res.other_conditions.push_back(e_simp); + } + } + + for (const Expr& e : rest) + res.other_conditions.push_back(e); + + return res; +} + + +// Simplify an iteration domain. +DomainSimplificationResult SimplifyDomain(const Expr& cond, + const Array& axis, + Map vranges, + bool eliminate_div_mod) { + if (eliminate_div_mod) { + auto elim_res = EliminateDivMod(cond, vranges); + + Map new_vranges = elim_res.ranges; + Array new_axis = Concat(axis, elim_res.new_variables); + Expr new_cond = elim_res.expr && All(elim_res.conditions); + + auto res = SimplifyDomain(new_cond, new_axis, new_vranges, false); + + Map new_old_to_new; + for (const Var& v : axis) { + new_old_to_new.Set(v, res.old_to_new[v]); + } + + Map new_new_to_old; + for (const auto& pair : res.new_to_old) { + new_new_to_old.Set(pair.first, Substitute(pair.second, elim_res.substitution)); + } + + res.old_to_new = std::move(new_old_to_new); + res.new_to_old = std::move(new_new_to_old); + + return res; + } + + auto factoratomic_res = FactorOutAtomicFormulas(cond); + std::vector& atomic_formulas = factoratomic_res.atomic_formulas; + Expr rest_of_cond = factoratomic_res.rest; + + // Put rest_of_cond into the vector of atomic formulas so that we don't forget about it. + // Although rest_of_cond is not atomic, the subsequent functions won't complain about it. + atomic_formulas.push_back(rest_of_cond); + + // vars are variables from axis followed by all the other variables from vranges + Array vars = axis; + for (const auto& pair : vranges) { + bool already = false; + for (const Var& v : vars) { + already = already || v.same_as(pair.first); + } + if (!already) { + vars.push_back(pair.first); + } + } + + auto solved_system = SolveSystemOfInequalities(atomic_formulas, vars, vranges); + + DomainSimplificationResult res; + std::unordered_map new_var_intsets; + + // Initialize new_var_intsets with the old var intsets + for (const auto& pair : vranges) { + new_var_intsets[pair.first.get()] = IntSet::range(pair.second); + } + + // We process variables in the reverse direction to start with the most independent one. + // This order is needed to compute new ranges. + for (auto it = axis.rbegin(); it != axis.rend(); ++it) { + const Var& var = *it; + auto& bnd = solved_system.bounds[var.get()]; + // Note that we replace old vars with new ones + bnd = bnd.substitute(res.old_to_new); + if (is_one(bnd.coef) && !bnd.equal.empty()) { + // There is an equation of the form `v == expr`, so this variable can be completely removed. + // Note that we use the 0-th expression because they are ordered by complexity, so it must be + // the simplest one. + res.old_to_new.Set(var, bnd.equal[0]); + } else { + Array lowers = Concat(bnd.equal, bnd.lower); + Array uppers = Concat(bnd.equal, bnd.upper); + + // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the + // pair with the minimal difference between the upper and the lower. + // Note that the bounds are for v*coef, not for v (because we don't want complex expressions + // involving division). + + // The lower bound of the best pair so far + Expr best_lower = vranges[var]->min * bnd.coef; + // The difference between the upper and the lower of the best pair so far + Expr best_diff = (vranges[var]->extent - 1) * bnd.coef; + // The overapproximation of the best difference + Expr best_diff_over = best_diff; + + for (const Expr& low : lowers) { + for (const Expr& upp : uppers) { + Expr diff = SuperSimplify(upp - low, vranges); + // Since diff may depend on some other variables, we compute its overapproximation + Expr diff_over = EvalSet(diff, new_var_intsets).max(); + + if (diff_over.same_as(HalideIR::Internal::Interval::pos_inf)) { + continue; + } + + // If it is provable that the new one is strictly better than the current best one, + // then replace it. Note that we are biased towards earlier pairs which should be simpler. + if (CanProve(diff_over - best_diff_over < 0, vranges)) { + best_lower = low; + best_diff = diff; + best_diff_over = diff_over; + } + } + } + + if (is_const_int(best_diff, 0)) { + // In this case coef*iv = best_lower + // Don't create an itervar, just replace it everywhere with its min + res.old_to_new.Set(var, SuperSimplify(best_lower / bnd.coef, vranges)); + // To assure correctness, we have to add a condition that best_lower can be divided by coef + res.conditions.push_back(SuperSimplify(best_lower % bnd.coef == 0, vranges)); + } else { + std::string suffix = Equal(best_lower, vranges[var]->min * bnd.coef) ? "" : ".shifted"; + Var new_var = var.copy_with_suffix(suffix); + + // We will replace our iv with new_var + shift. + // We use rounding-up division to compute shift. Since we want to use a single formula + // without selects in as many cases as possible, we try to prove conditions manually. + Expr shift; + if (CanProve(best_lower <= 0, vranges)) { + shift = best_lower / bnd.coef; + } else if (CanProve(best_lower > -bnd.coef, vranges)) { + shift = (best_lower + bnd.coef - 1)/bnd.coef; + } else { + shift = Select::make(best_lower <= -bnd.coef, + best_lower / bnd.coef, + (best_lower + bnd.coef - 1)/bnd.coef); + } + shift = SuperSimplify(shift, vranges); + + Expr diff = SuperSimplify(best_diff_over / bnd.coef, vranges); + + if (is_const_int(diff, 0)) { + // Don't create an itervar, just replace it everywhere with its min + res.old_to_new.Set(var, shift); + } else { + res.old_to_new.Set(var, new_var + shift); + // Note that we are substituting old with new, so best_lower contains new var, + // that is we have to substitute new with old in best_lower here + res.new_to_old.Set(new_var, + SuperSimplify(var - Substitute(shift, res.new_to_old), vranges)); + + new_var_intsets[new_var.get()] = IntSet::interval(make_zero(new_var.type()), diff); + + // Add the new var to the resulting axis + auto range = Range(make_zero(new_var.type()), SuperSimplify(diff + 1, vranges)); + res.axis.push_back(new_var); + res.ranges.Set(new_var, range); + vranges.Set(new_var, range); + } + } + } + } + + // Add the original conditions (with variables substituted) to the resulting conditions + for (const Expr& old_cond : solved_system.as_conditions()) { + res.conditions.push_back(SuperSimplify(Substitute(old_cond, res.old_to_new), vranges)); + } + + return res; +} + +// Use the condition of a reduction op to simplify its domain (axis) +Expr SimplifyReductionDomain(const Expr& expr, const Map& outer_vranges) { + if (const Reduce* red = expr.as()) { + Map vranges = Merge(outer_vranges, IterVarsToMap(red->axis)); + auto res = SimplifyDomain(red->condition, IterVarsToVars(red->axis), + Merge(outer_vranges, IterVarsToMap(red->axis))); + + Array new_source; + for (const Expr& src : red->source) { + new_source.push_back(Substitute(src, res.old_to_new)); + } + + Array new_axis = IterVarsFromMap(res.axis, res.ranges, kCommReduce); + + // Perform simplification mainly to remove a possibly empty reduction. + return Simplify(Reduce::make(red->combiner, new_source, new_axis, + All(res.conditions), red->value_index)); + } else { + return expr; + } +} + +// Extract the given expr under the given condition as a separate tensor if the volume of the +// extracted tensor will be less than the volume of the outer_axis +Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond, + const Array& outer_axis, + const Map& vranges) { + // TODO(sgrechanik-h): We don't use divmod elimination here because of some performance problems + auto res = SimplifyDomain(cond, outer_axis, vranges, false); + + Expr new_expr = SuperSimplify(Substitute(e, res.old_to_new), vranges); + + // Keep only those variables of the new axis which are used in the new_expr + { + Array used_res_axis; + for (const Var& var : res.axis) { + if (ExprUseVar(new_expr, var)) { + used_res_axis.push_back(var); + } + } + + res.axis = std::move(used_res_axis); + } + + // Use the new axis to simplify the new expr, removing redundant inequalities + new_expr = SuperSimplify(new_expr, res.ranges); + + // If the expression does not use vars then it is probably better to keep it inlined + if (res.axis.empty()) { + return new_expr; + } + + // Compute volumes before and after + Expr old_volume = make_const(Int(64), 1); + for (const Var& var : outer_axis) { + old_volume = old_volume * vranges[var]->extent; + } + + Expr new_volume = make_const(Int(64), 1); + for (const Var& var : res.axis) { + new_volume = new_volume * res.ranges[var]->extent; + } + + // if we can prove that the old volume is not greater than the new volume then + // prefer the old expression. + if (CanProve(old_volume <= new_volume, vranges)) { + return e; + } + + Tensor tensor = op::TensorFromExpr(new_expr, IterVarsFromMap(res.axis, res.ranges), + "extracted_tensor"); + + Array args; + for (const Var& var : res.axis) { + args.push_back(res.new_to_old[var]); + } + + return Call::make(e.type(), tensor->op->name, args, + Call::CallType::Halide, tensor->op, tensor->value_index); +} + + +class RemoveRedundantInequalitiesMutator : public IRMutator { + public: + explicit RemoveRedundantInequalitiesMutator(Array known) { + for (const Expr& cond : known) { + known_.push_back(SuperSimplify(cond)); + } + } + + virtual Expr Mutate_(const Select* op, const Expr& e) { + bool has_side_effect = HasSideEffect(e); + Expr new_cond = SuperSimplify(Mutate(op->condition)); + if (is_one(new_cond) && !has_side_effect) { + return Mutate(op->true_value); + } else if (is_zero(new_cond) && !has_side_effect) { + return Mutate(op->false_value); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return Select::make(new_cond, new_mutator.Mutate(op->true_value), Mutate(op->false_value)); + } + } + + virtual Expr Mutate_(const Call* op, const Expr& e) { + if (op->name == intrinsic::tvm_if_then_else) { + Expr new_cond = SuperSimplify(Mutate(op->args[0])); + if (is_one(new_cond)) { + return Mutate(op->args[1]); + } else if (is_zero(new_cond)) { + return Mutate(op->args[2]); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return if_then_else(new_cond, new_mutator.Mutate(op->args[1]), Mutate(op->args[2])); + } + } else { + return IRMutator::Mutate_(op, e); + } + } + + virtual Expr Mutate_(const Reduce* op, const Expr& e) { + Array known_with_axes = known_; + for (const Expr& axis_cond : IterVarsToInequalities(op->axis)) { + known_with_axes.push_back(axis_cond); + } + RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes); + + Expr new_cond = mutator_with_axes.Mutate(op->condition); + + Array new_known = known_with_axes; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + + Array new_source; + for (const Expr& src : op->source) { + new_source.push_back(new_mutator.Mutate(src)); + } + + return Reduce::make(op->combiner, new_source, op->axis, new_cond, op->value_index); + } + + virtual Expr Mutate_(const EQ* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return MutateAtomic_(e); } + + virtual Expr Mutate_(const And* op, const Expr& e) { + return Mutate(op->a) && Mutate(op->b); + } + + private: + Expr MutateAtomic_(const Expr& e) { + Expr simplified = SuperSimplify(e); + for (const Expr& other : known_) { + if (Equal(simplified, other)) { + return const_true(); + } + } + return simplified; + } + + Array known_; +}; + +// Propagate information from conditions and remove redundant inequalities +// TODO(sgrechanik-h): This should be merged into standard simplifiers +Expr RemoveRedundantInequalities(const Expr& expr, const Array& known) { + return RemoveRedundantInequalitiesMutator(known).Mutate(expr); +} + +// Extract from cond an implication of cond not containing vars +std::pair ImplicationNotContainingVars( + const Expr& cond, const std::unordered_set& vars) { + CHECK(cond.type().is_bool()) << "The type of cond must be bool"; + // TODO(sgrechanik-h): not + if (const And* op = cond.as()) { + auto pair_a = ImplicationNotContainingVars(op->a, vars); + auto pair_b = ImplicationNotContainingVars(op->b, vars); + return {pair_a.first && pair_b.first, + pair_a.second && pair_b.second}; + } else if (const Or* op = cond.as()) { + auto pair_a = ImplicationNotContainingVars(op->a, vars); + auto pair_b = ImplicationNotContainingVars(op->b, vars); + return {Or::make(pair_a.first, pair_b.first), cond}; + } else if (!ExprUseVar(cond, vars)) { + return {cond, const_true()}; + } else { + return {const_true(), cond}; + } +} + +// Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out +// (in)equalities which do not depend on the reduction variables. +std::pair LiftConditionsThroughReduction(const Expr& cond, + const Array& red_axis, + const Array& outer_axis) { + // Factor out atomics so that we can consider this as a system of inequalities + auto factoratomic_res = FactorOutAtomicFormulas(cond); + Array atomics = factoratomic_res.atomic_formulas; + const Expr& rest = factoratomic_res.rest; + + Array allvars; + for (const IterVar& v : red_axis) { + allvars.push_back(v->var); + } + for (const IterVar& v : outer_axis) { + allvars.push_back(v->var); + } + + auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis)); + // start from reduction vars, so that input vars don't depend on them + atomics = SolveSystemOfInequalities(atomics, allvars, vranges).as_conditions(); + + // Append the rest part + Expr rewritten_cond = All(atomics) && rest; + + std::unordered_set vset; + for (const IterVar& v : red_axis) { + vset.insert(v->var.get()); + } + + // The outer (first) condition does not contain reduction vars, + // the inner (second) condition is everything else + return ImplicationNotContainingVars(rewritten_cond, vset); +} + +class ExtractReductionsMutator : public IRMutator { + public: + explicit ExtractReductionsMutator(const Array& outer_axis, + Map vranges, + std::string name = "extracted_reduction") + : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {} + + Expr Mutate_(const Reduce* op, const Expr& e) { + ExtractReductionsMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_), + Merge(vranges_, IterVarsToMap(op->axis)), + name_); + + Array new_source; + for (const Expr& src : op->source) { + new_source.push_back(new_mutator.Mutate(src)); + } + + Expr new_reduce = + Reduce::make(op->combiner, new_source, op->axis, op->condition, op->value_index); + + ExprFreeVarsVisitor fv_visitor; + fv_visitor.Visit(new_reduce); + + // Vars of the tensor we are going to create for this reduction + Array vars; + for (const Var& v : outer_axis_) { + // We take variables from the outer_axis_ which are also present in the new reduction + if (fv_visitor.free.count(v.get())) { + vars.push_back(v); + } + } + + auto newaxis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_)); + Array new_axis = newaxis_vmap_pair.first; + new_reduce = SuperSimplify(Substitute(new_reduce, newaxis_vmap_pair.second), + IterVarsToMap(new_axis)); + + Tensor tensor = op::TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_); + + Array args; + for (const Var& v : vars) { + args.push_back(v); + } + + return Call::make(e.type(), tensor->op->name, args, + Call::CallType::Halide, tensor->op, tensor->value_index); + } + + private: + Array outer_axis_; + Map vranges_; + std::string name_; + std::string tag_; + Map attrs_; +}; + +// Extract reductions as separate tensors. +Expr ExtractReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges) { + return ExtractReductionsMutator(outer_axis, vranges).Mutate(expr); +} + +Expr ExtractNonTopReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges) { + if (const Reduce* red = expr.as()) { + Array new_outer_axis = Concat(IterVarsToVars(red->axis), outer_axis); + Map new_vranges = Merge(vranges, IterVarsToMap(red->axis)); + Array new_source; + for (const Expr& src : red->source) { + new_source.push_back(ExtractReductions(src, new_outer_axis, new_vranges)); + } + Expr new_condition = ExtractReductions(red->condition, new_outer_axis, new_vranges); + + return Reduce::make(red->combiner, new_source, red->axis, + new_condition, red->value_index); + } else { + return ExtractReductions(expr, outer_axis, vranges); + } +} + +Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, const Array& axis) { + Expr result; + + if (const Reduce* red = expr.as()) { + // TODO(sgrechanik-h): There are some other operations which behave like sum + bool is_sum = IsSumCombiner(red->combiner); + if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index)) { + Expr new_red = expr; + + // Here we simplify the reduction + { + Expr cond = red->condition; + Array source = red->source; + + // If it is a summation then we can lift nonzeroness conditions from the source + // and add them to the reduction conditions + if (is_sum) { + auto nz = NonzeronessCondition(red->source[red->value_index]); + cond = nz.cond && cond; + source.Set(0, nz.value); + } + + new_red = Reduce::make(red->combiner, source, red->axis, cond, red->value_index); + new_red = SimplifyReductionDomain(new_red, IterVarsToMap(axis)); + red = new_red.as(); + + // If the reduction disappears completely then transform the result as a non-reduction + if (!red) { + return OptimizeAndLiftNonzeronessConditionsImpl(new_red, axis); + } + } + + Expr new_outer_cond, new_reduce_cond; + Array new_source = red->source; + + // Partially lift conditions from the reduce condition + std::tie(new_outer_cond, new_reduce_cond) = + LiftConditionsThroughReduction(red->condition, red->axis, axis); + + // If it's not sum then we haven't yet lifted nonzeroness cond from the source + if (!is_sum) { + Expr outer_nz_cond, nz_cond, nz_source; + auto nz = NonzeronessCondition(red->source[red->value_index]); + // Append conditions from the reduction + nz_cond = new_reduce_cond && nz.cond; + nz_source = nz.value; + std::tie(outer_nz_cond, nz_cond) = + LiftConditionsThroughReduction(nz_cond, red->axis, axis); + new_outer_cond = new_outer_cond && outer_nz_cond; + new_source.Set(red->value_index, SelectElseZero(nz_cond, nz_source)); + } + + Expr new_reduce = Reduce::make(red->combiner, new_source, red->axis, + new_reduce_cond, red->value_index); + new_reduce = ExtractAsTensorMaybe(new_reduce, new_outer_cond, + IterVarsToVars(axis), IterVarsToMap(axis)); + result = SelectElseZero(new_outer_cond, new_reduce); + } else { + return SimplifyReductionDomain(expr, IterVarsToMap(axis)); + } + } else { + auto nz = NonzeronessCondition(expr); + Expr new_expr = ExtractAsTensorMaybe(nz.value, nz.cond, + IterVarsToVars(axis), IterVarsToMap(axis)); + result = SelectElseZero(nz.cond, new_expr); + } + + // Note that RemoveRedundantInequalities can sometimes propagate equalities which + // other simplifiers cannot, like (i % 3) == 0. + Array axis_conds = IterVarsToInequalities(axis); + result = RemoveRedundantInequalities(result, axis_conds); + + // Sometimes ExtractAsTensorMaybe doesn't perform extraction, so there may be some non-top + // reductions left, take care of them + Map vrange = IterVarsToMap(axis); + return SuperSimplify(ExtractReductions(result, IterVarsToVars(axis), vrange), + vrange); +} + +Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor) { + return op::TransformBody(tensor, OptimizeAndLiftNonzeronessConditionsImpl); +} + +TVM_REGISTER_API("ir_pass.IsSumCombiner") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IsSumCombiner(args[0]); + }); + +TVM_REGISTER_API("ir_pass.CanFactorZeroFromCombiner") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CanFactorZeroFromCombiner(args[0], args[1]); + }); + +TVM_REGISTER_API("ir_pass.LiftNonzeronessCondition") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LiftNonzeronessCondition(args[0]); + }); + +TVM_REGISTER_API("ir_pass.InlineTailCall") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = InlineTailCall(args[0]); + }); + +TVM_REGISTER_API("ir_pass.InlineTensors") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args[0].IsNodeType()) { + Expr e = args[0]; + if (args.size() == 1) { + *ret = InlineTensors(e); + } else if (args.size() == 2) { + *ret = InlineTensors(e, args[1]); + } else if (args.size() >= 3) { + *ret = InlineTensors(e, args[1], args[2]); + } + } else if (args[0].IsNodeType()) { + Tensor t = args[0]; + if (args.size() == 1) { + *ret = InlineTensors(t); + } else if (args.size() == 2) { + *ret = InlineTensors(t, args[1]); + } else if (args.size() >= 3) { + *ret = InlineTensors(t, args[1], args[2]); + } + } + }); + +TVM_REGISTER_API("ir_pass.SolveSystemOfInequalities") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SolveSystemOfInequalities(args[0], args[1], args[2]).as_conditions(); + }); + +TVM_REGISTER_API("ir_pass.SimplifyDomain") +.set_body([](TVMArgs args, TVMRetValue *ret) { + auto res = SimplifyDomain(args[0], args[1], args[2]); + Array axis = IterVarsFromMap(res.axis, res.ranges); + *ret = Array({All(res.conditions), axis, res.old_to_new, res.new_to_old}); + }); + +TVM_REGISTER_API("ir_pass.SimplifyReductionDomain") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SimplifyReductionDomain(args[0], args[1]); + }); + +TVM_REGISTER_API("ir_pass.ExtractAsTensorMaybe") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractAsTensorMaybe(args[0], args[1], args[2], args[3]); + }); + +TVM_REGISTER_API("ir_pass.ExtractReductions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractReductions(args[0], args[1], args[2]); + }); + +TVM_REGISTER_API("ir_pass.ExtractNonTopReductions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractNonTopReductions(args[0], args[1], args[2]); + }); + +TVM_REGISTER_API("ir_pass.OptimizeAndLiftNonzeronessConditions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = OptimizeAndLiftNonzeronessConditions(args[0]); + }); + +} // namespace ir +} // namespace tvm diff --git a/src/pass/zero_elimination.h b/src/pass/zero_elimination.h new file mode 100644 index 000000000000..1ac887bcb049 --- /dev/null +++ b/src/pass/zero_elimination.h @@ -0,0 +1,249 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file zero_elimination.h + * \brief Transform tensors in such a way as to eliminate summation over zeros. + */ +#ifndef TVM_PASS_ZERO_ELIMINATION_H_ +#define TVM_PASS_ZERO_ELIMINATION_H_ + +#include +#include + +#include + +namespace tvm { +namespace ir { + +/*! + * \brief Clone the reduction by cloning its iteration variables. + */ +Expr CloneReduction(const Expr& expr); + +/*! + * \brief Check if the given combiner represents summation. + */ +EXPORT bool IsSumCombiner(const CommReducer& combiner); + +/*! + * \brief Check if zero may be factored out of a reduction with this combiner when it is in + * the \p value_index position. + * + * For example, if the combiner works on tuples of two elements and `value_index = 1`, + * check that `(a, 0) combine (b, 0) = (c, 0)` for any a, b and some c. + * Note that all combiners generated by autodiff have this property. + */ +EXPORT bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index); + +/*! + * \brief Transform the expression into `c ? e : 0`, that is lift the condition of being + * possible to be non-zero to the top level. + */ +EXPORT Expr LiftNonzeronessCondition(const Expr& expr); + +/*! + * \brief If the body of the tensor consists of a single tensor call (indexing) expression, + * inline it. + */ +EXPORT Tensor InlineTailCall(const Tensor& tensor); + +/*! + * \brief Inline tensors recursively. + * + * This function will inline tensors recursively until it reaches a tensor which is impossible to + * inline (a reduction if \p inline_reductions is false, a non-compute tensor, a tensor which is + * not from \p inlineable). It won't descend into non-inlinable tensors' bodies. + * + * \param expr The expression to transform. + * \param inlineable A list of tensors which are allowed to be inlined. If empty, try + * to inline all tensors. + * \param inline_reductions Whether to inline reductions (this may result in top-level reduction + * nodes). + */ +EXPORT Expr InlineTensors(const Expr& expr, + const Array& inlineable = Array(), + bool inline_reductions = false); + +/*! + * \brief Inline tensors recursively. + * + * This function will inline tensors recursively until it reaches a tensor which is impossible to + * inline (a reduction if \p inline_reductions is false, a non-compute tensor, a tensor which is + * not from \p inlineable). It won't descend into non-inlinable tensors' bodies. + * + * \param tensor The tensor whose body to transform. + * \param inlineable A list of tensors which are allowed to be inlined. If empty, try + * to inline all tensors. + * \param inline_reductions Whether to inline reductions (this may result in top-level reduction + * nodes). + */ +EXPORT Tensor InlineTensors(const Tensor& tensor, + const Array& inlineable = Array(), + bool inline_reductions = false); + + +/*! + * \brief A struct representing a set of inequalities describing bounds of a variable. + * + * Given a variable x, this struct represents the following (in)equalities: + * - `coef*x >= low` for each `low` in `lower` + * - `coef*x == eq` for each `eq` in `equal` + * - `coef*x <= upp` for each `upp` in `upper` + * + * Note that every array is supposed to be sorted in the order of increasing expression + * complexity. + */ +struct VarBounds { + Expr coef; + Array lower; + Array equal; + Array upper; + + /*! + * \brief Perform substitution on all components of the struct. + */ + VarBounds substitute(const Map& subst) const; +}; + +/*! + * \brief A struct representing a system of inequalities resulted from Fourier-Motzkin elimination. + */ +struct SolveSystemOfInequalitiesResult { + Array variables; + std::unordered_map bounds; + Array other_conditions; + + /*! + * \brief Combine the information into an array of (in)equalities. + */ + Array as_conditions() const; +}; + +/*! + * \brief Rewrite the system of inequalities using Fourier-Motzkin elimination. + * + * This function takes an array of (in)equalities and an array of variables, and essentially + * rewrites the (in)equalities into an array of (in)equalities of the following form: + * + * x0 >= f0(x1, x2, ..., xn) + * x0 <= g0(x1, x2, ..., xn) + * x1 >= f1(x2, ..., xn) + * x1 <= g1(x2, ..., xn) + * ... + * xn >= fn() // just a constant + * xn <= gn() // just a constant + * + * This array is represented in a more structural way using SolveSystemOfInequalitiesResult. + * + * Note that the algorithm is extremely slow, it is super-exponential, so please provide variable + * ranges to aid the removal of redundant inequalities. + * + * \param inequalities The original (in)equalities. + * \param variables The variables x0, ..., xn + * \param vranges A map from variables to the corresponding value ranges. Extremely important for + * efficiency. + */ +EXPORT SolveSystemOfInequalitiesResult SolveSystemOfInequalities( + const Array& inequalities, const Array& variables, const Map& vranges); + +/*! + * \brief A struct representing a result of domain simplification. It is basically + * a new array of variables, the information about their ranges, and a new condition together with + * substitutions from the old variables to the new ones and from the new ones to the old ones. + */ +struct DomainSimplificationResult { + Array conditions; + Array axis; + Map ranges; + Map old_to_new; + Map new_to_old; +}; + +/*! + * \brief Simplify an iteration domain. + * + * An iteration domain is basically an array of variables and a condition. The function will do the + * following: + * - Replace div and mod operations with new variables (optional). + * - Extract (in)equalities from the condition. + * - Perform Fourier-Motzkin elimination. + * - Shear the domain of iteration (e.g. if `y <= x <= y + 2` then x will be replaced with `y + d` + * where `d` is a new variable such that `0 <= d <= 2`). + * - Remove redundant variables. + * - Infer new variable ranges (hopefully more precise). + * + * \param cond The condition of the original domain. + * \param axis The variables of the original domain. + * \param vranges A map from variables (both domain and outer) to their value ranges. + * \param eliminate_div_mod Whether to eliminate div and mod by introducing new variables. + */ +EXPORT DomainSimplificationResult SimplifyDomain(const Expr& cond, + const Array& axis, + Map vranges, + bool eliminate_div_mod = true); + + +/*! + * \brief Simplify the iteration domain of a reduction expression using SimplifyDomain. + */ +EXPORT Expr SimplifyReductionDomain(const Expr& expr, const Map& outer_vranges); + +/*! + * \brief Extract the given expression under the given condition as a separate tensor if the volume + * of the extracted tensor will be less than the volume of the \p outer_axis. + * + * \param expr The expression to extract. + * \param cond A condition which is assumed to be true. + * \param outer_axis Some variables, usually input variables of the enclosing tensor. + * \param vranges Information about ranges of variables. + * \return Either a call to an extracted tensor or the original expression. + */ +EXPORT Expr ExtractAsTensorMaybe(const Expr& expr, const Expr& cond, + const Array& outer_axis, + const Map& vranges); + +/*! + * \brief Extract reductions as separate tensors. This may be needed when non-top-level reductions + * are created. + * + * \param expr The expression from which to extract reductions. + * \param outer_axis Input variables of the enclosing tensor. + * \param vranges Information about ranges of variables. + * \return An expression without non-top-level reductions. + */ +EXPORT Expr ExtractReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges); + +/*! + * \brief Extract reductions as separate tensors, but if the expr itself is a reduction, leave it + * intact. + * + * \param expr The expression from which to extract reductions. + * \param outer_axis Input variables of the enclosing tensor. + * \param vranges Information about ranges of variables. + * \return An expression without non-top-level reductions. + */ +EXPORT Expr ExtractNonTopReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges); + +/*! + * \brief Perform lifting of conditions of being possible to be non-zero together with + * applying some transformations like simplifying the reduction domain. Works only with + * this particular tensor's body, i.e. doesn't perform inlining. + */ +EXPORT Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor); + +/*! + * \brief Pretty print the tensor with all its dependencies. + */ +EXPORT std::string PrintTensorRecursively(const Tensor& tensor); + +/*! + * \brief Pretty print the tensors with all their dependencies. + */ +EXPORT std::string PrintTensorsRecursively(const Array& tensor); + +} // namespace ir +} // namespace tvm +#endif // TVM_PASS_ZERO_ELIMINATION_H_ diff --git a/src/relay/op/autodiff_integration.cc b/src/relay/op/autodiff_integration.cc new file mode 100644 index 000000000000..a86d504e8798 --- /dev/null +++ b/src/relay/op/autodiff_integration.cc @@ -0,0 +1,209 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file autodiff_integration.cc + * \brief Integration with autodiff for TVM tensor expressions. + */ + +#include +#include +#include +#include +#include "./type_relations.h" +#include "./op_common.h" +#include "../../op/op_util.h" + +namespace tvm { +namespace relay { + +/*! \brief Attributes for the automatically generated gradient operation. */ +struct AutogeneratedGradientAttrs : public tvm::AttrsNode { + Op original_op; + Attrs original_attrs; + Type original_out_type; + + TVM_DECLARE_ATTRS(AutogeneratedGradientAttrs, "relay.attrs.AutogeneratedGradientAttrs") { + TVM_ATTR_FIELD(original_op) + .describe("The original operation."); + TVM_ATTR_FIELD(original_attrs) + .describe("The attributes of the original operation."); + TVM_ATTR_FIELD(original_out_type).set_default(Type(nullptr)) + .describe("The type of the original expression."); + } +}; + +bool AutogeneratedGradientRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // There are just two types: the type of the input tuple and the type of the output tuple. + CHECK(types.size() == 2) << "The size of the types array must be 2, not " << types.size(); + const auto* tuple_type = types[0].as(); + CHECK(tuple_type != nullptr) << "The input must be a tuple, not " << types[0]; + // The input tuple contains the original inputs and the last item is the adjoint + // for the output of the original operation. + Array input_types(tuple_type->fields.begin(), tuple_type->fields.end() + (-1)); + // The output of the gradient operation is a containing values of the same types as the + // original inputs. + reporter->Assign(types[1], TupleTypeNode::make(input_types)); + return true; +} + +Array AutogeneratedGradientCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + static auto fcompute = Op::GetAttr("FTVMCompute"); + + const AutogeneratedGradientAttrs* real_attrs = attrs.as(); + CHECK(real_attrs != nullptr); + + // We need the type of the original output to pass it to the + // FTVMCompute of the original operation. + Type original_out_type = real_attrs->original_out_type; + + // The `inputs` array contains both the original inputs and the adjoint, both in the + // flattened form. In general, the adjoint may consist of several tensors, so we need to know + // the number of the output tensors of the original operation. + size_t num_orig_outputs = 1; + // NOTE: Here we assume that there are no nested tuples + if (const auto* tuple_type = original_out_type.as()) { + num_orig_outputs = tuple_type->fields.size(); + } else if (const auto* tuple_type = out_type.as()) { + // Guess the number of outputs of the original op from the number of inputs of the original + // op (which is the same as the number of outputs of this gradient node). + num_orig_outputs = inputs.size() - tuple_type->fields.size(); + } + + CHECK(inputs.size() >= num_orig_outputs); + + // If the original output type hasn't been preserved, try to reconstruct it using the + // number of original outputs. + if (!original_out_type.defined()) { + Array fields; + for (auto it = inputs.end() + (-num_orig_outputs); it != inputs.end(); ++it) { + fields.push_back(TensorTypeNode::make((*it)->shape, (*it)->dtype)); + } + if (num_orig_outputs == 1) { + // If the number of the outputs is 1 then the output type is probably just a tensor, not + // a tuple of a single element. + original_out_type = fields[0]; + } else { + original_out_type = TupleTypeNode::make(fields); + } + } + + Array original_inputs(inputs.begin(), inputs.end() + (-num_orig_outputs)); + Array adjoints(inputs.end() + (-num_orig_outputs), inputs.end()); + + // In theory the inputs might contain duplicate entries which won't agree with the automatic + // differentiation, so we create new placeholders which we will replace with the inputs later. + Array input_placeholders; + std::unordered_map placeholders_to_inputs; + for (const Tensor& input : original_inputs) { + Tensor place = + tvm::PlaceholderOpNode::make(input->op->name, input->shape, input->dtype).output(0); + input_placeholders.push_back(place); + placeholders_to_inputs[place] = input; + } + + Array forward = + fcompute[real_attrs->original_op](real_attrs->original_attrs, input_placeholders, + original_out_type, target); + + CHECK(forward.size() == adjoints.size()); + + // If there are multiple outputs, we have to propagate gradients from all of them and + // add up the results. Note that there may be suboptimality, in the future we might want + // to make the Differentiate function accept arrays of outputs. + Array res; + for (size_t i = 0; i < forward.size(); ++i) { + Array part = + tvm::ir::Differentiate(forward[i], input_placeholders, adjoints[i])->result; + part = tvm::op::ReplaceTensorRecursively(part, placeholders_to_inputs); + + if (i == 0) { + res = part; + } else { + for (size_t j = 0; j < res.size(); ++j) { + res.Set(j, topi::add(res[j], part[j])); + } + } + } + + return res; +} + +RELAY_REGISTER_OP("autogenerated_gradient") +.describe(R"doc(Gradients for any specified operation generated using the automatic differentiation +for tensor expressions. + +- **input**: A tuple of the form `(x1, ..., xn, g)` where `x1, ..., xn` are the inputs of the + original operation, and g is the gradient of the loss with respect to the output + of the original operation. +- **out**: A tuple of the form `(g1, ..., gn)` containing the gradients of the loss with respect to + the inputs of the original operation. +)doc") +.set_num_inputs(1) +.add_argument("input", "Tuple", "A tuple containing the original inputs and the adjoint.") +.set_attrs_type_key("relay.attrs.AutogeneratedGradientAttrs") +.add_type_rel("AutogeneratedGradient", AutogeneratedGradientRel) +.set_attr("FTVMCompute", AutogeneratedGradientCompute) +.set_attr("TOpPattern", kOpaque) +.set_attr("FTVMSchedule", + [](const Attrs& attrs, const Array& outs, const Target& target) { + Array out_ops; + for (auto t : outs) + out_ops.push_back(t->op); + return create_schedule(out_ops); + }); + +FPrimalGradient AutogeneratedFPrimalGradient(const Op& op) { + return [op](const Expr& orig, const Expr& adjoint) -> Array { + const CallNode* call = orig.as(); + CHECK(call != nullptr); + + auto attrs = make_node(); + attrs->original_op = op; + attrs->original_attrs = call->attrs; + if (call->checked_type_.defined()) { + attrs->original_out_type = call->checked_type(); + } + + Array args_in_tuple = call->args; + args_in_tuple.push_back(adjoint); + Array args = {TupleNode::make(args_in_tuple)}; + auto grad_call = CallNode::make(Op::Get("autogenerated_gradient"), args, Attrs(attrs)); + + Array res; + for (size_t i = 0; i < call->args.size(); ++i) { + res.push_back(TupleGetItemNode::make(grad_call, i)); + } + return res; + }; +} + +/*! \brief Automatically generate primal gradient for the given operation. */ +void AutogeneratePrimalGradient(const std::string& op_name, int plevel = 100) { + OpRegistry& opreg = relay::OpRegistry::Registry()->__REGISTER_OR_GET__(op_name); + Op op = opreg.op(); + opreg.set_attr("FPrimalGradient", AutogeneratedFPrimalGradient(op), plevel); +} + +/*! \brief Automatically generate primal gradients for all operations in the registry. */ +void AutogeneratePrimalGradientForAll(int plevel = 5) { + for (const OpRegistry* opreg : relay::OpRegistry::Registry()->List()) { + AutogeneratePrimalGradient(opreg->op()->name, plevel); + } +} + +TVM_REGISTER_API("relay._ir_pass.AutogeneratePrimalGradient") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { + AutogeneratePrimalGradient(args[0]); + }); +TVM_REGISTER_API("relay._ir_pass.AutogeneratePrimalGradientForAll") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { + AutogeneratePrimalGradientForAll(); + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_primal_gradients.py b/tests/python/relay/test_primal_gradients.py new file mode 100644 index 000000000000..d23a920d578b --- /dev/null +++ b/tests/python/relay/test_primal_gradients.py @@ -0,0 +1,191 @@ +import tvm +import numpy as np + +from tvm import relay + +def to_int_array(arr, param_values=None): + if param_values is None: + param_values = {} + return [tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(s, param_values)).value + for s in arr] + +def tvm2numpy(something): + if isinstance(something, (tvm.ndarray.NDArray, relay.backend.interpreter.TensorValue)): + return something.asnumpy() + elif isinstance(something, list): + return [tvm2numpy(s) for s in something] + elif isinstance(something, (tuple, relay.backend.interpreter.TupleValue)): + return tuple(tvm2numpy(s) for s in something) + return something + +def check_relay_grad(expr, in_range=(-10,10), acceptable_fail_fraction=None): + expr = relay.ir_pass.infer_type(expr) + if len(expr.checked_type.shape) != 0: + expr = relay.op.sum(expr) + + input_vars = relay.ir_pass.free_vars(expr) + func = relay.Function(input_vars, expr) + func = relay.ir_pass.infer_type(func) + + gfunc = relay.ir_pass.gradient(func) + gfunc = relay.ir_pass.infer_type(gfunc) + + executor = relay.create_executor() + tvm_func = executor.evaluate(func) + tvm_gfunc = executor.evaluate(gfunc) + + np_func = lambda *a: tvm2numpy(tvm_func(*a)) + np_gfunc = lambda *a: tvm2numpy(tvm_gfunc(*a)) + + input_vals = [np.random.uniform(in_range[0], in_range[1], + size=to_int_array(a.type_annotation.shape)) + .astype(a.type_annotation.dtype) + for a in input_vars] + + tvm.testing.check_numerical_grads(np_func, input_vals, np_gfunc(*input_vals)[1], + acceptable_fail_fraction=acceptable_fail_fraction) + +def test_autogenerated_primal_gradients(): + relay._ir_pass.AutogeneratePrimalGradientForAll(100) + + x = relay.var("x", shape=(5,), dtype='float64') + y = relay.var("y", shape=(5,), dtype='float64') + + check_relay_grad(x*x) + check_relay_grad(x*y) + check_relay_grad(x*y + x + y) + + k = relay.var("k", shape=()) + x = relay.var("x", shape=(5,)) + y = relay.var("y", shape=(5,)) + ix = relay.var("ix", shape=(5,), dtype='int32') + iy = relay.var("iy", shape=(5,), dtype='int32') + ind = relay.var("ind", shape=(10,), dtype='int32') + x1 = relay.var("x1", shape=(5,)) + y1 = relay.var("y1", shape=(5,)) + X = relay.var("X", shape=(5, 7)) + Y = relay.var("Y", shape=(5, 10)) + Y1 = relay.var("Y1", shape=(1, 2, 5, 5)) + Y2 = relay.var("Y2", shape=(3, 2, 5, 5)) + W = relay.var("W", shape=(7, 10)) + A = relay.var("A", shape=(2, 5, 7, 7)) + w = relay.var("w", shape=(4, 5, 3, 3)) + w1 = relay.var("w1", shape=(5, 4, 3, 3)) + + check_relay_grad(x*x) + check_relay_grad(x*y) + check_relay_grad(x*y + x + y) + + check_relay_grad(relay.op.log(x), in_range=(0.1, 10)) + check_relay_grad(relay.op.sqrt(x), in_range=(0.1, 10)) + check_relay_grad(relay.op.exp(x)) + check_relay_grad(relay.op.sigmoid(x)) + check_relay_grad(relay.op.add(x, y)) + check_relay_grad(relay.op.subtract(x, y)) + check_relay_grad(relay.op.multiply(x, y)) + check_relay_grad(relay.op.divide(x, y)) + #check_relay_grad(relay.op.mod(x, y)) + check_relay_grad(relay.op.tanh(x)) + #check_relay_grad(relay.op.concatenate([x, y], 0)) + #check_relay_grad(relay.op.concatenate([X, Y], 1)) + check_relay_grad(relay.op.expand_dims(X, 1, 1)) + check_relay_grad(relay.op.expand_dims(X, 2, 3)) + check_relay_grad(relay.nn.softmax(x)) + check_relay_grad(relay.nn.log_softmax(X)) + check_relay_grad(relay.nn.relu(x)) + #check_relay_grad(relay.nn.dropout(x)) + #check_relay_grad(relay.nn.batch_norm(A, x, y, x1, y1)[0]) + check_relay_grad(relay.nn.bias_add(X, x, 0)) + + check_relay_grad(relay.nn.conv2d(A, w)) + check_relay_grad(relay.nn.conv2d(A, w, strides=(2, 1))) + check_relay_grad(relay.nn.conv2d(A, w, padding=(1, 0))) + check_relay_grad(relay.nn.conv2d(A, w, dilation=(1, 2))) + check_relay_grad(relay.nn.conv2d_transpose(A, w1)) + check_relay_grad(relay.nn.conv2d_transpose(A, w1, strides=(2, 1))) + check_relay_grad(relay.nn.conv2d_transpose(A, w1, padding=(1, 0))) + #check_relay_grad(relay.nn.conv2d_transpose(A, w1, dilation=(1, 2))) + check_relay_grad(relay.nn.dense(X, W)) + check_relay_grad(relay.nn.max_pool2d(A)) + check_relay_grad(relay.nn.max_pool2d(A, pool_size=(2, 2))) + check_relay_grad(relay.nn.max_pool2d(A, pool_size=(2, 2), strides=(2, 1))) + check_relay_grad(relay.nn.max_pool2d(A, pool_size=(3, 3), strides=(3, 2), padding=(1, 1))) + check_relay_grad(relay.nn.avg_pool2d(A)) + check_relay_grad(relay.nn.avg_pool2d(A, pool_size=(2, 2))) + check_relay_grad(relay.nn.avg_pool2d(A, pool_size=(2, 2), strides=(2, 1))) + check_relay_grad(relay.nn.avg_pool2d(A, pool_size=(3, 3), strides=(3, 2), padding=(1, 1))) + check_relay_grad(relay.nn.global_max_pool2d(A)) + check_relay_grad(relay.nn.global_avg_pool2d(A)) + check_relay_grad(relay.nn.upsampling(A, scale=2)) + check_relay_grad(relay.nn.batch_flatten(A)) + check_relay_grad(relay.nn.pad(A, ((0, 0), (0, 1), (1, 2), (2, 3)))) + check_relay_grad(relay.nn.lrn(A)) + check_relay_grad(relay.nn.l2_normalize(A, 0.01, axis=[1])) + #check_relay_grad(relay.nn.contrib_conv2d_winograd_without_weight_transform(...)) + check_relay_grad(relay.nn.contrib_conv2d_winograd_weight_transform(w, 4)) + + check_relay_grad(relay.nn.leaky_relu(x, 0.1)) + check_relay_grad(relay.nn.prelu(A, x)) + check_relay_grad(relay.reshape(Y, (2, 5, 5))) + check_relay_grad(relay.reshape_like(Y, Y1)) + check_relay_grad(relay.copy(x)) + check_relay_grad(relay.transpose(Y)) + check_relay_grad(relay.squeeze(Y1)) + check_relay_grad(relay.floor(x), acceptable_fail_fraction=0.2) + check_relay_grad(relay.ceil(x), acceptable_fail_fraction=0.2) + check_relay_grad(relay.trunc(x), acceptable_fail_fraction=0.2) + #check_relay_grad(relay.clip(x, -2, 2)) + check_relay_grad(relay.round(x), acceptable_fail_fraction=0.2) + check_relay_grad(relay.abs(x)) + check_relay_grad(relay.negative(x)) + #check_relay_grad(relay.take(x, ind), in_range=(0, 4)) + check_relay_grad(relay.zeros((5, 6), 'float32')) + check_relay_grad(relay.zeros_like(x)) + check_relay_grad(relay.ones((5, 6), 'float32')) + check_relay_grad(relay.ones_like(x)) + check_relay_grad(relay.full(k, (5, 6), 'float32')) + check_relay_grad(relay.full_like(x, k)) + #check_relay_grad(relay.cast(x, 'float64')) + #check_relay_grad(relay.split(x, (1, 3))) + + #check_relay_grad(relay.right_shift(ix, iy)) + #check_relay_grad(relay.left_shift(ix, iy)) + #check_relay_grad(relay.equal(ix, iy)) + #check_relay_grad(relay.not_equal(ix, iy)) + #check_relay_grad(relay.greater(x, y)) + #check_relay_grad(relay.greater_equal(x, y)) + #check_relay_grad(relay.less(x, y)) + #check_relay_grad(relay.less_equal(x, y)) + check_relay_grad(relay.maximum(x, y)) + check_relay_grad(relay.minimum(x, y)) + check_relay_grad(relay.power(relay.abs(x), y)) + #check_relay_grad(relay.where(ix, x, y)) + #check_relay_grad(relay.where(relay.greater(x, y), x, y)) + #check_relay_grad(relay.argmax(x)) + #check_relay_grad(relay.argmin(X, axis=1)) + check_relay_grad(relay.sum(x)) + check_relay_grad(relay.max(x)) + check_relay_grad(relay.min(x)) + check_relay_grad(relay.mean(x)) + check_relay_grad(relay.prod(x)) + check_relay_grad(relay.strided_slice(A, (0, 4, 2, 0), (1, 1, 5, 6), (1, -1, 2, 3))) + check_relay_grad(relay.broadcast_to(Y1, (3, 2, 5, 5))) + + check_relay_grad(relay.image.resize(A, (12, 10), method='BILINEAR')) + check_relay_grad(relay.image.resize(A, (12, 10), method='BILINEAR', align_corners=True)) + #check_relay_grad(relay.image.resize(A, (12, 10), method='NEAREST_NEIGHBOR')) + #check_relay_grad(relay.vision.multibox_prior(A)) + #check_relay_grad(relay.vision.multibox_transform_loc(...)) + #check_relay_grad(relay.vision.nms(...)) + + check_relay_grad(relay.broadcast_to_like(Y1, Y2)) + check_relay_grad(relay.collapse_sum_like(X, x)) + t1 = relay.var("t1", shape=(3, 4, 5)) + t2 = relay.var("t2", shape=(1, 2, 3)) + check_relay_grad(relay.slice_like(t1, t2)) + check_relay_grad(relay.layout_transform(w1, 'NCHW', 'NHCW2c')) + #check_relay_grad(relay.device_copy(...)) + #check_relay_grad(relay.annotation.on_device(...)) + +if __name__ == "__main__": + test_autogenerated_primal_gradients() diff --git a/tests/python/unittest/test_pass_autodiff.py b/tests/python/unittest/test_pass_autodiff.py new file mode 100644 index 000000000000..70d42dee6f14 --- /dev/null +++ b/tests/python/unittest/test_pass_autodiff.py @@ -0,0 +1,482 @@ +import tvm +import topi +import numpy as np +from tvm.testing import check_numerical_grads, estimate_performance, PerformanceEstimate +import time +import inspect +import sys + +# Whether to dump the generated code +verbose = False + +def get_shape(tensor, param_values=None): + if param_values is None: + param_values = {} + return [tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(s, param_values)).value + for s in tensor.shape] + +def check_equivalence(outputs1, outputs2, inputs, in_range=(-10, 10), iters=3): + outputs1 = list(outputs1) + outputs2 = list(outputs2) + sched1 = tvm.create_schedule([o.op for o in outputs1]) + mout1 = tvm.build(sched1, outputs1 + inputs) + + sched2 = tvm.create_schedule([o.op for o in outputs2]) + mout2 = tvm.build(sched2, outputs2 + inputs) + + arguments1 = [tvm.nd.empty(get_shape(t), t.dtype) for t in outputs1 + inputs] + arguments2 = [tvm.nd.empty(get_shape(t), t.dtype) for t in outputs1 + inputs] + + for i in range(iters): + arguments1 = [] + arguments2 = [] + for a in outputs1 + inputs: + val = np.random.uniform(in_range[0], in_range[1], size=get_shape(a)).astype(a.dtype) + arguments1.append(tvm.nd.array(val)) + arguments2.append(tvm.nd.array(val)) + mout1(*arguments1) + mout2(*arguments2) + + for j, _ in enumerate(outputs1): + tvm.testing.assert_allclose(arguments1[j].asnumpy(), arguments2[j].asnumpy()) + +def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None, param_values=None, + acceptable_fail_fraction=None): + line = inspect.getframeinfo(inspect.stack()[1][0]).lineno + + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + if param_values is None: + param_values = {} + + if verbose: + print("\n" + 80*"=" + "\n") + print("Testing gradients, line {}\n".format(line)) + print("Original tensors:\n") + print(tvm.PrintTensorRecursively(out)) + print() + + sout = tvm.create_schedule(out.op) + mout = tvm.build(sout, [out] + inputs + args) + + ones = topi.full_like(out, 1.0) + + grads = list(tvm.differentiate(out, inputs, ones)) + + if verbose: + print("Gradients:\n") + print(tvm.PrintTensorsRecursively(grads)) + print() + + grads_sched = tvm.create_schedule([g.op for g in grads]) + mgrad = tvm.build(grads_sched, grads + inputs + args) + + lowered = tvm.lower(grads_sched, grads + inputs + args, simple_mode=True) + + if verbose: + print("Lowered gradients:\n") + print(lowered) + print() + + if perf != False: + est = estimate_performance(grads, param_values=param_values) + est_lowered = estimate_performance(lowered, param_values=param_values) + + if verbose: + print("Note: performance tuples are (iterations, multiplications, memory)") + print("Expected performance of grads: {}".format(perf)) + print("Estimated performance of grads: {}".format(est.as_tuple())) + print("Estimated performance of lowered grads: {}".format(est_lowered.as_tuple())) + print() + + if est_lowered.memory > est.memory: + print("WARNING: Line {}: The estimated memory consumption increased after lowering, " + "this may indicate that tensor bounds have been expanded too much".format(line)) + print("before: {} after: {}".format(est, est_lowered)) + + (iters, mults, mem) = est.as_tuple() + if perf is None or isinstance(perf, str): + print("WARNING: Line {}: No performance information, you may set it to {}" + .format(line, est.as_tuple())) + if isinstance(perf, str): + print("0,/{!r}/{{s/{!r}/{}/}}".format(perf, perf, (iters, mults, mem))) + elif perf != (iters, mults, mem): + (ref_iters, ref_mults, ref_mem) = perf + ref_est = PerformanceEstimate(*perf) + + if est <= ref_est: + print("WARNING: Line {}: Estimated performance {} is better than {}. " + "Use this with sed:" + .format(line, est.as_tuple(), ref_est.as_tuple())) + print("0,/{}/{{s/{}/{}/}}".format(perf, perf, (iters, mults, mem))) + elif est >= ref_est: + print("WARNING: Line {}: Estimated performance {} IS WORSE THAN {}" + .format(line, est.as_tuple(), ref_est.as_tuple())) + else: + print("WARNING: Line {}: Estimated performance {} does not match {}" + .format(line, est.as_tuple(), ref_est.as_tuple())) + + EST_RTOL = 1.5 + if iters > ref_iters*EST_RTOL or mults > ref_mults*EST_RTOL or mem > ref_mem*EST_RTOL: + raise AssertionError("Line {}: Some of the estimated performance metrics are much " + "worse than the reference ones (by {}): " + "estimated {}, expected {}" + .format(line, EST_RTOL, est.as_tuple(), ref_est.as_tuple())) + + input_vals = [tvm.nd.array(np.random.uniform(in_range[0], in_range[1], + size=get_shape(a, param_values)).astype(a.dtype)) + for a in inputs] + arg_vals = [tvm.nd.array(np.random.uniform(in_range[0], in_range[1], + size=get_shape(a, param_values)).astype(a.dtype)) + for a in args] + + def fun(*arguments): + arrays = [tvm.nd.empty(get_shape(out, param_values), out.dtype)] + \ + [tvm.nd.array(a) for a in list(arguments) + arg_vals] + mout(*arrays) + return arrays[0].asnumpy().sum() + + g_arg_vals = \ + [tvm.nd.empty(get_shape(i, param_values), g.dtype) for i, g in zip(inputs, grads)] + \ + input_vals + arg_vals + mgrad(*g_arg_vals) + g_res = [g_arg_vals[g].asnumpy() for g, _ in enumerate(grads)] + + check_numerical_grads(fun, [a.asnumpy() for a in input_vals], g_res, + acceptable_fail_fraction=acceptable_fail_fraction) + +def test_differentiate_function(): + x = tvm.placeholder((32, 3, 28, 28), name='x') + + w = tvm.placeholder((10, 3, 3, 3), name='w') + t1 = topi.nn.conv2d(x, w, 1, 0, 1) + + t2 = topi.nn.flatten(t1) + t3 = topi.sum(t2) + + [dx1, dw1] = tvm.differentiate(t3, [x, w]) + [dx2, dw2] = tvm.differentiate(t2, [x, w], topi.full_like(t2, 1.0)) + + check_equivalence([dx1, dw1], [dx2, dw2], [x, w]) + + def mydiff(out, inp, head, t1=t1, t2=t2): + assert out == t2 and inp == [t1] + return [tvm.compute(t1.shape, + lambda ax0, ax1, ax2, ax3: head[ax0, ax3 + ax2*26 + ax1*676])] + + res = tvm.differentiate(t3, [x, w], override={t2: ([t1], mydiff)}) + check_equivalence(res.result, [dx1, dw1], [x, w]) + + def mydiff2(out, inputs, head): + return tvm.differentiate(out, inputs, head) + + res = tvm.differentiate(t3, [x, w], override={t1: ([x, w], mydiff2)}) + check_equivalence(res.result, [dx1, dw1], [x, w]) + +# Test some simple expressions +def test_autodiff(): + x = tvm.var("x", dtype='float32') + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + A0 = tvm.placeholder((10, 10), name='A0') + A1 = tvm.placeholder((10, 10), name='A1') + + B = tvm.compute((10, 10), lambda i, j: A0[i, j] + A0[j, i], name='B') + check_grad(B, A0, perf=(10100, 10000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.floor(A0[i, j]), name='B') + check_grad(B, A0, perf=(100, 0, 100), acceptable_fail_fraction=0.05) + + B = tvm.compute((10, 10), lambda i, j: tvm.ceil(A0[i, j]), name='B') + check_grad(B, A0, perf=(100, 0, 100), acceptable_fail_fraction=0.05) + + B = tvm.compute((10, 10), lambda i, j: tvm.trunc(A0[i, j]), name='B') + check_grad(B, A0, perf=(100, 0, 100), acceptable_fail_fraction=0.05) + + B = tvm.compute((10, 10), lambda i, j: tvm.round(A0[i, j]), name='B') + check_grad(B, A0, perf=(100, 0, 100), acceptable_fail_fraction=0.05) + + B = tvm.compute((10, 10), lambda i, j: A0[i, j] + tvm.exp(A0[j, i]), name='B') + check_grad(B, A0, perf=(10100, 20000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.log(tvm.abs(A0[i, j] + tvm.exp(A0[j, i]))), name='B') + check_grad(B, A0, perf=(10100, 70000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.sigmoid(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + check_grad(B, A0, perf=(10100, 120000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.tanh(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + check_grad(B, A0, perf=(10100, 120000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.sqrt(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + check_grad(B, A0, perf=(10100, 90000, 200), in_range=(0.1, 10)) + + B = tvm.compute((10, 10), lambda i, j: tvm.power(tvm.abs(A0[i, j]), A0[j, i]), name='B') + check_grad(B, A0, perf=(10100, 90000, 200)) + + B = tvm.compute((10, 10), lambda i, j: A0[i, j] * A0[j, i], name='B') + check_grad(B, A0, perf=(10100, 10000, 200)) + + # TODO: This one needs transforming Sum(a + b) -> Sum(a) + Sum(b) + B = tvm.compute((10,), lambda i: tvm.sum(A0[i, k]*A0[k, i], axis=k), name='B') + check_grad(B, A0, perf=(11010, 1000, 1110)) + + B = tvm.compute((10, 10), lambda i, j: tvm.sum(A0[i, k]*A0[k, i] + 5, axis=k), name='B') + check_grad(B, A0, perf=(20100, 10000, 1200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.max(A0[i, k]*A0[k, j] + 5, axis=k), name='B') + check_grad(B, A0, perf=(110100, 310000, 20200)) + + B = tvm.compute((10, 10), lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]), name='B') + check_grad(B, A0, [A1], perf=(10100, 10000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.sum(A0[k, k] - A0[tvm.min(j + k, 9), j]*A0[i, k], + axis=k), + name='B') + check_grad(B, A0, perf=(110100, 10000, 10200)) + + def fcombine(x, y): + return x*y + + def fidentity(t0): + return tvm.const(1, t0) + + prod = tvm.comm_reducer(fcombine, fidentity, name='prod') + B = tvm.compute((10, 10), lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), name='B') + check_grad(B, A0, perf=(20100, 40000, 2200)) + + X = tvm.placeholder((10,), name='X') + A = tvm.compute((10,), lambda i: X[i] + X[9 - i]) + B = tvm.compute((10,), lambda i: X[i] * X[9 - i]) + Y = topi.tensordot(A, B, 1) + check_grad(Y, X, perf=(251, 230, 71)) + +def test_topi_autodiff(): + X = tvm.placeholder((1, 2, 4, 4), name='X') + W = tvm.placeholder((5, 2, 3, 3), name='W') + W1 = tvm.placeholder((2, 5, 3, 3), name='W1') + W2 = tvm.placeholder((1,), name='W2') + + R = topi.nn.conv2d(X, W, 1, 1, 1) + check_grad(R, [X, W], perf=(3410, 2880, 652)) + + R1 = topi.nn.conv2d(topi.nn.relu(R), W1, 1, 0, 1) + check_grad(R1, [X, W, W1], perf=(6198, 5320, 1250)) + + R = topi.broadcast_to(W2, (5, 2, 3, 3)) + check_grad(R, [W2], perf=(180, 0, 91)) + + R = topi.nn.conv2d(X, topi.broadcast_to(W2, (5, 2, 3, 3)), 1, 1, 1) + check_grad(R, [X, W2], perf=(3590, 2880, 743)) + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'avg') + check_grad(R, X, perf=(40, 224, 40)) + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') + check_grad(R, X, perf=(168, 1248, 104)) + + X = tvm.placeholder((1, 2, 5, 5), name='X') + R = topi.reshape(X, (1, 32)) + check_grad(R, [X], perf=(82, 1200, 82)) + + X = tvm.placeholder((1, 2, 5, 5), name='X') + W = tvm.placeholder((2, 2, 3, 3), name='W') + + S = topi.reshape(X, (1, 50)) + check_grad(S, [X], perf=(100, 700, 100)) + + R = X + topi.nn.conv2d(X + topi.nn.conv2d(X, W, 1, 1, 1), W, 1, 1, 1) + check_grad(R, [X, W], perf=(6854, 5400, 1726)) + + S = topi.nn.softmax(topi.reshape(R, (1, 50))) + check_grad(S, [X, W], perf=(10956, 14201, 2333)) + + S = topi.sigmoid(topi.reshape(R, (1, 50))) + check_grad(S, [X, W], perf=(8004, 8350, 2026)) + + S = topi.tanh(topi.reshape(R, (1, 50))) + check_grad(S, [X, W], perf=(8004, 8350, 2026)) + + S = topi.nn.log_softmax(topi.reshape(R, (1, 50))) + check_grad(S, [X, W], perf=(10906, 13601, 2283)) + check_grad(S, [W], [X], perf=(8920, 11101, 1997)) + + X = tvm.placeholder((1, 2, 3, 5), name='X') + Y = tvm.placeholder((1, 2, 7, 5), name='Y') + S = topi.concatenate((X, Y), 2) + check_grad(S, [X, Y], perf=(100, 0, 100)) + + X = tvm.placeholder((1, 2, 6, 5), name='X') + (S, R) = topi.split(X, 2, 2) + check_grad(S, [X], perf=(120, 0, 120)) + check_grad(R, [X], perf=(120, 0, 120)) + R1 = topi.concatenate((S, R), 2) + check_grad(R1, [X], perf=(300, 0, 300)) + R2 = topi.concatenate((R, S), 2) + check_grad(R2, [X], perf=(300, 0, 300)) + + X = tvm.placeholder((4, 5), name='X') + I = tvm.placeholder((100,), name='I', dtype='int32') + R = topi.take(X, topi.abs(I)) + check_grad(R, [X], [I], perf=(2200, 6000, 220)) + +def test_stride_dilation(): + X = tvm.placeholder((1, 2, 10, 10), name='X') + + W = tvm.placeholder((2, 2, 1, 1), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + check_grad(Y, [X, W], perf=(1404, 800, 808)) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + check_grad(Y, [X, W], perf=(928, 1572, 670)) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + check_grad(Y, [X, W], perf=(932, 1728, 672)) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + check_grad(Y, [X, W], perf=(1404, 800, 808)) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + check_grad(Y, [X, W], perf=(928, 1572, 670)) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + check_grad(Y, [X, W], perf=(932, 1728, 672)) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + check_grad(Y, [X, W], perf=(1404, 800, 808)) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + check_grad(Y, [X, W], perf=(928, 1572, 670)) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + check_grad(Y, [X, W], perf=(932, 1728, 672)) + + W = tvm.placeholder((2, 2, 2, 2), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + check_grad(Y, [X, W], perf=(3922, 2896, 1242)) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + check_grad(Y, [X, W], perf=(1650, 2800, 1066)) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + check_grad(Y, [X, W], perf=(1146, 2880, 890)) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + check_grad(Y, [X, W], perf=(3500, 19720, 1092)) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + check_grad(Y, [X, W], perf=(3408, 88848, 2034)) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + check_grad(Y, [X, W], perf=(2254, 65232, 992)) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + check_grad(Y, [X, W], perf=(3138, 17696, 970)) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + check_grad(Y, [X, W], perf=(3816, 82368, 2176)) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + check_grad(Y, [X, W], perf=(3834, 104432, 2306)) + + W = tvm.placeholder((2, 2, 3, 3), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + check_grad(Y, [X, W], perf=(7420, 5904, 1752)) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + check_grad(Y, [X, W], perf=(3888, 58592, 2214)) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + check_grad(Y, [X, W], perf=(1552, 2268, 1102)) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + check_grad(Y, [X, W], perf=(5916, 42392, 1256)) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + check_grad(Y, [X, W], perf=(6736, 21784, 3694)) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + check_grad(Y, [X, W], perf=(2672, 146096, 1668)) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + check_grad(Y, [X, W], perf=(2896, 89152, 956)) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + check_grad(Y, [X, W], perf=(2280, 12856, 1992)) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + check_grad(Y, [X, W], perf=(2224, 12032, 716)) + + Y = topi.nn.pool(X, [1, 1], [1, 1], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(200, 0, 200)) + Y = topi.nn.pool(X, [1, 1], [2, 2], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(412, 1124, 412)) + Y = topi.nn.pool(X, [1, 1], [3, 3], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(232, 1200, 232)) + Y = topi.nn.pool(X, [2, 2], [1, 1], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(4162, 7200, 1962)) + Y = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(1050, 7800, 650)) + Y = topi.nn.pool(X, [2, 2], [3, 3], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(858, 6304, 602)) + Y = topi.nn.pool(X, [3, 3], [1, 1], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(18128, 34200, 3928)) + Y = topi.nn.pool(X, [3, 3], [2, 2], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(6712, 131312, 1690)) + Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], 'max') + check_grad(Y, [X], perf=(1838, 12950, 704)) + +def test_some_conv2d_net(): + batch_size = 1 + num_classes = 10 + + features = 4 + dense_units = 16 + + x = tvm.placeholder((batch_size, 28, 14, 1)) + y = tvm.placeholder((batch_size, num_classes)) + + w1 = tvm.placeholder((features, 1, 3, 5)) + b1 = tvm.placeholder((features,)) + w2 = tvm.placeholder((features, features, 3, 5)) + b2 = tvm.placeholder((features,)) + b3 = tvm.placeholder((dense_units,)) + w4 = tvm.placeholder((num_classes, dense_units)) + b4 = tvm.placeholder((num_classes,)) + + t = topi.transpose(x, [0, 3, 1, 2]) + t = topi.nn.relu(topi.nn.conv2d(t, w1, 1, 0, 1) + topi.reshape(b1, (1, features, 1, 1))) + t = topi.nn.relu(topi.nn.conv2d(t, w2, 1, 0, 1) + topi.reshape(b2, (1, features, 1, 1))) + t = topi.nn.pool(t, [2, 2], [2, 2], [0, 0, 0, 0], 'avg') + t = topi.transpose(t, [0, 2, 3, 1]) + t = topi.nn.flatten(t) + w3 = tvm.placeholder((dense_units, get_shape(t)[1])) + t = topi.nn.relu(topi.nn.dense(t, w3, b3)) + t = topi.nn.dense(t, w4, b4) + + t = - topi.sum(y * topi.nn.log_softmax(t)) / batch_size + + weights = [w1, b1, w2, b2, w3, b3, w4, b4] + + check_grad(t, weights, [x, y], in_range=(-1.0, 1.0), perf=(194865, 179089, 28194)) + +def test_free_vars(): + m = tvm.var('m') + n = tvm.var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((n,), name='B') + + Y = topi.add(A, B) + check_grad(Y, [A, B], perf=(160, 0, 120), param_values={m: 5, n: 10}) + + param_values = {m: 10} + x = tvm.var("x", dtype='float32') + k = tvm.reduce_axis((0, m), name="k") + A0 = tvm.placeholder((m, m), name='A0') + A1 = tvm.placeholder((m, m), name='A1') + + B = tvm.compute((m, m), lambda i, j: A0[i, j] + A0[j, i], name='B') + check_grad(B, A0, perf=(10200, 10000, 300), param_values=param_values) + + B = tvm.compute((m,), lambda i: tvm.sum(A0[i, k]*A0[k, i], axis=k), name='B') + check_grad(B, A0, perf=(11110, 1000, 1210), param_values=param_values) + + X = tvm.placeholder((m, n, 4, 4), name='X') + param_values = {m: 1, n: 2} + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'avg') + check_grad(R, X, perf=(72, 224, 72), param_values=param_values) + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') + check_grad(R, X, perf=(200, 1248, 136), param_values=param_values) + +if __name__ == "__main__": + if "-v" in sys.argv: + verbose = True + + test_differentiate_function() + test_autodiff() + test_topi_autodiff() + test_stride_dilation() + test_some_conv2d_net() + test_free_vars() diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py new file mode 100644 index 000000000000..a1d4070a72f7 --- /dev/null +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -0,0 +1,464 @@ +import random +import sys +import numpy as np +import tvm +from tvm import comm_reducer +from tvm.testing import estimate_performance +from tvm.ir_pass import Simplify, Equal, LiftNonzeronessCondition, IsSumCombiner, \ + CanFactorZeroFromCombiner, InlineTailCall, InlineTensors, SolveSystemOfInequalities, \ + SimplifyDomain, SimplifyReductionDomain, ExtractAsTensorMaybe, ExtractReductions, \ + ExtractNonTopReductions, OptimizeAndLiftNonzeronessConditions + +def get_shape(tensor): + return [s.value for s in tensor.shape] + +def check_eq(t1, t2, args): + s1 = tvm.create_schedule(t1.op) + m1 = tvm.build(s1, [t1] + args) + + s2 = tvm.create_schedule(t2.op) + m2 = tvm.build(s2, [t2] + args) + + for _ in range(5): + arg_vals = [tvm.ndarray.array(np.random.uniform(-10, 10, size=get_shape(a)) + .astype(a.dtype)) + for a in [t1] + args] + m1(*arg_vals) + res1 = arg_vals[0].asnumpy() + m2(*arg_vals) + res2 = arg_vals[0].asnumpy() + + np.testing.assert_allclose(res1, res2, atol=1e-3, rtol=1e-2) + +def check_symeq(expr1, expr2): + expr1 = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr1)) + expr2 = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr2)) + + if tvm.ir_pass.Equal(expr1, expr2): + return + + diff = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr1 - expr2)) + if not Equal(diff, tvm.const(0, expr1.dtype)): + raise AssertionError("Expressions {} and {} are not equal, their diff is {}" + .format(expr1, expr2, diff)) + +def compute(shape, fcompute): + """Like tvm.compute but automatically extracts reductions.""" + return tvm.compute(shape, + lambda *vs: ExtractNonTopReductions( + fcompute(*vs), vs, {v: tvm.Range(0, s) for v, s in zip(vs, shape)})) + +def check_tensor_symeq(A, B): + if not isinstance(B, tvm.tensor.Tensor): + B = compute(A.shape, B) + vmap = {a.var: b.var for a, b in zip(A.op.axis, B.op.axis)} + expr_a = tvm.ir_pass.Substitute(A.op.body[A.value_index], vmap) + expr_b = B.op.body[B.value_index] + expr_a = tvm.ir_pass.CanonicalSimplify(InlineTensors(expr_a, [], True)) + expr_b = tvm.ir_pass.CanonicalSimplify(InlineTensors(expr_b, [], True)) + if not Equal(expr_a, expr_b): + print(expr_a) + print(expr_b) + raise AssertionError("The expressions are not equal") + +def check_eq_bruteforce(expr1, expr2, vranges): + def _compute_body(*us): + vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} + return tvm.ir_pass.Substitute(expr1 == expr2, vmap) + + A = compute([r.extent.value for v, r in vranges.items()], _compute_body) + args = [tvm.ndarray.empty(A.shape, A.dtype)] + sch = tvm.create_schedule(A.op) + mod = tvm.build(sch, [A]) + mod(*args) + res = args[0].asnumpy() + if not np.all(res): + indices = list(np.argwhere(res == 0)[0]) + counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] + counterex = ", ".join([v + " = " + str(i) for v, i in sorted(counterex)]) + raise AssertionError("Expressions {}\nand {}\nare not equal on {}\n" + "Counterexample: {}" + .format(expr1, expr2, vranges, counterex)) + +prod_combiner = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0)) +sum_combiner = comm_reducer(lambda x, y: x + y, lambda t0: tvm.const(0, t0)) +sum2_combiner = comm_reducer(lambda x, y: y + x, lambda t0: tvm.const(0, t0)) +sum_derivative_combiner = comm_reducer(lambda x, y: (x[0] + y[0], y[1] + x[1]), + lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1))) +prod_derivative_combiner = comm_reducer(lambda x, y: (x[0]*y[0], x[0]*y[1] + x[1]*y[0]), + lambda t0, t1: (tvm.const(1, t0), tvm.const(0, t1))) +sum_both_combiner = comm_reducer(lambda x, y: (x[0] + y[0], x[0] + y[0] + x[1] + y[1]), + lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1))) +xor_combiner = comm_reducer(lambda x, y: x ^ y, lambda t0: tvm.const(0, t0)) + +def test_is_sum_combiner(): + k = tvm.reduce_axis((0, 10), name="k") + i = tvm.const(0, "int32") + f = tvm.const(0.0, "float32") + assert IsSumCombiner(sum_combiner(i, k).combiner) + assert IsSumCombiner(sum_combiner(f, k).combiner) + assert IsSumCombiner(sum2_combiner(i, k).combiner) + assert IsSumCombiner(sum2_combiner(f, k).combiner) + assert not IsSumCombiner(sum_derivative_combiner((f, f), k)[0].combiner) + assert not IsSumCombiner(prod_combiner(f, k).combiner) + assert not IsSumCombiner(prod_derivative_combiner((f, f), k)[1].combiner) + +def test_can_factor_zero_from_combiner(): + k = tvm.reduce_axis((0, 10), name="k") + i = tvm.const(0, "int32") + f = tvm.const(0.0, "float32") + assert CanFactorZeroFromCombiner(sum_combiner(i, k).combiner, 0) + assert CanFactorZeroFromCombiner(sum2_combiner(f, k).combiner, 0) + assert CanFactorZeroFromCombiner(sum_derivative_combiner((f, f), k)[0].combiner, 0) + assert CanFactorZeroFromCombiner(sum_derivative_combiner((f, f), k)[0].combiner, 1) + assert not CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 0) + assert CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 1) + assert CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 0) + assert not CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 1) + +def test_lift_nonzeroness_condition(): + k = tvm.reduce_axis((0, 5), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 5), name="n") + A = tvm.placeholder((10,), name='A') + + def _check(shape, fun, A=A): + T1 = tvm.compute(shape, fun) + T2 = tvm.compute(shape, lambda *args: LiftNonzeronessCondition(fun(*args))) + check_eq(T1, T2, [A]) + assert isinstance(T2.op.body[0], tvm.expr.Select) + + _check((10,), lambda i: A[i]) + _check((10,), lambda i: A[i] + (i % 2 == 0)) + _check((10,), lambda i: A[i]*(i % 2 == 0) + (i % 2 == 0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), A[i], 0.0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), A[i], 0.0) + (i % 2 == 0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), 0.0, A[i]) + (i % 2 == 0)) + def e1(i): return tvm.expr.Select((i % 2 == 1), 0.0, A[i]) + def e2(i): return tvm.expr.Select((i % 2 == 0), A[(i + 1) % 10], 0.0) + def e3(i): return tvm.expr.Select((i % 2 == 1), A[i], 0.0) + _check((10,), lambda i: e1(i) + e2(i) + e3(i) + e1(i)*e2(i)) + _check((10,), lambda i: e1(i)*e3(i)) + _check((10,), lambda i: e1(i)*e2(i)) + _check((10,10), lambda i, j: A[i]*(i == j) + A[j]*(i == 2*j) + A[j]*(j == i)) + _check((10,10), lambda i, j: tvm.min(A[i]*(i == j), A[j]*(i == 2*j))) + _check((10,10), lambda i, j: tvm.max(A[i]*(i == j), A[j]*(i == 2*j))) + _check((10,10), lambda i, j: A[i]*(i == j) - A[j]*(i == 2*j)) + _check((10,10), lambda i, j: A[i]*(i == j) / (1 + tvm.abs(A[j]*(i == 2*j)))) + _check((10,10), lambda i, j: i*(i < j) + j*(i > j)) + _check((10,10), lambda i, j: i*(i < j) % (1 + j*(i > j))) + + def _check_symeq(expr1, expr2): + expr1 = LiftNonzeronessCondition(expr1) + expr2 = LiftNonzeronessCondition(expr2) + print(expr1) + print(expr2) + print() + check_symeq(expr1, expr2) + + _check_symeq(tvm.expr.Select(tvm.expr.EQ(k, l), 0.0, tvm.expr.Cast('float32', (k < n))), + tvm.expr.Select(tvm.expr.And((k < n), tvm.expr.NE(k, l)), 1.0, 0.0)) + _check_symeq(tvm.min(tvm.expr.Cast('int32', k < n)*l, tvm.expr.Select(k >= n, 0, 1)), + tvm.expr.Select(k < n, tvm.min(l, 1), 0)) + +def test_inline_tail_call(): + A = tvm.compute((10, 10), lambda i, j: i + j*j) + B = tvm.compute((5, 6), lambda k, l: A[k + l, k + 1]) + C = InlineTailCall(B) + resbody = lambda k, l: k + l + (k + 1)*(k + 1) + check_symeq(C.op.body[0], resbody(*[iv.var for iv in C.op.axis])) + +def test_inline_tensors(): + A = tvm.compute((10, 10), lambda i, j: i + j) + B = tvm.compute((10, 10), lambda i, j: i * j) + C = tvm.compute((10, 10), lambda i, j: A[i, j] + B[i, j]) + k = tvm.reduce_axis((0, 5), name="k") + D = tvm.compute((10, 10), lambda i, j: tvm.sum(A[i, k], k)) + E = tvm.compute((10, 10), lambda i, j: A[2, j] + C[i, 2] + D[i, j]) + + R = InlineTensors(E) + resbody = lambda i, j: 2 + j + i + 2 + i*2 + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [A]) + resbody = lambda i, j: 2 + j + C[i, 2] + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [A, C]) + resbody = lambda i, j: 2 + j + ((i + 2) + B[i, 2]) + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [B, C]) + resbody = lambda i, j: A[2, j] + (A[i, 2] + i*2) + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + +def test_solve_system_of_inequalities(): + seed = random.randrange(sys.maxsize) + print("\nseed: {}\n".format(seed)) + random.seed(seed) + + def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): + vs = [tvm.var("x" + str(i)) for i in range(variables)] + + fs = [] + for i in range(formulas): + s1 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s1 += random.randint(coef[0], coef[1]) + s2 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s2 += random.randint(coef[0], coef[1]) + op = random.choice([tvm.expr.EQ, tvm.expr.LE, tvm.expr.LT, tvm.expr.GE, tvm.expr.GT]) + fs.append(op(s1, s2)) + + vranges = {v: tvm.Range(bounds[0], bounds[1] + 1) for v in vs} + + before = tvm.all(*fs) + print(before) + after = tvm.all(*SolveSystemOfInequalities(fs, vs, vranges)) + print(after) + print() + + check_eq_bruteforce(before, after, vranges) + + for i in range(3): + _check(1, 1) + for i in range(3): + _check(1, 2) + + for i in range(3): + _check(2, 1) + for i in range(3): + _check(2, 2) + for i in range(3): + _check(2, 3) + + # Somewhere here coefficients in the results become too large, leading to overflow, + # so we use smaller initial coefficients + + for i in range(5): + _check(3, 3, coef=(-2,2)) + for i in range(5): + _check(3, 4, coef=(-2,2)) + + for i in range(5): + _check(4, 3, coef=(-1,1)) + + for i in range(5): + _check(10, 2, coef=(-1,1), bounds=(0, 4)) + for i in range(5): + _check(10, 3, coef=(0,1), bounds=(0, 4)) + +def test_simplify_domain(): + # Note that here we test both SimplifyDomain and SimplifyReductionDomain. + def _check(cond, axis, volume, vranges={}): + vranges_with_axis = dict(vranges) + vranges_with_axis.update({iv.var: iv.dom for iv in axis}) + variables = [iv.var for iv in axis] + new_cond, new_axis, old_to_new, new_to_old = SimplifyDomain(cond, variables, + vranges_with_axis) + + print("old", axis, cond) + print("new", new_axis, new_cond) + print("old_to_new", old_to_new) + print("new_to_old", new_to_old) + print() + + cond_subst = tvm.ir_pass.Substitute(cond, old_to_new) + new_vranges = vranges.copy() + new_vranges.update({v.var: v.dom for v in new_axis}) + # If new_cond is true in the new domain, then cond_subst must also be true in the new + # domain, but the reverse is not necessarily true + check_eq_bruteforce(tvm.all(new_cond, cond_subst), new_cond, new_vranges) + + new_cond_subst = tvm.ir_pass.Substitute(new_cond, new_to_old) + old_vranges = vranges.copy() + old_vranges.update({v.var: v.dom for v in axis}) + check_eq_bruteforce(cond, tvm.all(cond, new_cond_subst), old_vranges) + + # Also check SimplifyReductionDomain + reduction = xor_combiner(sum([v*(i + 1) for i, v in enumerate(axis)]), axis) + new_reduction = SimplifyReductionDomain(reduction, vranges) + check_eq_bruteforce(reduction, new_reduction, vranges) + + vol = np.prod([iv.dom.extent.value for iv in new_axis]) + if vol != volume: + raise AssertionError("New volume is {} != {}\n" + "Old domain {} where {}\nNew domain {} where {}" + .format(vol, volume, axis, cond, new_axis, new_cond)) + + k = tvm.reduce_axis((0, 5), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 5), name="n") + + _check((k <= l), [k, l, n], 125) + _check((k < l), [k, l, n], 80) + _check(tvm.expr.EQ(k, l), [k, l, n], 25) + _check(tvm.all(tvm.expr.EQ(k, l), (l < n)), [k, l, n], 16) + _check(tvm.expr.EQ(2*l, k), [k, l, n], 15) + # TODO: the result depends on the order of variables because we don't have a proper solver for + # systems of linear equations yet + _check(tvm.expr.EQ(2*l, k), [n, l, k], 25) + _check(tvm.all(l - k < 2, 2*n == k), [k, l, n], 15) + _check(tvm.all(l - k < 2, l >= k), [k, l, n], 50) + + some_var = tvm.var('some_var') + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 50, {some_var: tvm.Range(0, 3)}) + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 25, {some_var: tvm.Range(0, 2)}) + + + k = tvm.reduce_axis((-3, 2), name="k") + l = tvm.reduce_axis((-3, 2), name="l") + n = tvm.reduce_axis((-3, 2), name="n") + + _check((k < l), [k, l, n], 80) + _check(tvm.expr.EQ(k, l), [k, l, n], 25) + _check(tvm.all(tvm.expr.EQ(k, l), (l < n)), [k, l, n], 16) + # Now there are only two possible values for l: {l = -1, k = -2} and {l = 0, k = 0} + _check(tvm.expr.EQ(2*l, k), [k, l, n], 10) + # TODO: the result depends on the order of variables because we don't have a proper solver for + # systems of linear equations + _check(tvm.expr.EQ(2*l, k), [n, l, k], 25) + _check(tvm.all(l - k < 2, 2*n == k), [k, l, n], 10) + _check(tvm.all(l - k < 2, l >= k), [k, l, n], 50) + + some_var = tvm.var('some_var') + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 50, {some_var: tvm.Range(0, 3)}) + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 25, {some_var: tvm.Range(0, 2)}) + + + k = tvm.reduce_axis((0, 6), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 30), name="n") + + _check(tvm.all(k + l*6 == n), [k, l, n], 30) + _check(tvm.all(k + l*6 == n), [n, k, l], 30) + _check(tvm.all(k + l*6 == n), [n, l, k], 30) + + _check(tvm.all(n / 5 == k, n % 5 == l), [l, k, n], 30) + # TODO: Same thing with the order + _check(tvm.all(n / 5 == k, n % 5 == l), [n, l, k], 30) + + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + # TODO: This is not fully optimized because we don't have a solver + _check(tvm.all((l + k)%3 <= 1, (l + k)/3 <= 2), [l, k], 144) + +def test_extract_as_tensor_maybe(): + def _check(shape, fcompute, volume=None, vranges={}): + def fcompute_extracted(*variables): + vranges_updated = dict(vranges) + vranges_updated.update({v: tvm.Range(0, s) for v, s in zip(variables, shape)}) + expr = fcompute(*variables) + if isinstance(expr, tvm.expr.Select): + new_true_value = ExtractAsTensorMaybe(expr.true_value, + expr.condition, + variables, + vranges_updated) + expr = tvm.expr.Select(expr.condition, + new_true_value, + expr.false_value) + if volume is not None: + assert isinstance(new_true_value, tvm.expr.Call) + vol = np.prod([iv.dom.extent.value for iv in new_true_value.func.axis]) + if vol != volume: + raise AssertionError("New volume is {} != {}" + .format(vol, volume)) + return expr + + A = tvm.compute(shape, fcompute) + B = tvm.compute(shape, fcompute_extracted) + check_eq(A, B, []) + + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, i + j, 0), volume=30) + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, j, 0), volume=10) + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, i, 0), volume=3) + _check((10, 10), lambda i, j: tvm.expr.Select(tvm.all(i < j, j < 5), i + j, 0), volume=16) + # This one doesn't get extracted + _check((10, 10), lambda i, j: tvm.expr.Select(i <= j, i + j, 0)) + +def test_extract_reductions(): + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + n = tvm.reduce_axis((0, 10), name="n") + + A = tvm.compute((10, 10), + lambda i, j: + ExtractReductions(sum_combiner(i + k + xor_combiner(j*k + l, l), k), + [i, j], + {i: tvm.Range(0, 10), j: tvm.Range(0, 10)})) + B = tvm.compute((10, 10), lambda j, k: xor_combiner(j*k + l, l)) + C = tvm.compute((10, 10), lambda i, j: sum_combiner(i + k + B[j, k], k)) + check_eq(C, A, []) + + fcompute = lambda i, j: \ + ExtractReductions(sum_both_combiner((prod_derivative_combiner((i*n + 2*k, j + k), k)[1], + xor_combiner(j*n + l, l)), n)[1], + [i, j], + {i: tvm.Range(0, 10), j: tvm.Range(0, 10)}) + A = tvm.compute((10, 10), fcompute) + _, B = tvm.compute((10, 10, 10), + lambda i, j, n: prod_derivative_combiner((i*n + 2*k, j + k), k)) + C = tvm.compute((10, 10), lambda j, n: xor_combiner(j*n + l, l)) + _, D = tvm.compute((10, 10), lambda i, j: sum_both_combiner((B[i, j, n], C[j, n]), n)) + check_eq(A, D, []) + +def test_optimize_and_lift_nonzeroness(): + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + n = tvm.reduce_axis((0, 10), name="n") + A = tvm.placeholder((10, 10), name="A") + + zero = tvm.const(0, 'float32') + + B = compute((10, 10), lambda i, j: tvm.sum((i == j)*A[i, k] + A[k, j]*(i == j), k)) + B = OptimizeAndLiftNonzeronessConditions(B) + R = lambda i, j: tvm.expr.Select(i == j, + tvm.sum(A[j, k] + A[k, j], k), + zero) + check_tensor_symeq(B, R) + + # TODO: This test is unstable: sometimes the resulting condition looks like + # (i == j)*(j == i) instead of (i == j) + # B = compute((10, 10), lambda i, j: tvm.sum((i == j)*(i == k)*A[i, k] + + # (i == j)*A[k, j]*(i == k), k)) + # B = OptimizeAndLiftNonzeronessConditions(B) + # R = lambda i, j: tvm.expr.Select(i == j, A[j, j]*2.0, zero) + # check_tensor_symeq(B, R) + + B = compute((10, 10), lambda i, j: tvm.sum((i < j)*(j < k)*A[j, k], k)) + B = OptimizeAndLiftNonzeronessConditions(B) + k1 = tvm.reduce_axis((2, 10), name="k1") + R = compute((10, 10), lambda i, j: + tvm.expr.Select(tvm.all(i < j, j < 10), + tvm.sum(tvm.expr.Select(j < k1, A[j, k1], zero), k1), + zero)) + check_eq(B, R, [A]) + assert estimate_performance(B) <= estimate_performance(R) + + # TODO: This one needs the equation solver + # B = compute((10, 10), lambda i, j: tvm.sum((i <= j)*(j <= k)*A[j, k], k, where=(i >= k))) + # B = OptimizeAndLiftNonzeronessConditions(B) + # R = compute((10, 10), lambda i, j: tvm.expr.Select((i == j), A[i, i], zero)) + # check_eq(B, R, [A]) + # assert estimate_performance(B) <= estimate_performance(R) + + B = compute((10, 10), + lambda i, j: prod_derivative_combiner((A[j, k], (i <= j)*(j < k)*A[i, k]), k)[1]) + B = OptimizeAndLiftNonzeronessConditions(B) + R = compute((10, 10), lambda i, j: + tvm.expr.Select(tvm.all(i <= j, j < 10), + prod_derivative_combiner((A[j, k], (j < k)*A[i, k]), k)[1], + zero)) + check_eq(B, R, [A]) + assert estimate_performance(B) <= estimate_performance(R) + +if __name__ == "__main__": + test_is_sum_combiner() + test_can_factor_zero_from_combiner() + test_lift_nonzeroness_condition() + test_inline_tail_call() + test_inline_tensors() + test_solve_system_of_inequalities() + test_simplify_domain() + test_extract_as_tensor_maybe() + test_extract_reductions() + test_optimize_and_lift_nonzeroness() diff --git a/tutorials/language/autodiff_basics.py b/tutorials/language/autodiff_basics.py new file mode 100644 index 000000000000..32646a8952b2 --- /dev/null +++ b/tutorials/language/autodiff_basics.py @@ -0,0 +1,160 @@ +""" +Automatic Differentiation of Tensor Expressions +=============================================== +**Author**: `Sergei Grechanik `_ + +This tutorial describes how to use automatic differentiation of tensor expressions. + +Usually differentiation is done on the level of NNVM/Relay graphs. However there are some situations +when one might want to perform differentiation on the lower level of TVM tensor expressions, e.g.: + - When you are experimenting with a completely new kind of operations. + - When gradients for some operations haven't been implemented yet in NNVM/Relay. + - When you are implementing gradients for a new operation manually and need a starting point. + - When you want to train models in pure TVM without NNVM/Relay (if you really do, please tell us + why). + +.. note:: + + - Automatic differentiation is still work in progress. Some operations are not differentiated + very well yet. + - Automatic differentiation doesn't perform scheduling. The generated code should be scheduled + by hand or using some autoscheduling and autotuning methods (which may require manually + writing schedule templates). + +""" +from __future__ import absolute_import, print_function +import tvm +import topi + +###################################################################### +# How to use automatic differentiation +# ------------------------------------ +# +# Basically, all you need is the function :any:`tvm.differentiate` (also known as +# :any:`tvm.autodiff.differentiate`) which takes a tensor, differentiates it with respect to other +# given tensors using reverse accumulation, and applies certain optimizations. Let's consider an +# example: + +# inputs +X = tvm.placeholder((32, 100), name='X') +W = tvm.placeholder((10, 100), name='W') +B = tvm.placeholder((10,), name='B') + +# forward computation, basically topi.nn.dense(X, W, B) +k = tvm.reduce_axis((0, 100)) +T = tvm.compute((32, 10), lambda i, j: tvm.sum(X[i, k]*W[j, k], k)) +Y = topi.add(T, B) +L = topi.sum(Y) + +# gradients +[dL_dW, dL_dB] = tvm.differentiate(L, [W, B]) + +###################################################################### +# `L` is a scalar, so the results are gradients, however in general the result is a full Jacobian. +# :any:`tvm.differentiate` also accepts the third parameter if you want to multiply the Jacobian by +# another tensor. + +[dY_dW] = tvm.differentiate(Y, [W]) +print("Y.shape", Y.shape) +print("W.shape", W.shape) +print("dY_dW.shape", dY_dW.shape) + +[dL_dW] = tvm.differentiate(Y, [W], topi.full_like(Y, 1.0)) + +###################################################################### +# The result of :any:`tvm.differentiate` mimics a list, however it is an object that also contains +# all intermediate adjoints. Note also that the list of input tensors may be omitted, in which case +# the output will be differentiated with respect to all the inputs: + +res = tvm.differentiate(L) +dL_dX = res.adjoints[X] +dL_dT = res.adjoints[T] +dL_dY = res.adjoints[Y] + +###################################################################### +# Examples of generated gradients +# ------------------------------- +# +# Let's print out some generated code. We'll start with the simple matrix multiplication +# we've already differentiated. + +T1 = tvm.compute((32, 10), lambda i, j: tvm.sum(X[i, k]*W[j, k], k), name='matmul') +H1 = tvm.placeholder(T1.shape, name='H1') + +[dW] = tvm.differentiate(T1, [W], H1) +print(tvm.PrintTensorRecursively(dW)) + +###################################################################### +# (The only problem here is that an unnecessary intermediate tensor was extracted.) +# +# Now let's look at some problematic operations, like maxpool: + +X1 = tvm.placeholder((64, 32, 28, 28), name='X1') +W1 = tvm.placeholder((64, 64, 3, 3), name='W1') +Y1 = topi.nn.pool(X1, [2, 2], [2, 2], [0, 0, 0, 0], 'max') +H1 = tvm.placeholder(Y1.shape, name='H1') + +[dX1] = tvm.differentiate(Y1, [X1], H1) +print(tvm.PrintTensorRecursively(dX1)) + +###################################################################### +# Here the elements of the adjoint `H1` are multiplied by the elements of a mask (computed with +# the tensor called `extracted_tensor`). The mask represents whether an element is the maximum of +# its neighborhood. This is not the optimal solution. + +###################################################################### +# Overriding the differentiation function +# --------------------------------------- +# +# :any:`tvm.differentiate` internally calls a function which performs differentiation of a given +# tensor with respect to one of its inputs. This functions may be overridden for every tensor or for +# some particular tensors, which is useful when the default differentiation function does a poor job +# and we need to provide some gradients manually. Let's define our own naive version of this +# function: + +def custom_fdiff(out, inp, head): + return topi.tensordot(head, tvm.autodiff.Jacobian(out, inp, False), len(out.shape)) + +###################################################################### +# This function must take the tensors `out`, `inp` and `head` where `out` is the tensor that should +# be differentiated with respect to `inp`, `inp` is an immediate dependency of `out`, and `head` is +# the adjoint of `out` which should be multiplied by the result of differentiation. The +# differentiation itself is done using the function :any:`tvm.autodiff.Jacobian`, and the +# multiplication is done with :any:`topi.tensordot`. The default differentiation function +# :any:`tvm.autodiff.DiffBuildingBlock` does the same thing, but it also applies certain optimizing +# transformations. +# +# A custom differentiation function may be used like this: + +res = tvm.differentiate(L, fdiff=custom_fdiff) + +###################################################################### +# A custom differentiation function may be used to override differentiation for certain operations +# by checking if `out` is the operation we want to differentiate differently. However, there is an +# alternative way: using the `override` keyword argument. `override` should be a dict mapping +# tensors to their dependencies and custom differentiation functions. +# +# Let's consider the following scenario: we want to block gradient flow from `Y` to `X` and compute +# gradients of `Y` wrt `B` and `W` using the unoptimized differentiation function `custom_fdiff`. +# Note that `W` and `X` are not immediate dependencies of `Y`. + +def custom_fdiff_2(out, inputs, head): + assert out == Y + assert inputs == [X, W, B] + # block gradients to X + dX = topi.full(head.shape[:-len(out.shape)] + list(X.shape), head.dtype, 0) + # use the custom unoptimized differentiation function for the rest + return [dX] + list(tvm.differentiate(out, [W, B], head, fdiff=custom_fdiff)) + +res = tvm.differentiate(L, override={Y: ([X, W, B], custom_fdiff_2)}) + +###################################################################### +# There are several things to note: +# - For efficiency reasons the custom differentiation function used in `override` has a slightly +# different interface than the custom differentiation functions used for `fdiff`, namely it +# takes a list of inputs instead of a single input, and returns the list of the corresponding +# adjoints. +# - We had overridden the dependencies for `Y` (its immediate dependencies are `T` and `B`, but we +# used `X`, `W` and `B` instead), so we couldn't use :any:`tvm.autodiff.Jacobian` or +# `custom_fdiff` directly, since they expect the input to be an immediate dependency for the +# output. That's why we had to wrap them in the call to :any:`tvm.differentiate`.