Skip to content

Commit

Permalink
[XPU] fuse swiglu in fc by using pass (#70221) (#70141)
Browse files Browse the repository at this point in the history
  • Loading branch information
linkk08 authored Dec 19, 2024
1 parent 96ac4d1 commit 1bc8f7d
Show file tree
Hide file tree
Showing 5 changed files with 440 additions and 181 deletions.
323 changes: 292 additions & 31 deletions paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,6 @@
#include "paddle/pir/include/pass/pass_registry.h"
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"

/*
fuse malmul + add to fc_xpu
For example:
graph:
x w
\ /
|
mul
|
|
bias --- add
|
|
output
------------------------------------------------------
After the pass is applied:
x w
\ /
|
bias--- fc_xpu
|
|
Output
*/

namespace {

int ConvertActivationType(const std::string &act_type) {
Expand All @@ -76,20 +50,47 @@ int ConvertActivationType(const std::string &act_type) {
return static_cast<int>(xpu::Activation_t::SWISH);
} else if (act_type == "relu6") {
return static_cast<int>(xpu::Activation_t::RELU6);
} else if (act_type == "swish_glu") {
return static_cast<int>(xpu::Activation_t::SWISH_GLU);
} else {
PADDLE_THROW(common::errors::Unimplemented(
"Not support convert activation_type(%s).", act_type));
}
return -1;
}

class FCXpuFusePattern : public paddle::drr::DrrPatternBase {
/*
fuse malmul + add to fc_xpu
For example:
graph:
x w
\ /
|
mul
|
|
bias --- add
|
|
output
------------------------------------------------------
After the pass is applied:
x w
\ /
|
bias--- fc_xpu
|
|
Output
*/
class FcXpuFuseAddPattern : public paddle::drr::DrrPatternBase {
private:
bool transpose_w_;

public:
explicit FCXpuFusePattern(bool transpose_w) : transpose_w_(transpose_w) {}
std::string name() const override { return "FCXpuFusePattern"; }
explicit FcXpuFuseAddPattern(bool transpose_w) : transpose_w_(transpose_w) {}
std::string name() const override { return "FcXpuFuseAddPattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
Expand Down Expand Up @@ -185,14 +186,274 @@ class FCXpuFusePattern : public paddle::drr::DrrPatternBase {
}
};

/*
fuse malmul + add + act to fc_xpu
For example:
graph:
x w
\ /
|
mul
|
|
bias --- add
|
|
act
|
|
output
------------------------------------------------------
After the pass is applied:
x w
\ /
|
bias--- fc_xpu
|
|
Output
*/
class FcXpuFuseAddActPattern : public paddle::drr::DrrPatternBase {
private:
bool transpose_w_;

public:
explicit FcXpuFuseAddActPattern(bool transpose_w)
: transpose_w_(transpose_w) {}
std::string name() const override { return "FcXpuFuseAddActPattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &mul = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("transpose_x")},
{"transpose_y", pat.Attr("transpose_y")}});
mul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("mul_out")});

const auto &add = pat.Op(paddle::dialect::AddOp::name());
add({&pat.Tensor("mul_out"), &pat.Tensor("bias")},
{&pat.Tensor("add_out")});
const auto &swiglu = pat.Op(paddle::dialect::SwigluOp::name());
swiglu({&pat.Tensor("add_out"), &pat.InputNoneTensor()},
{&pat.Tensor("act_out")});

// Constraints
pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w"));
auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x"));
auto bias_shape = pir::GetShapeFromValue(match_ctx.Tensor("bias"));
if (transpose_w_ != match_ctx.Attr<bool>("transpose_y")) {
return false;
}
return (w_shape.size() == 2 && x_shape.size() >= 2 &&
bias_shape.size() == 1);
});

// Result pattern
paddle::drr::ResultPattern res = pat.ResultPattern();

const auto &in_num_col_dims_attr =
res.ComputeAttr([&](const paddle::drr::MatchContext &match_ctx) -> int {
auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x"));
return x_shape.size() - 1;
});

if (!transpose_w_) {
// prepare weight, transpose it if necessary
const auto &perm_attr = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int> {
auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w"));
if (w_shape.size() == 2) {
return {1, 0};
} else {
PADDLE_THROW(common::errors::Unimplemented(
"Not support convert w_shape.size()(%d).", w_shape.size()));
}
});
const auto &transpose_op =
res.Op(paddle::dialect::TransposeOp::name(), {{"perm", perm_attr}});
res.Tensor("w_trans") = transpose_op(res.Tensor("w"));
VLOG(3) << "transpose weight for fc_xpu op";
}

const auto &out_dtype_attr = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> phi::DataType {
auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
// 目前仅支持以下几种非量化的情况
if (x_dtype.isa<pir::Float32Type>()) {
return phi::DataType::FLOAT32;
} else if (x_dtype.isa<pir::Float16Type>()) {
return phi::DataType::FLOAT16;
} else if (x_dtype.isa<pir::BFloat16Type>()) {
return phi::DataType::BFLOAT16;
} else {
return phi::DataType::UNDEFINED;
}
});
// only support float32 bias now
const auto &cast_op = res.Op(paddle::dialect::CastOp::name(),
{{"dtype", res.DataTypeAttr("float32")}});
res.Tensor("bias_fp32") = cast_op(res.Tensor("bias"));

