
[TVM] Automatic differentiation for tensor expressions #2498

Closed
wants to merge 11 commits
5 changes: 5 additions & 0 deletions docs/api/python/dev.rst
@@ -70,3 +70,8 @@ tvm.make
~~~~~~~~
.. automodule:: tvm.make
:members:

tvm.testing
~~~~~~~~~~~
.. automodule:: tvm.testing
:members:
1 change: 1 addition & 0 deletions docs/api/python/index.rst
@@ -15,6 +15,7 @@ Python API
container
function
autotvm
autodiff
graph_runtime
rpc
bridge
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -67,6 +67,7 @@ List of operators
topi.not_equal
topi.greater_equal
topi.less_equal
topi.tensordot
topi.image.resize


@@ -123,6 +124,7 @@ topi
.. autofunction:: topi.power
.. autofunction:: topi.greater
.. autofunction:: topi.less
.. autofunction:: topi.tensordot

topi.nn
~~~~~~~
152 changes: 152 additions & 0 deletions include/tvm/autodiff.h
@@ -0,0 +1,152 @@
/*!
* Copyright (c) 2018 by Contributors
* \file autodiff.h
* \brief Automatic differentiation of IR Expr.
*/
#ifndef TVM_AUTODIFF_H_
#define TVM_AUTODIFF_H_

#include <tvm/ir.h>
#include <tvm/tensor.h>

namespace tvm {
namespace ir {

class DifferentiationResultNode;

/*!
* \brief A result of differentiation.
*/
class DifferentiationResult : public NodeRef {
 public:
  /*! \brief default constructor, used internally */
  DifferentiationResult() {}
  explicit DifferentiationResult(NodePtr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const DifferentiationResultNode* operator->() const;
  /*! \brief specify container node */
  using ContainerType = DifferentiationResultNode;
};

/*! \brief Node to represent a differentiation result */
class DifferentiationResultNode : public Node {
 public:
  /*! \brief The requested adjoints, i.e. Jacobians or gradients with respect to the given inputs */
  Array<Tensor> result;
  /*! \brief A map from tensors to the corresponding adjoints (including internal nodes) */
  Map<Tensor, Tensor> adjoints;
  /*! \brief Individual summands of the adjoints */
  Map<Tensor, Map<Tensor, Tensor>> adjoint_summands;
  /*! \brief constructor */
  DifferentiationResultNode() {}

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("result", &result);
    v->Visit("adjoints", &adjoints);
    v->Visit("adjoint_summands", &adjoint_summands);
  }
  TVM_DLL static DifferentiationResult make(Array<Tensor> result,
                                            Map<Tensor, Tensor> adjoints,
                                            Map<Tensor, Map<Tensor, Tensor>> adjoint_summands);

  static constexpr const char* _type_key = "DifferentiationResult";
  TVM_DECLARE_NODE_TYPE_INFO(DifferentiationResultNode, Node);
};

inline const DifferentiationResultNode* DifferentiationResult::operator->() const {
  return static_cast<const DifferentiationResultNode*>(node_.get());
}


/*! \brief A type of a "local" differentiation function for reverse mode AD
*
* A function of this type is a building block for reverse-mode automatic differentiation. It
* should take three tensors: `output`, `input` and `head`, where `head` is the adjoint
* corresponding to `output`, and return (a summand of) the adjoint corresponding to `input`.
* In other words, it should differentiate `output` wrt `input` and multiply the result by `head`
* using a tensor dot product (`head` should be on the left of the multiplication). `input` must
* be an immediate dependency of `output` (i.e. used directly in the body of `output`).
*
* See also ::DiffBuildingBlock, which might be considered the reference implementation.
*/
using FDiffBuildingBlock = std::function<Tensor(const Tensor& output,
                                                const Tensor& input,
                                                const Tensor& head)>;
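Spelled out as a formula (an illustrative restatement of the comment above; the index names `p`, `j`, `i` are not from the header): writing `p` for the extra leading indices of `head`, `j` for the indices of `output`, and `i` for the indices of `input`, a building block returns a summand of the form

$$\mathrm{summand}[p, i] \;=\; \sum_{j} \mathrm{head}[p, j] \cdot \frac{\partial\, \mathrm{output}[j]}{\partial\, \mathrm{input}[i]}.$$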

/*!
* \brief Take the derivative of the expression with respect to the given variable.
* \param expr The expression to differentiate.
* \param var The variable to differentiate with respect to.
* \return The expression for the derivative.
*/
EXPORT Expr Derivative(const Expr& expr, const VarExpr& var);
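For illustration (variable names chosen here, not taken from the diff): for scalar expressions, `Derivative(x*x + y, x)` should yield an expression equivalent to `2*x`, while `Derivative(y, x)` should yield zero.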

/*!
* \brief Get the tensor representing the Jacobian of the output with respect to the input.
*
* Note that if \p output depends on \p input indirectly (by using some other tensor
* depending on \p input), this dependency won't contribute to the resulting Jacobian.
* For such cases use the function ::Differentiate.
*
* \param output The tensor to differentiate.
* \param input The input tensor, which \p output should directly use.
* \param optimize Whether to perform optimizations like lifting of nonzeroness conditions.
* \return The tensor representing the Jacobian of shape `output.shape + input.shape`.
*/
EXPORT Tensor Jacobian(const Tensor& output, const Tensor& input, bool optimize = true);
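Shape example (illustrative): if `output` has shape `(m, n)` and `input` has shape `(k, l)`, the returned Jacobian `J` has shape `(m, n, k, l)`, with `J[a, b, c, d]` holding the derivative of `output[a, b]` with respect to `input[c, d]`.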

/*!
* \brief The building block for reverse-mode AD.
*
* Differentiate \p output wrt \p input and multiply the result by \p head on the left using a
* tensor dot product. \p input must be an immediate dependency of \p output (i.e. used directly
* in the body of \p output). That is, the function computes a summand of the adjoint for \p input
* given the adjoint for \p output (which is called \p head here).
*
* \param output The tensor to differentiate.
* \param input The input tensor, which \p output should directly use.
* \param head The adjoint of \p output. Must be of shape `prefix + output.shape`
* \return The tensor representing the adjoint of \p input of shape `prefix + input.shape`.
*/
EXPORT Tensor DiffBuildingBlock(const Tensor& output, const Tensor& input, const Tensor& head);

/*!
* \brief Perform reverse mode automatic differentiation.
*
* Each item of the `result` field of the result is an adjoint for the corresponding item of
* \p inputs, i.e. \p head multiplied by the Jacobian of \p output with respect to the
* corresponding item of \p inputs.
*
* \param output The tensor to differentiate.
* \param inputs The array of input tensors. When the array is empty, differentiation is performed
* wrt all tensors the output depends on.
* \param head The adjoint of the output, in other words, some tensor, by which the Jacobians
* will be multiplied. Its shape must be of the form `prefix + output.shape`. If the
* null pointer is provided, the identity tensor of shape
* `output.shape + output.shape` will be used.
* \param fdiff The function performing differentiation and multiplication, see
* ::FDiffBuildingBlock.
* \param override_deps A map from tensors to their dependencies (`InputTensors()` are used by
* default). Overriding dependencies may be useful to treat a group of tensors
* as a single supertensor. In this case the fdiff functions should also be
* modified accordingly.
* \return An object of type DifferentiationResult which contains three fields:
* - `result` An array of adjoints corresponding to \p inputs.
* - `adjoints` A map from tensors to the corresponding adjoints (includes intermediate
* tensors).
* - `adjoint_summands` A map from tensors to maps from parent tensors to individual
* summands of the adjoint.
*/
EXPORT DifferentiationResult Differentiate(
    const Tensor& output,
    const Array<Tensor>& inputs = Array<Tensor>(),
    const Tensor& head = Tensor(),
    const FDiffBuildingBlock& fdiff = DiffBuildingBlock,
    const Map<Tensor, Array<Tensor>>& override_deps = Map<Tensor, Array<Tensor>>());

} // namespace ir
} // namespace tvm
#endif // TVM_AUTODIFF_H_
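As a usage sketch (an assumption based on the declarations above and on the `python/tvm/__init__.py` change below, which exports `tvm.differentiate`; the exact Python signature and the way result fields are accessed are not spelled out in this diff):

import tvm

# A small computation: y = W @ x, written with the top-level tvm.compute API in use at the time.
x = tvm.placeholder((3,), name='x')
W = tvm.placeholder((10, 3), name='W')
k = tvm.reduce_axis((0, 3), name='k')
y = tvm.compute((10,), lambda i: tvm.sum(W[i, k] * x[k], axis=k), name='y')

# Assumed wrapper around ir::Differentiate: with no head given, the identity
# tensor of shape y.shape + y.shape is used, so each adjoint is a full Jacobian.
res = tvm.differentiate(y, [W, x])

# DifferentiationResultNode exposes `result` and `adjoints`; assuming those
# fields are visible from Python, the returned Jacobians have shapes
#   res.result[0]: (10, 10, 3)   # dy/dW
#   res.result[1]: (10, 3)       # dy/dx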
30 changes: 30 additions & 0 deletions include/tvm/ir_operator.h
@@ -85,6 +85,16 @@ inline const uint64_t* as_const_uint(const Expr& x) {
*/
inline bool is_const_int(const Expr& x, int64_t value);

/*!
* \brief Check if the given expr is a const of any type equal to the given integer value.
* \param e The expression.
* \param value The value to compare to.
* \return Whether the expression is a const equal to the value.
* \tparam ValueType The value type
*/
template <typename ValueType>
inline bool is_const_value(const Expr& e, ValueType value);

/*!
* \brief Check whether stmt is nop.
* \param stmt The input statement
@@ -515,6 +525,26 @@ inline bool is_const_int(const Expr& x, int64_t value) {
return false;
}

template <typename ValueType>
inline bool is_const_value(const Expr& e, ValueType value) {
  static_assert(std::is_integral<ValueType>::value,
                "Comparison to non-integer values is forbidden.");
  // This implementation was copy-pasted from HalideIR
  if (const ir::IntImm* i = e.as<ir::IntImm>()) {
    return i->value == value;
  } else if (const ir::UIntImm* i = e.as<ir::UIntImm>()) {
    return (value >= 0) && (i->value == (uint64_t)value);
  } else if (const ir::FloatImm* i = e.as<ir::FloatImm>()) {
    return i->value == value;
  } else if (const ir::Cast* c = e.as<ir::Cast>()) {
    return is_const_value(c->value, value);
  } else if (const ir::Broadcast* b = e.as<ir::Broadcast>()) {
    return is_const_value(b->value, value);
  } else {
    return false;
  }
}
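// Illustrative usage (editor's sketch, not part of this change; assumes the existing
// make_const, cast and make_zero helpers from ir_operator.h): the check matches a
// constant through integer, unsigned and float immediates as well as through Cast
// and Broadcast nodes, e.g.
//   is_const_value(make_const(Int(32), 0), 0)                        // true
//   is_const_value(cast(Float(64), make_const(Float(32), 1.0f)), 1)  // true
//   is_const_value(ir::Broadcast::make(make_zero(Int(32)), 4), 0)    // true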

inline bool is_no_op(const Stmt& stmt) {
if (!stmt.defined()) return true;
if (const auto* op = stmt.as<ir::Evaluate>()) {
2 changes: 2 additions & 0 deletions python/tvm/__init__.py
@@ -19,6 +19,7 @@
from . import generic
from . import hybrid
from . import testing
from . import autodiff

from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
@@ -36,6 +37,7 @@
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
from .autodiff import differentiate

# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel