Add support for forward and reverse high-order automatic differentiation mechanism (#41919)
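
At a high level (a hedged summary inferred from the commit list below, not a verbatim design note): the mechanism rewrites a program into a small set of primitive operators (add_p, mul_p, matmul_p, reduce_p, ...) via orig2prim rules. Each primitive carries a linearization (JVP) rule that propagates tangents forward,

$$\dot{y} = \partial f(x)\,[\dot{x}],$$

and each linearized primitive a transpose rule that propagates cotangents backward,

$$\bar{x} = \partial f(x)^{\top}[\bar{y}].$$

Composing linearize + transpose over the primitive program yields reverse-mode gradients; because the result is itself a primitive program, the passes can be reapplied for higher-order derivatives, and prim2orig finally lowers everything back to original operators for execution.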

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warning in backward.py

* format python code

* support multi input in triple gradient checker

* Add matmul triple grad kernel

* Updated comments of TODO

* Supported some special tests

* Change code-format to follow CI std

* Updated gradient_checker.py

* Fix conflicts

* Removed unnecessary printing log

* Change code style to follow CI std

* merge upstream

* add primops.py

* add_p

* rm useless files

* add sub_p mul_p div_p

* add sqrt_p and tanh_p

* add reshape_p

* add broadcast_p

* Add python primitive wrappers.

* Jvp rules updated.

* JVP rules done for all the 17 primops.

* quick check and fixes.

* add jvp(op, *args)

* add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p

* add split_p and concat_p

* add gather_p and scatter_add_p

* add slice_select_p and slice_assign_p

* Add transpose rules.

* add multi input check for add_p, sub_p, mul_p, div_p

* update concat_p

* Linearize and transpose in progress..

* refine gather_p and scatter_add_p

* updated.

* update transpose.

* refine slice_assign_p and slice_select_p

* init commit for lower

* Merged with primitive ops.

* small update

* add rules for orig2prim and prim2orig

* add 9 test for prim ops

* add more test and fix some bug

* add more test

* register proto

* Adding primops test.

* add shape validity check for broadcast_p op, and add keepdim attr into reduce_p op proto

* support multi input and multi output for split_p and concat_p

* Test updated.

* update

* fix slice bug for slice_select_p and slice_assign_p

* updated.

* Ops updated.

* Refactor and bug fixes.

* updated.

* finish orig2prim and prim2orig rules

* dtype for axis attr should be long int

* update dtype for axis attr int64_t

* update for iscan CI

* Update primx.

* Refactor vars in primx.

* update for lower transform

* add more shape and dtype check

* update primx.py

* change IndexTensor into int32 dtype

* update

* Fix linearize and transpose.

* Update is_dot

* Update is_dot

* Update is_dot

* add gradient aggregation, fix add_transpose.

* pass first linearize+transpose test.

* update test

* refactor op registration and primx.

* update rule for slice_assign

* try test lower

* update orig2prim and prim2orig

* pass simple lower pass

* update

* Update input types in the unit test.

* orig2prim segfault.

* 50% for adam.minimize

* test updated.

* temp fix errors in removing vars.

* primx updated.

* update for matmul_v2 and reshape2 orig2prim

* update for minimize

* Refine primrules

* Remove some code

* supporting unused and unreachable vars.

* update for use prim2orig in minimize

* fix gather and scatter_add transpose.

* Add rules UT

* update scatter_add

* Refine UT code

* fix nonetype check in topo

* Update gather_p pywrapper.

* remove useless print

* Merge tongxin PR and refine code

* readd some test

* rm useless print

* polish code.

* fix bug in minimize

* add get_input_var_list and get_output_var_list and use it in lower

* Fix scatter_add_p prim2orig

* Update code and fix orig2prim/prim2orig UT

* delete vars after block.desc._remove

* Improve ops and vars clean up logics.

* fix some bug in linearize and lower

* update tanh transpose.

* use set instead of list for var2remove

* test updated.

* polish code.

* fix dot2bar delete.

* merge tx/ad

* add indextensor_dot for gather and scatter_add

* add sorted for set

* Fix scale_orig2prim params

* fix some syntax bug

* add global_lower_update list

* Better handling of unused vars.

* update tests.

* Fix elementwise_sub orig2prim

* support None for transpose rule

* Merge and add transform UT

* fix a bug in transpose

* Fix transpose and UT

* a hacky fix for concat op

* Fix executor place

* Refine variable name

* Add elementwise_mul orig2prim and support p_norm when p=1

* Add sqrt orig2prim rule and UT

* merge wz test

* rename files, add enable_prim, disable_prim, prim_enabled, delete global_lower_update (see the usage sketch after this commit list)

* fix a bug in test_ad_transform_trans

* revert modify in framework.py

* add paddle.fluid.incubate.ad_transform to python/setup.py.in

* Fix remove vars error

* Fix p_norm_orig2prim

* merge wz

* Modify the code directory

* Add utils.py and remove get_input/output_vars functions

* Update maolin code

* Rename UT and refine test_ad_transform_primops

* Fix div_p jvp rule

* Add higher derivatives UT

* Remove UT to autograd dir

* Fix comments

* import paddle in primops.py

* Add some error message for assert

* Refine UT class name and refine some comments in primreg.py

* update minimize of paddle/optimizer for supporting new autograd

* resolve circular importing between backward.py and optimizer.py

* fill gradients and minimize unittest

* Replace `assert isinstance` with `raise TypeError`

* Add some assert message for primx.py

* Polish variable name

* Add some assert message

* add some docstring

* refine some name

* update the format of English documents

* Split test_transform.py to two files to avoid ci error

* fix the document format of enable_prim/disable_prim/prim2orig/prim_enabled

* polish test_gradients_and_minimize

* add default value for prim_enabled api doc

* Remove some UT to avoid windows ci error

* Enlarge test_gradients_and_minimize time limit

* Fix UT time limit
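
For orientation before the diff below, here is a minimal usage sketch of the new mechanism, assembled only from APIs visible in this change (enable_prim/disable_prim, paddle.static.gradients, prim2orig). It mirrors the new unit test at the bottom of the diff; treat the exact import paths and behavior as assumptions of this sketch rather than documented guarantees.

import numpy as np
import paddle
from paddle.incubate.autograd.primx import prim2orig
from paddle.incubate.autograd.utils import enable_prim, disable_prim

paddle.enable_static()
enable_prim()  # route paddle.static.gradients through the primitive-based autodiff

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[1], dtype='float32')
    y = paddle.multiply(paddle.multiply(x, x), x)   # y = x^3
    grad1, = paddle.static.gradients([y], [x])      # 3 * x^2
    grad2, = paddle.static.gradients([grad1], [x])  # 6 * x, a second-order gradient
    prim2orig(main.block(0))  # lower primitive ops back to original ops before running

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
print(exe.run(main,
              feed={'x': np.array([2.], dtype='float32')},
              fetch_list=[grad2.name]))  # [array([12.], dtype=float32)]
disable_prim()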

Co-authored-by: veyron95 <veyron_wu@163.com>
Co-authored-by: Jiabin Yang <360788950@qq.com>
Co-authored-by: levi131 <limaolin01@baidu.com>
Co-authored-by: Tongxin Bai <waffle.bai@gmail.com>
Co-authored-by: Xiaoxu Chen <chenxx_id@163.com>
Co-authored-by: levi131 <83750468+levi131@users.noreply.github.com>
7 people authored May 18, 2022
1 parent b9342a8 commit f6ee202
Showing 17 changed files with 3,803 additions and 78 deletions.
54 changes: 0 additions & 54 deletions python/paddle/autograd/primreg.py

This file was deleted.

6 changes: 6 additions & 0 deletions python/paddle/fluid/backward.py
@@ -32,6 +32,7 @@
    from collections.abc import Sequence
except:
    from collections import Sequence

__all__ = [
    'append_backward',
    'gradients',
@@ -2113,6 +2114,11 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
    check_type(target_gradients, 'target_gradients', (
        framework.Variable, list, tuple, type(None)), 'paddle.static.gradients')

    from ..incubate.autograd.primx import _gradients
    from ..incubate.autograd.utils import prim_enabled
    if prim_enabled():
        return _gradients(targets, inputs, target_gradients)

    outs = calc_gradient(targets, inputs, target_gradients, no_grad_set)
    return _as_list(outs)

@@ -8,3 +8,4 @@ endforeach(TEST_OP)

set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 160)
set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160)
set_tests_properties(test_gradients_and_minimize PROPERTIES TIMEOUT 60)
@@ -0,0 +1,143 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle
from paddle.incubate.autograd.primx import prim2orig
from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled

paddle.enable_static()


class TestGradients(unittest.TestCase):
    def test_third_order(self):
        enable_prim()
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.static.data(name='x', shape=[1], dtype='float32')
            x2 = paddle.multiply(x, x)
            x3 = paddle.multiply(x2, x)
            x4 = paddle.multiply(x3, x)

            grad1, = paddle.static.gradients([x4], [x])
            grad2, = paddle.static.gradients([grad1], [x])
            grad3, = paddle.static.gradients([grad2], [x])

            prim2orig(main.block(0))

        feed = {x.name: np.array([2.]).astype('float32')}
        fetch_list = [grad3.name]
        result = [np.array([48.])]

        place = paddle.CPUPlace()
        if paddle.device.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        exe.run(startup)
        outs = exe.run(main, feed=feed, fetch_list=fetch_list)
        np.allclose(outs, result)
        disable_prim()

    def test_fourth_order(self):
        enable_prim()
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.static.data(name='x', shape=[1], dtype='float32')
            x2 = paddle.multiply(x, x)
            x3 = paddle.multiply(x2, x)
            x4 = paddle.multiply(x3, x)
            x5 = paddle.multiply(x4, x)
            out = paddle.sqrt(x5 + x4)

            grad1, = paddle.static.gradients([out], [x])
            grad2, = paddle.static.gradients([grad1], [x])
            grad3, = paddle.static.gradients([grad2], [x])
            grad4, = paddle.static.gradients([grad3], [x])

            prim2orig(main.block(0))

        feed = {x.name: np.array([2.]).astype('float32'), }
        fetch_list = [grad4.name]
        # (3*(-5*x^2-16*x-16))/(16*(x+1)^3.5)
        result = [np.array([-0.27263762711])]

        place = paddle.CPUPlace()
        if paddle.device.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        exe.run(startup)
        outs = exe.run(main, feed=feed, fetch_list=fetch_list)
        np.allclose(outs, result)
        disable_prim()


class TestMinimize(unittest.TestCase):
    def model(self, x, w, bias, opt):
        paddle.seed(0)
        place = paddle.CPUPlace()
        if paddle.device.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            input_x = paddle.static.data('x', x.shape, dtype=x.dtype)
            input_x.stop_gradient = False
            params_w = paddle.static.create_parameter(
                shape=w.shape, dtype=w.dtype, is_bias=False)
            params_bias = paddle.static.create_parameter(
                shape=bias.shape, dtype=bias.dtype, is_bias=True)
            y = paddle.tanh(paddle.matmul(input_x, params_w) + params_bias)
            loss = paddle.norm(y, p=2)
            opt = opt
            _, grads = opt.minimize(loss)
            if prim_enabled():
                prim2orig(main.block(0))
        exe.run(startup)
        grads = exe.run(main,
                        feed={'x': x,
                              'w': w,
                              'bias': bias},
                        fetch_list=grads)
        return grads

    def test_adam(self):
        x = np.random.rand(2, 20)
        w = np.random.rand(20, 2)
        bias = np.random.rand(2)
        enable_prim()
        prim_grads = self.model(x, w, bias, paddle.optimizer.Adam(0.01))
        disable_prim()
        orig_grads = self.model(x, w, bias, paddle.optimizer.Adam(0.01))
        for orig, prim in zip(orig_grads, prim_grads):
            np.testing.assert_allclose(orig, prim)

    def test_sgd(self):
        x = np.random.rand(2, 20)
        w = np.random.rand(20, 2)
        bias = np.random.rand(2)
        enable_prim()
        prim_grads = self.model(x, w, bias, paddle.optimizer.SGD(0.01))
        disable_prim()
        orig_grads = self.model(x, w, bias, paddle.optimizer.SGD(0.01))
        for orig, prim in zip(orig_grads, prim_grads):
            np.testing.assert_allclose(orig, prim)


if __name__ == '__main__':
    unittest.main()
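
The reference values hard-coded in the tests above check out analytically; the snippet below is an editor's sanity check of those constants (it assumes SymPy is available and is not part of the commit itself):

import sympy as sp

x = sp.Symbol('x', positive=True)

# test_third_order: d^3/dx^3 of x^4 at x = 2 is 24*x = 48
print(float(sp.diff(x**4, x, 3).subs(x, 2)))                  # 48.0

# test_fourth_order: d^4/dx^4 of sqrt(x^5 + x^4) at x = 2,
# i.e. 3*(-5*x^2 - 16*x - 16) / (16*(x + 1)**3.5) at x = 2
print(float(sp.diff(sp.sqrt(x**5 + x**4), x, 4).subs(x, 2)))  # ~ -0.27263762711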