[TVM] Automatic differentiation for tensor expressions #2498

Closed
wants to merge 11 commits
30 changes: 30 additions & 0 deletions include/tvm/ir_operator.h
@@ -85,6 +85,16 @@ inline const uint64_t* as_const_uint(const Expr& x) {
*/
inline bool is_const_int(const Expr& x, int64_t value);

/*!
* \brief Check if the given expr is a const of any type equal to the given integer value.
* \param e The expression.
* \param value The value to compare to.
* \return Whether the expression is a const equal to the value.
* \tparam ValueType The value type
*/
template <typename ValueType>
inline bool is_const_value(const Expr& e, ValueType value);

/*!
* \brief Check whether stmt is nop.
* \param stmt The input statement
@@ -515,6 +525,26 @@ inline bool is_const_int(const Expr& x, int64_t value) {
return false;
}

template <typename ValueType>
inline bool is_const_value(const Expr& e, ValueType value) {
static_assert(std::is_integral<ValueType>::value,
"Comparison to non-integer values is forbidden.");
// This implementation was copy-pasted from HalideIR
if (const ir::IntImm* i = e.as<ir::IntImm>()) {
return i->value == value;
} else if (const ir::UIntImm* i = e.as<ir::UIntImm>()) {
return (value >= 0) && (i->value == (uint64_t)value);
} else if (const ir::FloatImm* i = e.as<ir::FloatImm>()) {
return i->value == value;
} else if (const ir::Cast* c = e.as<ir::Cast>()) {
return is_const_value(c->value, value);
} else if (const ir::Broadcast* b = e.as<ir::Broadcast>()) {
return is_const_value(b->value, value);
} else {
return false;
}
}

inline bool is_no_op(const Stmt& stmt) {
if (!stmt.defined()) return true;
if (const auto* op = stmt.as<ir::Evaluate>()) {
2 changes: 2 additions & 0 deletions python/tvm/__init__.py
@@ -19,6 +19,7 @@
from . import generic
from . import hybrid
from . import testing
from . import autodiff

from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
@@ -36,6 +37,7 @@
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
from .autodiff import differentiate

# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
130 changes: 130 additions & 0 deletions python/tvm/autodiff.py
@@ -0,0 +1,130 @@
"""Namespace of autodiff-related functions.

The functions are automatically exported from the C++ side via PackedFunc.
See "include/tvm/autodiff.h" for their signatures.
"""
import logging

from ._ffi.function import _init_api
from ._ffi.node import NodeBase, register_node

_init_api("tvm.autodiff")

@register_node
class DifferentiationResult(NodeBase):
"""Result of differentiation.

Parameters
----------
result : list of Tensor
The requested adjoints, i.e. the jacobians or gradients of the given output
wrt the given inputs.

adjoints : dict from Tensor to Tensor
A map from tensors to the corresponding adjoints (including internal nodes).

adjoint_summands : dict from Tensor to dict from Tensor to Tensor
The individual summands of each adjoint.
"""
def __getattr__(self, name):
# Here we convert tvm Maps to dicts because Map compares keys by reference, which is
# wrong for Tensors. If Map gets fixed in the future, this function may be removed.
res = NodeBase.__getattr__(self, name)
if name == 'adjoints':
return dict(res.items())
if name == 'adjoint_summands':
return {k: dict(v.items()) for k, v in res.items()}
return res

def __getitem__(self, i):
return self.result[i]

def __len__(self):
return len(self.result)
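
For illustration, a minimal sketch of how a DifferentiationResult might be consumed, assuming the `tvm.differentiate` wrapper defined below and hypothetical placeholder tensors `a`, `b` and `y`:

import tvm
import topi

a = tvm.placeholder((8, 8), name='a')
b = tvm.placeholder((8, 8), name='b')
y = topi.sum(topi.multiply(a, b))

res = tvm.differentiate(y, [a, b])
da, db = res[0], res[1]            # indexing and len() go through res.result
adjoints = res.adjoints            # plain dict: Tensor -> adjoint, internal nodes included
summands = res.adjoint_summands    # dict of dicts with the individual adjoint summands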


def differentiate(output, inputs=None, head=None, manual=None, fdiff=None):
"""Perform reverse-mode automatic differentiation.

Example::

x = tvm.placeholder((32, 3, 28, 28), name='x')
w1 = tvm.placeholder((10, 3, 3, 3), name='w1')
w2 = tvm.placeholder((10, 10, 3, 3), name='w2')
y = topi.sum(topi.nn.conv2d(topi.nn.conv2d(x, w1, 1, 0), w2, 1, 0))

[dw1, dw2] = tvm.differentiate(y, [w1, w2])

Parameters
----------
output : Tensor
The tensor to differentiate.

inputs : list of Tensor
The list of input tensors. When the list is empty or None, differentiation is
performed wrt all tensors the output depends on (i.e. all adjoints are computed
and the corresponding dict is populated, but the list of results will be empty).

head : Tensor
The adjoint of the output, in other words, some tensor, by which the Jacobians
will be multiplied. Its shape must be of the form `prefix + output.shape`.
If `None` is passed, the identity tensor of shape `output.shape + output.shape`
will be used.

Review comment (Contributor): Can this be understood as the gradient of the output tensor w.r.t. the final output of a math expression? Could you explain why its shape is output.shape+output.shape when None is passed? My understanding is that it should default to ones(output.shape). Thanks.

Review comment (Contributor): Please disregard my question above. After reading your implementation of DiffBuildingBlock, I think I got the meaning of defaulting head to shape output.shape+output.shape.

manual : dict (Tensor, Tensor) -> function
A dict providing custom multiplication-differentiation functions (see `fdiff`)
for certain pairs of tensors. Each pair consists of an output and an input tensor,
the input one being an immediate dependency of the output one. Pairs of the form
`(None, tensor)` and `(tensor, None)` are allowed, `None` working as a wildcard.

fdiff : function (Tensor, Tensor, Tensor) -> Tensor
The function performing differentiation and multiplication; by default
`tvm.autodiff.DiffBuildingBlock` is used. The function must accept three
parameters:
- `output` - an output tensor
- `input` - an input tensor
- `head` - the adjoint of the output tensor
The result should be `head` multiplied by the jacobian of `output` wrt `input`.

Returns
-------
differentiation_result : DifferentiationResult
"""
if inputs is None:
inputs = []

if fdiff is None:
fdiff = DiffBuildingBlock

if manual is not None:
if not isinstance(manual, dict):
manual = dict(manual)

# pylint: disable=dangerous-default-value
used_items = set()

def _modified_fdiff(out, inp, head, manual=manual, old_fdiff=fdiff, used_items=used_items):
if (out, inp) in manual:
used_items.add((out, inp))
return manual[(out, inp)](out, inp, head)
if (out, None) in manual:
used_items.add((out, None))
return manual[(out, None)](out, inp, head)
if (None, inp) in manual:
used_items.add((None, inp))
return manual[(None, inp)](out, inp, head)
return old_fdiff(out, inp, head)

fdiff = _modified_fdiff

res = Differentiate(output, inputs, head, fdiff)

if manual is not None:
for k in manual:
if k not in used_items:
logging.warning("The manually specified differentiation function "
"for %s hasn't been used", k)

return res
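
A minimal sketch of the `manual` and `fdiff` hooks, reusing the x/w1/w2/y tensors from the docstring example above; the fallback to `tvm.autodiff.DiffBuildingBlock` is an assumption based on this diff rather than part of the documented example:

import tvm
import topi

x = tvm.placeholder((32, 3, 28, 28), name='x')
w1 = tvm.placeholder((10, 3, 3, 3), name='w1')
w2 = tvm.placeholder((10, 10, 3, 3), name='w2')
y = topi.sum(topi.nn.conv2d(topi.nn.conv2d(x, w1, 1, 0), w2, 1, 0))

def my_fdiff(out, inp, head):
    # Custom multiplication-differentiation rule; here it simply delegates to the
    # default building block; a real override would construct its own adjoint tensor.
    return tvm.autodiff.DiffBuildingBlock(out, inp, head)

# Use my_fdiff for every (output, w1) pair; all other pairs fall back to the default.
res = tvm.differentiate(y, [w1, w2], manual={(None, w1): my_fdiff})
dw1, dw2 = res[0], res[1]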
172 changes: 172 additions & 0 deletions python/tvm/testing.py
@@ -1,6 +1,7 @@
""" TVM testing utilities """
import logging
import numpy as np
import tvm

def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
""" Version of np.testing.assert_allclose with `atol` and `rtol` fields set
@@ -145,3 +146,174 @@ def compare_derivative(j, n_der, grad):
logging.info("Numerical grad test wrt '%s' of shape %s passes, "
"dist = %f, max_diff = %f, avg_diff = %f",
x_name, grad.shape, dist, max_diff, avg_diff)


class PerformanceEstimate:
"""A result of static performance estimation.

Parameters
----------
iterations : int
The total number of iterations of all the loops.

multiplications : int
The total number of expensive operations like multiplications.

memory : int
The amount of memory to allocate.
"""
def __init__(self, iterations=0, multiplications=0, memory=0):
self.iterations = iterations
self.multiplications = multiplications
self.memory = memory

def as_tuple(self):
return (self.iterations, self.multiplications, self.memory)

def __add__(self, other):
return PerformanceEstimate(iterations=self.iterations + other.iterations,
multiplications=self.multiplications + other.multiplications,
memory=self.memory + other.memory)

def max(self, other):
return PerformanceEstimate(
iterations=max(self.iterations, other.iterations),
multiplications=max(self.multiplications, other.multiplications),
memory=max(self.memory, other.memory))

def times(self, iters):
return PerformanceEstimate(iterations=self.iterations*iters,
multiplications=self.multiplications*iters,
memory=self.memory)

def __repr__(self):
return "PerformanceEstimate(iterations={}, multiplications={}, memory={})".format(
self.iterations, self.multiplications, self.memory)

def __le__(self, other):
return \
self.iterations <= other.iterations and \
self.multiplications <= other.multiplications and \
self.memory <= other.memory
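
For illustration, a small sketch of how these estimates compose (the numbers are made up):

body = PerformanceEstimate(iterations=1, multiplications=2)
loop = body.times(10)                            # repeat the body for 10 iterations
total = loop + PerformanceEstimate(memory=100)   # then account for a 100-element buffer
assert total.as_tuple() == (10, 20, 100)
assert PerformanceEstimate() <= total            # comparison is componentwise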


def estimate_performance(s, processed_ops=None):
"""Statically estimate performance of statements, expressions and tensors. Note that the
estimate is very rough, it mustn't be used to predict future performance, its only purpose is
to detect possible performance regressions.

Parameters
----------
s
A statement, an expression, a tensor, an operation, or a list
of any of the above.

processed_ops : dict, optional
An internal cache mapping already visited operations to their estimates;
callers should normally leave it as None.

Returns
-------
estimate : PerformanceEstimate
"""
from tvm import stmt
from tvm import expr

if processed_ops is None:
processed_ops = {}
res = estimate_performance(s, processed_ops)
for op_est in processed_ops.values():
res += op_est
return res

est = lambda e, processed_ops=processed_ops: estimate_performance(e, processed_ops)

def _prod(elems):
res = 1
for x in elems:
res *= x
return res

if s is None or isinstance(s, (stmt.AssertStmt, stmt.Free, stmt.Prefetch,
expr.ConstExpr, expr.Var, tvm.tensor.PlaceholderOp)):
return PerformanceEstimate()
elif isinstance(s, list):
res = PerformanceEstimate()
for item in s:
res += est(item)
return res
elif s in processed_ops:
return PerformanceEstimate()
elif isinstance(s, stmt.Allocate):
mem = _prod([e.value for e in s.extents])
return est(s.condition) + est(s.body) + PerformanceEstimate(memory=mem)
elif isinstance(s, stmt.Block):
return est(s.first) + est(s.rest)
elif isinstance(s, stmt.Evaluate):
return est(s.value)
elif isinstance(s, stmt.For):
body_est = est(s.body)
body_est.iterations = max(1, body_est.iterations)
return body_est.times(s.extent.value)
elif isinstance(s, stmt.IfThenElse):
return est(s.condition) + est(s.then_case) + est(s.else_case)
elif isinstance(s, stmt.LetStmt):
return est(s.value) + est(s.body)
elif isinstance(s, (stmt.ProducerConsumer, stmt.AttrStmt)):
return est(s.body)
elif isinstance(s, stmt.Provide):
return est(s.value)
elif isinstance(s, stmt.Realize):
return est(s.condition) + est(s.body)
elif isinstance(s, stmt.Store):
return est(s.value) + est(s.index) + est(s.predicate)
elif isinstance(s, (expr.Mul, expr.Div, expr.Mod)):
return est(s.a) + est(s.b) + PerformanceEstimate(multiplications=1)
elif isinstance(s, (expr.BinaryOpExpr, expr.CmpExpr, expr.LogicalExpr)):
if not hasattr(s, 'b'):
return est(s.a)
return est(s.a) + est(s.b)
elif isinstance(s, expr.Call):
res = PerformanceEstimate()
for a in s.args:
res += est(a)
if s.call_type == expr.Call.Halide:
# The estimate is added to processed_ops; we don't need the result here
est(s.func)
elif s.name == "tvm_if_then_else":
pass
else:
# If it is a non-Halide call (e.g. exp or log), consider it a multiplication
res += PerformanceEstimate(multiplications=1)
return res
elif isinstance(s, expr.Cast):
return est(s.value)
elif isinstance(s, expr.Load):
return est(s.index) + est(s.predicate)
elif isinstance(s, expr.Select):
return est(s.condition) + est(s.true_value) + est(s.false_value)
elif isinstance(s, expr.Reduce):
iterations = _prod([iv.dom.extent.value for iv in s.axis])
res = PerformanceEstimate()
for id_elem in s.combiner.identity_element:
res += est(id_elem)
on_each_iter = est(s.condition)
for src in s.source:
on_each_iter += est(src)
for comb_res in s.combiner.result:
on_each_iter += est(comb_res)
on_each_iter.iterations = max(1, on_each_iter.iterations)
return res + on_each_iter.times(iterations)
elif isinstance(s, tvm.tensor.Tensor):
return est(s.op)
elif isinstance(s, tvm.tensor.ComputeOp):
iterations = _prod([iv.dom.extent.value for iv in s.axis])
if s.reduce_axis:
res = est(s.body[0])
else:
res = PerformanceEstimate()
for b in s.body:
res += est(b)
res.iterations = max(1, res.iterations)
res = res.times(iterations) + PerformanceEstimate(memory=iterations*len(s.body))
processed_ops[s] = res
return PerformanceEstimate()

raise ValueError("Don't know how to estimate performance of {} of type {}"
.format(s, type(s)))
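
A hedged usage sketch of the estimator on a simple elementwise compute op; the printed numbers follow from the rules above and are illustrative only:

import tvm

A = tvm.placeholder((16, 16), name='A')
B = tvm.compute((16, 16), lambda i, j: A[i, j] * 2.0, name='B')
est = tvm.testing.estimate_performance(B)
# One multiplication per element over a 16x16 iteration space, plus the output buffer
print(est)  # PerformanceEstimate(iterations=256, multiplications=256, memory=256)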