Skip to content

Commit 458c6a4

Browse files
TfLite Round missing datatype support (#32) (#73)
* TfLite Round missing datatype support -Adds bf16, f16 support for round -Adds bf16, f16 round unit tests
1 parent e0aaea1 commit 458c6a4

File tree

4 files changed

+110
-13
lines changed

4 files changed

+110
-13
lines changed

tensorflow/lite/kernels/internal/reference/round.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,16 @@ inline float RoundToNearest(float value) {
3434
}
3535
}
3636

37-
inline void Round(const RuntimeShape& input_shape, const float* input_data,
38-
const RuntimeShape& output_shape, float* output_data) {
37+
template <typename Scalar>
38+
inline void Round(const RuntimeShape& input_shape, const Scalar* input_data,
39+
const RuntimeShape& output_shape, Scalar* output_data) {
3940
const int flat_size = MatchingFlatSize(input_shape, output_shape);
4041
for (int i = 0; i < flat_size; ++i) {
4142
// Note that this implementation matches that of tensorFlow tf.round
4243
// and corresponds to the bankers rounding method.
4344
// cfenv (for fesetround) is not yet supported universally on Android, so
4445
// using a work around.
45-
output_data[i] = RoundToNearest(input_data[i]);
46+
output_data[i] = static_cast<Scalar>(RoundToNearest(input_data[i]));
4647
}
4748
}
4849

tensorflow/lite/kernels/round.cc

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "tensorflow/lite/kernels/internal/reference/round.h"
1717

18+
#include "Eigen/Core"
1819
#include "tensorflow/lite/core/c/common.h"
1920
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
2021
#include "tensorflow/lite/kernels/internal/tensor.h"
@@ -37,7 +38,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
3738
GetOutputSafe(context, node, kOutputTensor, &output));
3839
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
3940
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
40-
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
41+
if (input->type != kTfLiteFloat32 && input->type != kTfLiteFloat16 &&
42+
input->type != kTfLiteBFloat16) {
43+
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by round.",
44+
TfLiteTypeGetName(input->type));
45+
return kTfLiteError;
46+
}
4147
output->type = input->type;
4248
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
4349
return context->ResizeTensor(context, output, output_size);
@@ -49,9 +55,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
4955
TfLiteTensor* output;
5056
TF_LITE_ENSURE_OK(context,
5157
GetOutputSafe(context, node, kOutputTensor, &output));
52-
53-
optimized_ops::Round(GetTensorShape(input), GetTensorData<float>(input),
54-
GetTensorShape(output), GetTensorData<float>(output));
58+
switch (output->type) {
59+
case kTfLiteFloat32: {
60+
optimized_ops::Round<float>(
61+
GetTensorShape(input), GetTensorData<float>(input),
62+
GetTensorShape(output), GetTensorData<float>(output));
63+
break;
64+
}
65+
case kTfLiteFloat16: {
66+
optimized_ops::Round<Eigen::half>(
67+
GetTensorShape(input), GetTensorData<Eigen::half>(input),
68+
GetTensorShape(output), GetTensorData<Eigen::half>(output));
69+
break;
70+
}
71+
case kTfLiteBFloat16: {
72+
optimized_ops::Round<Eigen::bfloat16>(
73+
GetTensorShape(input), GetTensorData<Eigen::bfloat16>(input),
74+
GetTensorShape(output), GetTensorData<Eigen::bfloat16>(output));
75+
break;
76+
}
77+
default: {
78+
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by round.",
79+
TfLiteTypeGetName(output->type));
80+
return kTfLiteError;
81+
}
82+
}
5583

5684
return kTfLiteOk;
5785
}

tensorflow/lite/kernels/round_test.cc

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#include <vector>
1818

1919
#include <gtest/gtest.h>
20+
#include "Eigen/Core"
2021
#include "tensorflow/lite/kernels/test_util.h"
2122
#include "tensorflow/lite/schema/schema_generated.h"
2223

@@ -25,11 +26,12 @@ namespace {
2526

2627
using ::testing::ElementsAreArray;
2728

29+
template <typename T>
2830
class RoundOpModel : public SingleOpModel {
2931
public:
30-
RoundOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
31-
input_ = AddInput(TensorType_FLOAT32);
32-
output_ = AddOutput(TensorType_FLOAT32);
32+
RoundOpModel(std::initializer_list<int> input_shape) {
33+
input_ = AddInput(GetTensorType<T>());
34+
output_ = AddOutput(GetTensorType<T>());
3335
SetBuiltinOp(BuiltinOperator_ROUND, BuiltinOptions_NONE, 0);
3436
BuildInterpreter({
3537
input_shape,
@@ -38,7 +40,7 @@ class RoundOpModel : public SingleOpModel {
3840

3941
int input() { return input_; }
4042

41-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
43+
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
4244
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
4345

4446
private:
@@ -47,15 +49,15 @@ class RoundOpModel : public SingleOpModel {
4749
};
4850

4951
TEST(RoundOpTest, SingleDim) {
50-
RoundOpModel model({6}, TensorType_FLOAT32);
52+
RoundOpModel<float> model({6});
5153
model.PopulateTensor<float>(model.input(), {8.5, 0.0, 3.5, 4.2, -3.5, -4.5});
5254
ASSERT_EQ(model.Invoke(), kTfLiteOk);
5355
EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0, 4, 4, -4, -4}));
5456
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6}));
5557
}
5658

