-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
136 additions
and
0 deletions.
There are no files selected for viewing
58 changes: 58 additions & 0 deletions
58
src/frontends/onnx/frontend/src/op/com.microsoft/quick_gelu.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#include "core/operator_set.hpp" | ||
#include "exceptions.hpp" | ||
#include "openvino/frontend/exception.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/sigmoid.hpp" | ||
|
||
|
||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace onnx { | ||
namespace com_microsoft { | ||
namespace opset_1 { | ||
ov::OutputVector quickgelu(const ov::frontend::onnx::Node& node) { | ||
// Original Documentation: | ||
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QuickGelu | ||
// Goal: Compute x * Sigmoid(alpha * x) | ||
|
||
const auto inputs = node.get_ov_inputs(); | ||
|
||
// Only one input (x) so give a check | ||
auto num_inputs = inputs.size(); | ||
FRONT_END_GENERAL_CHECK(num_inputs == 1, | ||
"QuickGelu takes only 1 input but was provided " + std::to_string(num_inputs)); | ||
const auto& x = inputs[0]; | ||
|
||
// Constrain input type to float16, float, double (f64), bfloat16 | ||
auto element_type = x.get_element_type(); | ||
CHECK_VALID_NODE(node, | ||
element_type == ov::element::f16 || element_type == ov::element::f32 || | ||
element_type == ov::element::f64 || element_type == ov::element::bf16, | ||
"Unsupported input x type, accepted FP16, FP32, FP64, BFP16 but got: ", | ||
element_type); | ||
|
||
// Get attribute from node | ||
const float alpha = node.get_attribute_value<float>("alpha"); | ||
|
||
// Numpy broadcasting rule is automatically applied with mismatched shapes according to: | ||
// https://docs.openvino.ai/2022.3/openvino_docs_ops_arithmetic_Multiply_1.html "Tensor with dimension of size 1 | ||
// will be implicitly broadcasted to match the size of the second tensor." Convert alpha to tensor with size 1 | ||
const auto alpha_tensor = std::make_shared<v0::Constant>(ov::element::f32, Shape{1}, alpha); | ||
|
||
auto alpha_x = std::make_shared<v1::Multiply>(alpha_tensor, x); | ||
auto sig_alpha_x = std::make_shared<v0::Sigmoid>(alpha_x); | ||
auto result = std::make_shared<v1::Multiply>(x, sig_alpha_x); | ||
|
||
return {result}; | ||
} // func end | ||
|
||
ONNX_OP("QuickGelu", OPSET_SINCE(1), com_microsoft::opset_1::quickgelu, MICROSOFT_DOMAIN); | ||
|
||
} // namespace opset_1 | ||
} // namespace com_microsoft | ||
} // namespace onnx | ||
} // namespace frontend | ||
} // namespace ov |
52 changes: 52 additions & 0 deletions
52
src/frontends/onnx/tests/models/com.microsoft/quick_gelu.prototxt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
ir_version: 3 | ||
producer_name: "OpenVINO ONNX Frontend" | ||
graph { | ||
name: "test_quick_gelu" | ||
node { | ||
input: "X" | ||
output: "Y" | ||
op_type: "QuickGelu" | ||
attribute { | ||
name: "alpha" | ||
f: 0.9974269270896912 | ||
type: FLOAT | ||
} | ||
domain: "com.microsoft" | ||
} | ||
input { | ||
name: "X" | ||
type { | ||
tensor_type { | ||
elem_type: 1 | ||
shape { | ||
dim { | ||
dim_value: 2 | ||
} | ||
dim { | ||
dim_value: 5 | ||
} | ||
} | ||
} | ||
} | ||
} | ||
output { | ||
name: "Y" | ||
type { | ||
tensor_type { | ||
elem_type: 1 | ||
shape { | ||
dim { | ||
dim_value: 2 | ||
} | ||
dim { | ||
dim_value: 5 | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
opset_import { | ||
domain: "com.microsoft" | ||
version: 1 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters