Skip to content

Commit

Permalink
Merge pull request #2 from levi131/lml/code_refine
Browse files Browse the repository at this point in the history
Lml/code refine
  • Loading branch information
levi131 authored May 15, 2022
2 parents 28e812e + d3efc1d commit 06dd62d
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 151 deletions.
4 changes: 3 additions & 1 deletion python/paddle/incubate/autograd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.autograd.functional import Hessian, Jacobian, jvp, vjp
from .primx import prim2orig
from .utils import enable_prim, disable_prim, prim_enabled

__all__ = [ # noqa
'vjp', 'jvp', 'Jacobian', 'Hessian'
'vjp', 'jvp', 'Jacobian', 'Hessian', 'prim2orig', 'enable_prim', 'disable_prim', 'prim_enabled'
]
4 changes: 2 additions & 2 deletions python/paddle/incubate/autograd/primreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def REGISTER_ORIG2PRIM(op_type):
.. code-block:: python
@REGISTER_ORIG2PRIM('tanh')
def tanh_orig2prim(op):
x = get_input_var_list(op)
x, = get_input_var_list(op)
return primops.tanh(x)
"""
Expand Down Expand Up @@ -209,7 +209,7 @@ def REGISTER_PRIM2ORIG(op_type):
.. code-block:: python
@REGISTER_PRIM2ORIG('tanh_p')
def tanh_prim2orig(op):
x = get_input_var_list(op)
x, = get_input_var_list(op)
return paddle.tanh(x)
"""
Expand Down
94 changes: 50 additions & 44 deletions python/paddle/incubate/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def linear_jvp(op, *args, **kwargs):
These original ops are fully supported:
elementwise_add
elementwise_sub
elementwise_mul
tanh
fill_zeros_like
sum
index_select
elementwise_sub
scale
assign
elementwise_mul
sqrt
These original ops are partially supported:
Expand All @@ -79,16 +79,58 @@ def elementwise_add_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
scale_x = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, scale_x)
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
scale_y = fill_const(
shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, scale_y)
z = add(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
scale_out = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, scale_out)
return z


@REGISTER_ORIG2PRIM('elementwise_sub')
def elementwise_sub_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
scale_x = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, scale_x)
if op.attr('Scale_y') - 1.0 > 1e-5:
scale_y = fill_const(
shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, scale_y)
z = sub(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
scale_out = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, scale_out)
return z


@REGISTER_ORIG2PRIM('elementwise_mul')
def elementwise_mul_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
scale_x = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, scale_x)
if op.attr('Scale_y') - 1.0 > 1e-5:
scale_y = fill_const(
shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, scale_y)
z = mul(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
scale_out = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
z = mul(z, scale_out)
return z


Expand All @@ -115,24 +157,6 @@ def index_select_orig2prim(op, index_t, x):
return gather(x, indextensor=index_t, axis=op.attr('dim'))


@REGISTER_ORIG2PRIM('elementwise_sub')
def elementwise_sub_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
z = sub(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
return z


@REGISTER_ORIG2PRIM('scale')
def scale_orig2prim(op, scale_t, x):
if scale_t is None:
Expand All @@ -151,24 +175,6 @@ def assign_orig2prim(op, x):
return add(x, zero_t)


@REGISTER_ORIG2PRIM('elementwise_mul')
def elementwise_mul_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
z = mul(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
return z


@REGISTER_ORIG2PRIM('sqrt')
def sqrt_orig2prim(op, x):
return sqrt(x)
Expand Down
Loading

0 comments on commit 06dd62d

Please sign in to comment.