From 491a9eda73e06fd49ff88be17574793ad61d6137 Mon Sep 17 00:00:00 2001
From: Asthestarsfalll <1186454801@qq.com>
Date: Sat, 19 Aug 2023 14:15:22 +0800
Subject: [PATCH 01/10] [LSTM/RNNBase] Usability improvement No.10, 11
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/paddle/nn/layer/rnn.py       | 72 +++++++++++++++++++++--------
 test/dygraph_to_static/test_lstm.py | 25 ++++++----
 test/rnn/rnn_numpy.py               | 22 +++++----
 test/rnn/test_rnn_nets.py           | 47 +++++++++++++++++++
 4 files changed, 129 insertions(+), 37 deletions(-)

diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index a2122c2dab3b7d..8bea81ff4b9a9a 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -554,7 +554,13 @@ class RNNCellBase(Layer):
     """

     def get_initial_states(
-        self, batch_ref, shape=None, dtype=None, init_value=0.0, batch_dim_idx=0
+        self,
+        batch_ref,
+        shape=None,
+        dtype=None,
+        init_value=0.0,
+        batch_dim_idx=0,
+        proj_size=None,
     ):
         r"""
         Generate initialized states according to provided shape, data type and
@@ -865,6 +871,8 @@ class LSTMCell(RNNCellBase):
     Parameters:
         input_size (int): The input size.
         hidden_size (int): The hidden size.
+        proj_size (int, optional): If specified, the output hidden state
+            will be projected to `proj_size`.
         weight_ih_attr(ParamAttr, optional): The parameter attribute for
             `weight_ih`. Default: None.
         weight_hh_attr(ParamAttr, optional): The parameter attribute for
@@ -879,6 +887,7 @@ class LSTMCell(RNNCellBase):
     Variables:
         - **weight_ih** (Parameter): shape (4 * hidden_size, input_size), input to hidden weight, which corresponds to the concatenation of :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula.
         - **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula.
+        - **weight_ho** (Parameter, optional): shape (proj_size, hidden_size), projects the hidden state.
        - **bias_ih** (Parameter): shape (4 * hidden_size, ), input to hidden bias, which corresponds to the concatenation of :math:`b_{ii}, b_{if}, b_{ig}, b_{io}` in the formula.
        - **bias_hh** (Parameter): shape (4 * hidden_size, ), hidden to hidden bias, which corresponds to the concatenation of :math:`b_{hi}, b_{hf}, b_{hg}, b_{ho}` in the formula.

@@ -888,7 +897,8 @@ class LSTMCell(RNNCellBase):

     Returns:
         - **outputs** (Tensor): shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula.
-        - **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula.
+        - **states** (tuple): a tuple of two tensors, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula, each of shape `[batch_size, hidden_size]`;
+            if proj_size is specified, the first element instead has shape `[batch_size, proj_size]`.
Notes:
        All the weights and bias are initialized with `Uniform(-std, std)` by

@@ -921,6 +931,7 @@ def __init__(
         self,
         input_size,
         hidden_size,
+        proj_size=None,
         weight_ih_attr=None,
         weight_hh_attr=None,
         bias_ih_attr=None,
@@ -941,7 +952,7 @@ def __init__(
             default_initializer=I.Uniform(-std, std),
         )
         self.weight_hh = self.create_parameter(
-            (4 * hidden_size, hidden_size),
+            (4 * hidden_size, proj_size or hidden_size),
             weight_hh_attr,
             default_initializer=I.Uniform(-std, std),
         )
@@ -957,6 +968,13 @@ def __init__(
             is_bias=True,
             default_initializer=I.Uniform(-std, std),
         )
+        self.proj_size = proj_size
+        if proj_size:
+            self.weight_ho = self.create_parameter(
+                (proj_size, hidden_size),
+                weight_hh_attr,
+                default_initializer=I.Uniform(-std, std),
+            )
         self.hidden_size = hidden_size
         self.input_size = input_size
@@ -966,6 +984,7 @@ def forward(self, inputs, states=None):
         if states is None:
             states = self.get_initial_states(inputs, self.state_shape)
+
         pre_hidden, pre_cell = states
         gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
         if self.bias_ih is not None:
@@ -981,6 +1000,8 @@ def forward(self, inputs, states=None):
         o = self._gate_activation(chunked_gates[3])
         c = f * pre_cell + i * self._activation(chunked_gates[2])
         h = o * self._activation(c)
+        if self.proj_size:
+            h = paddle.matmul(h, self.weight_ho, transpose_y=True)
         return h, (h, c)
@@ -992,7 +1013,7 @@ def state_shape(self):
         automatically inserted into shape). These two shapes correspond
         to :math:`h_{t-1}` and :math:`c_{t-1}` separately.
         """
-        return ((self.hidden_size,), (self.hidden_size,))
+        return ((self.proj_size or self.hidden_size,), (self.hidden_size,))

     def extra_repr(self):
         return '{input_size}, {hidden_size}'.format(**self.__dict__)
@@ -1329,6 +1350,7 @@ def __init__(
         hidden_size,
         num_layers=1,
         direction="forward",
+        proj_size=None,
         time_major=False,
         dropout=0.0,
         weight_ih_attr=None,
@@ -1354,28 +1376,37 @@ def __init__(
             "bias_hh_attr": bias_hh_attr,
         }

+        self.proj_size = proj_size
+        if proj_size:
+            assert mode == 'LSTM'
+
         if mode == "LSTM":
             rnn_cls = LSTMCell
+            kwargs["proj_size"] = proj_size
         elif mode == "GRU":
             rnn_cls = GRUCell
-        else:
+        elif mode == "RNN_RELU":
             rnn_cls = SimpleRNNCell
-            kwargs["activation"] = self.activation
+            kwargs["activation"] = 'relu'
+        elif mode == "RNN_TANH":
+            rnn_cls = SimpleRNNCell
+            kwargs["activation"] = 'tanh'

+        in_size = proj_size or hidden_size
         if direction in ["forward"]:
             is_reverse = False
             cell = rnn_cls(input_size, hidden_size, **kwargs)
             self.append(RNN(cell, is_reverse, time_major))
             for i in range(1, num_layers):
-                cell = rnn_cls(hidden_size, hidden_size, **kwargs)
+                cell = rnn_cls(in_size, hidden_size, **kwargs)
                 self.append(RNN(cell, is_reverse, time_major))
         elif direction in bidirectional_list:
             cell_fw = rnn_cls(input_size, hidden_size, **kwargs)
             cell_bw = rnn_cls(input_size, hidden_size, **kwargs)
             self.append(BiRNN(cell_fw, cell_bw, time_major))
             for i in range(1, num_layers):
-                cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
-                cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
+                cell_fw = rnn_cls(2 * in_size, hidden_size, **kwargs)
+                cell_bw = rnn_cls(2 * in_size, hidden_size, **kwargs)
                 self.append(BiRNN(cell_fw, cell_bw, time_major))
         else:
             raise ValueError(
@@ -1569,12 +1600,8 @@ def forward(self, inputs, initial_states=None, sequence_length=None):
         batch_index = 1 if self.time_major else 0
         dtype = inputs.dtype
         if initial_states is None:
-            state_shape = (
-                self.num_layers * self.num_directions,
-                -1,
-                self.hidden_size,
-            )
-
+            state_shape = 
(self.num_layers * self.num_directions, -1) + dims = ([self.proj_size or self.hidden_size], [self.hidden_size]) fill_shape = list(state_shape) if inputs.shape[batch_index] > 0: fill_shape[1] = inputs.shape[batch_index] @@ -1582,8 +1609,10 @@ def forward(self, inputs, initial_states=None, sequence_length=None): fill_shape[1] = paddle.shape(inputs)[batch_index].item() initial_states = tuple( [ - paddle.full(shape=fill_shape, fill_value=0, dtype=dtype) - for _ in range(self.state_components) + paddle.full( + shape=fill_shape + dims[i], fill_value=0, dtype=dtype + ) + for i in range(self.state_components) ] ) else: @@ -1745,6 +1774,7 @@ def __init__( hidden_size, num_layers, direction, + None, time_major, dropout, weight_ih_attr, @@ -1793,6 +1823,8 @@ class LSTM(RNNBase): direction (str, optional): The direction of the network. It can be "forward" or "bidirect"(or "bidirectional"). When "bidirect", the way to merge outputs of forward and backward is concatenating. Defaults to "forward". + proj_size (int, optional): If specified, the output hidden state of each layer + will be projected to `proj_size`. time_major (bool, optional): Whether the first dimension of the input means the time steps. If time_major is True, the shape of Tensor is [time_steps,batch_size,input_size], otherwise [batch_size, time_steps,input_size]. @@ -1820,7 +1852,8 @@ class LSTM(RNNBase): - **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence. - - **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. + - **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If `proj_size` is specified, the last dimension of h will be proj_size. + Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1. Variables: - **weight_ih_l[k]**: the learnable input-hidden weights of the k-th layer. If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise, the shape is `[hidden_size, num_directions * hidden_size]`. 
@@ -1857,6 +1890,7 @@ def __init__( hidden_size, num_layers=1, direction="forward", + proj_size=None, time_major=False, dropout=0.0, weight_ih_attr=None, @@ -1871,6 +1905,7 @@ def __init__( hidden_size, num_layers, direction, + proj_size, time_major, dropout, weight_ih_attr, @@ -1990,6 +2025,7 @@ def __init__( hidden_size, num_layers, direction, + None, time_major, dropout, weight_ih_attr, diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py index 4dc5b5a0fba75e..2b517fccbac40b 100644 --- a/test/dygraph_to_static/test_lstm.py +++ b/test/dygraph_to_static/test_lstm.py @@ -24,7 +24,7 @@ class LSTMLayer(nn.Layer): - def __init__(self, in_channels, hidden_size): + def __init__(self, in_channels, hidden_size, proj_size=None): super().__init__() self.cell = nn.LSTM( in_channels, hidden_size, direction='bidirectional', num_layers=2 @@ -36,9 +36,9 @@ def forward(self, x): class Net(nn.Layer): - def __init__(self, in_channels, hidden_size): + def __init__(self, in_channels, hidden_size, proj_size=None): super().__init__() - self.lstm = LSTMLayer(in_channels, hidden_size) + self.lstm = LSTMLayer(in_channels, hidden_size, proj_size=proj_size) def forward(self, x): x = self.lstm(x) @@ -49,6 +49,8 @@ def forward(self, x): class TestLstm(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() + self.net = Net(12, 2) + self.inputs = paddle.zeros((2, 10, 12)) def tearDown(self): self.temp_dir.cleanup() @@ -60,10 +62,8 @@ def run_lstm(self, to_static): paddle.static.default_main_program().random_seed = 1001 paddle.static.default_startup_program().random_seed = 1001 - net = Net(12, 2) - net = paddle.jit.to_static(net) - x = paddle.zeros((2, 10, 12)) - y = net(paddle.to_tensor(x)) + net = paddle.jit.to_static(self.net) + y = net(paddle.to_tensor(self.inputs)) return y.numpy() def test_lstm_to_static(self): @@ -74,8 +74,8 @@ def test_lstm_to_static(self): @ast_only_test def test_save_in_eval(self, with_training=True): paddle.jit.enable_to_static(True) - net = Net(12, 2) - x = paddle.randn((2, 10, 12)) + net = self.net + x = self.inputs if with_training: x.stop_gradient = False dygraph_out = net(x) @@ -123,6 +123,13 @@ def test_save_without_training(self): self.test_save_in_eval(with_training=False) +class TestLstmWithProjsize(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.net = Net(12, 2, 4) + self.inputs = paddle.zeros((2, 10, 12)) + + class LinearNet(nn.Layer): def __init__(self): super().__init__() diff --git a/test/rnn/rnn_numpy.py b/test/rnn/rnn_numpy.py index 2c8cb26e4b4c24..467760d0d4897a 100644 --- a/test/rnn/rnn_numpy.py +++ b/test/rnn/rnn_numpy.py @@ -401,18 +401,17 @@ def forward(self, inputs, initial_states=None, sequence_length=None): batch_size = inputs.shape[batch_index] dtype = inputs.dtype if initial_states is None: - state_shape = ( - self.num_layers * self.num_directions, - batch_size, - self.hidden_size, - ) + state_shape = (self.num_layers * self.num_directions, batch_size) + proj_size = self.proj_size if hasattr(self, 'proj_size') else None + + dims = ([proj_size or self.hidden_size], [self.hidden_size]) if self.state_components == 1: initial_states = np.zeros(state_shape, dtype) else: initial_states = tuple( [ - np.zeros(state_shape, dtype) - for _ in range(self.state_components) + np.zeros(state_shape + dims[i], dtype) + for i in range(self.state_components) ] ) @@ -502,6 +501,7 @@ def __init__( hidden_size, num_layers=1, direction="forward", + proj_size=None, dropout=0.0, 
time_major=False, dtype="float64", @@ -509,20 +509,21 @@ def __init__( super().__init__() bidirectional_list = ["bidirectional", "bidirect"] + in_size = proj_size or hidden_size if direction in ["forward"]: is_reverse = False cell = LSTMCell(input_size, hidden_size, dtype=dtype) self.append(RNN(cell, is_reverse, time_major)) for i in range(1, num_layers): - cell = LSTMCell(hidden_size, hidden_size, dtype=dtype) + cell = LSTMCell(in_size, hidden_size, dtype=dtype) self.append(RNN(cell, is_reverse, time_major)) elif direction in bidirectional_list: cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype) cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype) self.append(BiRNN(cell_fw, cell_bw, time_major)) for i in range(1, num_layers): - cell_fw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype) - cell_bw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype) + cell_fw = LSTMCell(2 * in_size, hidden_size, dtype=dtype) + cell_bw = LSTMCell(2 * in_size, hidden_size, dtype=dtype) self.append(BiRNN(cell_fw, cell_bw, time_major)) else: raise ValueError( @@ -537,6 +538,7 @@ def __init__( self.time_major = time_major self.num_layers = num_layers self.state_components = 2 + self.proj_size = proj_size class GRU(RNNMixin): diff --git a/test/rnn/test_rnn_nets.py b/test/rnn/test_rnn_nets.py index 0ac68bdbf30d62..817cab023784eb 100644 --- a/test/rnn/test_rnn_nets.py +++ b/test/rnn/test_rnn_nets.py @@ -289,6 +289,53 @@ def runTest(self): self.test_predict() +class TestLSTMWithProjSize(unittest.TestCase): + def setUp(self): + # Since `set_device` is global, set `set_device` in `setUp` rather than + # `__init__` to avoid using an error device set by another test case. + place = paddle.set_device(self.place) + paddle.disable_static(place) + rnn1 = LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + proj_size=8, + ) + rnn2 = paddle.nn.LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + proj_size=8, + ) + convert_params_for_net(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 8) + prev_c = np.random.randn(2 * self.num_directions, 4, 32) + + y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) + y2, (h2, c2) = rnn2( + paddle.to_tensor(x), + (paddle.to_tensor(prev_h), paddle.to_tensor(prev_c)), + ) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def predict_test_util(place, mode, stop_gradient=True): place = paddle.set_device(place) paddle.seed(123) From 606f0fab2a834d96152fd917c39b871a98a6ba55 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 21 Aug 2023 13:14:32 +0800 Subject: [PATCH 02/10] update --- python/paddle/nn/layer/rnn.py | 1 - test/rnn/rnn_numpy.py | 23 +++++++++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 8bea81ff4b9a9a..07e73c360a1f1b 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -560,7 +560,6 @@ def get_initial_states( dtype=None, init_value=0.0, batch_dim_idx=0, - proj_size=None, ): r""" Generate initialized states according to provided shape, data type and diff --git a/test/rnn/rnn_numpy.py 
b/test/rnn/rnn_numpy.py
index 467760d0d4897a..97611193bc8d76 100644
--- a/test/rnn/rnn_numpy.py
+++ b/test/rnn/rnn_numpy.py
@@ -144,7 +144,14 @@ def forward(self, inputs, hx=None):

 class LSTMCell(LayerMixin):
-    def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        bias=True,
+        dtype="float64",
+        proj_size=None,
+    ):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.bias = bias
@@ -154,10 +161,16 @@ def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
             -std, std, (4 * hidden_size, input_size)
         ).astype(dtype)
         self.weight_hh = np.random.uniform(
-            -std, std, (4 * hidden_size, hidden_size)
+            -std, std, (4 * hidden_size, proj_size or hidden_size)
         ).astype(dtype)
         self.parameters['weight_ih'] = self.weight_ih
         self.parameters['weight_hh'] = self.weight_hh
+        self.proj_size = proj_size
+        if proj_size:
+            self.weight_ho = np.random.uniform(
+                -std, std, (proj_size, hidden_size)
+            ).astype(dtype)
+            self.parameters['weight_ho'] = self.weight_ho
         if bias:
             self.bias_ih = np.random.uniform(
                 -std, std, (4 * hidden_size)
@@ -195,6 +208,8 @@ def forward(self, inputs, hx=None):
         o = 1.0 / (1.0 + np.exp(-chunked_gates[3]))
         c = f * pre_cell + i * np.tanh(chunked_gates[2])
         h = o * np.tanh(c)
+        if self.proj_size:
+            h = np.matmul(h, self.weight_ho.T)
         return h, (h, c)
@@ -401,10 +416,10 @@ def forward(self, inputs, initial_states=None, sequence_length=None):
         batch_size = inputs.shape[batch_index]
         dtype = inputs.dtype
         if initial_states is None:
-            state_shape = (self.num_layers * self.num_directions, batch_size)
+            state_shape = (self.wum_layers * self.num_directions, batch_size)
             proj_size = self.proj_size if hasattr(self, 'proj_size') else None

-            dims = ([proj_size or self.hidden_size], [self.hidden_size])
+            dims = ((proj_size or self.hidden_size), (self.hidden_size))
             if self.state_components == 1:
                 initial_states = np.zeros(state_shape, dtype)
             else:

From a1424325aef2d06e1485fd6b69f1b2ef165f43e1 Mon Sep 17 00:00:00 2001
From: Asthestarsfalll <1186454801@qq.com>
Date: Mon, 16 Oct 2023 22:09:15 +0800
Subject: [PATCH 03/10] update docstring

---
 python/paddle/nn/layer/rnn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 07e73c360a1f1b..9ccce46042cf44 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -560,6 +560,7 @@ def get_initial_states(
         dtype=None,
         init_value=0.0,
         batch_dim_idx=0,
+        proj_size=None,
     ):
         r"""
         Generate initialized states according to provided shape, data type and
@@ -885,7 +886,7 @@ class LSTMCell(RNNCellBase):
     Variables:
         - **weight_ih** (Parameter): shape (4 * hidden_size, input_size), input to hidden weight, which corresponds to the concatenation of :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula.
-        - **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula.
+        - **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula. If proj_size is specified, the shape will be (4 * hidden_size, proj_size).
         - **weight_ho** (Parameter, optional): shape (proj_size, hidden_size), projects the hidden state.
- **bias_ih** (Parameter): shape (4 * hidden_size, ), input to hidden bias, which corresponds to the concatenation of :math:`b_{ii}, b_{if}, b_{ig}, b_{io}` in the formula.
        - **bias_hh** (Parameter): shape (4 * hidden_size, ), hidden to hidden bias, which corresponds to the concatenation of :math:`b_{hi}, b_{hf}, b_{hg}, b_{ho}` in the formula.

From 784c72a2796e75a5d17818e704a44ffd0325b90b Mon Sep 17 00:00:00 2001
From: Asthestarsfalll <1186454801@qq.com>
Date: Wed, 18 Oct 2023 19:12:16 +0800
Subject: [PATCH 04/10] fix unittest

---
 test/rnn/rnn_numpy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/rnn/rnn_numpy.py b/test/rnn/rnn_numpy.py
index e09a8ff7be333d..c98b62e3a600a9 100644
--- a/test/rnn/rnn_numpy.py
+++ b/test/rnn/rnn_numpy.py
@@ -415,7 +415,7 @@ def forward(self, inputs, initial_states=None, sequence_length=None):
         batch_size = inputs.shape[batch_index]
         dtype = inputs.dtype
         if initial_states is None:
-            state_shape = (self.wum_layers * self.num_directions, batch_size)
+            state_shape = (self.num_layers * self.num_directions, batch_size)
             proj_size = self.proj_size if hasattr(self, 'proj_size') else None

             dims = ((proj_size or self.hidden_size), (self.hidden_size))

From 8eb60e6f9f3bbe77f25b7b2b8120af3362bb766a Mon Sep 17 00:00:00 2001
From: Asthestarsfalll <1186454801@qq.com>
Date: Wed, 8 Nov 2023 15:53:19 +0800
Subject: [PATCH 05/10] update

---
 python/paddle/nn/layer/rnn.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 3f56960e94c4cd..d16d3dfc5ec478 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -560,7 +560,6 @@ def get_initial_states(
         dtype=None,
         init_value=0.0,
         batch_dim_idx=0,
-        proj_size=None,
     ):
         r"""
         Generate initialized states according to provided shape, data type and

From 0ea08ceb0328686e0ab117d156ad777ef92469c3 Mon Sep 17 00:00:00 2001
From: Asthestarsfalll <1186454801@qq.com>
Date: Wed, 6 Dec 2023 19:19:23 +0800
Subject: [PATCH 06/10] fix unittest

---
 python/paddle/nn/layer/rnn.py | 18 +++++++++++++++---
 test/rnn/rnn_numpy.py         |  2 +-
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 38f7e41e42aa29..39664ee567d6b1 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -910,7 +910,8 @@ class LSTMCell(RNNCellBase):
         input_size (int): The input size.
         hidden_size (int): The hidden size.
         proj_size (int, optional): If specified, the output hidden state
-            will be projected to `proj_size`.
+            will be projected to `proj_size`. `proj_size` must be smaller than
+            `hidden_size`. Default: None.
         weight_ih_attr(ParamAttr, optional): The parameter attribute for
             `weight_ih`. Default: None.
weight_hh_attr(ParamAttr, optional): The parameter attribute for
@@ -983,6 +984,16 @@ def __init__(
                     self.__class__.__name__, hidden_size
                 )
             )
+        if proj_size and proj_size <= 0:
+            raise ValueError(
+                "proj_size of {} must be greater than 0, but now equals to {}".format(
+                    self.__class__.__name__, proj_size
+                )
+            )
+
+        if proj_size and proj_size >= hidden_size:
+            raise ValueError("proj_size must be smaller than hidden_size")
+
         std = 1.0 / math.sqrt(hidden_size)
         if weight_ih_attr is not False:
             self.weight_ih = self.create_parameter(
@@ -1912,7 +1923,7 @@ class LSTM(RNNBase):

         \widetilde{c}_{t} & = \tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})

-        c_{t} & = f_{t} * c_{t-1} + i_{t} * \widetilde{c}_{t}
+        c_{t} & = f_{t} * c_{t-2} + i_{t} * \widetilde{c}_{t}

         h_{t} & = o_{t} * \tanh(c_{t})

@@ -1931,7 +1942,8 @@
            or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
            outputs of forward and backward is concatenating. Defaults to "forward".
        proj_size (int, optional): If specified, the output hidden state of each layer
-            will be projected to `proj_size`.
+            will be projected to `proj_size`. `proj_size` must be smaller than `hidden_size`.
+            Defaults to None.
        time_major (bool, optional): Whether the first dimension of the input
            means the time steps. If time_major is True, the shape of Tensor is
            [time_steps,batch_size,input_size], otherwise [batch_size, time_steps,input_size].
diff --git a/test/rnn/rnn_numpy.py b/test/rnn/rnn_numpy.py
index 5371f05bbb0408..501dde9ad5ba5b 100644
--- a/test/rnn/rnn_numpy.py
+++ b/test/rnn/rnn_numpy.py
@@ -445,7 +445,7 @@ def forward(self, inputs, initial_states=None, sequence_length=None):
             state_shape = (self.num_layers * self.num_directions, batch_size)
             proj_size = self.proj_size if hasattr(self, 'proj_size') else None

-            dims = ((proj_size or self.hidden_size), (self.hidden_size))
+            dims = ((proj_size or self.hidden_size,), (self.hidden_size,))
             if self.state_components == 1:
                 initial_states = np.zeros(state_shape, dtype)
             else:

From d11450a9b7a2ead1b8115bede39790258451fa2e Mon Sep 17 00:00:00 2001
From: Asthestarsfalll <1186454801@qq.com>
Date: Thu, 7 Dec 2023 12:18:24 +0800
Subject: [PATCH 07/10] fix shape of initial states

---
 python/paddle/nn/layer/rnn.py       | 2 +-
 test/dygraph_to_static/test_lstm.py | 8 ++++++--
 test/rnn/rnn_numpy.py               | 2 +-
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 39664ee567d6b1..074906999a32ab 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -1923,7 +1923,7 @@ class LSTM(RNNBase):

         \widetilde{c}_{t} & = \tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})

-        c_{t} & = f_{t} * c_{t-2} + i_{t} * \widetilde{c}_{t}
+        c_{t} & = f_{t} * c_{t-1} + i_{t} * \widetilde{c}_{t}

         h_{t} & = o_{t} * \tanh(c_{t})

diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py
index 789b76e38dbe64..b8cd805aa74dfc 100644
--- a/test/dygraph_to_static/test_lstm.py
+++ b/test/dygraph_to_static/test_lstm.py
@@ -27,7 +27,11 @@ class LSTMLayer(nn.Layer):
     def __init__(self, in_channels, hidden_size, proj_size=None):
         super().__init__()
         self.cell = nn.LSTM(
-            in_channels, hidden_size, direction='bidirectional', num_layers=2
+            in_channels,
+            hidden_size,
+            direction='bidirectional',
+            num_layers=2,
+            proj_size=proj_size,
         )

     def forward(self, x):
@@ -122,6 +126,7 @@ def test_save_with_training(self):
         self.save_in_eval(with_training=True)


-class TestLstmWithProjsize(unittest.TestCase):
+class 
TestLstmWithProjsize(TestLstm): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.net = Net(12, 2, 4) diff --git a/test/rnn/rnn_numpy.py b/test/rnn/rnn_numpy.py index 501dde9ad5ba5b..7ecba65134bd91 100644 --- a/test/rnn/rnn_numpy.py +++ b/test/rnn/rnn_numpy.py @@ -447,7 +447,7 @@ def forward(self, inputs, initial_states=None, sequence_length=None): dims = ((proj_size or self.hidden_size,), (self.hidden_size,)) if self.state_components == 1: - initial_states = np.zeros(state_shape, dtype) + initial_states = np.zeros(state_shape + dims[0], dtype) else: initial_states = tuple( [ From aa674c9741e6af35a889ccb21def8c6cbd88f8cc Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Sun, 10 Dec 2023 19:31:38 +0800 Subject: [PATCH 08/10] fix --- test/dygraph_to_static/test_lstm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py index b8cd805aa74dfc..221f3b2ed6880c 100644 --- a/test/dygraph_to_static/test_lstm.py +++ b/test/dygraph_to_static/test_lstm.py @@ -129,7 +129,7 @@ def test_save_with_training(self): class TestLstmWithProjsize(TestLstm): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() - self.net = Net(12, 2, 4) + self.net = Net(12, 8, 4) self.inputs = paddle.zeros((2, 10, 12)) From 48390edc07bba7538b0026ce6b75879f3f25d07e Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 11 Dec 2023 23:56:18 +0800 Subject: [PATCH 09/10] add unittests --- test/rnn/test_rnn_nets.py | 28 +++++------------- test/rnn/test_rnn_nets_static.py | 50 +++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 22 deletions(-) diff --git a/test/rnn/test_rnn_nets.py b/test/rnn/test_rnn_nets.py index 8e16142bd5a640..241aefa677c1f7 100644 --- a/test/rnn/test_rnn_nets.py +++ b/test/rnn/test_rnn_nets.py @@ -289,7 +289,7 @@ def runTest(self): self.test_predict() -class TestLSTMWithProjSize(unittest.TestCase): +class TestLSTMWithProjSize(TestLSTM): def setUp(self): # Since `set_device` is global, set `set_device` in `setUp` rather than # `__init__` to avoid using an error device set by another test case. 
@@ -316,25 +316,6 @@ def setUp(self): self.rnn1 = rnn1 self.rnn2 = rnn2 - def test_with_initial_state(self): - rnn1 = self.rnn1 - rnn2 = self.rnn2 - - x = np.random.randn(12, 4, 16) - if not self.time_major: - x = np.transpose(x, [1, 0, 2]) - prev_h = np.random.randn(2 * self.num_directions, 4, 8) - prev_c = np.random.randn(2 * self.num_directions, 4, 32) - - y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) - y2, (h2, c2) = rnn2( - paddle.to_tensor(x), - (paddle.to_tensor(prev_h), paddle.to_tensor(prev_c)), - ) - np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) - np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) - np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) - def predict_test_util(place, mode, stop_gradient=True): place = paddle.set_device(place) @@ -410,7 +391,12 @@ def load_tests(loader, tests, pattern): for direction in ["forward", "bidirectional", "bidirect"]: for time_major in [True, False]: for device in devices: - for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: + for test_class in [ + TestSimpleRNN, + TestLSTM, + TestGRU, + TestLSTMWithProjSize, + ]: suite.addTest(test_class(time_major, direction, device)) return suite diff --git a/test/rnn/test_rnn_nets_static.py b/test/rnn/test_rnn_nets_static.py index 20b8a7975e8c21..79d4da786de8df 100644 --- a/test/rnn/test_rnn_nets_static.py +++ b/test/rnn/test_rnn_nets_static.py @@ -505,13 +505,61 @@ def runTest(self): self.test_with_input_lengths() +class TestLSTMWithProjSize(TestLSTM): + def setUp(self): + # Since `set_device` is global, set `set_device` in `setUp` rather than + # `__init__` to avoid using an error device set by another test case. + place = paddle.set_device(self.place) + rnn1 = LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + proj_size=8, + ) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.base.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction, + proj_size=8, + ) + + exe = paddle.static.Executor(place) + scope = paddle.base.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_net_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def load_tests(loader, tests, pattern): suite = unittest.TestSuite() devices = ["cpu", "gpu"] if paddle.base.is_compiled_with_cuda() else ["cpu"] for direction in ["forward", "bidirectional", "bidirect"]: for time_major in [True, False]: for device in devices: - for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: + for test_class in [ + TestSimpleRNN, + TestLSTM, + TestGRU, + TestLSTMWithProjSize, + ]: suite.addTest(test_class(time_major, direction, device)) return suite From d301a0cb9e53c96367ccbdb5a980dff28f6b07ab Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 18 Dec 2023 16:40:26 +0800 Subject: [PATCH 10/10] try --- test/dygraph_to_static/test_lstm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py index 232508bcb340b7..27fb6367ab4fd1 100644 --- a/test/dygraph_to_static/test_lstm.py +++ b/test/dygraph_to_static/test_lstm.py @@ -65,13 +65,13 @@ def tearDown(self): self.temp_dir.cleanup() def run_lstm(self, to_static): - with enable_to_static_guard(to_static): + with enable_to_static_guard(False): 
paddle.static.default_main_program().random_seed = 1001 paddle.static.default_startup_program().random_seed = 1001 - net = paddle.jit.to_static(self.net) - y = net(paddle.to_tensor(self.inputs)) - return y.numpy() + net = paddle.jit.to_static(self.net) + y = net(paddle.to_tensor(self.inputs)) + return y.numpy() def test_lstm_to_static(self): dygraph_out = self.run_lstm(to_static=False)
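
Taken together, the series adds a `proj_size` option to `paddle.nn.LSTM` and `paddle.nn.LSTMCell`. A minimal usage sketch of the layer-level API follows; it is not part of the patches, it assumes a Paddle build that already includes them, and the shape expectations are read off the docstrings and tests above:

    import paddle

    batch, time_steps, input_size = 4, 12, 16
    hidden_size, proj_size, num_layers = 32, 8, 2

    # proj_size must be smaller than hidden_size (validated since PATCH 06)
    lstm = paddle.nn.LSTM(
        input_size,
        hidden_size,
        num_layers=num_layers,
        direction="bidirectional",
        proj_size=proj_size,
    )

    # batch-first input, since time_major defaults to False
    x = paddle.randn((batch, time_steps, input_size))
    outputs, (h, c) = lstm(x)

    num_directions = 2  # "bidirectional"
    # h is projected to proj_size per direction; c keeps the full hidden_size
    assert outputs.shape == [batch, time_steps, num_directions * proj_size]
    assert h.shape == [num_layers * num_directions, batch, proj_size]
    assert c.shape == [num_layers * num_directions, batch, hidden_size]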
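
At the cell level, only the hidden state h is projected (through weight_ho); the cell state c always keeps hidden_size, which is why state_shape pairs `proj_size or hidden_size` with `hidden_size`. A companion sketch under the same assumptions:

    import paddle

    batch, input_size, hidden_size, proj_size = 4, 16, 32, 8
    cell = paddle.nn.LSTMCell(input_size, hidden_size, proj_size=proj_size)

    step_input = paddle.randn((batch, input_size))
    # With states=None, the cell builds zero initial states from state_shape,
    # i.e. ((proj_size,), (hidden_size,)) plus the batch dimension.
    out, (h, c) = cell(step_input)

    assert out.shape == [batch, proj_size]  # out is the projected h
    assert h.shape == [batch, proj_size]
    assert c.shape == [batch, hidden_size]  # c is untouched by the projection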