Skip to content

Commit bfc15f7

Browse files
committed
[slimtensor] Add all required dtype support (Int8/16/32/64, Bool, BFloat16)
This diff adds support for all required scalar types in SlimTensor to support ExecuTorch aoti-driven backend usage: Int8 (Char), Int16 (Short), Int32 (Int), Int64 (Long), Bool, and BFloat16. **Key changes:** 1. **`c10/core/ScalarType.h`** - Extended with all required types: - Added enum values matching PyTorch's c10::ScalarType for compatibility - Added type alias constants (kChar, kShort, kInt, kLong, kBool, kBFloat16) - Extended `elementSize()` to return correct sizes for all types - Extended `toString()` for all types - Fixed `isFloatingType()` to include BFloat16 - Fixed `isIntegralType()` to properly handle all integral types and Bool - Added `isBoolType()` helper function - Imported BFloat16 from ExecuTorch's portable_type Differential Revision: [D89821402](https://our.internmc.facebook.com/intern/diff/D89821402/) ghstack-source-id: 331195238 Pull Request resolved: #16399
1 parent bb8580d commit bfc15f7

File tree

5 files changed

+538
-43
lines changed

5 files changed

+538
-43
lines changed

backends/aoti/slim/c10/core/ScalarType.h

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,64 @@
1212
#include <cstdint>
1313
#include <ostream>
1414

15+
#include <executorch/runtime/core/portable_type/bfloat16.h>
1516
#include <executorch/runtime/platform/assert.h>
1617

1718
namespace executorch::backends::aoti::slim::c10 {
1819

20+
// Import BFloat16 from ExecuTorch's portable_type
21+
using BFloat16 = ::executorch::runtime::etensor::BFloat16;
22+
1923
/// Enum representing the scalar type (dtype) of tensor elements.
2024
/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility.
2125
enum class ScalarType : int8_t {
22-
// Byte = 0,
23-
// Char = 1,
24-
// Short = 2,
25-
// Int = 3,
26-
// Long = 4,
27-
Float = 6,
28-
// Bool = 11,
29-
// BFloat16 = 15,
26+
// Byte = 0, // uint8_t - not currently needed
27+
Char = 1, // int8_t
28+
Short = 2, // int16_t
29+
Int = 3, // int32_t
30+
Long = 4, // int64_t
31+
// Half = 5, // float16 - not currently needed
32+
Float = 6, // float
33+
// Double = 7, // double - not currently needed
34+
// ComplexHalf = 8,
35+
// ComplexFloat = 9,
36+
// ComplexDouble = 10,
37+
Bool = 11, // bool
38+
// QInt8 = 12,
39+
// QUInt8 = 13,
40+
// QInt32 = 14,
41+
BFloat16 = 15, // bfloat16
3042
Undefined = -1,
31-
NumOptions = 7,
3243
};
3344

34-
/// Constant for Float scalar type.
45+
// Type alias constants for convenience
46+
constexpr ScalarType kChar = ScalarType::Char;
47+
constexpr ScalarType kShort = ScalarType::Short;
48+
constexpr ScalarType kInt = ScalarType::Int;
49+
constexpr ScalarType kLong = ScalarType::Long;
3550
constexpr ScalarType kFloat = ScalarType::Float;
51+
constexpr ScalarType kBool = ScalarType::Bool;
52+
constexpr ScalarType kBFloat16 = ScalarType::BFloat16;
3653

3754
/// Returns the size in bytes of a single element of the given scalar type.
3855
/// @param t The scalar type.
3956
/// @return The size in bytes of a single element.
4057
inline size_t elementSize(ScalarType t) {
4158
switch (t) {
59+
case ScalarType::Char:
60+
return sizeof(int8_t);
61+
case ScalarType::Short:
62+
return sizeof(int16_t);
63+
case ScalarType::Int:
64+
return sizeof(int32_t);
65+
case ScalarType::Long:
66+
return sizeof(int64_t);
4267
case ScalarType::Float:
4368
return sizeof(float);
69+
case ScalarType::Bool:
70+
return sizeof(bool);
71+
case ScalarType::BFloat16:
72+
return sizeof(BFloat16);
4473
default:
4574
ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast<int>(t));
4675
}
@@ -51,8 +80,20 @@ inline size_t elementSize(ScalarType t) {
5180
/// @return The name of the scalar type.
5281
inline const char* toString(ScalarType t) {
5382
switch (t) {
83+
case ScalarType::Char:
84+
return "Char";
85+
case ScalarType::Short:
86+
return "Short";
87+
case ScalarType::Int:
88+
return "Int";
89+
case ScalarType::Long:
90+
return "Long";
5491
case ScalarType::Float:
5592
return "Float";
93+
case ScalarType::Bool:
94+
return "Bool";
95+
case ScalarType::BFloat16:
96+
return "BFloat16";
5697
case ScalarType::Undefined:
5798
return "Undefined";
5899
default:
@@ -64,16 +105,32 @@ inline const char* toString(ScalarType t) {
64105
/// @param t The scalar type to check.
65106
/// @return true if the scalar type is floating point, false otherwise.
66107
inline bool isFloatingType(ScalarType t) {
67-
return t == ScalarType::Float;
108+
return t == ScalarType::Float || t == ScalarType::BFloat16;
68109
}
69110

70-
/// Checks if the scalar type is an integral type (including bool).
111+
/// Checks if the scalar type is an integral type (including bool optionally).
71112
/// @param t The scalar type to check.
72113
/// @param includeBool Whether to consider Bool as integral.
73114
/// @return true if the scalar type is integral, false otherwise.
74-
inline bool isIntegralType(ScalarType t, bool /*includeBool*/) {
75-
(void)t;
76-
return false;
115+
inline bool isIntegralType(ScalarType t, bool includeBool) {
116+
switch (t) {
117+
case ScalarType::Char:
118+
case ScalarType::Short:
119+
case ScalarType::Int:
120+
case ScalarType::Long:
121+
return true;
122+
case ScalarType::Bool:
123+
return includeBool;
124+
default:
125+
return false;
126+
}
127+
}
128+
129+
/// Checks if the scalar type is a boolean type.
130+
/// @param t The scalar type to check.
131+
/// @return true if the scalar type is Bool, false otherwise.
132+
inline bool isBoolType(ScalarType t) {
133+
return t == ScalarType::Bool;
77134
}
78135

79136
inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) {

backends/aoti/slim/c10/core/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def define_common_targets():
3636
],
3737
visibility = ["@EXECUTORCH_CLIENTS"],
3838
exported_deps = [
39+
"//executorch/runtime/core/portable_type:portable_type",
3940
"//executorch/runtime/platform:platform",
4041
],
4142
)

backends/aoti/slim/c10/core/test/test_scalar_type.cpp

Lines changed: 165 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,49 +13,186 @@
1313

1414
using namespace executorch::backends::aoti::slim::c10;
1515

16-
class ScalarTypeTest : public ::testing::Test {};
16+
// =============================================================================
17+
// Test Data Structures for Parameterized Tests
18+
// =============================================================================
1719

18-
TEST_F(ScalarTypeTest, FloatEnumValue) {
19-
// Verify Float has the correct enum value (6) to match PyTorch
20-
EXPECT_EQ(static_cast<int>(ScalarType::Float), 6);
20+
struct ScalarTypeTestData {
21+
ScalarType dtype;
22+
int expected_enum_value;
23+
size_t expected_element_size;
24+
const char* expected_name;
25+
bool is_floating;
26+
bool is_integral;
27+
bool is_integral_with_bool;
28+
bool is_bool;
29+
};
30+
31+
// All supported scalar types with their expected properties
32+
const std::vector<ScalarTypeTestData> kAllScalarTypes = {
33+
// dtype, enum_value, element_size, name, is_float, is_int, is_int_w_bool,
34+
// is_bool
35+
{ScalarType::Char, 1, 1, "Char", false, true, true, false},
36+
{ScalarType::Short, 2, 2, "Short", false, true, true, false},
37+
{ScalarType::Int, 3, 4, "Int", false, true, true, false},
38+
{ScalarType::Long, 4, 8, "Long", false, true, true, false},
39+
{ScalarType::Float, 6, 4, "Float", true, false, false, false},
40+
{ScalarType::Bool, 11, 1, "Bool", false, false, true, true},
41+
{ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false},
42+
};
43+
44+
// =============================================================================
45+
// Parameterized Test Fixture
46+
// =============================================================================
47+
48+
class ScalarTypeParamTest
49+
: public ::testing::TestWithParam<ScalarTypeTestData> {};
50+
51+
TEST_P(ScalarTypeParamTest, EnumValue) {
52+
const auto& data = GetParam();
53+
EXPECT_EQ(static_cast<int>(data.dtype), data.expected_enum_value)
54+
<< "Failed for dtype: " << toString(data.dtype);
55+
}
56+
57+
TEST_P(ScalarTypeParamTest, ElementSize) {
58+
const auto& data = GetParam();
59+
EXPECT_EQ(elementSize(data.dtype), data.expected_element_size)
60+
<< "Failed for dtype: " << toString(data.dtype);
61+
}
62+
63+
TEST_P(ScalarTypeParamTest, ToString) {
64+
const auto& data = GetParam();
65+
EXPECT_STREQ(toString(data.dtype), data.expected_name)
66+
<< "Failed for dtype: " << toString(data.dtype);
67+
}
68+
69+
TEST_P(ScalarTypeParamTest, IsFloatingType) {
70+
const auto& data = GetParam();
71+
EXPECT_EQ(isFloatingType(data.dtype), data.is_floating)
72+
<< "Failed for dtype: " << toString(data.dtype);
73+
}
74+
75+
TEST_P(ScalarTypeParamTest, IsIntegralTypeWithoutBool) {
76+
const auto& data = GetParam();
77+
EXPECT_EQ(isIntegralType(data.dtype, false), data.is_integral)
78+
<< "Failed for dtype: " << toString(data.dtype);
79+
}
80+
81+
TEST_P(ScalarTypeParamTest, IsIntegralTypeWithBool) {
82+
const auto& data = GetParam();
83+
EXPECT_EQ(isIntegralType(data.dtype, true), data.is_integral_with_bool)
84+
<< "Failed for dtype: " << toString(data.dtype);
85+
}
86+
87+
TEST_P(ScalarTypeParamTest, IsBoolType) {
88+
const auto& data = GetParam();
89+
EXPECT_EQ(isBoolType(data.dtype), data.is_bool)
90+
<< "Failed for dtype: " << toString(data.dtype);
91+
}
92+
93+
TEST_P(ScalarTypeParamTest, StreamOperator) {
94+
const auto& data = GetParam();
95+
std::ostringstream oss;
96+
oss << data.dtype;
97+
EXPECT_EQ(oss.str(), data.expected_name)
98+
<< "Failed for dtype: " << toString(data.dtype);
99+
}
100+
101+
INSTANTIATE_TEST_SUITE_P(
102+
AllTypes,
103+
ScalarTypeParamTest,
104+
::testing::ValuesIn(kAllScalarTypes),
105+
[](const ::testing::TestParamInfo<ScalarTypeTestData>& info) {
106+
return std::string(info.param.expected_name);
107+
});
108+
109+
// =============================================================================
110+
// Type Constant Tests
111+
// =============================================================================
112+
113+
class ScalarTypeConstantsTest : public ::testing::Test {};
114+
115+
TEST_F(ScalarTypeConstantsTest, KCharConstant) {
116+
EXPECT_EQ(kChar, ScalarType::Char);
117+
}
118+
119+
TEST_F(ScalarTypeConstantsTest, KShortConstant) {
120+
EXPECT_EQ(kShort, ScalarType::Short);
21121
}
22122

23-
TEST_F(ScalarTypeTest, KFloatConstant) {
24-
// Verify kFloat constant
123+
TEST_F(ScalarTypeConstantsTest, KIntConstant) {
124+
EXPECT_EQ(kInt, ScalarType::Int);
125+
}
126+
127+
TEST_F(ScalarTypeConstantsTest, KLongConstant) {
128+
EXPECT_EQ(kLong, ScalarType::Long);
129+
}
130+
131+
TEST_F(ScalarTypeConstantsTest, KFloatConstant) {
25132
EXPECT_EQ(kFloat, ScalarType::Float);
26133
}
27134

28-
TEST_F(ScalarTypeTest, ElementSizeFloat) {
29-
// Verify elementSize returns correct size for Float (4 bytes)
30-
EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float));
31-
EXPECT_EQ(elementSize(ScalarType::Float), 4);
135+
TEST_F(ScalarTypeConstantsTest, KBoolConstant) {
136+
EXPECT_EQ(kBool, ScalarType::Bool);
32137
}
33138

34-
TEST_F(ScalarTypeTest, ToStringFloat) {
35-
// Verify toString returns correct string for Float
36-
EXPECT_STREQ(toString(ScalarType::Float), "Float");
139+
TEST_F(ScalarTypeConstantsTest, KBFloat16Constant) {
140+
EXPECT_EQ(kBFloat16, ScalarType::BFloat16);
37141
}
38142

39-
TEST_F(ScalarTypeTest, ToStringUndefined) {
40-
// Verify toString returns correct string for Undefined
143+
// =============================================================================
144+
// Edge Cases and Special Values
145+
// =============================================================================
146+
147+
class ScalarTypeEdgeCasesTest : public ::testing::Test {};
148+
149+
TEST_F(ScalarTypeEdgeCasesTest, UndefinedToString) {
41150
EXPECT_STREQ(toString(ScalarType::Undefined), "Undefined");
42151
}
43152

44-
TEST_F(ScalarTypeTest, IsFloatingType) {
45-
// Verify isFloatingType works correctly
46-
EXPECT_TRUE(isFloatingType(ScalarType::Float));
153+
TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotFloating) {
154+
EXPECT_FALSE(isFloatingType(ScalarType::Undefined));
47155
}
48156

49-
TEST_F(ScalarTypeTest, IsIntegralType) {
50-
// Verify isIntegralType works correctly
51-
// Currently no integral types are supported, so Float should return false
52-
EXPECT_FALSE(isIntegralType(ScalarType::Float, false));
53-
EXPECT_FALSE(isIntegralType(ScalarType::Float, true));
157+
TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotIntegral) {
158+
EXPECT_FALSE(isIntegralType(ScalarType::Undefined, false));
159+
EXPECT_FALSE(isIntegralType(ScalarType::Undefined, true));
54160
}
55161

56-
TEST_F(ScalarTypeTest, StreamOperator) {
57-
// Verify stream operator works
58-
std::ostringstream oss;
59-
oss << ScalarType::Float;
60-
EXPECT_EQ(oss.str(), "Float");
162+
TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotBool) {
163+
EXPECT_FALSE(isBoolType(ScalarType::Undefined));
164+
}
165+
166+
// =============================================================================
167+
// Element Size Consistency Tests
168+
// =============================================================================
169+
170+
class ElementSizeConsistencyTest : public ::testing::Test {};
171+
172+
TEST_F(ElementSizeConsistencyTest, CharMatchesSizeofInt8) {
173+
EXPECT_EQ(elementSize(ScalarType::Char), sizeof(int8_t));
174+
}
175+
176+
TEST_F(ElementSizeConsistencyTest, ShortMatchesSizeofInt16) {
177+
EXPECT_EQ(elementSize(ScalarType::Short), sizeof(int16_t));
178+
}
179+
180+
TEST_F(ElementSizeConsistencyTest, IntMatchesSizeofInt32) {
181+
EXPECT_EQ(elementSize(ScalarType::Int), sizeof(int32_t));
182+
}
183+
184+
TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) {
185+
EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t));
186+
}
187+
188+
TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) {
189+
EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float));
190+
}
191+
192+
TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) {
193+
EXPECT_EQ(elementSize(ScalarType::Bool), sizeof(bool));
194+
}
195+
196+
TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) {
197+
EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16));
61198
}

backends/aoti/slim/core/test/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,13 @@ def define_common_targets():
3434
"//executorch/backends/aoti/slim/core:storage",
3535
],
3636
)
37+
38+
runtime.cxx_test(
39+
name = "test_slimtensor_dtypes",
40+
srcs = [
41+
"test_slimtensor_dtypes.cpp",
42+
],
43+
deps = [
44+
"//executorch/backends/aoti/slim/factory:empty",
45+
],
46+
)

0 commit comments

Comments
 (0)