Skip to content

Commit 7d79bfe

Browse files
authored
Move isnan out of contrib_ops and add float16 support for it as per the spec. (#141)
* Move isnan out of contrib_ops and add float16 support for it as per the spec. * Remove isnan from list of broken tests
1 parent 9bf78e1 commit 7d79bfe

File tree

8 files changed

+101
-91
lines changed

8 files changed

+101
-91
lines changed

onnxruntime/contrib_ops/contrib_ops.cc

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,6 @@ Sample echo operator.)DOC");
7878
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema);
7979
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema);
8080

81-
ONNX_CONTRIB_OPERATOR_SCHEMA(IsNaN)
82-
.SetDomain(kMSDomain)
83-
.SinceVersion(1)
84-
.Input(0, "X", "input", "T1")
85-
.Output(0, "Y", "output", "T2")
86-
.TypeConstraint(
87-
"T1",
88-
ONNX_NAMESPACE::OpSchema::numeric_types_for_math_reduction(),
89-
"Constrain to any numeric tensor type. If the dtype attribute is not provided this must be a valid output type.")
90-
.TypeConstraint(
91-
"T2",
92-
{"tensor(bool)"},
93-
"Constrain outputs to boolean tensor")
94-
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
95-
.SetDoc(R"DOC(Returns which elements of the input are NaN.)DOC");
96-
9781
ONNX_CONTRIB_OPERATOR_SCHEMA(Tokenizer)
9882
.SetDomain(kMSDomain)
9983
.SinceVersion(1)
@@ -203,8 +187,8 @@ should be equal to the number of columns of input 'b'.)DOC")
203187
.SetDomain(kMSDomain)
204188
.SinceVersion(1)
205189
.SetDoc(R"DOC(
206-
The convolution operator consumes a quantized input tensor, its scale and zero point,
207-
a quantized filter, its scale and zero point, and output's scale and zero point,
190+
The convolution operator consumes a quantized input tensor, its scale and zero point,
191+
a quantized filter, its scale and zero point, and output's scale and zero point,
208192
and computes the quantized output. Each scale and zero point pair must have same shape.
209193
It means they must be either scalars (per tensor) or 1-D tensors (per channel).)DOC")
210194
.Input(
@@ -522,7 +506,6 @@ The bounding box coordinates corresponding to the selected indices can then be o
522506
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
523507
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
524508
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
525-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN);
526509
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
527510
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
528511
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
@@ -538,7 +521,6 @@ void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
538521

539522
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>());
540523
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>());
541-
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN)>());
542524
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>());
543525
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>());
544526
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>());

onnxruntime/contrib_ops/cpu/isnan.cc

Lines changed: 0 additions & 37 deletions
This file was deleted.

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, And
7373
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Or);
7474
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Xor);
7575
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Less);
76-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less);
7776
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Greater);
78-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
7977
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal);
8078
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
8179
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
@@ -155,7 +153,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Dro
155153
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Identity);
156154
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler);
157155
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization);
158-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization);
159156
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Pad);
160157
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, Reshape_1);
161158
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape);
@@ -174,7 +171,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
174171
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice);
175172
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
176173
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice);
177-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress);
178174
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
179175
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace);
180176
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Split);
@@ -189,10 +185,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sca
189185
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
190186
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If);
191187
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Loop);
192-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
193188

194189
// Opset 9
190+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress);
191+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization);
192+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
193+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less);
195194
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantLike);
195+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
196+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
197+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
196198

197199
void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
198200
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Clip)>());
@@ -258,9 +260,7 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
258260
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Or)>());
259261
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Xor)>());
260262
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Less)>());
261-
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less)>());
262263
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Greater)>());
263-
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>());
264264
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal)>());
265265
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal)>());
266266
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal)>());
@@ -340,7 +340,6 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
340340
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Identity)>());
341341
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler)>());
342342
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization)>());
343-
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization)>());
344343
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Pad)>());
345344
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, Reshape_1)>());
346345
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape)>());
@@ -359,7 +358,6 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
359358
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice)>());
360359
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
361360
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice)>());
362-
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress)>());
363361
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>());
364362
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace)>());
365363
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Split)>());
@@ -374,10 +372,16 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
374372
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>());
375373
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If)>());
376374
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Loop)>());
377-
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>());
378375

379376
// Opset 9
377+
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress)>());
378+
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization)>());
379+
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>());
380+
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less)>());
380381
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantLike)>());
382+
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>());
383+
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>());
384+
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>());
381385
}
382386

