Skip to content

Commit

Permalink
[Phi] add temporal_shift yaml (#44409)
Browse files Browse the repository at this point in the history
* add temporal_shift yaml and unittest
  • Loading branch information
ccrrong authored Jul 21, 2022
1 parent 438ca7f commit 0243c6c
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 2 deletions.
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2154,6 +2154,16 @@
func : tanh_shrink
backward : tanh_shrink_grad

# temporal_shift
- api : temporal_shift
args : (Tensor x, int seg_num, float shift_ratio, str data_format_str)
output : Tensor
infer_meta :
func : TemporalShiftInferMeta
kernel :
func : temporal_shift
backward : temporal_shift_grad

# thresholded_relu
- api : thresholded_relu
args : (Tensor x, float threshold)
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2173,6 +2173,16 @@
func : tanh_triple_grad
inplace : (grad_x_grad_forward -> grad_out_forward_grad)

- backward_api : temporal_shift_grad
forward : temporal_shift(Tensor x, int seg_num, float shift_ratio, str data_format_str) -> Tensor(out)
args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format_str)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : temporal_shift_grad

- backward_api : thresholded_relu_grad
forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float threshold)
Expand Down
6 changes: 4 additions & 2 deletions python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class TestTemporalShift(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'temporal_shift'
self.python_api = paddle.nn.functional.temporal_shift
x = np.random.random(self.x_shape).astype(self.dtype)

self.attrs = {
Expand All @@ -61,12 +62,13 @@ def setUp(self):
output = temporal_shift(x, self.seg_num, self.shift_ratio,
self.data_format)
self.outputs = {"Out": output}
self.python_out_sig = ["Out"]

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad_ignore_uv(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)

def initTestCase(self):
self.x_shape = (6, 4, 4, 4)
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/nn/functional/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
if in_dygraph_mode():
return _C_ops.final_state_temporal_shift(x, seg_num, shift_ratio,
data_format)
if _non_static_mode():
return _C_ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
shift_ratio, 'data_format', data_format)
Expand Down

0 comments on commit 0243c6c

Please sign in to comment.