Skip to content

Commit

Permalink
fused_rms_norm infer spmd
Browse files (browse the repository at this point in the history)
  • Loading branch information
liuzhenhai93 committed Jan 10, 2024
1 parent 74bc22d commit 5b46fe6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/spmd_rules/rms_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ SpmdInfo RmsNormInferSpmd(const DistMetaTensor& x,
return {{x_dist_attr_dst, scale_dist_attr_dst}, {out, invvar}};
}

SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x,
SpmdInfo RmsNormInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& out,
const DistMetaTensor& invvar,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/spmd_rules/rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SpmdInfo RmsNormInferSpmd(const DistMetaTensor& x,
SpmdInfo RmsNormGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& invvar,
const DistMetaTensor& dy,
const DistMetaTensor& out_grad,
float epsilon);

SpmdInfo RmsNormInferSpmdReverse(const DistMetaTensor& x,
Expand Down
35 changes: 35 additions & 0 deletions test/cpp/auto_parallel/fused_rms_norm_spmd_rule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ TEST(FusedRmsNormSPMDRule, test_fused_rms_norm) {
// build input data class
std::vector<int64_t> x_shape = {64, 32, 1024};
std::vector<int64_t> scale_shape = {1024};
std::vector<int64_t> variance_shape = {64, 32};

std::vector<int64_t> mesh_shape = {2, 3};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
Expand Down Expand Up @@ -68,6 +69,40 @@ TEST(FusedRmsNormSPMDRule, test_fused_rms_norm) {
check_dim_mapping(infered_dist_attrs.second[0], {1, 0, -1});
check_dim_mapping(infered_dist_attrs.second[1], {1, 0});
VLOG(4) << "test2 done.";


TensorDistAttr out_dist_attr = TensorDistAttr();
out_dist_attr.set_process_mesh(process_mesh);
out_dist_attr.set_dims_mapping(std::vector<int64_t>({0, 1, -1}));
out_dist_attr.set_dynamic_dims(std::vector<bool>({false, false, false}));
phi::distributed::DistMetaTensor out(common::make_ddim(x_shape), out_dist_attr);


TensorDistAttr invvar_dist_attr = TensorDistAttr();
invvar_dist_attr.set_process_mesh(process_mesh);
invvar_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1}));
invvar_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
phi::distributed::DistMetaTensor invvar(common::make_ddim(variance_shape),invvar_dist_attr);

infered_dist_attrs = phi::distributed::RmsNormInferSpmdReverse(x, scale, out, invvar, 0.5);
check_dim_mapping(infered_dist_attrs.first[0], {0, 1, -1});
check_dim_mapping(infered_dist_attrs.first[1], {-1});
check_dim_mapping(infered_dist_attrs.second[0], {0, 1, -1});
check_dim_mapping(infered_dist_attrs.second[1], {0, 1});
VLOG(4) << "test3 done.";

x_dist_attr.set_dims_mapping({0, 1, -1});
x = phi::distributed::DistMetaTensor(common::make_ddim(x_shape), x_dist_attr);
infered_dist_attrs = phi::distributed::RmsNormGradInferSpmd(x, scale, invvar, out, 0.5);

check_dim_mapping(infered_dist_attrs.first[0], {0, 1, -1});
check_dim_mapping(infered_dist_attrs.first[1], {-1});
check_dim_mapping(infered_dist_attrs.first[2], {0, 1});
check_dim_mapping(infered_dist_attrs.first[3], {0, 1, -1});
check_dim_mapping(infered_dist_attrs.second[0], {0, 1, -1});
check_dim_mapping(infered_dist_attrs.second[1], {-1});
check_partial_dims(infered_dist_attrs.second[1], {0, 1});

}
} // namespace auto_parallel
} // namespace distributed
Expand Down

0 comments on commit 5b46fe6

Please sign in to comment.