[PIR] Migrate dropout into PIR (#57319)
0x45f authored Sep 20, 2023
1 parent 113cd81 commit 3e8da40
Showing 2 changed files with 14 additions and 9 deletions.
19 changes: 12 additions & 7 deletions python/paddle/nn/functional/common.py
@@ -15,10 +15,15 @@
 import numpy
 
 import paddle
-from paddle import _C_ops
+from paddle import _C_ops, ir
 from paddle.base.layer_helper import LayerHelper
 from paddle.common_ops_import import Variable, default_main_program
-from paddle.framework import core, in_dynamic_mode, in_pir_mode
+from paddle.framework import (
+    core,
+    in_dynamic_mode,
+    in_dynamic_or_pir_mode,
+    in_pir_mode,
+)
 from paddle.tensor.creation import full
 
 from ...base.data_feeder import (
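
Note: the newly imported in_dynamic_or_pir_mode guard is what lets the hunks below fold the PIR program-building path into the existing dygraph branch. A minimal sketch of the dispatch pattern, assuming this commit is applied (the helper scale_like_dropout is hypothetical, not part of the commit):

    import paddle
    from paddle.framework import in_dynamic_or_pir_mode

    def scale_like_dropout(x, p):
        # Hypothetical helper: eager execution and PIR program capture
        # now share one branch; only the legacy static graph stays apart.
        if in_dynamic_or_pir_mode():
            return paddle.scale(x, scale=1 - p)
        raise NotImplementedError("legacy static-graph path elided")
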
@@ -1090,7 +1095,7 @@ def dropout(
             [[0., 0., 6.],
              [0., 0., 0.]])
     """
-    if not isinstance(p, (float, int, Variable)):
+    if not isinstance(p, (float, int, Variable, ir.OpResult)):
         raise TypeError("p argument should be a number or Variable")
 
     if isinstance(p, (int, float)):
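
Note: under PIR, every value produced while building a program is an ir.OpResult rather than a legacy Variable, so a graph-valued dropout probability must pass this widened type check. A sketch of the call pattern (classic static graph shown; under PIR the same p arrives as an ir.OpResult):

    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()  # build a program instead of executing eagerly
    x = paddle.static.data('x', [2, 3], 'float32')
    p = paddle.full([1], 0.5)  # a graph value, not a Python float
    y = F.dropout(x, p=p, training=True)
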
@@ -1112,7 +1117,7 @@ def dropout(
         'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
     )  # semantic transfer
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         if default_main_program().random_seed != 0:
             seed = default_main_program().random_seed
 
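Note: the unified branch still honors a program-level seed, so deterministic dropout works the same way in both modes. For example:

    import paddle

    paddle.enable_static()
    main = paddle.static.default_main_program()
    main.random_seed = 42  # picked up by the branch above as dropout's seed
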
@@ -1176,7 +1181,7 @@ def get_attrs(prog, dropout_prob, is_test, seed):
     dtype = x.dtype
     keep_prob = 1 - p
     if training:
-        if in_dynamic_mode() and p == 1.0:
+        if in_dynamic_or_pir_mode() and p == 1.0:
             return paddle.scale(x, scale=0.0)
 
         scale_input = (
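
Note: with p == 1.0 every element is dropped, so the op degenerates to an all-zeros output; paddle.scale(x, scale=0.0) reproduces that (including a zero gradient) without generating a mask. The equivalence, as a quick check:

    import paddle

    x = paddle.rand([2, 3])
    out = paddle.scale(x, scale=0.0)  # same result as dropout with p=1.0
    assert float(out.abs().sum()) == 0.0
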
@@ -1187,7 +1192,7 @@ def get_attrs(prog, dropout_prob, is_test, seed):

         # get mask shape
         input_shape = x.shape
-        if not in_dynamic_mode():
+        if not in_dynamic_or_pir_mode():
             input_shape_tensor = paddle.shape(x)
         drop_axes = [axis] if isinstance(axis, int) else list(axis)
         if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1:
                 )
             )
         mask_shape = [1] * len(input_shape)
-        if not in_dynamic_mode():
+        if not in_dynamic_or_pir_mode():
             for i in drop_axes:
                 mask_shape[i] = input_shape_tensor[i]
         else:
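
Note: for axis-wise dropout the mask keeps size 1 on every non-dropped axis so it broadcasts against x; only the dropped axes get their full extent (read from paddle.shape(x) when shapes are dynamic, i.e. outside the dynamic/PIR branch). A sketch of the masking math with static shapes and p = 0.5:

    import paddle

    x = paddle.rand([2, 3, 4])
    drop_axes = [0, 1]
    mask_shape = [1] * len(x.shape)   # [1, 1, 1]
    for i in drop_axes:
        mask_shape[i] = x.shape[i]    # -> [2, 3, 1], broadcasts over axis 2
    keep = (paddle.rand(mask_shape) >= 0.5).astype(x.dtype)
    y = x * keep / 0.5                # upscale_in_train rescaling
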
4 changes: 2 additions & 2 deletions test/legacy_test/test_dropout_op.py
@@ -81,11 +81,11 @@ def setUp(self):
         self.enable_check_static_comp = False
 
     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output(check_prim=True, check_new_ir=True)
 
     def test_check_grad_normal(self):
         # Now in dy2st mode x_grad = [], so set check_prim=False
-        self.check_grad(['X'], 'Out', check_prim=False)
+        self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=True)
 
 
 class TestDropoutOp_ZeroDim(TestDropoutOp):
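
Note: check_new_ir=True asks OpTest to run the same case again through the new IR executor and compare results against the legacy path. Roughly what that toggles (the flag name below is an assumption from the same migration effort, not verified against this commit):

    import paddle

    # Assumed flag name; OpTest flips an equivalent switch internally
    # when check_new_ir=True so the program runs through the PIR pipeline.
    paddle.set_flags({'FLAGS_enable_new_ir_in_executor': True})
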
