Skip to content

Commit 2230e5b

Browse files
authored
Conv+Activation fusion for CPU (#105)
* Add conv+activation fusion. * Adding tests * Adding activation LeakyRelu. * Refactoring the code to use a fusedConv custom op instead of changing the original conv op at runtime. * fix build issue. * fix build issue. * In order to reduce binary size: 1. reuse onnx shape inference for conv 2. remove most doc. * Accomodating PR comments. * Accomodating PR comments * Remove unused variables
1 parent 7f0e526 commit 2230e5b

File tree

14 files changed

+537
-155
lines changed

14 files changed

+537
-155
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "fused_conv.h"
5+
6+
namespace onnxruntime {
7+
namespace contrib {
8+
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
9+
FusedConv,
10+
1,
11+
float,
12+
KernelDefBuilder()
13+
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
14+
FusedConv<float>);
15+
} // namespace contrib
16+
} // namespace onnxruntime
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/cpu/nn/conv_impl.h"
7+
8+
namespace onnxruntime {
9+
namespace contrib {
10+
11+
template <typename T>
12+
class FusedConv : public Conv<T> {
13+
public:
14+
FusedConv(const OpKernelInfo& info) : Conv<T>(info) {
15+
Conv<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
16+
Conv<T>::alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
17+
}
18+
19+
Status Compute(OpKernelContext* context) const override {
20+
return Conv<T>::Compute(context);
21+
}
22+
};
23+
} // namespace contrib
24+
} // namespace onnxruntime

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
#include "core/graph/contrib_ops/contrib_defs.h"
77
#include "core/graph/contrib_ops/range_schema_defs.h"
88
#include "core/graph/op.h"
9+
#include "onnx/defs/shape_inference.h"
910

11+
namespace ONNX_NAMESPACE {
12+
void convPoolTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape);
13+
}
1014
namespace onnxruntime {
1115
namespace contrib {
1216
using ::ONNX_NAMESPACE::AttributeProto;
@@ -28,6 +32,62 @@ void RegisterContribSchemas() {
2832
Sample echo operator.)DOC");
2933

3034
// register schemas for more operators here
35+
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedConv)
36+
.SetDomain(kMSDomain)
37+
.SinceVersion(1)
38+
.SetDoc(R"DOC(
39+
The fused convolution operator schema is the same as Conv besides it includes an attribute
40+
activation.)DOC")
41+
.Attr(
42+
"auto_pad",
43+
"",
44+
AttributeProto::STRING,
45+
std::string("NOTSET"))
46+
.Attr(
47+
"kernel_shape",
48+
"",
49+
AttributeProto::INTS,
50+
OPTIONAL)
51+
.Attr(
52+
"dilations",
53+
"",
54+
AttributeProto::INTS,
55+
OPTIONAL)
56+
.Attr(
57+
"strides", "", AttributeProto::INTS, OPTIONAL)
58+
.Attr("pads",
59+
"",
60+
AttributeProto::INTS, OPTIONAL)
61+
.Attr(
62+
"group",
63+
"",
64+
AttributeProto::INT,
65+
static_cast<int64_t>(1))
66+
.Attr(
67+
"activation",
68+
"",
69+
AttributeProto::STRING,
70+
OPTIONAL)
71+
.Input(
72+
0,
73+
"X",
74+
"",
75+
"T")
76+
.Input(
77+
1,
78+
"W",
79+
"",
80+
"T")
81+
.Input(2, "B", "", "T", OpSchema::Optional)
82+
.Output(
83+
0,
84+
"Y",
85+
"",
86+
"T")
87+
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors")
88+
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
89+
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
90+
});
3191