383387
// Forward declarations of ml op kernels
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "isnan.h"
5+
#include "core/util/math_cpuonly.h"
6+
#include "core/common/common.h"
7+
#include "core/framework/tensor.h"
8+
#include "Eigen/src/Core/arch/CUDA/Half.h"
9+
10+
namespace onnxruntime {
11+
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#IsNaN
12+
#define ADD_TYPED_ISNAN_OP(data_type) \
13+
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
14+
IsNaN, \
15+
9, \
16+
data_type, \
17+
KernelDefBuilder() \
18+
.TypeConstraint("T1", DataTypeImpl::GetTensorType<data_type>()) \
19+
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()), \
20+
IsNaN<data_type>);
21+
22+
ADD_TYPED_ISNAN_OP(float);
23+
ADD_TYPED_ISNAN_OP(MLFloat16);
24+
25+
template <>
26+
Status IsNaN<float>::Compute(OpKernelContext* context) const {
27+
const Tensor* X_ptr = context->Input<Tensor>(0);
28+
if (!X_ptr) {
29+
return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr");
30+
}
31+
auto& X = *X_ptr;
32+
auto& dims = X.Shape();
33+
auto& Y = *context->Output(0, dims);
34+
35+
EigenMap<bool>(Y) = EigenMap<float>(X).array().isNaN();
36+
37+
return Status::OK();
38+
}
39+
40+
template <>
41+
Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
42+
const Tensor* X_ptr = context->Input<Tensor>(0);
43+
if (!X_ptr) {
44+
return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr");
45+
}
46+
auto X_data = X_ptr->template Data<MLFloat16>();
47+
auto& dims = X_ptr->Shape();
48+
auto shape_size = dims.Size();
49+
auto& Y = *context->Output(0, dims);
50+
51+
EigenMap<bool>(Y) = ConstEigenVectorMap<Eigen::half>(static_cast<const Eigen::half*>(static_cast<const void*>(X_data)), shape_size).array().isNaN();
52+
53+
return Status::OK();
54+
}
55+
} // namespace onnxruntime

onnxruntime/contrib_ops/cpu/isnan.h renamed to onnxruntime/core/providers/cpu/tensor/isnan.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66
#include "core/framework/op_kernel.h"
77

88
namespace onnxruntime {
9-
namespace contrib {
109
template <typename T>
1110
class IsNaN : public OpKernel {
1211
public:
1312
explicit IsNaN(const OpKernelInfo& info) : OpKernel(info) {}
1413
Status Compute(OpKernelContext* context) const override;
1514
};
16-
} // namespace contrib
1715
} // namespace onnxruntime

onnxruntime/test/contrib_ops/isnan_test.cc

Lines changed: 0 additions & 20 deletions
This file was deleted.

onnxruntime/test/onnx/main.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,7 @@ int real_main(int argc, char* argv[]) {
330330
{"sign", "opset 9 not supported yet"},
331331
{"scatter_with_axis", "opset 9 not supported yet"},
332332
{"scatter_without_axis", "opset 9 not supported yet"},
333-
{"scan_sum", "opset 9 not supported yet"},
334-
{"isnan", "opset 9 not supported yet"}};
333+
{"scan_sum", "opset 9 not supported yet"}};
335334

336335
#ifdef USE_CUDA
337336
broken_tests["maxpool_2d_default"] = "cudnn pooling only support input dimension >= 3";
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "gtest/gtest.h"
5+
#include "test/providers/provider_test_utils.h"
6+
#include <cmath> // NAN
7+
#include "core/util/math.h"
8+
9+
namespace onnxruntime {
10+
namespace test {
11+
12+
TEST(IsNaNOpTest, IsNaNFloat) {
13+
OpTester test("IsNaN", 9, kOnnxDomain);
14+
std::vector<int64_t> dims{2, 2};
15+
test.AddInput<float>("X", dims, {1.0f, NAN, 2.0f, NAN});
16+
test.AddOutput<bool>("Y", dims, {false, true, false, true});
17+
test.Run();
18+
}
19+
20+
TEST(IsNaNOpTest, IsNaNFloat16) {
21+
OpTester test("IsNaN", 9, kOnnxDomain);
22+
std::vector<int64_t> dims{2, 2};
23+
test.AddInput<MLFloat16>("X", dims, std::initializer_list<MLFloat16>({MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(NAN)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(NAN))}));
24+
test.AddOutput<bool>("Y", dims, {false, true, false, true});
25+
test.Run();
26+
}
27+
28+
} // namespace test
29+
} // namespace onnxruntime

0 commit comments

Comments
 (0)