[DRAFT/DO NOT REVIEW] Add Rotary Embedding from ONNX Opset 23 #23507

Draft: wants to merge 2 commits into base branch main
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -26,6 +26,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbeddingONNX);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbeddingONNX);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
@@ -301,6 +303,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbeddingONNX)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbeddingONNX)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
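For context, kMSDomain is the com.microsoft custom-op domain, so a graph node of type RotaryEmbeddingONNX in that domain resolves to the CPU kernels registered above. Below is a hypothetical usage sketch (not part of this PR) of running such a model through the public C++ API; the model path, input names, and shapes are illustrative assumptions, and cos_cache/sin_cache are assumed to be baked into the model as initializers.

// Hypothetical sketch: exercise a model containing a com.microsoft RotaryEmbeddingONNX
// node on the default CPU execution provider. Model path and tensor names are assumptions.
#include <onnxruntime_cxx_api.h>

#include <array>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "rotary_embedding_demo");
  Ort::SessionOptions options;  // CPU EP is the default, so the kernels above are used
  Ort::Session session(env, "model_with_rotary_embedding.onnx", options);  // illustrative path

  // input: (batch=1, sequence_length=4, hidden_size=8); position_ids: shape (1)
  std::vector<float> x(1 * 4 * 8, 0.5f);
  std::vector<int64_t> x_shape{1, 4, 8};
  std::vector<int64_t> pos{0};
  std::vector<int64_t> pos_shape{1};

  auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  std::array<Ort::Value, 2> inputs{
      Ort::Value::CreateTensor<float>(mem, x.data(), x.size(), x_shape.data(), x_shape.size()),
      Ort::Value::CreateTensor<int64_t>(mem, pos.data(), pos.size(), pos_shape.data(), pos_shape.size())};
  const char* input_names[] = {"input", "position_ids"};  // graph-dependent names
  const char* output_names[] = {"output"};

  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), inputs.size(),
                             output_names, 1);
  return outputs.size() == 1 ? 0 : 1;
}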
139 changes: 139 additions & 0 deletions onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc
@@ -0,0 +1,139 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h"
#include "contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h"

#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using onnxruntime::contrib::rotary_embedding_onnx_helper::RotaryParameters;

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RotaryEmbeddingONNX, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
RotaryEmbeddingONNX<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
RotaryEmbeddingONNX<T>::RotaryEmbeddingONNX(const OpKernelInfo& info) : OpKernel(info) {
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));

if (rotary_embedding_dim > 0) {
ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
}
}
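Note: as with the existing contrib RotaryEmbedding kernel, num_heads and rotary_embedding_dim default to 0, meaning "infer from the shapes": CheckInputs in the helper derives head_size from cos_cache when rotary_embedding_dim is unset, and num_heads from hidden_size / head_size when num_heads is unset.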

// TODO: rotary embedding in place
template <typename T>
Status RunRotaryEmbeddingONNX(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input,
const T* cos_cache, const T* sin_cache, const int64_t* position_ids, T* output,
bool interleaved) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int n_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int head_stride = parameters.head_stride;
const int seq_stride = parameters.seq_stride;
const int batch_stride = parameters.batch_stride;
const int position_ids_format = parameters.position_ids_format;
const int rotary_emb_dim = parameters.rotary_embedding_dim;
const int half_rotary_emb_dim = rotary_emb_dim / 2;

const int loop_len = batch_size * sequence_length * n_heads;
const double cost = static_cast<double>(rotary_emb_dim);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / n_heads) / sequence_length);
const int s = static_cast<int>((ptr / n_heads) % sequence_length);
const int n = static_cast<int>(ptr % n_heads);

const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;

const T* input_data = input + block_offset;
T* output_data = output + block_offset;

// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(position_ids[0]) + s
: static_cast<int>(position_ids[b * sequence_length + s]);
const int cache_offset = position_id * half_rotary_emb_dim;
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;

MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);

if (rotary_emb_dim < head_size) {
std::memcpy(output_data + rotary_emb_dim,
input_data + rotary_emb_dim,
(static_cast<size_t>(head_size) - static_cast<size_t>(rotary_emb_dim)) * sizeof(T));

}
}
});

return Status::OK();
}
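For reference, the per-pair transform MlasRotaryEmbedOneRow is expected to apply (an assumption about the MLAS helper, consistent with the (M, rotary_embedding_dim / 2) cache layout noted above), with p the resolved position id, d = rotary_emb_dim, and i \in [0, d/2):

\begin{aligned}
y_{2i}   &= x_{2i}\cos\theta_{p,i} - x_{2i+1}\sin\theta_{p,i} \\
y_{2i+1} &= x_{2i}\sin\theta_{p,i} + x_{2i+1}\cos\theta_{p,i}
\end{aligned}

where \cos\theta_{p,i} = cos_cache[p][i] and \sin\theta_{p,i} = sin_cache[p][i]. With interleaved = 0, the same rotation is applied to the split pairs (x_i, x_{i+d/2}) instead of adjacent lanes.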

template Status RunRotaryEmbeddingONNX<float>(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input,
const float* cos_cache, const float* sin_cache, const int64_t* position_ids, float* output,
bool interleaved);

template Status RunRotaryEmbeddingONNX<MLFloat16>(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input,
const MLFloat16* cos_cache, const MLFloat16* sin_cache, const int64_t* position_ids,
MLFloat16* output, bool interleaved);

template <typename T>
Status RotaryEmbeddingONNX<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* cos_cache = context->Input<Tensor>(1);
const Tensor* sin_cache = context->Input<Tensor>(2);
const Tensor* position_ids = context->Input<Tensor>(3);

RotaryParameters parameters = {};
ORT_RETURN_IF_ERROR(rotary_embedding_onnx_helper::CheckInputs<Tensor>(input,
cos_cache,
sin_cache,
position_ids,
num_heads,
rotary_embedding_dim,
&parameters));

Tensor* output = context->Output(0, input->Shape());

if (parameters.sequence_length > parameters.max_sequence_length) {
// Launch update_cos_sin_cache kernel with scale
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
}

const T* input_src = input->Data<T>();
const T* cos_cache_data = cos_cache->Data<T>();
const T* sin_cache_data = sin_cache->Data<T>();
const int64_t* pos_ids_data = position_ids->Data<int64_t>();
T* output_dest = output->MutableData<T>();

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();

return RunRotaryEmbeddingONNX<T>(tp, parameters, input_src, cos_cache_data, sin_cache_data, pos_ids_data, output_dest,
interleaved);
}

} // namespace contrib
} // namespace onnxruntime
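A minimal scalar sketch of that per-row rotation, for readers who want the semantics without digging into MLAS (an assumed reference mirroring the formula above, not the MLAS implementation; names are illustrative):

#include <cstddef>

void RotaryEmbedOneRowRef(const float* x, const float* sin_data, const float* cos_data,
                          std::ptrdiff_t dim, bool interleaved, float* y) {
  const std::ptrdiff_t half = dim / 2;
  for (std::ptrdiff_t i = 0; i < half; ++i) {
    // Pick the pair: adjacent lanes when interleaved, split halves otherwise.
    const std::ptrdiff_t i0 = interleaved ? 2 * i : i;
    const std::ptrdiff_t i1 = interleaved ? 2 * i + 1 : i + half;
    const float x0 = x[i0];
    const float x1 = x[i1];
    // Rotate (x0, x1) by the angle cached for this position and lane.
    y[i0] = x0 * cos_data[i] - x1 * sin_data[i];
    y[i1] = x0 * sin_data[i] + x1 * cos_data[i];
  }
}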
30 changes: 30 additions & 0 deletions onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "rotary_embedding_onnx_helper.h"

Check warning on line 7 in onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h:7: Include the directory when naming header files [build/include_subdir] [4]

namespace onnxruntime {
namespace contrib {

template <typename T>
Status RunRotaryEmbeddingONNX(onnxruntime::concurrency::ThreadPool* tp, rotary_embedding_onnx_helper::RotaryParameters parameters, const T* input,
const T* cos_cache, const T* sin_cache, const int64_t* position_ids, T* output,
bool interleaved);

template <typename T>
class RotaryEmbeddingONNX final : public OpKernel {
public:
RotaryEmbeddingONNX(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;

protected:
bool interleaved;
int rotary_embedding_dim;
int num_heads;
};

} // namespace contrib
} // namespace onnxruntime
160 changes: 160 additions & 0 deletions onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h
@@ -0,0 +1,160 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/common.h"

namespace onnxruntime {
namespace contrib {
namespace rotary_embedding_onnx_helper {

// Parameters deduced from node attributes and inputs/outputs.
struct RotaryParameters {
int batch_size; // Batch size used by input
int sequence_length; // Sequence length used by input
int hidden_size; // Hidden size used by input
int head_size; // Head size
int rotary_embedding_dim; // Rotary embedding dimension.
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int head_stride; // Head stride
int seq_stride; // Sequence stride
int batch_stride; // Batch stride
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, head_size)
};

template <typename T>
Status CheckInputs(const T* input,
const T* cos_cache,
const T* sin_cache,
const T* position_ids,
int num_heads,
int rotary_embedding_dim,
void* parameters) {
// input : (batch_size, sequence_length, hidden_size)
// cos cache : (max_sequence_length, rotary_embedding_dim / 2)
// sin cache : (max_sequence_length, rotary_embedding_dim / 2)
// position ids : (1) or (batch_size, sequence_length)

// Check input
const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 3 && input_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ",
input_dims.size());
}
// Check cos_cache and sin_cache
const auto& cos_cache_dims = cos_cache->Shape().GetDims();
if (cos_cache_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ",
cos_cache_dims.size());
}
const auto& sin_cache_dims = sin_cache->Shape().GetDims();
if (sin_cache_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ",
sin_cache_dims.size());
}
if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
"the same shape");
}
// Check position_ids
const auto& position_ids_dims = position_ids->Shape().GetDims();
if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ",
"dimensions, got ", position_ids_dims.size());
}

// Check num_heads and rotary_embedding_dim
if (rotary_embedding_dim > 0 && num_heads == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ",
"specified");
}

// Get attributes from inputs
int batch_size = static_cast<int>(input_dims[0]);
int sequence_length = static_cast<int>(input_dims[1]);
int hidden_size = static_cast<int>(input_dims[2]);

bool transposed = false;
if (input_dims.size() == 4) {
// input is [batch, num_heads, seq_len, head_size]
sequence_length = static_cast<int>(input_dims[2]);
hidden_size = static_cast<int>(input_dims[1]) * static_cast<int>(input_dims[3]);
transposed = true;
}
int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
int head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[1]) * 2
: static_cast<int>(hidden_size / num_heads);
if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
"head_size");
}

int position_ids_format = -1;

// Check position_ids input shapes
if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) {
if (batch_size != static_cast<int>(position_ids_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ",
"batch_size, got ", position_ids_dims[0]);
}
if (sequence_length != static_cast<int>(position_ids_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ",
"sequence_length, got ", position_ids_dims[1]);
}
position_ids_format = 1;
} else {
position_ids_format = 0;
}

// Check cos_cache input shapes
if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
if ((head_size / 2) != static_cast<int>(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[1]))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
"head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
}

num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
// Calculate stride values
int head_stride;
int seq_stride;
int batch_stride;
if (transposed) {
// Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = num_heads * head_stride;
} else {
// Default input tensor shape is [batch, seq_len, hidden_size]
head_stride = head_size;
seq_stride = num_heads * head_stride;
batch_stride = sequence_length * seq_stride;
}
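// Worked example (illustrative numbers, not taken from any input):
// batch_size=2, sequence_length=8, num_heads=4, head_size=16 => hidden_size=64.
//   Default [2, 8, 64]:       head_stride = 16, seq_stride = 4*16 = 64,  batch_stride = 8*64  = 512
//   Transposed [2, 4, 8, 16]: seq_stride  = 16, head_stride = 8*16 = 128, batch_stride = 4*128 = 512
// Either way, entry (b, s, n) starts at offset b*batch_stride + s*seq_stride + n*head_stride.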

// Set rotary parameters
if (parameters != nullptr) {
RotaryParameters* output_parameters = reinterpret_cast<RotaryParameters*>(parameters);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length;
output_parameters->hidden_size = hidden_size;
output_parameters->head_size = head_size;
output_parameters->num_heads = num_heads;
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->head_stride = head_stride;
output_parameters->seq_stride = seq_stride;
output_parameters->batch_stride = batch_stride;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;
}

return Status::OK();
}

} // namespace rotary_embedding_onnx_helper
} // namespace contrib
} // namespace onnxruntime