
Commit 93c3267

rascani authored and facebook-github-bot committed
Use nullptr to represent fallback kernels
Summary:
`KernelKey`s can be constructed as either a "fallback" or "specialized" key. The "fallback" type uses the default constructor, and the "specialized" takes a specially formatted string. Internally, this was represented as a pointer to the optional key string and a bool for whether it is a fallback kernel. As it would not make sense to construct a "specialized" kernel without a key string, this diff eliminates the bool `is_fallback_` in favor of using `kernel_key_data_ == nullptr` to represent fallback kernels.

Each `KernelKey` is nested within the `Kernel` data structure, which makes up the list of registered kernels. As the default size of the `registered_kernels` array is 2000 kernel entries, this diff can reduce the size of the `registered_kernels` array by 8 KB.

This diff also changes the backing storage buffer for `registered_kernels_data` to ensure that there is enough space for each `Kernel` element in the array to be aligned according to `alignas(Kernel)`.

Differential Revision: D76201866
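To make the size argument in the summary concrete, here is a minimal sketch that compares the two layouts. `OldKeyLayout`, `NewKeyLayout`, `kEntries`, and `bytes_saved_in_keys` are hypothetical names introduced only for illustration; they are not the real `KernelKey` or registry types. The per-entry difference depends on the target's pointer size and padding, so the sketch computes it rather than hard-coding a figure.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the pre-commit layout: a pointer plus an explicit flag.
struct OldKeyLayout {
  const char* kernel_key_data_ = nullptr;
  bool is_fallback_ = true;
};

// Hypothetical stand-in for the post-commit layout: a null pointer is the flag.
struct NewKeyLayout {
  const char* kernel_key_data_ = nullptr;
  bool is_fallback() const { return kernel_key_data_ == nullptr; }
};

// Dropping the bool leaves a single pointer, so the trailing padding goes too.
static_assert(sizeof(NewKeyLayout) == sizeof(const char*),
              "new layout should be exactly one pointer wide");
static_assert(sizeof(OldKeyLayout) > sizeof(NewKeyLayout),
              "old layout carries the flag plus padding");

// Table size quoted in the summary (assumed here, not read from the registry).
constexpr std::size_t kEntries = 2000;

// Bytes saved across the table by the key alone, for the current target.
constexpr std::size_t bytes_saved_in_keys() {
  return kEntries * (sizeof(OldKeyLayout) - sizeof(NewKeyLayout));
}

// A default-constructed key needs no flag to be recognised as a fallback.
inline void check_default_is_fallback() {
  assert(NewKeyLayout{}.is_fallback());
}
```

The real saving also depends on how `KernelKey` packs inside `Kernel`, which this sketch does not model.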
1 parent 8f05c35 commit 93c3267

2 files changed: +10 -8 lines changed

runtime/kernel/operator_registry.cpp

Lines changed: 5 additions & 2 deletions
@@ -35,9 +35,12 @@ constexpr uint32_t kMaxRegisteredKernels = kMaxOperators * kMaxKernelsPerOp;
 // require constructing them at init time. Since we don't care about the values
 // until we add each entry to the table, allocate static zeroed memory instead
 // and point the table at it.
+struct alignas(Kernel) KernelBuffer {
+  uint8_t data[sizeof(Kernel)];
+};
+
 // @lint-ignore CLANGTIDY facebook-hte-CArray
-alignas(sizeof(Kernel)) uint8_t
-    registered_kernels_data[kMaxRegisteredKernels * sizeof(Kernel)];
+KernelBuffer registered_kernels_data[kMaxRegisteredKernels];
 
 /// Global table of registered kernels.
 Kernel* registered_kernels = reinterpret_cast<Kernel*>(registered_kernels_data);
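The storage change above can be sketched in isolation as follows. `Entry`, `EntryBuffer`, `entry_storage`, `kMaxEntries`, and `add_entry` are hypothetical names standing in for `Kernel` and the registry; this is an illustration of the pattern under those assumptions, not the registry's actual code. One likely motivation, not stated in the commit, is that `alignas` only accepts power-of-two values, so `alignas(sizeof(T))` stops being valid once `sizeof(T)` is no longer a power of two, whereas `alignas(T)` always requests exactly the alignment `T` needs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <new>

// Hypothetical element type standing in for Kernel. With three pointer-sized
// members its size is usually not a power of two.
struct Entry {
  const char* name;
  const char* key;
  void (*fn)();
};

// One raw slot per element: same size and alignment as Entry, but trivially
// constructible, so a static array of these needs no init-time constructors.
struct alignas(Entry) EntryBuffer {
  std::uint8_t data[sizeof(Entry)];
};

static_assert(sizeof(EntryBuffer) == sizeof(Entry), "slot must match element size");
static_assert(alignof(EntryBuffer) == alignof(Entry), "slot must match element alignment");

constexpr std::size_t kMaxEntries = 8;

// Static storage is zero-initialized; entries are constructed in place later.
EntryBuffer entry_storage[kMaxEntries];
Entry* entries = reinterpret_cast<Entry*>(entry_storage);
std::size_t num_entries = 0;

// Constructs the next entry in place; returns false when the table is full.
bool add_entry(const char* name, const char* key) {
  assert(name != nullptr);
  if (num_entries >= kMaxEntries) {
    return false;
  }
  new (&entries[num_entries]) Entry{name, key, nullptr};
  ++num_entries;
  return true;
}
```

Because each slot has the element's own size and alignment, indexing the `Entry*` view lands every element on a correctly aligned address.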

runtime/kernel/operator_registry.h

Lines changed: 5 additions & 6 deletions
@@ -123,15 +123,15 @@ struct KernelKey {
    * for all input tensor dtypes and dim orders if the specialized kernel is not
    * registered.
    */
-  KernelKey() : is_fallback_(true) {}
+  KernelKey() = default;
 
   /**
    * Creates a specialized (non-fallback) kernel key that matches a specific
    * set of input tensor dtypes and dim orders. See the class comment for the
    * expected format of `kernel_key_data`.
    */
   /* implicit */ KernelKey(const char* kernel_key_data)
-      : kernel_key_data_(kernel_key_data), is_fallback_(false) {}
+      : kernel_key_data_(kernel_key_data) {}
 
   bool operator==(const KernelKey& other) const {
     return this->equals(other);
@@ -142,17 +142,17 @@ struct KernelKey {
   }
 
   bool equals(const KernelKey& other) const {
-    if (is_fallback_ != other.is_fallback_) {
+    if (is_fallback() != other.is_fallback()) {
       return false;
     }
-    if (is_fallback_) {
+    if (is_fallback()) {
       return true;
     }
     return strcmp(kernel_key_data_, other.kernel_key_data_) == 0;
   }
 
   bool is_fallback() const {
-    return is_fallback_;
+    return kernel_key_data_ == nullptr;;
   }
 
   const char* data() const {
@@ -168,7 +168,6 @@ struct KernelKey {
 
  private:
   const char* kernel_key_data_ = nullptr;
-  bool is_fallback_;
 };
 
 /**
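The comparison semantics after this change can be exercised with a reduced copy of the class. `KeySketch` and `check_key_semantics` are hypothetical names, and the strings `"key-a"` and `"key-b"` are arbitrary placeholders rather than the real kernel key format. One consequence worth noting: a key built from a null `const char*` is now treated as a fallback, so `equals` never passes a null pointer to `strcmp`.

```cpp
#include <cassert>
#include <cstring>

// Reduced, hypothetical copy of the post-commit key: the pointer doubles as
// the fallback flag.
struct KeySketch {
  KeySketch() = default;
  /* implicit */ KeySketch(const char* data) : data_(data) {}

  bool is_fallback() const { return data_ == nullptr; }

  bool equals(const KeySketch& other) const {
    if (is_fallback() != other.is_fallback()) {
      return false;  // a fallback never equals a specialized key
    }
    if (is_fallback()) {
      return true;  // all fallback keys compare equal
    }
    return std::strcmp(data_, other.data_) == 0;  // both pointers are non-null here
  }

 private:
  const char* data_ = nullptr;
};

// Walks through the cases the diff touches.
inline void check_key_semantics() {
  KeySketch fallback;
  KeySketch a("key-a");
  KeySketch a_again("key-a");
  KeySketch b("key-b");
  KeySketch from_null(static_cast<const char*>(nullptr));

  assert(fallback.is_fallback());
  assert(!a.is_fallback());
  assert(a.equals(a_again));
  assert(!a.equals(b));
  assert(!a.equals(fallback));
  assert(from_null.is_fallback());
  assert(from_null.equals(fallback));
}
```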