3292
ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims)
3393
.SetDomain(kMSDomain)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/graph/initializer.h"
5+
#include "core/graph/conv_activation_fusion.h"
6+
#include "core/graph/graph_utils.h"
7+
8+
using namespace onnx;
9+
using namespace ::onnxruntime::common;
10+
namespace onnxruntime {
11+
12+
namespace {
13+
bool IsFusableActivation(const Node& node) {
14+
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
15+
}
16+
} // namespace
17+
18+
Status ConvActivationFusion::Apply(Graph& graph, bool& modified) const {
19+
GraphViewer graph_viewer(graph);
20+
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
21+
22+
std::vector<onnxruntime::NodeIndex> removed_nodes;
23+
for (auto index : order) {
24+
auto node = graph.GetNode(index);
25+
if (!utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", 1) || node->GetOutputEdgesCount() != 1) {
26+
continue;
27+
}
28+
const Node& next_node = *(node->OutputNodesBegin());
29+
if (!IsFusableActivation(next_node) || graph.IsNodeOutputsInGraphOutputs(next_node)) {
30+
continue;
31+
}
32+
33+
Node* conv_node = node;
34+
const Node& act_node = next_node;
35+
std::vector<NodeArg> input_args, output_args;
36+
37+
Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node->Name()), "FusedConv",
38+
"fused Conv " + conv_node->Name() + "with activation " + act_node.OpType(),
39+
conv_node->MutableInputDefs(),
40+
conv_node->MutableOutputDefs(),
41+
&conv_node->GetAttributes(),
42+
"com.microsoft");
43+
44+
//Add a new attribute to specify the activation type
45+
fused_conv.AddAttribute("activation", "string");
46+
47+
//Add optional attributes for activations
48+
if (act_node.OpType() == "LeakyRelu") {
49+
const NodeAttributes attrs = act_node.GetAttributes();
50+
for (auto it = attrs.begin(); it != attrs.end(); ++it) {
51+
fused_conv.AddAttribute(it->first, it->second);
52+
}
53+
}
54+
55+
// Replace the input of the node following activation node
56+
const NodeArg* act_output_def = act_node.OutputDefs()[0];
57+
NodeArg* fused_conv_output_def = fused_conv.MutableOutputDefs()[0];
58+
for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
59+
auto output_node = graph.GetNode((*it).Index());
60+
if (!output_node) {
61+
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
62+
}
63+
64+
auto& input_defs = output_node->MutableInputDefs();
65+
for (auto& def : input_defs) {
66+
if (def == act_output_def) {
67+
def = fused_conv_output_def;
68+
}
69+
}
70+
}
71+
72+
removed_nodes.push_back(act_node.Index());
73+
removed_nodes.push_back(conv_node->Index());
74+
}
75+
76+
for (auto i : removed_nodes) {
77+
graph.RemoveNode(i);
78+
}
79+
80+
if (!removed_nodes.empty()) {
81+
modified = true;
82+
ONNXRUNTIME_RETURN_IF_ERROR(graph.Resolve());
83+
}
84+
return Status::OK();
85+
}
86+
} // namespace onnxruntime
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/graph/graph_transformer.h"
7+
8+
namespace onnxruntime {
9+
10+
class ConvActivationFusion : public onnxruntime::GraphTransformer {
11+
public:
12+
ConvActivationFusion() noexcept : onnxruntime::GraphTransformer("ConvActivationFusion", "Fusing Activation into Conv") {}
13+
Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
14+
};
15+
16+
} // namespace onnxruntime

onnxruntime/core/providers/cpu/nn/conv.cc

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,159 @@
44
#include "core/providers/cpu/nn/conv_impl.h"
55