const auto &fc_xpu = res.Op(
paddle::dialect::FcXpuOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"transpose_x", pat.Attr("transpose_x")},
{"alpha", res.Float32Attr(1.0f)},
{"beta", res.Float32Attr(0.f)},
{"act_type", res.Int32Attr(ConvertActivationType("swish_glu"))},
{"act_alpha", res.Float32Attr(0.0f)},
{"out_dtype", out_dtype_attr},
}});
fc_xpu(
{
&res.Tensor("x"),
&res.InputNoneTensor(),
transpose_w_ ? &res.Tensor("w") : &res.Tensor("w_trans"),
&res.InputNoneTensor(),
&res.Tensor("bias_fp32"),
&res.InputNoneTensor(),
&res.InputNoneTensor(),
},
{&res.Tensor("act_out"), &res.Tensor("out_max")});
}
};

/*
fuse malmul + act to fc_xpu
For example:
graph:
x w
\ /
|
mul
|
|
act
|
|
output
------------------------------------------------------
After the pass is applied:
x w
\ /
|
bias--- fc_xpu
|
|
Output
*/

class FcXpuFuseActPattern : public paddle::drr::DrrPatternBase {
private:
bool transpose_w_;

public:
explicit FcXpuFuseActPattern(bool transpose_w) : transpose_w_(transpose_w) {}
std::string name() const override { return "FcXpuFuseActPattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &mul = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("transpose_x")},
{"transpose_y", pat.Attr("transpose_y")}});
mul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("mul_out")});

const auto &swiglu = pat.Op(paddle::dialect::SwigluOp::name());
swiglu({&pat.Tensor("mul_out"), &pat.InputNoneTensor()},
{&pat.Tensor("act_out")});

// Constraints
pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w"));
auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x"));
if (transpose_w_ != match_ctx.Attr<bool>("transpose_y")) {
return false;
}
return (w_shape.size() == 2 && x_shape.size() >= 2);
});

// Result pattern
paddle::drr::ResultPattern res = pat.ResultPattern();

const auto &in_num_col_dims_attr =
res.ComputeAttr([&](const paddle::drr::MatchContext &match_ctx) -> int {
auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x"));
return x_shape.size() - 1;
});

if (!transpose_w_) {
// prepare weight, transpose it if necessary
const auto &perm_attr = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int> {
auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w"));
if (w_shape.size() == 2) {
return {1, 0};
} else {
PADDLE_THROW(common::errors::Unimplemented(
"Not support convert w_shape.size()(%d).", w_shape.size()));
}
});
const auto &transpose_op =
res.Op(paddle::dialect::TransposeOp::name(), {{"perm", perm_attr}});
res.Tensor("w_trans") = transpose_op(res.Tensor("w"));
VLOG(3) << "transpose weight for fc_xpu op";
}

const auto &out_dtype_attr = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> phi::DataType {
auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
// 目前仅支持以下几种非量化的情况
if (x_dtype.isa<pir::Float32Type>()) {
return phi::DataType::FLOAT32;
} else if (x_dtype.isa<pir::Float16Type>()) {
return phi::DataType::FLOAT16;
} else if (x_dtype.isa<pir::BFloat16Type>()) {
return phi::DataType::BFLOAT16;
} else {
return phi::DataType::UNDEFINED;
}
});

const auto &fc_xpu = res.Op(
paddle::dialect::FcXpuOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"transpose_x", pat.Attr("transpose_x")},
{"alpha", res.Float32Attr(1.0f)},
{"beta", res.Float32Attr(0.f)},
{"act_type", res.Int32Attr(ConvertActivationType("swish_glu"))},
{"act_alpha", res.Float32Attr(0.0f)},
{"out_dtype", out_dtype_attr},
}});
fc_xpu(
{
&res.Tensor("x"),
&res.InputNoneTensor(),
transpose_w_ ? &res.Tensor("w") : &res.Tensor("w_trans"),
&res.InputNoneTensor(),
&res.InputNoneTensor(),
&res.InputNoneTensor(),
&res.InputNoneTensor(),
},
{&res.Tensor("act_out"), &res.Tensor("out_max")});
}
};

class FCXpuFusePass : public pir::PatternRewritePass {
public:
FCXpuFusePass() : pir::PatternRewritePass("fc_xpu_fuse_pass", 2) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(paddle::drr::Create<FCXpuFusePattern>(context, false));
ps.Add(paddle::drr::Create<FCXpuFusePattern>(context, true));
ps.Add(paddle::drr::Create<FcXpuFuseAddActPattern>(context, false));
ps.Add(paddle::drr::Create<FcXpuFuseAddActPattern>(context, true));
ps.Add(paddle::drr::Create<FcXpuFuseActPattern>(context, false));
ps.Add(paddle::drr::Create<FcXpuFuseActPattern>(context, true));
ps.Add(paddle::drr::Create<FcXpuFuseAddPattern>(context, false));
ps.Add(paddle::drr::Create<FcXpuFuseAddPattern>(context, true));
return ps;
}
};
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,9 @@ void FcXPUInferMeta(const MetaTensor& x,
out_shape[i] = static_cast<int>(x.dims()[i]);
}
out_shape[in_num_col_dims] = static_cast<int>(w.dims()[0]);
if (act_type == 23 /*phi::backends::xpu::Activation_t::SWISH_GLU*/) {
out_shape[in_num_col_dims] = out_shape[in_num_col_dims] / 2;
}
out->set_dims(DDim(out_shape.data(), static_cast<int>(out_shape.size())));
out->set_dtype(out_dtype);
out->set_layout(x.layout());
Expand Down
Loading

0 comments on commit 1bc8f7d

Please sign in to comment.