Skip to content

Commit

Permalink
[Semi-Auto] Move and adapt elementwise rule to phi (#57197)
Browse files Browse the repository at this point in the history
* adapt general spmd rule

* polish details

* add new rules

* bugfix

---------

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: liangjianzhong <liangjianzhong@baidu.com>
  • Loading branch information
3 people authored Sep 20, 2023
1 parent 4fd9767 commit a98e997
Show file tree
Hide file tree
Showing 9 changed files with 829 additions and 431 deletions.

This file was deleted.

This file was deleted.

89 changes: 0 additions & 89 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
Expand All @@ -43,94 +42,6 @@ REGISTER_SPMD_RULE(min, ReductionSPMDRule);
REGISTER_SPMD_RULE(prod, ReductionSPMDRule);
REGISTER_SPMD_RULE(sum, ReductionSPMDRule);

// elementwise rule
REGISTER_SPMD_RULE(add, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(assign, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(assign_out_, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(divide, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(elementwise_pow, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(exponential_, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(floor_divide, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(fmin, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(hardswish, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(heaviside, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(maximum, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(minimum, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(mish, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(multiply, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(relu6, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(remainder, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(subtract, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(swish, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(acos, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(acosh, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(asin, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(asinh, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(atan, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(atanh, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(bernoulli, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(bitwise_and, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(bitwise_not, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(bitwise_or, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(bitwise_xor, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(ceil, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(celu, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(clip, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(conj, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(cos, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(cosh, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(det, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(digamma, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(elu, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(erf, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(erfinv, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(exp, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(expm1, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(fill, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(floor, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(fmax, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(gelu, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(hardshrink, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(hardsigmoid, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(hardtanh, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(label_smooth, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(leaky_relu, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(lgamma, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(log, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(log10, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(log1p, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(log2, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(logical_and, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(logical_not, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(logical_or, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(logical_xor, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(logit, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(logsigmoid, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(poisson, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(pow, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(reciprocal, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(relu, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(round, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(rsqrt, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(scale, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(selu, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(sigmoid, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(sign, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(silu, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(sin, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(sinh, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(softplus, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(softshrink, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(softsign, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(sqrt, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(square, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(stanh, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(tan, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(tanh, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(tanh_shrink, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(thresholded_relu, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(trunc, ElementwiseSPMDRule);

// layer_norm rule
REGISTER_SPMD_RULE(layer_norm, LayerNormSPMDRule);

Expand Down
Loading

0 comments on commit a98e997

Please sign in to comment.