66
namespace onnxruntime {
7+
8+
template <>
9+
Status Conv<float>::Compute(OpKernelContext* context) const {
10+
size_t num_inputs = OpKernel::Node().InputDefs().size();
11+
const Tensor* X = context->Input<Tensor>(0);
12+
const Tensor* W = context->Input<Tensor>(1);
13+
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
14+
const int64_t N = X->Shape()[0];
15+
const int64_t C = X->Shape()[1];
16+
const int64_t M = W->Shape()[0];
17+
ONNXRUNTIME_RETURN_IF_ERROR(ValidateInputShape(X, W));
18+
19+
std::vector<int64_t> kernel_shape = ComputeKernelShape(W->Shape());
20+
21+
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
22+
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
23+
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
24+
" W: ", W->Shape().ToString().c_str());
25+
}
26+
27+
for (size_t i = 0; i < kernel_shape.size(); ++i) {
28+
if (kernel_shape[i] != W->Shape()[i + 2]) {
29+
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
30+
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
31+
" W: ", W->Shape().ToString().c_str());
32+
}
33+
}
34+
35+
std::vector<int64_t> pads(pads_);
36+
if (pads.empty()) {
37+
pads.resize(kernel_shape.size() * 2, 0);
38+
}
39+
std::vector<int64_t> dilations(dilations_);
40+
if (dilations.empty()) {
41+
dilations.resize(kernel_shape.size(), 1);
42+
}
43+
std::vector<int64_t> strides(strides_);
44+
if (strides.empty()) {
45+
strides.resize(kernel_shape.size(), 1);
46+
}
47+
48+
std::vector<int64_t> Y_dims;
49+
Y_dims.insert(Y_dims.begin(), {N, M});
50+
TensorShape input_shape = X->Shape().Slice(2);
51+
ONNXRUNTIME_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
52+
Tensor* Y = context->Output(0, TensorShape(Y_dims));
53+
TensorShape output_shape = Y->Shape().Slice(2);
54+
55+
AllocatorPtr alloc;
56+
ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
57+
58+
const float* Xdata = X->template Data<float>();
59+
float* Ydata = Y->template MutableData<float>();
60+
61+
const size_t kernel_rank = kernel_shape.size();
62+
63+
if (kernel_rank == 2 || kernel_rank == 3) {
64+
MLAS_CONV_PARAMETERS Parameters;
65+
size_t WorkingBufferSize;
66+
MlasConvPrepare(&Parameters,
67+
kernel_rank,
68+
static_cast<size_t>(N),
69+
static_cast<size_t>(group_),
70+
static_cast<size_t>(C / group_),
71+
input_shape.GetDims().data(),
72+
kernel_shape.data(),
73+
dilations.data(),
74+
pads.data(),
75+
strides.data(),
76+
output_shape.GetDims().data(),
77+
static_cast<size_t>(M / group_),
78+
&WorkingBufferSize);
79+
80+
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
81+
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
82+
83+
MlasConv(&Parameters,
84+
Xdata,
85+
W->template Data<float>(),
86+
B != nullptr ? B->template Data<float>() : nullptr,
87+
static_cast<float*>(working_buffer.get()),
88+
Ydata);
89+
90+
//TODO: this will be replaced with Tracy's changes.
91+
fuse_activation(activation_, Ydata, Y->Shape().Size(), alpha_);
92+
93+
} else {
94+
const int64_t input_image_size = input_shape.Size();
95+
const int64_t output_image_size = output_shape.Size();
96+
const int64_t kernel_size = TensorShape(kernel_shape).Size();
97+
const int64_t X_offset = C / group_ * input_image_size;
98+
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
99+
const int64_t W_offset = W->Shape().Size() / group_;
100+
const int64_t kernel_dim = C / group_ * kernel_size;
101+
const int64_t col_buffer_size = kernel_dim * output_image_size;
102+
103+
auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size);
104+
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
105+
float* col_buffer_data = static_cast<float*>(col_buffer.get());
106+
107+
TensorShape image_shape = X->Shape().Slice(1);
108+
std::vector<int64_t> col_buffer_shape{kernel_dim};
109+
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
110+
output_shape.GetDims().end());
111+
112+
for (int image_id = 0; image_id < N; ++image_id) {
113+
for (int group_id = 0; group_id < group_; ++group_id) {
114+
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>(
115+
Xdata + group_id * X_offset,
116+
image_shape.GetDims().data(),
117+
col_buffer_shape.data(),
118+
C * input_image_size,
119+
col_buffer_size,
120+
kernel_shape.data(),
121+
strides.data(),
122+
dilations.data(),
123+
pads.data(),
124+
static_cast<int>(kernel_shape.size()),
125+
col_buffer_data,
126+
&CPUMathUtil::Instance());
127+
math::Gemm<float, CPUMathUtil>(
128+
CblasNoTrans,
129+
CblasNoTrans,
130+
M / group_,
131+
output_image_size,
132+
kernel_dim,
133+
1,
134+
W->template Data<float>() + group_id * W_offset,
135+
col_buffer_data,
136+
0,
137+
Ydata + group_id * Y_offset,
138+
&CPUMathUtil::Instance());
139+
}
140+
141+
if (B != nullptr) {
142+
auto Ymatrix = EigenMatrixMap<float>(Ydata, output_image_size, M);
143+
auto Bvec = ConstEigenVectorMap<float>(B->template Data<float>(), M);
144+
Ymatrix.rowwise() += Bvec.transpose();
145+
}
146+
147+
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
148+
149+
Xdata += X_offset * group_;
150+
Ydata += Y_offset * group_;
151+
}
152+
}
153+
154+
return Status::OK();
155+
}
156+
7157
ONNX_CPU_OPERATOR_KERNEL(
8158
Conv,
9159
1,
10160
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
11161
Conv<float>);
12-
}
162+
} // namespace onnxruntime

0 commit comments

Comments
 (0)