[Prim][PIR] group_norm decomposite rule support dynamic shape #62793
Changes from 6 commits
First changed file — the group_norm composite (decomposition) rule:

@@ -890,21 +890,38 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   if (need_cast) {
     x_cast = cast<T>(x, DataType::FLOAT32);
   }

-  auto x_dim = x.shape();
-  std::vector<int64_t> one_axis(1, 1);
-
-  std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
-  x_cast = reshape<T>(x_cast, x_shape);
-  auto mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
-  auto var_tmp_ =
-      mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) - mean_ * mean_;
-  auto var_ =
-      maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
-  auto var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
-  auto res = (x_cast - mean_) * var_inv;
-  auto out = reshape<T>(res, x_dim);
+  Tensor out, mean_, var_;
+  if (has_dynamic_shape(x.shape())) {
+    Tensor x_dim = shape<T>(x);
+    std::vector<int64_t> one_axis(1, 1);
+    Tensor x_shape = get_slice<T>(x_dim, 0) * groups;
+    Tensor dim_1 = full<T>({1}, -1, x_dim.type());
+    x_shape = concat<T>({x_shape, dim_1});
+    x_cast = backend::reshape<T>(x_cast, x_shape);
+    mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
+    Tensor var_tmp_ =
+        mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) -
+        mean_ * mean_;
+    var_ = maximum<T>(
+        var_tmp_,
+        backend::full_with_tensor<T>(shape<T>(var_tmp_), 0, var_tmp_.dtype()));
+    Tensor var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
+    Tensor res = (x_cast - mean_) * var_inv;
+    out = backend::reshape<T>(res, x_dim);
+  } else {
+    auto x_dim = x.shape();
+    std::vector<int64_t> one_axis(1, 1);
+
+    std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
+    x_cast = reshape<T>(x_cast, x_shape);
+    mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
+    auto var_tmp_ = mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) -
+                    mean_ * mean_;
+    var_ = maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
+    auto var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
+    auto res = (x_cast - mean_) * var_inv;
+    out = reshape<T>(res, x_dim);
+  }
   auto scale_ptr = scale.get_ptr();
   auto bias_ptr = bias.get_ptr();

Review comment (on the added line `mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);`): remove IntArray

@@ -933,11 +950,20 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
     }
     out = out + bias_cast;
   }

-  std::vector<int64_t> res_shape{x_dim[0], groups};
-  auto mean_out = reshape<T>(mean_, res_shape);
-  auto var_out = reshape<T>(var_, res_shape);
+  Tensor mean_out, var_out;
+  if (has_dynamic_shape(x.shape())) {
+    Tensor x_dim = shape<T>(x);
+    Tensor x_shape = get_slice<T>(x_dim, 0);
+    Tensor dim_1 = full<T>({1}, groups, x_shape.type());
+    x_shape = concat<T>({x_shape, dim_1});
+    mean_out = backend::reshape<T>(mean_, x_shape);
+    var_out = backend::reshape<T>(var_, x_shape);
+  } else {
+    auto x_dim = x.shape();
+    std::vector<int64_t> res_shape{x_dim[0], groups};
+    mean_out = reshape<T>(mean_, res_shape);
+    var_out = reshape<T>(var_, res_shape);
+  }
   if (need_cast) {
     out = cast<T>(out, org_dtype);
   }
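For readers less familiar with the prim composite API, here is a minimal NumPy sketch of what the static-shape branch above computes; the dynamic-shape branch builds the same `[N * groups, -1]` and `[N, groups]` shapes at runtime with `shape<T>`, `get_slice<T>` and `concat<T>` instead of compile-time `int64_t` vectors. The helper name `group_norm_ref` and the NCHW layout are assumptions made for illustration only; this is not the Paddle implementation.

```python
import numpy as np


def group_norm_ref(x, groups, epsilon=1e-5, scale=None, bias=None):
    """Illustrative NumPy mirror of the static-shape decomposition (assumes NCHW)."""
    n, c = x.shape[0], x.shape[1]
    x_flat = x.reshape(n * groups, -1)  # [N * groups, -1], as in the composite rule
    mean = x_flat.mean(axis=1, keepdims=True)  # per-group mean
    # Variance as E[x^2] - E[x]^2, clamped at zero like maximum(var_tmp_, 0)
    var = np.maximum(
        (x_flat * x_flat).mean(axis=1, keepdims=True) - mean * mean, 0.0
    )
    var_inv = 1.0 / np.sqrt(var + epsilon)
    out = ((x_flat - mean) * var_inv).reshape(x.shape)
    if scale is not None:  # per-channel affine parameters broadcast over N, H, W
        out = out * scale.reshape(1, c, 1, 1)
    if bias is not None:
        out = out + bias.reshape(1, c, 1, 1)
    mean_out = mean.reshape(n, groups)  # matches the [x_dim[0], groups] reshape
    var_out = var.reshape(n, groups)
    return out, mean_out, var_out
```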
Second changed file — dynamic-shape prim tests:

@@ -92,6 +92,35 @@ def swiglu_net2(x):
     return paddle.incubate.nn.functional.swiglu(x)


+def group_norm_net1(x):
+    group_norm = paddle.nn.GroupNorm(num_channels=x.shape[1], num_groups=32)
+    return group_norm(x)
+
+
+def group_norm_net2(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1], num_groups=32, weight_attr=False
+    )
+    return group_norm(x)
+
+
+def group_norm_net3(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1], num_groups=32, bias_attr=False
+    )
+    return group_norm(x)
+
+
+def group_norm_net4(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1],
+        num_groups=32,
+        weight_attr=False,
+        bias_attr=False,
+    )
+    return group_norm(x)
+
+
 def layer_norm_net1(x):
     return paddle.nn.functional.layer_norm(x, x.shape[1:])
@@ -365,5 +394,57 @@ def setUp(self):
         self.tol = 1e-6


+class TestPrimGroupNorm1(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net1
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm2(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net2
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm3(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net3
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm4(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net4
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 if __name__ == "__main__":
     unittest.main()

Review comment (on `self.necessary_ops = "pd_op.flatten"` in TestPrimGroupNorm1): group_norm
Reply: Done
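The dynamic-shape test base class is not part of this diff, so the following is only a hypothetical sketch of the mechanism that `init_x_shape = [None, 640, None, None]` relies on: the net is converted with `paddle.jit.to_static` using an `InputSpec` whose `None` entries stay symbolic, and the decomposed program's output is checked against the eager result. The helper name `run_with_dynamic_shape` and the exact flow are assumptions; the real harness also enables prim decomposition and verifies `self.necessary_ops` against the resulting program, which is omitted here.

```python
import numpy as np

import paddle
from paddle.static import InputSpec


def run_with_dynamic_shape(net, x_np, init_x_shape, tol=1e-6):
    """Hypothetical illustration only; not the test base class used by this PR."""
    x = paddle.to_tensor(x_np)
    eager_out = net(x)  # reference result in dygraph mode

    # None entries in init_x_shape become symbolic (dynamic) dimensions.
    static_net = paddle.jit.to_static(
        net, input_spec=[InputSpec(shape=init_x_shape, dtype=str(x_np.dtype))]
    )
    static_out = static_net(x)  # runs the traced static program

    np.testing.assert_allclose(
        eager_out.numpy(), static_out.numpy(), rtol=tol, atol=tol
    )
```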