Skip to content

Commit

Permalink
[Semi-Auto] Adapt layer_norm spmd rule to phi (PaddlePaddle#57374)
Browse files Browse the repository at this point in the history
* adapt layer_norm spmd rule to phi

* modify api in unit test

* bug fix

* fix bug in cpp unit test
  • Loading branch information
pkuzyc authored Sep 22, 2023
1 parent 5badfd7 commit 24128d7
Show file tree
Hide file tree
Showing 10 changed files with 457 additions and 382 deletions.

This file was deleted.

This file was deleted.

4 changes: 0 additions & 4 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
Expand All @@ -29,9 +28,6 @@ namespace paddle {
namespace distributed {
namespace auto_parallel {

// layer_norm rule
REGISTER_SPMD_RULE(layer_norm, LayerNormSPMDRule);

// replicated rule
REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule);

Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ AttrType InferSpmdContext::AttrAt(size_t idx) const {
}
}

template float InferSpmdContext::AttrAt(size_t idx) const;
template int InferSpmdContext::AttrAt(size_t idx) const;

template <>
bool InferSpmdContext::AttrAt(size_t idx) const {
try {
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ struct InferSpmdFnImpl<Return (*)(Args...), infer_spmd_fn> {

// TODO(chenweihang): support other attr type later as needed
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<int64_t>);
Expand Down
Loading

0 comments on commit 24128d7

Please sign in to comment.