diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index 5e33ee2bb88d5..ab7a3654e582c 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -295,6 +295,12 @@ def test_predict_without_inputs(self): np.testing.assert_equal(output[0].shape[0], len(self.test_dataset)) fluid.disable_dygraph() + def test_summary_gpu(self): + paddle.disable_static(self.device) + rnn = paddle.nn.LSTM(16, 32, 2) + params_info = paddle.summary( + rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))]) + class MyModel(paddle.nn.Layer): def __init__(self):