From cec228ae315f78bd3a2cb7a26fd436bcaca211cf Mon Sep 17 00:00:00 2001
From: zeroRains
Date: Sun, 17 Mar 2024 06:53:37 +0000
Subject: [PATCH 1/6] support dynamic shape for group_norm, but it needs
 dynamic shape support for sqrt_decomp

---
 paddle/fluid/primitive/composite/composite.h  | 73 ++++++++++++++-----
 .../test_prim_sub_graph_dynamic_shape.py      | 63 ++++++++++++++++
 2 files changed, 116 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index 8513dcc2839237..1ca78078cb7f24 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -792,21 +792,45 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   if (need_cast) {
     x_cast = cast<T>(x, DataType::FLOAT32);
   }
-
-  auto x_dim = x.shape();
-  std::vector<int64_t> one_axis(1, 1);
-
-  std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
-  x_cast = reshape<T>(x_cast, x_shape);
-  auto mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
-  auto var_tmp_ =
-      mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) - mean_ * mean_;
-  auto var_ =
-      maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
-  auto var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
-  auto res = (x_cast - mean_) * var_inv;
-  auto out = reshape<T>(res, x_dim);
-
+  Tensor out, mean_, var_;
+  if (has_dynamic_shape(x.shape())) {
+    std::cout<<"step1_______________________________________________"<<std::endl;
+    Tensor x_dim = shape<T>(x);
+    std::vector<int64_t> one_axis(1, 1);
+    Tensor x_shape = get_slice<T>(x_dim, 0) * groups;
+    Tensor dim_1 = full<T>({1}, -1, x_dim.type());
+    std::cout<<"step2_______________________________________________"<<std::endl;
+    x_shape = concat<T>({x_shape, dim_1});
+    x_cast = backend::reshape<T>(x_cast, x_shape);
+    std::cout<<"step3_______________________________________________"<<std::endl;
+    mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
+    std::cout<<"step4_______________________________________________"<<std::endl;
+    Tensor var_tmp_ =
+        mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) -
+        mean_ * mean_;
+    std::cout<<"step5_______________________________________________"<<std::endl;
+    var_ = maximum<T>(
+        var_tmp_,
+        backend::full_with_tensor<T>(shape<T>(var_tmp_), 0, var_tmp_.dtype()));
+    std::cout<<"step6_______________________________________________"<<std::endl;
+    Tensor var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);  // TODO: support dynamic shape
+    std::cout<<"step7_______________________________________________"<<std::endl;
+    Tensor res = (x_cast - mean_) * var_inv;
+    out = backend::reshape<T>(res, x_dim);
+  } else {
+    auto x_dim = x.shape();
+    std::vector<int64_t> one_axis(1, 1);
+
+    std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
+    x_cast = reshape<T>(x_cast, x_shape);
+    mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
+    auto var_tmp_ = mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) -
+                    mean_ * mean_;
+    var_ = maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
+    auto var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
+    auto res = (x_cast - mean_) * var_inv;
+    out = reshape<T>(res, x_dim);
+  }
   auto scale_ptr = scale.get_ptr();
   auto bias_ptr = bias.get_ptr();
 
@@ -835,11 +859,20 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
     }
     out = out + bias_cast;
   }
-
-  std::vector<int64_t> res_shape{x_dim[0], groups};
-  auto mean_out = reshape<T>(mean_, res_shape);
-  auto var_out = reshape<T>(var_, res_shape);
-
+  Tensor mean_out, var_out;
+  if (has_dynamic_shape(x.shape())) {
+    Tensor x_dim = shape<T>(x);
+    Tensor x_shape = get_slice<T>(x_dim, 0);
+    Tensor dim_1 = full<T>({1}, groups, x_shape.type());
+    x_shape = concat<T>({x_shape, dim_1});
+    mean_out = backend::reshape<T>(mean_, x_shape);
+    var_out = backend::reshape<T>(var_, x_shape);
+  } else {
+    auto x_dim = x.shape();
+    std::vector<int64_t> res_shape{x_dim[0], groups};
+    mean_out = reshape<T>(mean_, res_shape);
+    var_out = reshape<T>(var_, res_shape);
+  }
   if (need_cast) {
     out = cast<T>(out, org_dtype);
   }
diff --git a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
index 155cfbdeeb2683..dbc71acf96c8ef 100644
--- a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
+++ b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -92,6 +92,12 @@ def swiglu_net2(x):
     return paddle.incubate.nn.functional.swiglu(x)
 
 