5759
TEST(RoundOpTest, MultiDims) {
58-
RoundOpModel model({2, 1, 1, 6}, TensorType_FLOAT32);
60+
RoundOpModel<float> model({2, 1, 1, 6});
5961
model.PopulateTensor<float>(
6062
model.input(), {0.0001, 8.0001, 0.9999, 9.9999, 0.5, -0.0001, -8.0001,
6163
-0.9999, -9.9999, -0.5, -2.5, 1.5});
@@ -65,5 +67,70 @@ TEST(RoundOpTest, MultiDims) {
6567
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 6}));
6668
}
6769

70+
TEST(RoundOpTest, Float16SingleDim) {
71+
RoundOpModel<Eigen::half> model({6});
72+
model.PopulateTensor<Eigen::half>(
73+
model.input(), {Eigen::half(8.5), Eigen::half(0.0), Eigen::half(3.5),
74+
Eigen::half(4.2), Eigen::half(-3.5), Eigen::half(-4.5)});
75+
ASSERT_EQ(model.Invoke(), kTfLiteOk);
76+
EXPECT_THAT(
77+
model.GetOutput(),
78+
ElementsAreArray({Eigen::half(8), Eigen::half(0), Eigen::half(4),
79+
Eigen::half(4), Eigen::half(-4), Eigen::half(-4)}));
80+
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6}));
81+
}
82+
83+
TEST(RoundOpTest, Float16MultiDims) {
84+
RoundOpModel<Eigen::half> model({2, 1, 1, 6});
85+
model.PopulateTensor<Eigen::half>(
86+
model.input(),
87+
{Eigen::half(0.0001), Eigen::half(8.0001), Eigen::half(0.9999),
88+
Eigen::half(9.9999), Eigen::half(0.5), Eigen::half(-0.0001),
89+
Eigen::half(-8.0001), Eigen::half(-0.9999), Eigen::half(-9.9999),
90+
Eigen::half(-0.5), Eigen::half(-2.5), Eigen::half(1.5)});
91+
ASSERT_EQ(model.Invoke(), kTfLiteOk);
92+
EXPECT_THAT(
93+
model.GetOutput(),
94+
ElementsAreArray({Eigen::half(0), Eigen::half(8), Eigen::half(1),
95+
Eigen::half(10), Eigen::half(0), Eigen::half(0),
96+
Eigen::half(-8), Eigen::half(-1), Eigen::half(-10),
97+
Eigen::half(-0), Eigen::half(-2), Eigen::half(2)}));
98+
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 6}));
99+
}
100+
101+
TEST(RoundOpTest, BFloat16SingleDim) {
102+
RoundOpModel<Eigen::bfloat16> model({6});
103+
model.PopulateTensor<Eigen::bfloat16>(
104+
model.input(),
105+
{Eigen::bfloat16(8.5), Eigen::bfloat16(0.0), Eigen::bfloat16(3.5),
106+
Eigen::bfloat16(4.2), Eigen::bfloat16(-3.5), Eigen::bfloat16(-4.5)});
107+
ASSERT_EQ(model.Invoke(), kTfLiteOk);
108+
EXPECT_THAT(model.GetOutput(),
109+
ElementsAreArray({Eigen::bfloat16(8), Eigen::bfloat16(0),
110+
Eigen::bfloat16(4), Eigen::bfloat16(4),
111+
Eigen::bfloat16(-4), Eigen::bfloat16(-4)}));
112+
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6}));
113+
}
114+
115+
TEST(RoundOpTest, BFloat16MultiDims) {
116+
RoundOpModel<Eigen::bfloat16> model({2, 1, 1, 6});
117+
model.PopulateTensor<Eigen::bfloat16>(
118+
model.input(),
119+
{Eigen::bfloat16(0.0001), Eigen::bfloat16(8.0001),
120+
Eigen::bfloat16(0.9999), Eigen::bfloat16(9.9999), Eigen::bfloat16(0.5),
121+
Eigen::bfloat16(-0.0001), Eigen::bfloat16(-8.0001),
122+
Eigen::bfloat16(-0.9999), Eigen::bfloat16(-9.9999),
123+
Eigen::bfloat16(-0.5), Eigen::bfloat16(-2.5), Eigen::bfloat16(1.5)});
124+
ASSERT_EQ(model.Invoke(), kTfLiteOk);
125+
EXPECT_THAT(
126+
model.GetOutput(),
127+
ElementsAreArray(
128+
{Eigen::bfloat16(0), Eigen::bfloat16(8), Eigen::bfloat16(1),
129+
Eigen::bfloat16(10), Eigen::bfloat16(0), Eigen::bfloat16(0),
130+
Eigen::bfloat16(-8), Eigen::bfloat16(-1), Eigen::bfloat16(-10),
131+
Eigen::bfloat16(-0), Eigen::bfloat16(-2), Eigen::bfloat16(2)}));
132+
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 6}));
133+
}
134+
68135
} // namespace
69136
} // namespace tflite

tensorflow/lite/kernels/test_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,7 @@ TFLITE_TENSOR_TYPE_ASSOC(Eigen::bfloat16, TensorType_BFLOAT16);
11721172
TFLITE_TENSOR_TYPE_ASSOC(float, TensorType_FLOAT32);
11731173
TFLITE_TENSOR_TYPE_ASSOC(double, TensorType_FLOAT64);
11741174
TFLITE_TENSOR_TYPE_ASSOC(std::string, TensorType_STRING);
1175+
TFLITE_TENSOR_TYPE_ASSOC(Eigen::bfloat16, TensorType_BFLOAT16);
11751176

11761177
#undef TFLITE_TENSOR_TYPE_ASSOC
11771178

0 commit comments

Comments
 (0)