diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 8946674cc58f92..cbde8a7f10597b 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -24,7 +24,7 @@
 from paddle.base.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
 from paddle.base.framework import (
     default_startup_program,
-    in_dygraph_mode,
+    in_dynamic_or_pir_mode,
     program_guard,
 )
 from paddle.common_ops_import import Variable
@@ -106,7 +106,7 @@ def rnn(

     """

-    if in_dygraph_mode():
+    if in_dynamic_or_pir_mode():
         return _rnn_dynamic_graph(
             cell,
             inputs,
@@ -1590,7 +1590,7 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
         if not self.time_major:
             inputs = paddle.tensor.transpose(inputs, [1, 0, 2])

-        if in_dygraph_mode():
+        if in_dynamic_or_pir_mode():
             out, _, state = _C_ops.rnn(
                 inputs,
                 initial_states,
@@ -1606,29 +1606,6 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
                 0,
                 not self.training,
             )
-        elif in_dynamic_mode():
-            _, _, out, state = _legacy_C_ops.rnn(
-                inputs,
-                initial_states,
-                self._all_weights,
-                sequence_length,
-                self._dropout_state,
-                self.state_components,
-                'dropout_prob',
-                self.dropout,
-                'is_bidirec',
-                self.num_directions == 2,
-                'input_size',
-                self.input_size,
-                'hidden_size',
-                self.hidden_size,
-                'num_layers',
-                self.num_layers,
-                'mode',
-                self.mode,
-                'is_test',
-                not self.training,
-            )
         else:
             out = self._helper.create_variable_for_type_inference(inputs.dtype)
             state = [
diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py
index 127b5636ea230f..2222638fb339bd 100644
--- a/test/legacy_test/op_test.py
+++ b/test/legacy_test/op_test.py
@@ -39,7 +39,12 @@
 from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker
 from testsuite import append_input_output, append_loss_ops, create_op, set_input

-sys.path.append("..")
+# Add test/legacy_test and test to sys.path
+legacy_test_dir = pathlib.Path(__file__).parent  # test/legacy_test
+test_dir = legacy_test_dir.parent  # test
+sys.path.append(str(legacy_test_dir.absolute()))
+sys.path.append(str(test_dir.absolute()))
+
 from utils import static_guard
 from white_list import (
     check_shape_white_list,
@@ -66,8 +71,6 @@
 )
 from paddle.base.wrapped_decorator import signature_safe_contextmanager

-sys.path.append(os.path.abspath(os.path.dirname(__file__)))
-

 @signature_safe_contextmanager
 def paddle_static_guard():
@@ -1385,7 +1388,8 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
                 fetch_list = getattr(self, "fetch_list", [])
                 # if the fetch_list is customized by user, we use it directly.
                 # if not, fill the fetch_list by the user configured outputs in test.
-
+                # filter ret_tuple
+                ret_to_check = []
                 if len(fetch_list) == 0:
                     if isinstance(ret_tuple, (tuple, list)):
                         assert len(ret_tuple) == len(outputs_sig)
@@ -1395,14 +1399,17 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
                         if not self._need_fetch(sig_name):
                             continue
                         if isinstance(var, list):
+                            ret_to_check.append(var)
                             for v in var:
                                 fetch_list.append(v)
                         else:
+                            ret_to_check.append(var)
                             fetch_list.append(var)
                 elif isinstance(
                     ret_tuple, paddle.base.libpaddle.pir.OpResult
                 ):
                     fetch_list.append(ret_tuple)
+                    ret_to_check = ret_tuple
                 elif ret_tuple is None:
                     pass
                 else:
@@ -1415,19 +1422,27 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
                 outs = executor.run(
                     ir_program, feed=feed, fetch_list=[fetch_list]
                 )
-
                 outputs_sig = [
                     sig_name
                     for sig_name in outputs_sig
                     if self._need_fetch(sig_name)
                 ]
+
+                if paddle.utils.is_sequence(
+                    ret_to_check
+                ) and paddle.utils.is_sequence(outs):
+                    outs = paddle.utils.pack_sequence_as(ret_to_check, outs)
+
                 result = construct_output_dict_by_kernel_sig(outs, outputs_sig)
                 if hasattr(self, "python_out_sig_sub_name"):
                     for key in self.python_out_sig_sub_name.keys():
-                        for i in range(len(self.python_out_sig_sub_name[key])):
-                            result[key][0][
-                                i
-                            ].name = self.python_out_sig_sub_name[key][i]
+                        result[key][0] = {
+                            a: [b]
+                            for a, b in zip(
+                                self.python_out_sig_sub_name[key],
+                                result[key][0],
+                            )
+                        }
                 return result

     def _check_ir_output(self, place, program, feed_map, fetch_list, outs):
@@ -2435,12 +2450,24 @@ def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
                     expect_np = convert_uint16_to_float(expect_np)
                 return actual_np, expect_np

-            def find_imperative_actual(target_name, pir_outs, place):
+            def find_pir_actual(self, target_name, pir_outs, place):
                 for name in pir_outs:
                     if name == target_name:
                         return pir_outs[name][0]

-                    var_list = pir_outs[name]
+                    sub_dict = pir_outs[name][0]
+                    if isinstance(sub_dict, dict):
+                        for key, value in sub_dict.items():
+                            if key == target_name:
+                                return value[0]
+
+                raise AssertionError("No pir output named " + target_name)
+
+            def find_pir_expect(self, target_name, dygraph_outs, place):
+                for name in dygraph_outs:
+                    if name == target_name:
+                        return dygraph_outs[name][0]
+                    var_list = dygraph_outs[name]
                     for i, var in enumerate(var_list):
                         if isinstance(var, list):
                             for tensor in var:
@@ -2450,26 +2477,14 @@ def find_imperative_actual(target_name, pir_outs, place):
                             isinstance(var, paddle.Tensor)
                             and var.name == target_name
                         ):
-                            return pir_outs[name][i]
-                self.assertTrue(
-                    False,
-                    f"Found failed {pir_outs.keys()} {target_name}",
-                )
-
-            def find_imperative_expect(self, target_name, pir_outs, place):
-                for name in pir_outs:
-                    if name == target_name:
-                        return pir_outs[name][0]
-                self.assertTrue(
-                    False,
-                    f"Found failed {pir_outs.keys()} {target_name}",
-                )
+                            return dygraph_outs[name][i]
+                raise AssertionError("No pir ref_output named " + target_name)

             def find_actual_value(self, target_name):
                 with paddle.pir.core.program_guard(
                     paddle.pir.core.default_main_program()
                 ):
-                    actual = find_imperative_actual(
+                    actual = self.find_pir_actual(
                         target_name, self.outputs, place
                     )
                     actual_t = np.array(actual)
@@ -2479,7 +2494,7 @@ def find_expect_value(self, target_name):
                 with paddle.pir.core.program_guard(
                     paddle.pir.core.default_main_program()
                 ):
-                    expect = self.find_imperative_expect(
+                    expect = self.find_pir_expect(
                         target_name, self.ref_outputs, place
                     )
                     expect_t = np.array(expect)
@@ -3674,10 +3689,19 @@ def _get_gradient(

         return res

-    def _find_var_in_pir(self, output_vars, name):
-        if name in output_vars:
-            return output_vars[name]
-        raise AssertionError(name, " not in outputs:", output_vars.keys())
+    def _find_var_in_pir(self, output_vars, target_name):
+        for name in output_vars:
+            if name == target_name:
+                return output_vars[name]
+
+            sub_dict = output_vars[name][0]
+            if isinstance(sub_dict, dict):
+                for key, value in sub_dict.items():
+                    if key == target_name:
+                        return value
+        raise AssertionError(
+            target_name, " not in outputs:", output_vars.keys()
+        )

     def _get_ir_gradient(
         self,
@@ -3751,10 +3775,13 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
             )
             if hasattr(self, "python_out_sig_sub_name"):
                 for key in self.python_out_sig_sub_name.keys():
-                    for i in range(len(self.python_out_sig_sub_name[key])):
-                        outputs[key][0][
-                            i
-                        ].name = self.python_out_sig_sub_name[key][i]
+                    outputs[key][0] = {
+                        a: [b]
+                        for a, b in zip(
+                            self.python_out_sig_sub_name[key],
+                            outputs[key][0],
+                        )
+                    }

             fetch_list = getattr(self, "fetch_list", [])
             # cast outputs
diff --git a/test/legacy_test/test_rnn_op.py b/test/legacy_test/test_rnn_op.py
index 4eb2d8332d9eca..06409237276adc 100644
--- a/test/legacy_test/test_rnn_op.py
+++ b/test/legacy_test/test_rnn_op.py
@@ -15,6 +15,7 @@
 import random
 import sys
 import unittest
+from pathlib import Path

 import numpy as np
 from op_test import OpTest
@@ -22,7 +23,9 @@
 import paddle
 from paddle.base import core

-sys.path.append("../../test/rnn")
+# Add test/rnn to sys.path
+test_dir = Path(__file__).resolve().parents[1]
+sys.path.append(str(test_dir / "rnn"))
 from convert import get_params_for_net
 from rnn_numpy import LSTM

@@ -45,7 +48,7 @@ def rnn_wrapper(
     seed=0,
     is_test=False,
 ):
-    dropout_state_in = paddle.Tensor()
+    dropout_state_in = paddle.tensor.fill_constant([], "float32", 0.0)
     return paddle._C_ops.rnn(
         Input,
         PreState,
@@ -168,7 +171,9 @@ def rocm_rnn_get_place():
         }

     def test_output(self):
-        self.check_output(no_check_set=['Reserve', 'DropoutState'])
+        self.check_output(
+            no_check_set=['Reserve', 'DropoutState'], check_pir=True
+        )

     def set_attrs(self):
         pass
@@ -179,7 +184,9 @@ def test_grad(self):
             grad_check_list = ['Input', 'init_h', 'init_c']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                check_pir=True,
             )

     def test_grad_only_input(self):
@@ -188,7 +195,9 @@ def test_grad_only_input(self):
             grad_check_list = ['Input']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                check_pir=True,
             )

     def test_grad_only_h(self):
@@ -197,7 +206,9 @@ def test_grad_only_h(self):
             grad_check_list = ['init_h']
             grad_check_list.extend(var_name_list)
             self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                check_pir=True,
            )

     def test_grad_only_c(self):
@@ -206,7 +217,9 @@ def test_grad_only_c(self):
            grad_check_list = ['init_c']
            grad_check_list.extend(var_name_list)
            self.check_grad(
-                set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
+                set(grad_check_list),
+                ['Out', 'last_hidden', 'last_cell'],
+                check_pir=True,
            )