Skip to content

Commit

Permalink
Revert fuse conv fix err (#6859)
Browse files Browse the repository at this point in the history
* merge fuse cuda conv revert

* resolve merge conflict revert exclude unsupported type

* add Stream for slicing

* remove file

* add Stream

Co-authored-by: RandySheriffH <randysheriff@hotmail.com>
  • Loading branch information
oliviajain and RandySheriffH authored Mar 2, 2021
1 parent 29b30bb commit 40b0929
Show file tree
Hide file tree
Showing 12 changed files with 299 additions and 720 deletions.
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv);

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
Expand Down Expand Up @@ -175,7 +174,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
126 changes: 0 additions & 126 deletions onnxruntime/contrib_ops/cuda/fused_conv.cc

This file was deleted.

6 changes: 0 additions & 6 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1273,12 +1273,6 @@ activation.)DOC")
"",
"T",
OpSchema::Optional)
.Input(
3,
"Z",
"",
"T",
OpSchema::Optional)
.Output(
0,
"Y",
Expand Down
124 changes: 34 additions & 90 deletions onnxruntime/core/optimizer/conv_activation_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,106 +100,50 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
continue;
}

if (node->GetExecutionProviderType() == onnxruntime::kCudaExecutionProvider) {
if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
continue;
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13})) {
Node& conv_node = *node;
Node& act_node = *graph.GetNode(next_node.Index());
auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" + act_node.Name());
Node& fused_conv = graph.AddNode(node_name,
"FusedConv",
node_name,
conv_node.MutableInputDefs(),
{},
&conv_node.GetAttributes(),
onnxruntime::kMSDomain);
fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
fused_conv.AddAttribute("activation", "Relu");
graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
modified = true;
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13})) {
const auto& last_node = *(next_node.OutputNodesBegin());
if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) {
continue;
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13}) &&
next_node.GetOutputEdgesCount() == 1) {
Node& conv_node = *node;
Node& add_node = *graph.GetNode(next_node.Index());
Node& act_node = *graph.GetNode(last_node.Index());
auto conv_inputs = conv_node.MutableInputDefs();
auto conv_outputs = conv_node.MutableOutputDefs();
auto add_inputs = add_node.MutableInputDefs();
for (auto add_input : add_inputs) {
if (add_input->Name() != conv_outputs[0]->Name()) {
conv_inputs.push_back(add_input);
break;
}
}
auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" +
add_node.Name() + "_" +
act_node.Name());
Node& fused_conv = graph.AddNode(node_name,
"FusedConv",
node_name,
conv_inputs,
{}, &conv_node.GetAttributes(),
onnxruntime::kMSDomain);
fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
fused_conv.AddAttribute("activation", "Relu");
graph_utils::FinalizeNodeFusion(graph, {conv_node, add_node, act_node}, fused_conv);
modified = true;
}
}
} else {
// Test if this is an activation that can be fused and also extract the
// activation's parameters.
std::vector<float> activation_params;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
float min, max;
if (GetClipConstantMinMax(graph, next_node, min, max)) {
activation_params.push_back(min);
activation_params.push_back(max);
} else {
continue;
}
// Test if this is an activation that can be fused and also extract the
// activation's parameters.
std::vector<float> activation_params;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
float min, max;
if (GetClipConstantMinMax(graph, next_node, min, max)) {
activation_params.push_back(min);
activation_params.push_back(max);
} else {
continue;
}
} else {
continue;
}
}

Node& conv_node = *node;
Node& act_node = *graph.GetNode(next_node.Index());
Node& conv_node = *node;
Node& act_node = *graph.GetNode(next_node.Index());

Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node.Name()), "FusedConv",
"fused Conv " + conv_node.Name() + "with activation " + act_node.OpType(),
conv_node.MutableInputDefs(),
{},
&conv_node.GetAttributes(),
"com.microsoft");
Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node.Name()), "FusedConv",
"fused Conv " + conv_node.Name() + "with activation " + act_node.OpType(),
conv_node.MutableInputDefs(),
{},
&conv_node.GetAttributes(),
"com.microsoft");

// Assign provider to this new node. Provider should be same as the provider for old node.
fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
// Assign provider to this new node. Provider should be same as the provider for old node.
fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());

// Add attributes to specify the activation type and parameters.
fused_conv.AddAttribute("activation", next_node.OpType());
if (activation_params.size() > 0) {
fused_conv.AddAttribute("activation_params", activation_params);
}
// Add attributes to specify the activation type and parameters.
fused_conv.AddAttribute("activation", next_node.OpType());
if (activation_params.size() > 0) {
fused_conv.AddAttribute("activation_params", activation_params);
}

// move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
// move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);

modified = true;
}
modified = true;
}

return Status::OK();
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
transformers.emplace_back(onnxruntime::make_unique<DynamicQuantizeMatMulFusion>(cpu_execution_providers));

std::unordered_set<std::string> cpu_acl_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider};
std::unordered_set<std::string> cpu_cuda_acl_armnn_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider};
std::unordered_set<std::string> cpu_acl_armnn_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider};

transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(cpu_cuda_acl_armnn_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(cpu_acl_armnn_execution_providers));

std::unordered_set<std::string> cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider};
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(cpu_cuda_execution_providers));
Expand Down
Loading

0 comments on commit 40b0929

Please sign in to comment.