+def group_norm_net(x, weight, bias):
+    group_norm = paddle.nn.GroupNorm(num_channels=x.shape[1], num_groups=32)
+    paddle.assign(weight, group_norm.weight)
+    paddle.assign(bias, group_norm.bias)
+    return group_norm(x)
+
 class TestPrimBase(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
@@ -305,5 +311,62 @@ def setUp(self):
         self.enable_cinn = False
 
 
+class TestPrimThree(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [5, 640, 10, 20]
+        self.shape_y = [640]
+        self.shape_z = [640]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.dtype_z = "float32"
+        self.init_x_shape = [None, 640, None, None]
+        self.init_y_shape = [640]
+        self.init_z_shape = [640]
+        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
+        self.y = np.random.random(self.shape_y).astype(self.dtype_y)
+        self.z = np.random.random(self.shape_z).astype(self.dtype_z)
+        self.net = group_norm_net
+        self.necessary_ops = "pd_op.group_norm"
+        self.enable_cinn = False
+
+    def base_net(self, flag=None):
+        x = paddle.to_tensor(self.x)
+        y = paddle.to_tensor(self.y)
+        z = paddle.to_tensor(self.z)
+        if flag == "prim":
+            core._set_prim_all_enabled(True)
+            fn = apply_to_static(
+                self.net,
+                use_cinn=self.enable_cinn,
+                input_spec=[
+                    InputSpec(shape=self.init_x_shape, dtype=self.dtype_x),
+                    InputSpec(shape=self.init_y_shape, dtype=self.dtype_y),
+                    InputSpec(shape=self.init_z_shape, dtype=self.dtype_z),
+                ],
+            )
+            fn.eval()
+        else:
+            fn = self.net
+        res = fn(x, y, z)
+
+        if flag == "prim":
+            ops = [
+                op.name()
+                for op in fn.program_cache.last()[-1][-1]
+                .infer_program.program.global_block()
+                .ops
+            ]
+            assert self.necessary_ops not in ops
+            core._set_prim_all_enabled(False)
+        return res
+
+    def test_prim_all_dynamic(self):
+        res_ref = self.base_net()
+        res = self.base_net("prim")
+        for ref, actual in zip(res_ref, res):
+            np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
+
 if __name__ == "__main__":
     unittest.main()

From feaa27ed1ab8371351c262aa5f2144abb6b367df Mon Sep 17 00:00:00 2001
From: zeroRains
Date: Sun, 17 Mar 2024 06:55:20 +0000
Subject: [PATCH 2/6] fix code style

---
 paddle/fluid/primitive/composite/composite.h  | 24 ++++++++++++-------
 .../test_prim_sub_graph_dynamic_shape.py      |  1 +
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index 1ca78078cb7f24..6027e14c61c0b3 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -794,27 +794,35 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   }
   Tensor out, mean_, var_;
   if (has_dynamic_shape(x.shape())) {
-    std::cout<<"step1_______________________________________________"<<std::endl;
+    std::cout << "step1_______________________________________________"
+              << std::endl;
     Tensor x_dim = shape<T>(x);
     std::vector<int64_t> one_axis(1, 1);
     Tensor x_shape = get_slice<T>(x_dim, 0) * groups;
     Tensor dim_1 = full<T>({1}, -1, x_dim.type());
-    std::cout<<"step2_______________________________________________"<<std::endl;
+    std::cout << "step2_______________________________________________"
+              << std::endl;
     x_shape = concat<T>({x_shape, dim_1});
     x_cast = backend::reshape<T>(x_cast, x_shape);
-    std::cout<<"step3_______________________________________________"<<std::endl;
+    std::cout << "step3_______________________________________________"
+              << std::endl;
     mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
std::cout<<"step4_______________________________________________"<(x_cast * x_cast, IntArray(one_axis), true) - mean_ * mean_; - std::cout<<"step5_______________________________________________"<( var_tmp_, backend::full_with_tensor(shape(var_tmp_), 0, var_tmp_.dtype())); - std::cout<<"step6_______________________________________________"<(var_ + epsilon); // TODO: support dynamic shape - std::cout<<"step7_______________________________________________"<(var_ + epsilon); // TODO: support dynamic shape + std::cout << "step7_______________________________________________" + << std::endl; Tensor res = (x_cast - mean_) * var_inv; out = backend::reshape(res, x_dim); } else { diff --git a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py index dbc71acf96c8ef..0bfd7b1d0681b2 100644 --- a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py +++ b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py @@ -98,6 +98,7 @@ def group_norm_net(x, weight, bias): paddle.assign(bias, group_norm.bias) return group_norm(x) + class TestPrimBase(unittest.TestCase): def setUp(self): np.random.seed(2023) From 233b730f48f5d9b2a4369387017f0e575ed101af Mon Sep 17 00:00:00 2001 From: zeroRains Date: Sun, 17 Mar 2024 08:50:05 +0000 Subject: [PATCH 3/6] remove todo --- paddle/fluid/primitive/composite/composite.h | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 6027e14c61c0b3..8b293b353a0de9 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -794,33 +794,22 @@ std::tuple group_norm_decomp( } Tensor out, mean_, var_; if (has_dynamic_shape(x.shape())) { - std::cout << "step1_______________________________________________" - << std::endl; Tensor x_dim = shape(x); std::vector one_axis(1, 1); Tensor x_shape = get_slice(x_dim, 0) * groups; Tensor dim_1 = full({1}, -1, x_dim.type()); - std::cout << "step2_______________________________________________" - << std::endl; x_shape = concat({x_shape, dim_1}); x_cast = backend::reshape(x_cast, x_shape); - std::cout << "step3_______________________________________________" - << std::endl; mean_ = mean_decomp(x_cast, IntArray(one_axis), true); - std::cout << "step4_______________________________________________" - << std::endl; Tensor var_tmp_ = mean_decomp(x_cast * x_cast, IntArray(one_axis), true) - mean_ * mean_; - std::cout << "step5_______________________________________________" - << std::endl; var_ = maximum( var_tmp_, backend::full_with_tensor(shape(var_tmp_), 0, var_tmp_.dtype())); std::cout << "step6_______________________________________________" << std::endl; - Tensor var_inv = - 1 / sqrt_decomp(var_ + epsilon); // TODO: support dynamic shape + Tensor var_inv = 1 / sqrt_decomp(var_ + epsilon); std::cout << "step7_______________________________________________" << std::endl; Tensor res = (x_cast - mean_) * var_inv; From 9fa990da634193f38f90c06d89458b79ae7635d6 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Tue, 19 Mar 2024 08:21:40 +0000 Subject: [PATCH 4/6] modify the test --- .../test_prim_sub_graph_dynamic_shape.py | 119 ++++++++++-------- 1 file changed, 68 insertions(+), 51 deletions(-) diff --git a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py index f1705d8c603b0e..67390df01d03a8 100644 --- a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py 
+++ b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -92,10 +92,32 @@ def swiglu_net2(x):
     return paddle.incubate.nn.functional.swiglu(x)
 
 
-def group_norm_net(x, weight, bias):
+def group_norm_net1(x):
     group_norm = paddle.nn.GroupNorm(num_channels=x.shape[1], num_groups=32)
-    paddle.assign(weight, group_norm.weight)
-    paddle.assign(bias, group_norm.bias)
+    return group_norm(x)
+
+
+def group_norm_net2(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1], num_groups=32, weight_attr=False
+    )
+    return group_norm(x)
+
+
+def group_norm_net3(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1], num_groups=32, bias_attr=False
+    )
+    return group_norm(x)
+
+
+def group_norm_net4(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1],
+        num_groups=32,
+        weight_attr=False,
+        bias_attr=False,
+    )
     return group_norm(x)
 
 
