[Relax][Training] Add automatic differentiation pass (tlc-pack#103)
This is the PR following tlc-pack#55 after the source branch moved to a personal repo.

This PR is based on tlc-pack#98.

This PR adds the new automatic differentiation API:
- `Gradient(func: GlobalVar, require_grads: Optional[Union[Var, List[Var]]] = None) -> tvm.ir.transform.Pass`
  - transforms the given function in the IRModule and adds a new function that calculates the gradient with respect to the function's output

Currently, `Gradient` only supports differentiating a function in the IRModule that has a single dataflow block, with respect to the function's only return value, which must be a scalar.
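
A minimal usage sketch, assuming an IRModule `Module` whose `main` function has a single dataflow block and returns a scalar (the module and parameter names here are illustrative, not part of this PR):

```python
import tvm
from tvm import relax

# Assumption: `Module` is a relax IRModule whose "main" function has a single
# dataflow block and returns a scalar (e.g. a loss value).
main_gv = Module.get_global_var("main")

# Differentiate with respect to all parameters of "main" ...
mod_with_grad = relax.transform.Gradient(main_gv)(Module)

# ... or only with respect to a chosen subset of its parameters.
first_param = Module["main"].params[0]
mod_with_grad = relax.transform.Gradient(main_gv, require_grads=first_param)(Module)

# The transformed module now also contains a "main_adjoint" function that
# computes the requested adjoints.
```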

This PR adds two unit test files:
- `tests/python/relax/test_transform_gradient.py` only contains
`assert_structural_equal` assertions.
- `tests/python/relax/test_transform_gradient_numeric.py` contains numeric checks,
including manually derived gradients and numerical differentiation via
`check_numerical_grads` (see the sketch after this list).
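
For illustration, a rough NumPy-only sketch of the kind of check `check_numerical_grads` performs, assuming a hypothetical scalar function `f(x, w) = sum(x * w)`; in the real tests the gradients come from running the generated `_adjoint` function rather than hand-written formulas, and the exact helper signature may differ:

```python
import numpy as np
from tvm.testing import check_numerical_grads

# Hypothetical scalar-valued function: f(x, w) = sum(x * w).
x = np.random.rand(3, 3).astype("float32")
w = np.random.rand(3, 3).astype("float32")

def forward(x, w):
    return np.sum(x * w)

# Analytic adjoints of f: df/dx = w and df/dw = x. In the real tests these
# come from the "_adjoint" function produced by the Gradient pass.
analytic_grads = [w, x]

# Compare the analytic adjoints against central finite differences of `forward`.
check_numerical_grads(forward, [x, w], analytic_grads)
```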

Checkpoints:
- [x] Refactor to use CopyWithNewParams and ExprFunctor
- [x] Check that int64/int32 tensors are not differentiated (currently only checked for parameters)
- [x] Rebase & migrate to StructInfo
- [x] Refactor about Tuple
- [x] Refactor about NestedMsg
- [x] Support ops taking in tuple or returning tuple
- [x] Eliminate `collapse_sum_to` (done in tlc-pack#98)

Future:
- (Not in this PR) Handle undefined gradients in `add` and in the return value
	- Currently these are treated as zeros

Co-authored-by: SiriusNEO <1713833595@qq.com>
2 people authored and MasterJH5574 committed Jan 31, 2023
1 parent b2ae2c3 commit 8fc0ffb
Showing 7 changed files with 1,784 additions and 10 deletions.
20 changes: 20 additions & 0 deletions include/tvm/relax/transform.h
@@ -224,6 +224,26 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_opt
*/
TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype);

/*!
 * \brief Reverse-mode automatic differentiation.
 *
 * Now only supports differentiating one function in the IRModule with one dataflow block
 * with respect to the only return value of the function, which needs to be a scalar.
 *
 * For a given function specified by the input global var, it generates a new function with the
 * name `[name of original function] + "_adjoint"`. The new function computes the adjoints of the
 * specified arguments of the original function with respect to the single return value of the
 * original function.
 *
 * For examples, see the MLP examples in `tests/python/relax/test_transform_gradient.py` and
 * `tests/python/relax/test_transform_gradient_numeric.py`.
 *
 * \param global_var The GlobalVar of the specified function.
 * \param require_grads The relax variables whose adjoints are needed. Must be parameters of the
 * given function. If not specified, adjoints of all arguments will be computed.
 * \return The Pass.
 */
TVM_DLL Pass Gradient(GlobalVar global_var, Optional<Array<Var>> require_grads = NullOpt);
} // namespace transform
} // namespace relax
} // namespace tvm
38 changes: 38 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -25,6 +25,7 @@
import tvm.ir
from tvm.runtime import NDArray
from . import _ffi_api
from ..expr import Var, GlobalVar


@tvm._ffi.register_object("relax.FunctionPass")
@@ -410,6 +411,43 @@ def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass:
    return _ffi_api.ToMixedPrecision(out_dtype) # type: ignore


def Gradient(
    global_var: GlobalVar, require_grads: Optional[Union[Var, List[Var]]] = None
) -> tvm.ir.transform.Pass:
    """Reverse-mode automatic differentiation.

    Now only supports differentiating one function in the IRModule with one dataflow block
    with respect to the only return value of the function, which needs to be a scalar.

    For a given function specified by the input global var, it generates a new function with
    the name `[name of original function] + "_adjoint"`. The new function computes the adjoints
    of the specified arguments of the original function with respect to the single return value
    of the original function.

    For examples, see the MLP examples in tests/python/relax/test_transform_gradient.py and
    tests/python/relax/test_transform_gradient_numeric.py.

    Parameters
    ----------
    global_var : relax.GlobalVar
        The GlobalVar of the specified function.

    require_grads : Optional[Union[relax.Var, List[relax.Var]]]
        The relax variables whose adjoints are needed. Must be parameters of the given function
        and must not contain duplicates. If not specified, adjoints of all arguments will be
        computed.

    Returns
    -------
    ret : tvm.ir.transform.Pass
        The Pass.
    """
    if require_grads is not None and not isinstance(require_grads, list):
        require_grads = [require_grads]

    return _ffi_api.Gradient(global_var, require_grads) # type: ignore


def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
