[PIR] Migrate rnn into pir (#60180)
* pir183

* use abspath to support run by python

* use fill_constant instead of create Tensor

* test_rnn_op passed

* fix bug

* fix ci bug

---------

Co-authored-by: ooooo <3164076421@qq.com>
Co-authored-by: SigureMo <sigure.qaq@gmail.com>
3 people authored Dec 21, 2023
1 parent 5f9721d commit 6c574bd
Showing 3 changed files with 85 additions and 68 deletions.
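One recurring piece of this change is replacing relative sys.path.append("..") calls with absolute paths derived from __file__, so the tests also work when a file is launched directly with python from another directory (the "use abspath to support run by python" item in the commit message). A minimal sketch of the pattern used in op_test.py below, assuming the file lives in test/legacy_test/:

```python
# Sketch of the abspath-based sys.path setup adopted in op_test.py and
# test_rnn_op.py; directory names assume the file lives in test/legacy_test/.
import pathlib
import sys

legacy_test_dir = pathlib.Path(__file__).parent  # .../test/legacy_test
test_dir = legacy_test_dir.parent                # .../test

# Absolute paths keep imports such as `from utils import static_guard`
# working regardless of the current working directory.
sys.path.append(str(legacy_test_dir.absolute()))
sys.path.append(str(test_dir.absolute()))
```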
29 changes: 3 additions & 26 deletions python/paddle/nn/layer/rnn.py
@@ -24,7 +24,7 @@
from paddle.base.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
from paddle.base.framework import (
default_startup_program,
in_dygraph_mode,
in_dynamic_or_pir_mode,
program_guard,
)
from paddle.common_ops_import import Variable
@@ -106,7 +106,7 @@ def rnn(
"""

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
return _rnn_dynamic_graph(
cell,
inputs,
@@ -1590,7 +1590,7 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
if not self.time_major:
inputs = paddle.tensor.transpose(inputs, [1, 0, 2])

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
out, _, state = _C_ops.rnn(
inputs,
initial_states,
@@ -1606,29 +1606,6 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
0,
not self.training,
)
elif in_dynamic_mode():
_, _, out, state = _legacy_C_ops.rnn(
inputs,
initial_states,
self._all_weights,
sequence_length,
self._dropout_state,
self.state_components,
'dropout_prob',
self.dropout,
'is_bidirec',
self.num_directions == 2,
'input_size',
self.input_size,
'hidden_size',
self.hidden_size,
'num_layers',
self.num_layers,
'mode',
self.mode,
'is_test',
not self.training,
)
else:
out = self._helper.create_variable_for_type_inference(inputs.dtype)
state = [
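The rnn.py diff above folds the old in_dygraph_mode()/in_dynamic_mode() split into a single in_dynamic_or_pir_mode() check, so dynamic graph and the new PIR program both reach _C_ops.rnn and the _legacy_C_ops.rnn branch is removed. A minimal sketch of the resulting dispatch, where _dynamic_or_pir_impl and _static_impl are hypothetical stand-ins for the argument-heavy calls shown in the diff:

```python
# Sketch only: _dynamic_or_pir_impl and _static_impl are hypothetical
# helpers standing in for the real call sites shown in the diff above.
from paddle.base.framework import in_dynamic_or_pir_mode


def _cudnn_impl(self, inputs, initial_states, sequence_length):
    if in_dynamic_or_pir_mode():
        # Dynamic graph and PIR now share the same paddle._C_ops.rnn
        # entry point, so no separate _legacy_C_ops.rnn branch is needed.
        return self._dynamic_or_pir_impl(inputs, initial_states, sequence_length)
    # The old-style static graph keeps the LayerHelper-based path.
    return self._static_impl(inputs, initial_states, sequence_length)
```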
97 changes: 62 additions & 35 deletions test/legacy_test/op_test.py
@@ -39,7 +39,12 @@
from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker
from testsuite import append_input_output, append_loss_ops, create_op, set_input

sys.path.append("..")
# Add test/legacy_test and test to sys.path
legacy_test_dir = pathlib.Path(__file__).parent # test/legacy_test
test_dir = legacy_test_dir.parent # test
sys.path.append(str(legacy_test_dir.absolute()))
sys.path.append(str(test_dir.absolute()))

from utils import static_guard
from white_list import (
check_shape_white_list,
@@ -66,8 +71,6 @@
)
from paddle.base.wrapped_decorator import signature_safe_contextmanager

sys.path.append(os.path.abspath(os.path.dirname(__file__)))


@signature_safe_contextmanager
def paddle_static_guard():
@@ -1385,7 +1388,8 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly.
# if not, fill the fetch_list by the user configured outputs in test.

# filter ret_tuple
ret_to_check = []
if len(fetch_list) == 0:
if isinstance(ret_tuple, (tuple, list)):
assert len(ret_tuple) == len(outputs_sig)
@@ -1395,14 +1399,17 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if not self._need_fetch(sig_name):
continue
if isinstance(var, list):
ret_to_check.append(var)
for v in var:
fetch_list.append(v)
else:
ret_to_check.append(var)
fetch_list.append(var)
elif isinstance(
ret_tuple, paddle.base.libpaddle.pir.OpResult
):
fetch_list.append(ret_tuple)
ret_to_check = ret_tuple
elif ret_tuple is None:
pass
else:
@@ -1415,19 +1422,27 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
outs = executor.run(
ir_program, feed=feed, fetch_list=[fetch_list]
)

outputs_sig = [
sig_name
for sig_name in outputs_sig
if self._need_fetch(sig_name)
]

if paddle.utils.is_sequence(
ret_to_check
) and paddle.utils.is_sequence(outs):
outs = paddle.utils.pack_sequence_as(ret_to_check, outs)

result = construct_output_dict_by_kernel_sig(outs, outputs_sig)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][
i
].name = self.python_out_sig_sub_name[key][i]
result[key][0] = {
a: [b]
for a, b in zip(
self.python_out_sig_sub_name[key],
result[key][0],
)
}
return result

def _check_ir_output(self, place, program, feed_map, fetch_list, outs):
@@ -2435,12 +2450,24 @@ def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
expect_np = convert_uint16_to_float(expect_np)
return actual_np, expect_np

def find_imperative_actual(target_name, pir_outs, place):
def find_pir_actual(self, target_name, pir_outs, place):
for name in pir_outs:
if name == target_name:
return pir_outs[name][0]

var_list = pir_outs[name]
sub_dict = pir_outs[name][0]
if isinstance(sub_dict, dict):
for key, value in sub_dict.items():
if key == target_name:
return value[0]

raise AssertionError("No pir output named " + target_name)

def find_pir_expect(self, target_name, dygraph_outs, place):
for name in dygraph_outs:
if name == target_name:
return dygraph_outs[name][0]
var_list = dygraph_outs[name]
for i, var in enumerate(var_list):
if isinstance(var, list):
for tensor in var:
@@ -2450,26 +2477,14 @@ def find_imperative_actual(target_name, pir_outs, place):
isinstance(var, paddle.Tensor)
and var.name == target_name
):
return pir_outs[name][i]
self.assertTrue(
False,
f"Found failed {pir_outs.keys()} {target_name}",
)

def find_imperative_expect(self, target_name, pir_outs, place):
for name in pir_outs:
if name == target_name:
return pir_outs[name][0]
self.assertTrue(
False,
f"Found failed {pir_outs.keys()} {target_name}",
)
return dygraph_outs[name][i]
raise AssertionError("No pir ref_output named " + target_name)

def find_actual_value(self, target_name):
with paddle.pir.core.program_guard(
paddle.pir.core.default_main_program()
):
actual = find_imperative_actual(
actual = self.find_pir_actual(
target_name, self.outputs, place
)
actual_t = np.array(actual)
Expand All @@ -2479,7 +2494,7 @@ def find_expect_value(self, target_name):
with paddle.pir.core.program_guard(
paddle.pir.core.default_main_program()
):
expect = self.find_imperative_expect(
expect = self.find_pir_expect(
target_name, self.ref_outputs, place
)
expect_t = np.array(expect)
@@ -3674,10 +3689,19 @@ def _get_gradient(

return res

def _find_var_in_pir(self, output_vars, name):
if name in output_vars:
return output_vars[name]
raise AssertionError(name, " not in outputs:", output_vars.keys())
def _find_var_in_pir(self, output_vars, target_name):
for name in output_vars:
if name == target_name:
return output_vars[name]

sub_dict = output_vars[name][0]
if isinstance(sub_dict, dict):
for key, value in sub_dict.items():
if key == target_name:
return value
raise AssertionError(
target_name, " not in outputs:", output_vars.keys()
)

def _get_ir_gradient(
self,
@@ -3751,10 +3775,13 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
outputs[key][0][
i
].name = self.python_out_sig_sub_name[key][i]
outputs[key][0] = {
a: [b]
for a, b in zip(
self.python_out_sig_sub_name[key],
outputs[key][0],
)
}
fetch_list = getattr(self, "fetch_list", [])

# cast outputs
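Two op_test.py changes above work together: fetched outputs are re-nested with paddle.utils.pack_sequence_as so they mirror the structure returned by the Python wrapper, and python_out_sig_sub_name entries are regrouped into a per-sub-name dict instead of assigning .name on the fetched values as the old code did. A small standalone illustration of the regrouping, with example names and plain strings standing in for fetched tensors:

```python
# Standalone illustration of the python_out_sig_sub_name regrouping in
# the diff above; 'State', 'last_hidden' and 'last_cell' are example
# names, and plain strings stand in for fetched tensors.
python_out_sig_sub_name = {"State": ["last_hidden", "last_cell"]}
result = {"State": [["<hidden value>", "<cell value>"]]}

for key in python_out_sig_sub_name.keys():
    # Map each sub-name to a one-element list, so later lookups by
    # sub-name (e.g. find_pir_actual or _find_var_in_pir) can walk into
    # the nested dict instead of relying on tensor .name attributes.
    result[key][0] = {
        a: [b]
        for a, b in zip(python_out_sig_sub_name[key], result[key][0])
    }

print(result)
# {'State': [{'last_hidden': ['<hidden value>'], 'last_cell': ['<cell value>']}]}
```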
27 changes: 20 additions & 7 deletions test/legacy_test/test_rnn_op.py
@@ -15,14 +15,17 @@
import random
import sys
import unittest
from pathlib import Path

import numpy as np
from op_test import OpTest

import paddle
from paddle.base import core

sys.path.append("../../test/rnn")
# Add test/rnn to sys.path
legacy_test_dir = Path(__file__).resolve().parents[1]
sys.path.append(str(legacy_test_dir / "rnn"))
from convert import get_params_for_net
from rnn_numpy import LSTM

@@ -45,7 +48,7 @@ def rnn_wrapper(
seed=0,
is_test=False,
):
dropout_state_in = paddle.Tensor()
dropout_state_in = paddle.tensor.fill_constant([], "float32", 0.0)
return paddle._C_ops.rnn(
Input,
PreState,
@@ -168,7 +171,9 @@ def rocm_rnn_get_place():
}

def test_output(self):
self.check_output(no_check_set=['Reserve', 'DropoutState'])
self.check_output(
no_check_set=['Reserve', 'DropoutState'], check_pir=True
)

def set_attrs(self):
pass
@@ -179,7 +184,9 @@ def test_grad(self):
grad_check_list = ['Input', 'init_h', 'init_c']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'],
check_pir=True,
)

def test_grad_only_input(self):
Expand All @@ -188,7 +195,9 @@ def test_grad_only_input(self):
grad_check_list = ['Input']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'],
check_pir=True,
)

def test_grad_only_h(self):
Expand All @@ -197,7 +206,9 @@ def test_grad_only_h(self):
grad_check_list = ['init_h']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'],
check_pir=True,
)

def test_grad_only_c(self):
Expand All @@ -206,7 +217,9 @@ def test_grad_only_c(self):
grad_check_list = ['init_c']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'],
check_pir=True,
)


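In test_rnn_op.py above, the DropoutState placeholder passed to paddle._C_ops.rnn is now built with fill_constant instead of an empty paddle.Tensor(), matching the "use fill_constant instead of create Tensor" item in the commit message, and the output and gradient checks gain check_pir=True. A minimal before/after sketch of the placeholder change; the motivation stated here (a concrete value that can be constructed the same way when the wrapper is traced for PIR) is an assumption, not something spelled out in the commit:

```python
import paddle

# Old wrapper code: an empty, uninitialized Tensor handle used purely as
# a placeholder for DropoutState (assumed to be dygraph-only friendly).
dropout_state_old = paddle.Tensor()

# New wrapper code: a concrete 0-D float32 tensor holding 0.0, created
# by an op, so the same line also works when traced into a PIR program.
dropout_state_new = paddle.tensor.fill_constant([], "float32", 0.0)

print(dropout_state_new.shape)  # []
```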
