diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp index d7e7b298c10..d5c9a982d6d 100644 --- a/runtime/kernel/operator_registry.cpp +++ b/runtime/kernel/operator_registry.cpp @@ -35,9 +35,12 @@ constexpr uint32_t kMaxRegisteredKernels = kMaxOperators * kMaxKernelsPerOp; // require constructing them at init time. Since we don't care about the values // until we add each entry to the table, allocate static zeroed memory instead // and point the table at it. +struct alignas(Kernel) KernelBuffer { + uint8_t data[sizeof(Kernel)]; +}; + // @lint-ignore CLANGTIDY facebook-hte-CArray -alignas(sizeof(Kernel)) uint8_t - registered_kernels_data[kMaxRegisteredKernels * sizeof(Kernel)]; +KernelBuffer registered_kernels_data[kMaxRegisteredKernels]; /// Global table of registered kernels. Kernel* registered_kernels = reinterpret_cast(registered_kernels_data); diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index f7a62208dd8..9bd6318676c 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -123,7 +123,7 @@ struct KernelKey { * for all input tensor dtypes and dim orders if the specialized kernel is not * registered. */ - KernelKey() : is_fallback_(true) {} + KernelKey() = default; /** * Creates a specialized (non-fallback) kernel key that matches a specific @@ -131,7 +131,7 @@ struct KernelKey { * expected format of `kernel_key_data`. */ /* implicit */ KernelKey(const char* kernel_key_data) - : kernel_key_data_(kernel_key_data), is_fallback_(false) {} + : kernel_key_data_(kernel_key_data) {} bool operator==(const KernelKey& other) const { return this->equals(other); @@ -142,17 +142,17 @@ struct KernelKey { } bool equals(const KernelKey& other) const { - if (is_fallback_ != other.is_fallback_) { + if (is_fallback() != other.is_fallback()) { return false; } - if (is_fallback_) { + if (is_fallback()) { return true; } return strcmp(kernel_key_data_, other.kernel_key_data_) == 0; } bool is_fallback() const { - return is_fallback_; + return kernel_key_data_ == nullptr; } const char* data() const { @@ -168,7 +168,6 @@ struct KernelKey { private: const char* kernel_key_data_ = nullptr; - bool is_fallback_; }; /**