Skip to content

Commit b054646

Browse files
authored
Askhade/implement erf (#137)
* erf implementation for op9 * enable erf node tests + review comment fixes * update CMAKE flag * plus erf to execution provider
1 parent 7d79bfe commit b054646

File tree

6 files changed

+38
-1
lines changed

6 files changed

+38
-1
lines changed

cmake/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ set_target_properties(onnx_proto PROPERTIES FOLDER "External/ONNX")
285285
# fix a warning in onnx code we can't do anything about
286286
if (MSVC)
287287
target_compile_options(onnx_proto PRIVATE /wd4146) # unary minus operator applied to unsigned type
288+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DEIGEN_HAS_C99_MATH") # required to be set explicitly to enable Eigen-Unsupported SpecialFunctions
288289
endif()
289290
set(onnxruntime_EXTERNAL_DEPENDENCIES gsl onnx_proto)
290291

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Con
195195
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
196196
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
197197
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
198+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf);
198199

199200
void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
200201
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Clip)>());
@@ -382,6 +383,7 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
382383
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>());
383384
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>());
384385
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>());
386+
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf)>());
385387
}
386388

387389
// Forward declarations of ml op kernels

onnxruntime/core/providers/cpu/math/element_wise_ops.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
#include "core/providers/cpu/math/element_wise_ops.h"
5+
#include <unsupported/Eigen/SpecialFunctions>
56

67
namespace onnxruntime {
78

@@ -311,6 +312,12 @@ ONNX_CPU_OPERATOR_KERNEL(
311312
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
312313
Scale<float>);
313314

315+
ONNX_CPU_OPERATOR_KERNEL(
316+
Erf,
317+
9,
318+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
319+
Erf<float>);
320+
314321
template <typename T>
315322
Status Add<T>::Compute(OpKernelContext* context) const {
316323
return BroadcastTwo<T, T>(
@@ -874,4 +881,16 @@ Status Scale<float>::Compute(OpKernelContext* ctx) const {
874881
return Status::OK();
875882
}
876883

884+
template <>
885+
Status Erf<float>::Compute(OpKernelContext* context) const {
886+
auto X_ptr = context->Input<Tensor>(0);
887+
ONNXRUNTIME_ENFORCE(X_ptr != nullptr);
888+
auto& X = *X_ptr;
889+
auto& Y = *context->Output(0, X.Shape());
890+
891+
EigenMap<float>(Y) = EigenMap<float>(X).array().erf();
892+
893+
return Status::OK();
894+
}
895+
877896
} // namespace onnxruntime

onnxruntime/core/providers/cpu/math/element_wise_ops.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,15 @@ class Scale final : public OpKernel {
317317
float scale_;
318318
};
319319

320+
template <typename T>
321+
class Erf final : public OpKernel {
322+
public:
323+
Erf(const OpKernelInfo& info) : OpKernel(info) {
324+
}
325+
326+
Status Compute(OpKernelContext* context) const override;
327+
};
328+
320329
template <typename T>
321330
auto MakeEigenArrayMap(Tensor& t) { return EigenVectorArrayMap<T>(t.template MutableData<T>(), t.Shape().Size()); }
322331
template <typename T>

onnxruntime/test/onnx/main.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,6 @@ int real_main(int argc, char* argv[]) {
326326
{"acosh_example", "opset 9 not supported yet"},
327327
{"atanh_example", "opset 9 not supported yet"},
328328
{"sign_model", "opset 9 not supported yet"},
329-
{"erf", "opset 9 not supported yet"},
330329
{"sign", "opset 9 not supported yet"},
331330
{"scatter_with_axis", "opset 9 not supported yet"},
332331
{"scatter_without_axis", "opset 9 not supported yet"},

onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,13 @@ TEST(MathOpTest, Scale_Default) {
869869
test.Run();
870870
}
871871

872+
TEST(MathOpTest, Erf) {
873+
OpTester test("Erf", 9);
874+
std::vector<int64_t> dims{2, 2};
875+
test.AddInput<float>("A", dims, {0.5f, 1.0f, 0.7f, 2.0f});
876+
test.AddOutput<float>("B", dims, {0.5204999f, 0.8427008f, 0.6778012f, 0.9953223f});
877+
test.Run();
878+
}
872879
} // namespace test
873880

874881
} // namespace onnxruntime

0 commit comments

Comments
 (0)