【Prim】Support Relu Custom VJP #51742
Changes from all commits
@@ -33,3 +33,4 @@
- put_along_axis
- greater_than
- less_equal
- where
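For context, the `where` primitive added above is what a composite VJP for relu naturally uses: the incoming gradient passes through where the input is positive and is zero elsewhere. A minimal numpy sketch of that rule (illustrative only; `relu_vjp` is a made-up helper, not Paddle's actual composite rule):

import numpy as np

def relu_vjp(x, out_grad):
    # relu'(x) is 1 where x > 0 and 0 elsewhere, so the VJP is a select.
    return np.where(x > 0, out_grad, np.zeros_like(out_grad))

x = np.array([-1.0, 0.5, 2.0])
print(relu_vjp(x, np.ones_like(x)))  # [0. 1. 1.]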
@@ -0,0 +1,122 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from utils import TOLERANCE

import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def generate_data(shape, dtype="float32"):
    np_data = np.random.random(shape).astype(dtype)
    return np_data


class Attr:
    def __init__(self) -> None:
        self.dtype = None
        self.shape = None

    def set_dtype(self, dtype) -> None:
        self.dtype = dtype
        return

    def set_shape(self, shape) -> None:
        self.shape = shape
        return

    def get_rtol(self, flag):
        rtol = TOLERANCE[self.dtype][flag].get("rtol")
        return rtol

    def get_atol(self, flag):
        atol = TOLERANCE[self.dtype][flag].get("atol")
        return atol


attrs = Attr()


def fn(x):
    return F.relu(x)


def expect_grad(inputs):
    # Reference gradient: eager (dygraph) autograd without composite rules.
    paddle.disable_static()
    inputs.stop_gradient = False
    res = fn(inputs)

    gradients = paddle.grad(res, inputs)
    return gradients


class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
    "test composite softmax and prim backward"
Review comment: change softmax as relu
Reply: Done in PR 51838
    def setUp(self):
        core._set_prim_backward_enabled(True)
        self.dtypes = ["float16", "float32", "float64"]
        self.shapes = [[2, 3, 4], [2, 3]]

    def cal_composite_grad(self, inputs):
        # Composite (prim) gradient: build a static program, lower it to
        # primitive ops with to_prim, then run the backward pass.
        paddle.enable_static()
        core._set_prim_all_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                'x', shape=inputs.shape, dtype=str(inputs.dtype)
            )
            x.stop_gradient = False
            y = fn(x)
            blocks = main_program.blocks
            z = paddle.static.gradients([y], x)
            paddle.incubate.autograd.primapi.to_prim(blocks)

        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
        core._set_prim_all_enabled(False)
        return res

    def compare_backward(self):
        np_data = generate_data(attrs.shape)
        tensor_data = paddle.to_tensor(np_data)

        expect = expect_grad(tensor_data)[0].numpy()
        actual = self.cal_composite_grad(np_data)[0]

        assert expect.dtype == actual.dtype
        np.testing.assert_allclose(
            expect,
            actual,
            rtol=attrs.get_rtol("prim_backward"),
            atol=attrs.get_atol("prim_backward"),
        )

    def test_prim_backward(self):
        for j in self.dtypes:
            for t in self.shapes:
                attrs.set_dtype(j)
                attrs.set_shape(t)
                self.compare_backward()


if __name__ == '__main__':
    unittest.main()
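The test imports TOLERANCE from a local utils module; its contents are not part of this diff. A minimal sketch of the shape implied by the get_rtol/get_atol accessors above (the nesting is inferred from the code, the threshold values below are placeholders, not the real ones):

# Hypothetical stand-in for the utils.TOLERANCE table used above.
TOLERANCE = {
    "float16": {"prim_backward": {"rtol": 1e-3, "atol": 1e-3}},
    "float32": {"prim_backward": {"rtol": 1e-6, "atol": 1e-6}},
    "float64": {"prim_backward": {"rtol": 1e-15, "atol": 1e-15}},
}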
@@ -97,6 +97,7 @@ def composite_batchnorm(
    batch_mean = zeros(run_mean.shape, run_mean.dtype)
    batch_var = zeros(run_var.shape, run_var.dtype)
    if not use_run_stat:

Review comment: unnecessary change
Reply: pre-commit hook did this
        batch_mean = mean(x, reduce_axes, keepdim=True)
        temp = mean(x * x, reduce_axes, keepdim=True)
        batch_var = temp - batch_mean * batch_mean
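The composite rule above computes the batch variance via the identity Var[x] = E[x²] − (E[x])². A quick numpy check of that identity (illustrative only, independent of Paddle; it reduces over all axes rather than the non-channel axes the rule uses):

import numpy as np

x = np.random.random((2, 3, 4)).astype("float64")
batch_mean = x.mean()
batch_var = (x * x).mean() - batch_mean * batch_mean  # E[x^2] - (E[x])^2
np.testing.assert_allclose(batch_var, x.var(), rtol=1e-10)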
Review comment: it's better to delete this if the op test with prim can cover this test.
Reply: Since the activation op test only tests simple cases, adding this case makes it safer to use.