[Relax][Training] Add automatic differentiation pass (tlc-pack#103)
This is the PR following tlc-pack#55 after the source branch moved to a personal repo. This PR is based on tlc-pack#98.

This PR adds the new automatic differentiation API:

- `Gradient(func: GlobalVar, require_grads: Optional[Union[Var, List[Var]]] = None) -> tvm.ir.transform.Pass` - transforms the given function in the IRModule and adds a new function that calculates the gradient with respect to the function's output (see the usage sketch at the end of this message).

Currently `Gradient` only supports differentiating a function in the IRModule that has a single dataflow block, with respect to the function's only return value, which must be a scalar.

This PR adds two files for unit tests:

- `tests/python/relax/test_transform_gradient.py` only contains `assert_structural_equal` assertions.
- `tests/python/relax/test_transform_gradient_numeric.py` contains numeric checks, including manually derived gradients and the numerical differentiation method `check_numerical_grads`.

Checkpoints:

- [x] Refactor to use CopyWithNewParams and ExprFunctor
- [x] Check that int64/int32 tensors are not differentiated (currently only checked in params)
- [x] Rebase & migrate to StructInfo
- [x] Refactor Tuple handling
- [x] Refactor using NestedMsg
- [x] Support ops taking in or returning tuples
- [x] Eliminate collapse_sum_to (done in tlc-pack#98)

Future (not in this PR):

- Handle undefined gradients in `add` and in the return value; they are currently treated as zeros.

Co-authored-by: SiriusNEO <1713833595@qq.com>
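---

To make the API concrete, here is a minimal usage sketch based on the signature above. It is illustrative only: the example module, the op names `R.add`/`R.sum`, and the way the result is inspected are assumptions for this sketch, not code from this PR.

```python
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((3, 3), "float32"),
        y: R.Tensor((3, 3), "float32"),
    ):
        # A single dataflow block with a scalar return value:
        # the kind of function the pass currently supports.
        with R.dataflow():
            lv = R.add(x, y)
            gv = R.sum(lv)
            R.output(gv)
        return gv


# Differentiate "main" with respect to all of its parameters.
mod = relax.transform.Gradient(Module.get_global_var("main"))(Module)

# Or restrict differentiation to a subset of the parameters
# via require_grads:
mod = relax.transform.Gradient(
    Module.get_global_var("main"),
    require_grads=Module["main"].params[0],
)(Module)

# The transformed module keeps "main" and gains a new function that
# computes the gradients alongside the original output.
print(mod)
```

The generated function is an ordinary Relax function, so it can be inspected with `assert_structural_equal` (as in `test_transform_gradient.py`) or compiled and executed to check gradient values numerically (as in `test_transform_gradient_numeric.py`).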