[DRAFT/DO NOT REVIEW] Add Rotary Embedding from ONNX Opset 23 #23507
onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc
@@ -0,0 +1,139 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h"
#include "contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h"

#include <cstring>

#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using namespace onnxruntime::contrib::rotary_embedding_onnx_helper;

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
#define REGISTER_KERNEL_TYPED(T)                                        \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                        \
      RotaryEmbeddingONNX,                                              \
      kMSDomain,                                                        \
      1,                                                                \
      T,                                                                \
      kCpuExecutionProvider,                                            \
      KernelDefBuilder()                                                \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
          .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
      RotaryEmbeddingONNX<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
RotaryEmbeddingONNX<T>::RotaryEmbeddingONNX(const OpKernelInfo& info) : OpKernel(info) {
  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
  num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
  rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));

  if (rotary_embedding_dim > 0) {
    ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
  }
}

// TODO: rotary embedding in place

template <typename T>
Status RunRotaryEmbeddingONNX(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input,
                              const T* cos_cache, const T* sin_cache, const int64_t* position_ids, T* output,
                              bool interleaved) {
  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int n_heads = parameters.num_heads;
  const int head_size = parameters.head_size;
  const int head_stride = parameters.head_stride;
  const int seq_stride = parameters.seq_stride;
  const int batch_stride = parameters.batch_stride;
  const int position_ids_format = parameters.position_ids_format;
  const int rotary_emb_dim = parameters.rotary_embedding_dim;
  const int half_rotary_emb_dim = rotary_emb_dim / 2;

  const int loop_len = batch_size * sequence_length * n_heads;
  const double cost = static_cast<double>(rotary_emb_dim);
  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
      const int b = static_cast<int>((ptr / n_heads) / sequence_length);
      const int s = static_cast<int>((ptr / n_heads) % sequence_length);
      const int n = static_cast<int>(ptr % n_heads);

      const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;

      const T* input_data = input + block_offset;
      T* output_data = output + block_offset;

      // Cache is (M, H/2) or (M, rotary_embedding_dim/2)
      const int position_id = (position_ids_format == 0)
                                  ? static_cast<int>(position_ids[0]) + s
                                  : static_cast<int>(position_ids[b * sequence_length + s]);
      const int cache_offset = position_id * half_rotary_emb_dim;
      const T* cos_data = cos_cache + cache_offset;
      const T* sin_data = sin_cache + cache_offset;

      MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);

      if (rotary_emb_dim < head_size) {
        // Copy the non-rotated tail of the head unchanged. Cast to the wider
        // type before subtracting to address the PREfast warning (io.2).
        std::memcpy(output_data + rotary_emb_dim,
                    input_data + rotary_emb_dim,
                    (static_cast<size_t>(head_size) - rotary_emb_dim) * sizeof(T));
      }
    }
  });

  return Status::OK();
}

template Status RunRotaryEmbeddingONNX<float>(concurrency::ThreadPool* tp, RotaryParameters parameters,
                                              const float* input, const float* cos_cache, const float* sin_cache,
                                              const int64_t* position_ids, float* output, bool interleaved);

template Status RunRotaryEmbeddingONNX<MLFloat16>(concurrency::ThreadPool* tp, RotaryParameters parameters,
                                                  const MLFloat16* input, const MLFloat16* cos_cache,
                                                  const MLFloat16* sin_cache, const int64_t* position_ids,
                                                  MLFloat16* output, bool interleaved);

template <typename T>
Status RotaryEmbeddingONNX<T>::Compute(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* cos_cache = context->Input<Tensor>(1);
  const Tensor* sin_cache = context->Input<Tensor>(2);
  const Tensor* position_ids = context->Input<Tensor>(3);

  RotaryParameters parameters = {};
  ORT_RETURN_IF_ERROR(rotary_embedding_onnx_helper::CheckInputs<Tensor>(input,
                                                                        cos_cache,
                                                                        sin_cache,
                                                                        position_ids,
                                                                        num_heads,
                                                                        rotary_embedding_dim,
                                                                        &parameters));

  Tensor* output = context->Output(0, input->Shape());

  if (parameters.sequence_length > parameters.max_sequence_length) {
    // Launch update_cos_sin_cache kernel with scale
    ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
  }

  const T* input_src = input->Data<T>();
  const T* cos_cache_data = cos_cache->Data<T>();
  const T* sin_cache_data = sin_cache->Data<T>();
  const int64_t* pos_ids_data = position_ids->Data<int64_t>();
  T* output_dest = output->MutableData<T>();

  AllocatorPtr allocator;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
  auto* tp = context->GetOperatorThreadPool();

  return RunRotaryEmbeddingONNX<T>(tp, parameters, input_src, cos_cache_data, sin_cache_data, pos_ids_data,
                                   output_dest, interleaved);
}

}  // namespace contrib
}  // namespace onnxruntime
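Review note: the per-row rotation itself is delegated to MlasRotaryEmbedOneRow, so the actual math never appears in this diff. For readers unfamiliar with that MLAS call, below is a minimal standalone scalar sketch of the standard RoPE rotation it is expected to perform. The name rotary_embed_one_row_ref is illustrative only; the real MLAS kernel is vectorized, and its exact sign conventions should be verified against mlas.h.

// Standalone scalar reference for one row of rotary embedding.
// rotary_embed_one_row_ref is an illustrative name, not part of MLAS.
// Argument order mirrors the MlasRotaryEmbedOneRow call site above.
#include <cstdio>

template <typename T>
void rotary_embed_one_row_ref(const T* input, const T* sin_data, const T* cos_data,
                              int rotary_emb_dim, bool interleaved, T* output) {
  const int half = rotary_emb_dim / 2;
  for (int i = 0; i < half; ++i) {
    const T c = cos_data[i];
    const T s = sin_data[i];
    if (interleaved) {
      // Pairs are adjacent: (x[2i], x[2i+1]) is rotated by angle theta_i.
      const T x0 = input[2 * i];
      const T x1 = input[2 * i + 1];
      output[2 * i] = x0 * c - x1 * s;
      output[2 * i + 1] = x0 * s + x1 * c;
    } else {
      // Pairs are split across halves: (x[i], x[i + half]) is rotated.
      const T x0 = input[i];
      const T x1 = input[i + half];
      output[i] = x0 * c - x1 * s;
      output[i + half] = x0 * s + x1 * c;
    }
  }
}

int main() {
  const float input[4] = {1.0f, 0.0f, 0.0f, 1.0f};
  const float cos_data[2] = {0.0f, 1.0f};  // angles pi/2 and 0
  const float sin_data[2] = {1.0f, 0.0f};
  float output[4];
  rotary_embed_one_row_ref(input, sin_data, cos_data, 4, /*interleaved=*/false, output);
  std::printf("%f %f %f %f\n", output[0], output[1], output[2], output[3]);
  return 0;
}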
onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "rotary_embedding_onnx_helper.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
Status RunRotaryEmbeddingONNX(onnxruntime::concurrency::ThreadPool* tp,
                              rotary_embedding_onnx_helper::RotaryParameters parameters, const T* input,
                              const T* cos_cache, const T* sin_cache, const int64_t* position_ids, T* output,
                              bool interleaved);

template <typename T>
class RotaryEmbeddingONNX final : public OpKernel {
 public:
  RotaryEmbeddingONNX(const OpKernelInfo& info);
  Status Compute(OpKernelContext* context) const override;

 protected:
  bool interleaved;
  int rotary_embedding_dim;
  int num_heads;
};

}  // namespace contrib
}  // namespace onnxruntime
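For context, the header exposes RunRotaryEmbeddingONNX so it can be driven outside the kernel's Compute, e.g. from a test. A hedged sketch of a direct caller follows, assuming an onnxruntime build environment; the shape values and the helper name ApplyRotaryToBuffers are illustrative, not part of this PR. The strides follow the default [batch, seq, hidden] layout computed by CheckInputs.

// Illustrative only: assumes this header and an onnxruntime build environment.
#include <cstdint>

#include "contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h"

using namespace onnxruntime;
using namespace onnxruntime::contrib;
using rotary_embedding_onnx_helper::RotaryParameters;

Status ApplyRotaryToBuffers(concurrency::ThreadPool* tp,
                            const float* input, const float* cos_cache, const float* sin_cache,
                            const int64_t* position_ids, float* output) {
  RotaryParameters p = {};
  p.batch_size = 1;
  p.sequence_length = 8;
  p.num_heads = 4;
  p.head_size = 16;
  p.rotary_embedding_dim = 16;  // rotate the full head
  p.hidden_size = p.num_heads * p.head_size;
  p.max_sequence_length = 64;   // rows in the cos/sin caches
  // Default [batch, seq, hidden] layout strides:
  p.head_stride = p.head_size;
  p.seq_stride = p.num_heads * p.head_stride;
  p.batch_stride = p.sequence_length * p.seq_stride;
  p.position_ids_format = 0;    // position_ids holds a single start offset
  p.transposed = false;
  return RunRotaryEmbeddingONNX<float>(tp, p, input, cos_cache, sin_cache,
                                       position_ids, output, /*interleaved=*/false);
}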
onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h
@@ -0,0 +1,160 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/common.h"

namespace onnxruntime {
namespace contrib {
namespace rotary_embedding_onnx_helper {

// Parameters deduced from node attributes and inputs/outputs.
struct RotaryParameters {
  int batch_size;            // Batch size used by input
  int sequence_length;       // Sequence length used by input
  int hidden_size;           // Hidden size used by input
  int head_size;             // Head size
  int rotary_embedding_dim;  // Rotary embedding dimension.
  int num_heads;             // num_heads = hidden_size / head_size
  int max_sequence_length;   // Sequence length used by cos/sin cache
  int head_stride;           // Head stride
  int seq_stride;            // Sequence stride
  int batch_stride;          // Batch stride
  int position_ids_format;   // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
  bool transposed;           // Whether the input tensor has been transposed into (batch, num_heads, seq_len, head_size)
};

template <typename T>
Status CheckInputs(const T* input,
                   const T* cos_cache,
                   const T* sin_cache,
                   const T* position_ids,
                   int num_heads,
                   int rotary_embedding_dim,
                   void* parameters) {
  // input        : (batch_size, sequence_length, hidden_size)
  // cos cache    : (max_sequence_length, rotary_embedding_dim / 2)
  // sin cache    : (max_sequence_length, rotary_embedding_dim / 2)
  // position ids : (1) or (batch_size, sequence_length)

  // Check input
  const auto& input_dims = input->Shape().GetDims();
  if (input_dims.size() != 3 && input_dims.size() != 4) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ",
                           input_dims.size());
  }
  // Check cos_cache and sin_cache
  const auto& cos_cache_dims = cos_cache->Shape().GetDims();
  if (cos_cache_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ",
                           cos_cache_dims.size());
  }
  const auto& sin_cache_dims = sin_cache->Shape().GetDims();
  if (sin_cache_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ",
                           sin_cache_dims.size());
  }
  if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
                           "the same shape");
  }
  // Check position_ids
  const auto& position_ids_dims = position_ids->Shape().GetDims();
  if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ",
                           "dimensions, got ", position_ids_dims.size());
  }

  // Check num_heads and rotary_embedding_dim
  if (rotary_embedding_dim > 0 && num_heads == 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ",
                           "specified");
  }

  // Get attributes from inputs
  int batch_size = static_cast<int>(input_dims[0]);
  int sequence_length = static_cast<int>(input_dims[1]);
  int hidden_size = static_cast<int>(input_dims[2]);

  bool transposed = false;
  if (input_dims.size() == 4) {
    // input is [batch, seq, num_heads, head_size]
    hidden_size = static_cast<int>(input_dims[2]) * static_cast<int>(input_dims[3]);
    transposed = true;
  }
  int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
  int head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[1]) * 2
                                            : static_cast<int>(hidden_size / num_heads);
  if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
                           "head_size");
  }

  int position_ids_format = -1;

  // Check position_ids input shapes
  if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) {
    if (batch_size != static_cast<int>(position_ids_dims[0])) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ",
                             "batch_size, got ", position_ids_dims[0]);
    }
    if (sequence_length != static_cast<int>(position_ids_dims[1])) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ",
                             "sequence_length, got ", position_ids_dims[1]);
    }
    position_ids_format = 1;
  } else {
    position_ids_format = 0;
  }

  // Check cos_cache input shapes
  if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
                           "max_sequence_length, got ", cos_cache_dims[0]);
  }
  if ((head_size / 2) != static_cast<int>(cos_cache_dims[1]) &&
      (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[1]))) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
                           "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
  }

  num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
  // Calculate stride values
  int head_stride;
  int seq_stride;
  int batch_stride;
  if (transposed) {
    // Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
    seq_stride = head_size;
    head_stride = sequence_length * seq_stride;
    batch_stride = num_heads * head_stride;
  } else {
    // Default input tensor shape is [batch, seq_len, hidden_size]
    head_stride = head_size;
    seq_stride = num_heads * head_stride;
    batch_stride = sequence_length * seq_stride;
  }

  // Set rotary parameters
  if (parameters != nullptr) {
    RotaryParameters* output_parameters = reinterpret_cast<RotaryParameters*>(parameters);
    output_parameters->batch_size = batch_size;
    output_parameters->sequence_length = sequence_length;
    output_parameters->hidden_size = hidden_size;
    output_parameters->head_size = head_size;
    output_parameters->num_heads = num_heads;
    output_parameters->max_sequence_length = max_sequence_length;
    output_parameters->head_stride = head_stride;
    output_parameters->seq_stride = seq_stride;
    output_parameters->batch_stride = batch_stride;
    output_parameters->position_ids_format = position_ids_format;
    output_parameters->transposed = transposed;
    output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;
  }

  return Status::OK();
}

}  // namespace rotary_embedding_onnx_helper
}  // namespace contrib
}  // namespace onnxruntime
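The stride selection in CheckInputs is the only place where the two supported layouts differ downstream; everything else indexes through head_stride/seq_stride/batch_stride. A small standalone sanity check of the two stride schemes (values are illustrative):

// Standalone check that the two stride schemes address element (b, s, n, i)
// correctly for each layout. Compile with assertions enabled.
#include <cassert>

int main() {
  const int batch = 2, seq = 3, heads = 4, head_size = 8;

  // Default layout: [batch, seq_len, hidden_size], where hidden_size is
  // heads * head_size laid out head-major within each position.
  int head_stride = head_size;
  int seq_stride = heads * head_stride;
  int batch_stride = seq * seq_stride;
  const int b = 1, s = 2, n = 3, i = 5;
  int offset = b * batch_stride + s * seq_stride + n * head_stride + i;
  assert(offset == ((b * seq + s) * heads + n) * head_size + i);

  // Transposed layout: [batch, n_heads, seq_len, head_size].
  seq_stride = head_size;
  head_stride = seq * seq_stride;
  batch_stride = heads * head_stride;
  offset = b * batch_stride + s * seq_stride + n * head_stride + i;
  assert(offset == ((b * heads + n) * seq + s) * head_size + i);
  return 0;
}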