[TVM] Automatic differentiation for tensor expressions #2498
Closed

Commits (11)
a0822d3 [TVM] Automatic differentiation for tensor expressions (sgrechanik-h)
0de13b9 Fix the failing tests (sgrechanik-h)
9f99cd7 Fix topi.take (sgrechanik-h)
0775a4e Fix a python2 syntax error (sgrechanik-h)
d90323e Move autodiff.h to include/tvm (sgrechanik-h)
204f2f7 Reduce the probability of test failure because of integer overflow (sgrechanik-h)
952629b Fix a problem with free vars (sgrechanik-h)
75418c7 [AD] More intrinsics; Fixed treatment of ints (sgrechanik-h)
cf5083a [AD] Autodiff/relay integration (sgrechanik-h)
5ae4ac1 Simplified overriding of gradients; Tutorial and docs (sgrechanik-h)
a342e5e Several fixes (sgrechanik-h)
@@ -0,0 +1,130 @@
"""Namespace of autodiff-related functions. | ||
|
||
The functions are automatically exported from C++ side via PackedFunc. | ||
You can read "include/tvm/autodiff.h" for the function signature of these functions. | ||
""" | ||
import logging | ||
|
||
from ._ffi.function import _init_api | ||
from ._ffi.node import NodeBase, register_node | ||
|
||
_init_api("tvm.autodiff") | ||
|
||
@register_node | ||
class DifferentiationResult(NodeBase): | ||
"""Result of differentiation. | ||
|
||
Parameters | ||
---------- | ||
result : list of Tensor | ||
The requested adjoints, i.e. the jacobians or gradients of the given output | ||
wrt to the given inputs. | ||
|
||
adjoints : dict from Tensor to Tensor | ||
A map from tensors to the corresponding adjoints (including internal nodes). | ||
|
||
adjoint_summands : dict from Tensor to dict from Tensor to Tensor | ||
Single summands of the adjoints. | ||
""" | ||
def __getattr__(self, name): | ||
# Here we convert tvm Maps to dicts because Map compares keys by reference which is | ||
# wrong for Tensors. Hopefully, in the future Map gets fixed somehow, and this function | ||
# may be removed then. | ||
res = NodeBase.__getattr__(self, name) | ||
if name == 'adjoints': | ||
return dict(res.items()) | ||
if name == 'adjoint_summands': | ||
return {k: dict(v.items()) for k, v in res.items()} | ||
return res | ||
|
||
def __getitem__(self, i): | ||
return self.result[i] | ||
|
||
def __len__(self): | ||
return len(self.result) | ||
|
||
|
||
def differentiate(output, inputs=None, head=None, manual=None, fdiff=None): | ||
"""Perform reverse-mode automatic differentiation. | ||
|
||
Example:: | ||
|
||
x = tvm.placeholder((32, 3, 28, 28), name='x') | ||
w1 = tvm.placeholder((10, 3, 3, 3), name='w1') | ||
w2 = tvm.placeholder((10, 10, 3, 3), name='w2') | ||
y = topi.sum(topi.nn.conv2d(topi.nn.conv2d(x, w1, 1, 0), w2, 1, 0)) | ||
|
||
[dw1, dw2] = tvm.differentiate(y, [w1, w2]) | ||
|
||
Parameters | ||
---------- | ||
output : Tensor | ||
The tensor to differentiate. | ||
|
||
inputs : list of Tensor | ||
The list of input tensors. When the list is empty or None, will perform | ||
differentiation wrt all tensors the output depends on (i.e. will compute all | ||
adjoints and populate the corresponding dict, but the list of results | ||
will be empty). | ||
|
||
head : Tensor | ||
The adjoint of the output, in other words, some tensor, by which the Jacobians | ||
will be multiplied. Its shape must be of the form `prefix + output.shape`. | ||
If `None` is passed, the identity tensor of shape `output.shape + output.shape` | ||
will be used. | ||
|
||
manual : dict (Tensor, Tensor) -> function | ||
A dict providing custom multiplication-differentiation functions (see `fdiff`) | ||
for certain pairs of tensors. Each pair consists of an output and an input tensor, | ||
the input one being an immediate dependency of the output one. Pairs of the form | ||
`(None, tensor)` and `(tensor, None)` are allowed, `None` working as a wildcard. | ||
|
||
fdiff : function (Tensor, Tensor, Tensor) -> Tensor | ||
The default function performing differentiation and multiplication, by default | ||
`tvm.autodiff.FDiffBuildingBlock` is used. The function must accept three | ||
parameters: | ||
- `output` - an output tensor | ||
- `input` - an input tensor | ||
- `head` - the adjoint of the output tensor | ||
The result should be `head` multiplied by the jacobian of `output` wrt `input` | ||
|
||
Returns | ||
------- | ||
differentiation_result : DifferentiationResult | ||
""" | ||
if inputs is None: | ||
inputs = [] | ||
|
||
if fdiff is None: | ||
fdiff = DiffBuildingBlock | ||
|
||
if manual is not None: | ||
if not isinstance(manual, dict): | ||
manual = dict(manual) | ||
|
||
# pylint: disable=dangerous-default-value | ||
used_items = set() | ||
|
||
def _modified_fdiff(out, inp, head, manual=manual, old_fdiff=fdiff, used_items=used_items): | ||
if (out, inp) in manual: | ||
used_items.add((out, inp)) | ||
return manual[(out, inp)](out, inp, head) | ||
if (out, None) in manual: | ||
used_items.add((out, None)) | ||
return manual[(out, None)](out, inp, head) | ||
if (None, inp) in manual: | ||
used_items.add((None, inp)) | ||
return manual[(None, inp)](out, inp, head) | ||
return old_fdiff(out, inp, head) | ||
|
||
fdiff = _modified_fdiff | ||
|
||
res = Differentiate(output, inputs, head, fdiff) | ||
|
||
if manual is not None: | ||
for k in manual: | ||
if k not in used_items: | ||
logging.warning("The manually specified differentiation function " | ||
"for %s hasn't been used", k) | ||
|
||
return res |
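
A minimal usage sketch of the `manual` override hook described in the docstring above. It assumes the exported names this file documents (`tvm.differentiate`, `tvm.autodiff.DiffBuildingBlock`); `my_fdiff` and the placeholder shapes are illustrative only::

    import tvm
    import topi

    x = tvm.placeholder((3, 4), name='x')
    y = topi.sum(topi.sigmoid(x))

    # Hypothetical custom building block: here it just delegates to the default
    # one; a real override would return head multiplied by a hand-written Jacobian.
    def my_fdiff(out, inp, head):
        return tvm.autodiff.DiffBuildingBlock(out, inp, head)

    # The (None, x) wildcard applies my_fdiff whenever x is the input tensor.
    res = tvm.differentiate(y, [x], manual={(None, x): my_fdiff})
    dx = res[0]               # requested adjoint, same as res.result[0]
    adjoints = res.adjoints   # dict mapping each tensor to its adjoint

If a key in `manual` never matches any (output, input) pair encountered during the traversal, the function logs a warning, which helps catch overrides attached to the wrong tensors.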
Review comment (on the `head` parameter):
Can this be understood as the gradient of the output tensor w.r.t. the final output of a math expression? Could you explain why its shape is `output.shape + output.shape` when `None` is passed? My understanding is that it should default to `ones(output.shape)`. Thanks.

Follow-up by the same reviewer:
Please disregard my question above. After reading your implementation of `DiffBuildingBlock`, I think I got the meaning of defaulting `head` to shape `output.shape + output.shape`.
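
To make the resolved question concrete: per the docstring, `head` has shape `prefix + output.shape` and each result has shape `prefix + input.shape`. A short sketch of the two cases, assuming this PR's `tvm.differentiate` and standard topi operators (shapes and names are illustrative)::

    import tvm
    import topi

    x = tvm.placeholder((3, 4), name='x')
    y = topi.exp(x)  # y.shape == (3, 4)

    # head=None: head defaults to the identity tensor of shape y.shape + y.shape,
    # so the result is the full Jacobian, of shape y.shape + x.shape == (3, 4, 3, 4).
    jac = tvm.differentiate(y, [x])[0]

    # head of shape y.shape (empty prefix): the Jacobian is contracted with ones,
    # yielding a gradient-style result of shape x.shape == (3, 4), which is what
    # defaulting head to ones(output.shape) would have produced.
    ones = topi.full((3, 4), 'float32', 1.0)
    grad = tvm.differentiate(y, [x], ones)[0]

Defaulting to the identity rather than to ones keeps the full Jacobian available for non-scalar outputs; for a scalar output the two conventions coincide.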