From 2eba09d4ff038ca8ffacfb5ca6e77de91c703346 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com> Date: Wed, 31 Jul 2024 13:04:50 +0800 Subject: [PATCH] [Paddle.incubate.jit.inference] allow repeate call `my_layer = paddle.incubate.jit.inference(my_layer)` (#66689) * first commit --- .../paddle/incubate/jit/inference_decorator.py | 17 +++++++++++++++++ .../test_incubate_jit_inference.py | 2 ++ 2 files changed, 19 insertions(+) diff --git a/python/paddle/incubate/jit/inference_decorator.py b/python/paddle/incubate/jit/inference_decorator.py index acc446912d1b68..8011ab4d3ff1b8 100644 --- a/python/paddle/incubate/jit/inference_decorator.py +++ b/python/paddle/incubate/jit/inference_decorator.py @@ -236,6 +236,14 @@ def forward(self, args): ) input_specs.append(None) + for i in range(len(input_specs)): + if input_specs[i] is not None: + if isinstance(input_specs[i], list): + for j in range(len(input_specs[i])): + input_specs[i][j].stop_gradient = True + else: + input_specs[i].stop_gradient = True + # update the input_spec's shape for doing d2s d2s_shapes_id = 0 # initial the self.d2s_input_names! @@ -547,6 +555,15 @@ def inference( >>> decorator_result = mylayer(x) """ + # if function has already been decorated by @paddle.incubate.jit.inference(), then we just return it. + if ( + hasattr(function, "__name__") + and function.__name__ == "innermost_decorator" + ): + return function + elif isinstance(function, Layer): + if function.forward.__name__ == "innermost_decorator": + return function used_as_at_decorator = function is None diff --git a/test/dygraph_to_static/test_incubate_jit_inference.py b/test/dygraph_to_static/test_incubate_jit_inference.py index ac32b2cb1fb805..8266574a44e03e 100644 --- a/test/dygraph_to_static/test_incubate_jit_inference.py +++ b/test/dygraph_to_static/test_incubate_jit_inference.py @@ -101,6 +101,7 @@ def test_dygraph_static_same_result(self): result_x0 = my_layer(x).numpy() result_y0 = my_layer(y).numpy() + my_layer.func = paddle.incubate.jit.inference(my_layer.func) my_layer.func = paddle.incubate.jit.inference(my_layer.func) result_x1 = my_layer(x).numpy() @@ -119,6 +120,7 @@ def test_dygraph_static_same_result(self): my_layer = TestLayer2(hidd) result0 = my_layer([x, x]).numpy() my_static_layer = paddle.incubate.jit.inference(my_layer) + my_static_layer = paddle.incubate.jit.inference(my_layer) result1 = my_layer([x, x]).numpy() np.testing.assert_allclose(result0, result1, rtol=0.001, atol=1e-05)