@@ -17,6 +17,7 @@ limitations under the License.
1717#include < vector>
1818
1919#include < gtest/gtest.h>
20+ #include " Eigen/Core"
2021#include " tensorflow/lite/kernels/test_util.h"
2122#include " tensorflow/lite/schema/schema_generated.h"
2223
@@ -25,11 +26,12 @@ namespace {
2526
2627using ::testing::ElementsAreArray;
2728
29+ template <typename T>
2830class RoundOpModel : public SingleOpModel {
2931 public:
30- RoundOpModel (std::initializer_list<int > input_shape, TensorType input_type ) {
31- input_ = AddInput (TensorType_FLOAT32 );
32- output_ = AddOutput (TensorType_FLOAT32 );
32+ RoundOpModel (std::initializer_list<int > input_shape) {
33+ input_ = AddInput (GetTensorType<T>() );
34+ output_ = AddOutput (GetTensorType<T>() );
3335 SetBuiltinOp (BuiltinOperator_ROUND, BuiltinOptions_NONE, 0 );
3436 BuildInterpreter ({
3537 input_shape,
@@ -38,7 +40,7 @@ class RoundOpModel : public SingleOpModel {
3840
3941 int input () { return input_; }
4042
41- std::vector<float > GetOutput () { return ExtractVector<float >(output_); }
43+ std::vector<T > GetOutput () { return ExtractVector<T >(output_); }
4244 std::vector<int > GetOutputShape () { return GetTensorShape (output_); }
4345
4446 private:
@@ -47,15 +49,15 @@ class RoundOpModel : public SingleOpModel {
4749};
4850
4951TEST (RoundOpTest, SingleDim) {
50- RoundOpModel model ({6 }, TensorType_FLOAT32 );
52+ RoundOpModel< float > model ({6 });
5153 model.PopulateTensor <float >(model.input (), {8.5 , 0.0 , 3.5 , 4.2 , -3.5 , -4.5 });
5254 ASSERT_EQ (model.Invoke (), kTfLiteOk );
5355 EXPECT_THAT (model.GetOutput (), ElementsAreArray ({8 , 0 , 4 , 4 , -4 , -4 }));
5456 EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({6 }));
5557}
5658
5759TEST (RoundOpTest, MultiDims) {
58- RoundOpModel model ({2 , 1 , 1 , 6 }, TensorType_FLOAT32 );
60+ RoundOpModel< float > model ({2 , 1 , 1 , 6 });
5961 model.PopulateTensor <float >(
6062 model.input (), {0.0001 , 8.0001 , 0.9999 , 9.9999 , 0.5 , -0.0001 , -8.0001 ,
6163 -0.9999 , -9.9999 , -0.5 , -2.5 , 1.5 });
@@ -65,5 +67,70 @@ TEST(RoundOpTest, MultiDims) {
6567 EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 , 1 , 1 , 6 }));
6668}
6769
70+ TEST (RoundOpTest, Float16SingleDim) {
71+ RoundOpModel<Eigen::half> model ({6 });
72+ model.PopulateTensor <Eigen::half>(
73+ model.input (), {Eigen::half (8.5 ), Eigen::half (0.0 ), Eigen::half (3.5 ),
74+ Eigen::half (4.2 ), Eigen::half (-3.5 ), Eigen::half (-4.5 )});
75+ ASSERT_EQ (model.Invoke (), kTfLiteOk );
76+ EXPECT_THAT (
77+ model.GetOutput (),
78+ ElementsAreArray ({Eigen::half (8 ), Eigen::half (0 ), Eigen::half (4 ),
79+ Eigen::half (4 ), Eigen::half (-4 ), Eigen::half (-4 )}));
80+ EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({6 }));
81+ }
82+
83+ TEST (RoundOpTest, Float16MultiDims) {
84+ RoundOpModel<Eigen::half> model ({2 , 1 , 1 , 6 });
85+ model.PopulateTensor <Eigen::half>(
86+ model.input (),
87+ {Eigen::half (0.0001 ), Eigen::half (8.0001 ), Eigen::half (0.9999 ),
88+ Eigen::half (9.9999 ), Eigen::half (0.5 ), Eigen::half (-0.0001 ),
89+ Eigen::half (-8.0001 ), Eigen::half (-0.9999 ), Eigen::half (-9.9999 ),
90+ Eigen::half (-0.5 ), Eigen::half (-2.5 ), Eigen::half (1.5 )});
91+ ASSERT_EQ (model.Invoke (), kTfLiteOk );
92+ EXPECT_THAT (
93+ model.GetOutput (),
94+ ElementsAreArray ({Eigen::half (0 ), Eigen::half (8 ), Eigen::half (1 ),
95+ Eigen::half (10 ), Eigen::half (0 ), Eigen::half (0 ),
96+ Eigen::half (-8 ), Eigen::half (-1 ), Eigen::half (-10 ),
97+ Eigen::half (-0 ), Eigen::half (-2 ), Eigen::half (2 )}));
98+ EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 , 1 , 1 , 6 }));
99+ }
100+
101+ TEST (RoundOpTest, BFloat16SingleDim) {
102+ RoundOpModel<Eigen::bfloat16> model ({6 });
103+ model.PopulateTensor <Eigen::bfloat16>(
104+ model.input (),
105+ {Eigen::bfloat16 (8.5 ), Eigen::bfloat16 (0.0 ), Eigen::bfloat16 (3.5 ),
106+ Eigen::bfloat16 (4.2 ), Eigen::bfloat16 (-3.5 ), Eigen::bfloat16 (-4.5 )});
107+ ASSERT_EQ (model.Invoke (), kTfLiteOk );
108+ EXPECT_THAT (model.GetOutput (),
109+ ElementsAreArray ({Eigen::bfloat16 (8 ), Eigen::bfloat16 (0 ),
110+ Eigen::bfloat16 (4 ), Eigen::bfloat16 (4 ),
111+ Eigen::bfloat16 (-4 ), Eigen::bfloat16 (-4 )}));
112+ EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({6 }));
113+ }
114+
115+ TEST (RoundOpTest, BFloat16MultiDims) {
116+ RoundOpModel<Eigen::bfloat16> model ({2 , 1 , 1 , 6 });
117+ model.PopulateTensor <Eigen::bfloat16>(
118+ model.input (),
119+ {Eigen::bfloat16 (0.0001 ), Eigen::bfloat16 (8.0001 ),
120+ Eigen::bfloat16 (0.9999 ), Eigen::bfloat16 (9.9999 ), Eigen::bfloat16 (0.5 ),
121+ Eigen::bfloat16 (-0.0001 ), Eigen::bfloat16 (-8.0001 ),
122+ Eigen::bfloat16 (-0.9999 ), Eigen::bfloat16 (-9.9999 ),
123+ Eigen::bfloat16 (-0.5 ), Eigen::bfloat16 (-2.5 ), Eigen::bfloat16 (1.5 )});
124+ ASSERT_EQ (model.Invoke (), kTfLiteOk );
125+ EXPECT_THAT (
126+ model.GetOutput (),
127+ ElementsAreArray (
128+ {Eigen::bfloat16 (0 ), Eigen::bfloat16 (8 ), Eigen::bfloat16 (1 ),
129+ Eigen::bfloat16 (10 ), Eigen::bfloat16 (0 ), Eigen::bfloat16 (0 ),
130+ Eigen::bfloat16 (-8 ), Eigen::bfloat16 (-1 ), Eigen::bfloat16 (-10 ),
131+ Eigen::bfloat16 (-0 ), Eigen::bfloat16 (-2 ), Eigen::bfloat16 (2 )}));
132+ EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 , 1 , 1 , 6 }));
133+ }
134+
68135} // namespace
69136} // namespace tflite
0 commit comments