@@ -372,61 +394,56 @@ def setUp(self):
         self.tol = 1e-6
 
 
-class TestPrimThree(unittest.TestCase):
+class TestPrimGroupNorm1(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
-        self.shape_x = [5, 640, 10, 20]
-        self.shape_y = [640]
-        self.shape_z = [640]
-        self.dtype_x = "float32"
-        self.dtype_y = "float32"
-        self.dtype_z = "float32"
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
         self.init_x_shape = [None, 640, None, None]
-        self.init_y_shape = [640]
-        self.init_z_shape = [640]
-        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
-        self.y = np.random.random(self.shape_y).astype(self.dtype_y)
-        self.z = np.random.random(self.shape_z).astype(self.dtype_z)
-        self.net = group_norm_net
-        self.necessary_ops = "pd_op.group_norm"
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net1
+        self.necessary_ops = "pd_op.flatten"
         self.enable_cinn = False
+        self.tol = 1e-6
 
-    def base_net(self, flag=None):
-        x = paddle.to_tensor(self.x)
-        y = paddle.to_tensor(self.y)
-        z = paddle.to_tensor(self.z)
-        if flag == "prim":
-            core._set_prim_all_enabled(True)
-            fn = apply_to_static(
-                self.net,
-                use_cinn=self.enable_cinn,
-                input_spec=[
-                    InputSpec(shape=self.init_x_shape, dtype=self.dtype_x),
-                    InputSpec(shape=self.init_y_shape, dtype=self.dtype_y),
-                    InputSpec(shape=self.init_z_shape, dtype=self.dtype_z),
-                ],
-            )
-            fn.eval()
-        else:
-            fn = self.net
-        res = fn(x, y, z)
-
-        if flag == "prim":
-            ops = [
-                op.name()
-                for op in fn.program_cache.last()[-1][-1]
-                .infer_program.program.global_block()
-                .ops
-            ]
-            assert self.necessary_ops not in ops
-            core._set_prim_all_enabled(False)
-        return res
+
+class TestPrimGroupNorm2(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net2
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
 
-    def test_prim_all_dynamic(self):
-        res_ref = self.base_net()
-        res = self.base_net("prim")
-        for ref, actual in zip(res_ref, res):
-            np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
+class TestPrimGroupNorm3(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net3
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm4(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
= "float32" + self.x_shape = [50, 640, 10, 20] + self.init_x_shape = [None, 640, None, None] + self.x = np.random.random(self.x_shape).astype(self.dtype) + self.net = group_norm_net4 + self.necessary_ops = "pd_op.flatten" + self.enable_cinn = False + self.tol = 1e-6 if __name__ == "__main__": From 451eb16fea8e17dd025f1d68286eb9557befa3de Mon Sep 17 00:00:00 2001 From: zeroRains Date: Tue, 19 Mar 2024 08:23:53 +0000 Subject: [PATCH 5/6] remote debug tag --- paddle/fluid/primitive/composite/composite.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index c815a87bf8f29a..93ed4074b31f20 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -905,11 +905,7 @@ std::tuple group_norm_decomp( var_ = maximum( var_tmp_, backend::full_with_tensor(shape(var_tmp_), 0, var_tmp_.dtype())); - std::cout << "step6_______________________________________________" - << std::endl; Tensor var_inv = 1 / sqrt_decomp(var_ + epsilon); - std::cout << "step7_______________________________________________" - << std::endl; Tensor res = (x_cast - mean_) * var_inv; out = backend::reshape(res, x_dim); } else { From c0eb649d98308894add8bd022fa8b581249a33c7 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Tue, 19 Mar 2024 09:48:14 +0000 Subject: [PATCH 6/6] fix a typo --- test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py index 67390df01d03a8..54fc95319b9094 100644 --- a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py +++ b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py @@ -402,7 +402,7 @@ def setUp(self): self.init_x_shape = [None, 640, None, None] self.x = np.random.random(self.x_shape).astype(self.dtype) self.net = group_norm_net1 - self.necessary_ops = "pd_op.flatten" + self.necessary_ops = "pd_op.group_norm" self.enable_cinn = False self.tol = 1e-6 @@ -415,7 +415,7 @@ def setUp(self): self.init_x_shape = [None, 640, None, None] self.x = np.random.random(self.x_shape).astype(self.dtype) self.net = group_norm_net2 - self.necessary_ops = "pd_op.flatten" + self.necessary_ops = "pd_op.group_norm" self.enable_cinn = False self.tol = 1e-6 @@ -428,7 +428,7 @@ def setUp(self): self.init_x_shape = [None, 640, None, None] self.x = np.random.random(self.x_shape).astype(self.dtype) self.net = group_norm_net3 - self.necessary_ops = "pd_op.flatten" + self.necessary_ops = "pd_op.group_norm" self.enable_cinn = False self.tol = 1e-6 @@ -441,7 +441,7 @@ def setUp(self): self.init_x_shape = [None, 640, None, None] self.x = np.random.random(self.x_shape).astype(self.dtype) self.net = group_norm_net4 - self.necessary_ops = "pd_op.flatten" + self.necessary_ops = "pd_op.group_norm" self.enable_cinn = False self.tol = 1e-6