
Commit 5a9a249

feat: support fused LayerNorm operator
1 parent 28f6fb0 commit 5a9a249
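
This commit adds a fused LayerNormalization operator: a new LayernormFuse pass matches the decomposed pattern (ReduceMean/Sub/Pow/ReduceMean/Add/Sqrt/Div/Mul/Add hanging off a shared Add output) and rewrites it into a single LayerNormalization node that carries epsilon and axis. The pass is registered in pass_register.h and enabled in Graph::optimize(), and the ONNX scripts (to_onnx.py, make_serialize.py) learn to export and parse the new operator, along with smaller fixes: Tanh/Where/Split support, multi-output node names, and ReduceMean keepDims parsing.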

File tree

8 files changed, +363 -8 lines changed


scripts/onnx/make_serialize.py

+15-3
@@ -1,6 +1,8 @@
 from refactor_graph.onnx import make_compiler
 from onnx import load
 import argparse
+from onnx.external_data_helper import load_external_data_for_model
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -9,17 +11,27 @@ def parse_args():
     parser.add_argument(
         "--model", type=str, required=True, help="Path to the model file file."
     )
-    parser.add_argument("--output", type=str, default="./", help="Path to save the output file.")
+    parser.add_argument(
+        "--output", type=str, default="./", help="Path to save the output file."
+    )
     args = parser.parse_args()
     return (
         args.model,
         args.output,
     )
 
+
 def main():
     model_path, output_path = parse_args()
-    compiler = make_compiler(load(model_path))
+    model = load(model_path)
+    # model = load(model_path, load_external_data=False)
+    # load_external_data_for_model(
+    #     model,
+    #     "/home/zhangyunze/workspace/RefactorGraph/scripts/onnx/bert_bs1.pb",
+    # )
+    compiler = make_compiler(model)
     compiler.serialize(output_path)
 
+
 if __name__ == "__main__":
-    main()
+    main()
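
For reference, the flow the script now follows (including the commented-out external-data branch) looks roughly like this; the file names are hypothetical, and onnx's load_external_data_for_model expects the directory containing the external-data files:

    from onnx import load
    from onnx.external_data_helper import load_external_data_for_model
    from refactor_graph.onnx import make_compiler

    # Hypothetical paths; mirrors the commented-out external-data handling above.
    model = load("model.onnx", load_external_data=False)
    load_external_data_for_model(model, "external_data_dir/")
    make_compiler(model).serialize("./")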

scripts/onnx/to_onnx.py

+48-5
@@ -121,7 +121,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                 ),
                 [],
             )
-        if self.type == "Relu":
+        if self.type in ["Relu", "Tanh"]:
             return (
                 make_node(
                     self.type,
@@ -166,6 +166,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
             "Log",
             "Neg",
             "Sigmoid",
+            "Where",
         ]:
             return (
                 make_node(
@@ -235,14 +236,14 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                 ),
                 [shape],
             )
-        if self.type in ["Gather", "Concat", "Softmax"]:
+        if self.type in ["Gather", "Concat", "Softmax", "Split"]:
             meta = self.meta.split(b"/")
             axis = int(meta[0])
             return (
                 make_node(
                     self.type,
                     [tensors[i].name for i in self.topo.inputs],
-                    [tensors[self.topo.outputs[0]].name],
+                    [tensors[i].name for i in self.topo.outputs],
                     self.name,
                     domain=DEFAULT_DOMAIN,
                     axis=axis,
@@ -251,7 +252,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
             )
         if self.type == "ReduceMean":
             meta = self.meta.split(b",")
-            keepDims = meta[2] == b"true"
+            keepDims = meta[2] == b" true"
             axes = [int(x) for x in split_array(meta[0])]
             return (
                 make_node(
@@ -311,7 +312,35 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                     [tensors[i].name for i in self.topo.outputs],
                     self.name,
                     domain="refactor",
-                    epsilon=1e-5,
+                    epsilon=float(self.meta.split(b"=")[0]),
+                ),
+                [],
+            )
+        if self.type == "LayerNormalization":
+            meta = self.meta.split(b",")
+            epsilon = float(meta[0].split(b"=")[0].strip())
+            axis = int(meta[1])
+            return (
+                make_node(
+                    self.type,
+                    [tensors[i].name for i in self.topo.inputs],
+                    [tensors[i].name for i in self.topo.outputs],
+                    self.name,
+                    domain="refactor",
+                    epsilon=epsilon,
+                    axis=axis,
+                ),
+                [],
+            )
+        if self.type == "RotaryPositionEmbedding":
+            return (
+                make_node(
+                    self.type,
+                    [tensors[i].name for i in self.topo.inputs],
+                    [tensors[i].name for i in self.topo.outputs],
+                    self.name,
+                    domain="refactor",
+                    theta=float(self.meta.split(b"=")[0]),
                 ),
                 [],
             )
@@ -364,7 +393,14 @@ def main():
    with open(data_path, "r") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m:
            nodes = []
+            # for t in tensors:
+            #     if t.size != 0:
+            #         print(f"tensor_name is {t.name}")
            initializer = [
+                # (
+                #     ,
+                #     print(f"tensor_name is {t.name}"),
+                # )
                make_tensor(
                    t.name,
                    t.dt,
@@ -391,6 +427,13 @@ def main():
                for t in (tensors[i] for i in graph.outputs)
            ],
            initializer,
+            value_info=[
+                make_tensor_value_info(t.name, t.dt, t.shape)
+                for t in tensors
+                if t.size == 0
+                and t.name not in graph.inputs
+                and t.name not in graph.outputs
+            ],
        )
        # model = make_model(
        #     graph, opset_imports=[make_opsetid(domain="", version=13)]
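
The new LayerNormalization branch assumes self.meta carries the attribute string emitted by LayerNormalization::serialize() further down, i.e. an epsilon field of the form "<value>=<bit pattern>" followed by the axis. A minimal sketch of the same parsing on a hypothetical meta string:

    # Hypothetical meta payload in the "<epsilon>=<hex bits>,<axis>" format;
    # only the text before "=" and the trailing axis are used.
    meta = b"1.000000e-05=0x3727c5ac,-1"
    parts = meta.split(b",")
    epsilon = float(parts[0].split(b"=")[0].strip())  # 1e-05
    axis = int(parts[1])                              # -1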

src/05computation/include/computation/operators/layernorm.h

+24
@@ -0,0 +1,24 @@
+#ifndef COMPUTATION_LAYER_NORMALIZATION_H
+#define COMPUTATION_LAYER_NORMALIZATION_H
+
+#include "../operator.h"
+
+namespace refactor::computation {
+
+    struct LayerNormalization final : public Operator {
+        float epsilon;
+        int axis;
+
+        constexpr explicit LayerNormalization(float epsilon_, int axis_) noexcept
+            : Operator(), epsilon(epsilon_), axis(axis_) {}
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+        // kernel::CollectorBox candidateKernels(Target) const final;
+        std::string serialize() const noexcept final;
+    };
+
+}// namespace refactor::computation
+
+#endif// COMPUTATION_LAYER_NORMALIZATION_H
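
Note that candidateKernels is still commented out here, so the fused operator has no kernel collector of its own yet; at this stage it mainly exists so the optimized graph can be exported and serialized with epsilon and axis preserved.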

src/05computation/include/computation/pass/layernorm_fuse.h

+134
@@ -0,0 +1,134 @@
+#ifndef COMPUTATION_LAYERNORM_FUSE_H
+#define COMPUTATION_LAYERNORM_FUSE_H
+
+#include "../graph.h"
+#include "computation/operators/layernorm.h"
+#include "computation/operators/reduce.h"
+#include "computation/operators/simple_binary.h"
+#include "computation/operators/simple_unary.h"
+#include "computation/pass/converter.h"
+
+namespace refactor::computation {
+
+    class LayernormFuse : public Converter {
+    public:
+        virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
+            auto nodesList = g->internal().nodes();
+            size_t count = 0;
+            for (auto opMatch : nodesList) {
+                if (opMatch->info().op == nullptr) {
+                    continue;
+                }
+                size_t optype = opMatch->info().op->opTypeId();
+                if (optype != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) {
+                    continue;
+                }
+                if (opMatch->successors().size() < 2) {
+                    continue;
+                }
+                auto input = opMatch->inputs()[0]->info().tensor;
+                auto targets = opMatch->outputs()[0]->targets();
+                auto ReduceMeanOp = *targets.begin();
+                auto SubOp1 = *(std::next(targets.begin()));
+                if (ReduceMeanOp == nullptr || SubOp1 == nullptr ||
+                    ReduceMeanOp->info().op->opTypeId() != Reduce::typeId(refactor::kernel::ReduceType::Mean) ||
+                    SubOp1->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Sub)) {
+                    continue;
+                }
+                auto reduceOp = dynamic_cast<Reduce *>(ReduceMeanOp->info().op.get());
+                auto axes = reduceOp->axes;
+                if (axes.size() != 1) {
+                    continue;
+                }
+                auto keepDims = reduceOp->keepDims;
+                if (ReduceMeanOp->successors().size() != 1 || *(ReduceMeanOp->outputs()[0]->targets().begin()) != SubOp1) {
+                    continue;
+                }
+                if (SubOp1->successors().size() != 2) {
+                    continue;
+                }
+                auto targets1 = SubOp1->outputs()[0]->targets();
+                auto PowOp = *targets1.begin();
+                auto DivOp = *(std::next(targets1.begin()));
+                if (PowOp == nullptr || DivOp == nullptr ||
+                    PowOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Pow) ||
+                    DivOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Div)) {
+                    continue;
+                }
+                if (PowOp->successors().size() != 1 || DivOp->successors().size() != 1) {
+                    continue;
+                }
+                auto ReduceMeanOp1 = *(PowOp->outputs()[0]->targets().begin());
+                auto MulOp = *(DivOp->outputs()[0]->targets().begin());
+                if (ReduceMeanOp1 == nullptr || MulOp == nullptr ||
+                    ReduceMeanOp1->info().op->opTypeId() != Reduce::typeId(refactor::kernel::ReduceType::Mean) ||
+                    MulOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Mul)) {
+                    continue;
+                }
+                auto reduce1Op = dynamic_cast<Reduce *>(ReduceMeanOp1->info().op.get());
+                auto axes1 = reduce1Op->axes;
+                if (axes != axes1) {
+                    continue;
+                }
+                if (auto keepDims1 = reduce1Op->keepDims; keepDims != keepDims1) {
+                    continue;
+                }
+                if (MulOp->successors().size() != 1 || ReduceMeanOp1->successors().size() != 1) {
+                    continue;
+                }
+                auto AddOrSqrtOp = *(ReduceMeanOp1->outputs()[0]->targets().begin());
+                auto AddOp2 = *(MulOp->outputs()[0]->targets().begin());
+                if (AddOrSqrtOp == nullptr || AddOp2 == nullptr ||
+                    AddOp2->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) {
+                    continue;
+                }
+                if (AddOrSqrtOp->successors().size() != 1) {
+                    continue;
+                }
+                float epsilon = 0.0;
+                if (auto AddOp = AddOrSqrtOp; AddOp->info().op->opTypeId() == SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) {
+                    auto SqrtOp = *(AddOp->outputs()[0]->targets().begin());
+                    if (SqrtOp == nullptr || SqrtOp->info().op->opTypeId() != SimpleUnary::typeId(refactor::kernel::SimpleUnaryType::Sqrt)) {
+                        continue;
+                    }
+                    if (SqrtOp->successors().size() != 1 || *(SqrtOp->outputs()[0]->targets().begin()) != DivOp) {
+                        continue;
+                    }
+                    // start replace with LayernormOp
+                    if (auto t = AddOp->inputs()[1]->info().tensor->data; t) {
+                        epsilon = *t->get<float>();
+                    }
+                } else if (auto SqrtOp = AddOrSqrtOp; SqrtOp->info().op->opTypeId() == SimpleUnary::typeId(refactor::kernel::SimpleUnaryType::Sqrt)) {
+                    if (*(SqrtOp->outputs()[0]->targets().begin()) != DivOp) {
+                        continue;
+                    }
+                } else {
+                    continue;
+                }
+
+                int axis = axes[0];
+                auto layernormOp = g->internal().pushNode(
+                    {std::make_unique<LayerNormalization>(epsilon, axis), fmt::format("Layernorm_{}", count)},
+                    {g->internal().shareEdge({Tensor::share(input->dataType, input->shape), fmt::format("Layernorm_{}_out", count)})});
+                layernormOp->connect(0, opMatch->outputs()[0]);
+                layernormOp->connect(1, MulOp->inputs()[1]);
+                layernormOp->connect(2, AddOp2->inputs()[1]);
+                if (AddOp2->outputs()[0]->targets().size() == 0) {//global output
+                    g->internal().replaceOutput(AddOp2->outputs()[0], layernormOp->outputs()[0]);
+                } else {
+                    for (auto node : AddOp2->outputs()[0]->targets()) {
+                        auto it = std::find(node->inputs().begin(), node->inputs().end(), AddOp2->outputs()[0]);
+                        node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], layernormOp->outputs()[0]);
+                    }
+                }
+                count++;
+                g->internal().cleanup();
+            }
+            return true;
+        };
+    };
+
+
+}// namespace refactor::computation
+
+#endif// COMPUTATION_LAYERNORM_FUSE_H
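
For orientation, the subgraph this pass matches is the usual decomposed LayerNorm, anchored on an Add whose output feeds both a ReduceMean and a Sub. A minimal numpy sketch of that pattern (x stands for the leading Add's output; gamma, beta, eps are illustrative names, and the eps Add is the optional branch the pass probes for):

    import numpy as np

    def decomposed_layernorm(x, gamma, beta, eps=1e-5, axis=-1):
        mean = x.mean(axis=axis, keepdims=True)            # ReduceMean
        diff = x - mean                                    # Sub
        var = (diff ** 2).mean(axis=axis, keepdims=True)   # Pow + ReduceMean
        norm = diff / np.sqrt(var + eps)                   # Add(eps) + Sqrt + Div
        return norm * gamma + beta                         # Mul + Add

The fused node is then wired exactly as the connect(0..2) calls above do it: input 0 is the original Add output, input 1 the Mul weight, input 2 the final Add bias.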

src/05computation/include/computation/pass_register.h

+2
@@ -2,6 +2,7 @@
 #define COMPUTATION_PASS_REGISTER_H
 #include "pass/conv_to_matmul.h"
 #include "pass/converter.h"
+#include "pass/layernorm_fuse.h"
 #include "pass/matmul_transpose.h"
 
 namespace refactor::computation {
@@ -10,6 +11,7 @@ namespace refactor::computation {
 #define REGISTER(PASS, NAME) static ConverterRegister<PASS> NAME("" #NAME);
     REGISTER(MatMulTransposeFuse, MatMulTransposeFuse)
     REGISTER(ConvToMatmul, ConvToMatmul)
+    REGISTER(LayernormFuse, LayernormFuse)
 };
 
 
src/05computation/src/graph.cc

+1
@@ -220,6 +220,7 @@ namespace refactor::computation {
     void Graph::optimize() {
         auto graphMutant = GraphMutant(*this);
         std::vector<std::string_view> passes = {
+            "LayernormFuse",
             // "MatMulTransposeFuse",
             // "ConvToMatmul",
         };
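
Registration and activation are separate steps: REGISTER(LayernormFuse, LayernormFuse) in pass_register.h only makes the converter available under the name "LayernormFuse"; listing that name in the passes vector of Graph::optimize() is what actually runs it (MatMulTransposeFuse and ConvToMatmul stay commented out here).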

src/05computation/src/operators/layernorm.cc

+22
@@ -0,0 +1,22 @@
+#include "computation/operators/layernorm.h"
+
+namespace refactor::computation {
+    using Op = LayerNormalization;
+
+    auto Op::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
+    auto Op::name() const noexcept -> std::string_view { return "LayerNormalization"; }
+    auto Op::serialize() const noexcept -> std::string {
+        union code {
+            float f;
+            int32_t i;
+        };
+        return fmt::format(("{}({:e}={:#010x},{})"),
+                           name(), epsilon,
+                           code{epsilon}.i, axis);
+    }
+
+}// namespace refactor::computation
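
With the fmt pattern above, serialize() renders the attributes roughly as LayerNormalization(1.000000e-05=0x3727c5ac,-1): epsilon in scientific notation, its raw 32-bit pattern obtained through the code union, and the axis. The importer in to_onnx.py reads only the text before "=" and the axis, so the hex field is informational.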
