From 78ac6dfd4d68ab6f399557340a0ff0decd98a8ba Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Tue, 28 Jan 2025 00:56:12 +0000 Subject: [PATCH 1/2] add cpu kernel --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 4 + .../cpu/onnx_std_exp/rotary_embedding_onnx.cc | 139 ++++ .../cpu/onnx_std_exp/rotary_embedding_onnx.h | 30 + .../rotary_embedding_onnx_helper.h | 160 +++++ .../core/graph/contrib_ops/bert_defs.cc | 293 ++++++++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 4 +- .../rotary_embedding_onnx_op_test.cc | 633 ++++++++++++++++++ 7 files changed, 1262 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc create mode 100644 onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h create mode 100644 onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h create mode 100644 onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index c742cd1e95bdd..a0039f039806d 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -26,6 +26,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbeddingONNX); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbeddingONNX); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -301,6 +303,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc new file mode 100644 index 0000000000000..f37dcb6ea7179 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h" +#include "contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h" + +#include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" + +using onnxruntime::concurrency::ThreadPool; +using namespace onnxruntime::contrib::rotary_embedding_onnx_helper; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbeddingONNX, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbeddingONNX); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +RotaryEmbeddingONNX::RotaryEmbeddingONNX(const OpKernelInfo& info) : OpKernel(info) { + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + + if (rotary_embedding_dim > 0) { + ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } +} + +// TODO: rotary embedding in place +template +Status RunRotaryEmbeddingONNX(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input, + const T* cos_cache, const T* sin_cache, const int64_t* position_ids, T* output, + bool interleaved) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int n_heads = parameters.num_heads; + const int head_size = parameters.head_size; + const int head_stride = parameters.head_stride; + const int seq_stride = parameters.seq_stride; + const int batch_stride = parameters.batch_stride; + const int position_ids_format = parameters.position_ids_format; + const int rotary_emb_dim = parameters.rotary_embedding_dim; + const int half_rotary_emb_dim = rotary_emb_dim / 2; + + const int loop_len = batch_size * sequence_length * n_heads; + const double cost = static_cast(rotary_emb_dim); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / n_heads) / sequence_length); + const int s = static_cast((ptr / n_heads) % sequence_length); + const int n = static_cast(ptr % n_heads); + + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + + const T* input_data = input + block_offset; + T* output_data = output + block_offset; + + // Cache is (M, H/2) or (M, rotary_embedding_dim/2) + const int position_id = (position_ids_format == 0) + ? static_cast(position_ids[0]) + s + : static_cast(position_ids[b * sequence_length + s]); + const int cache_offset = position_id * half_rotary_emb_dim; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + MlasRotaryEmbedOneRow(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data); + + if (rotary_emb_dim < head_size) { + std::memcpy(output_data + rotary_emb_dim, + input_data + rotary_emb_dim, + (head_size - rotary_emb_dim) * sizeof(T)); + } + } + }); + + return Status::OK(); +} + +template Status RunRotaryEmbeddingONNX(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input, + const float* cos_cache, const float* sin_cache, const int64_t* position_ids, float* output, + bool interleaved); + +template Status RunRotaryEmbeddingONNX(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input, + const MLFloat16* cos_cache, const MLFloat16* sin_cache, const int64_t* position_ids, + MLFloat16* output, bool interleaved); + +template +Status RotaryEmbeddingONNX::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* cos_cache = context->Input(1); + const Tensor* sin_cache = context->Input(2); + const Tensor* position_ids = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_onnx_helper::CheckInputs(input, + cos_cache, + sin_cache, + position_ids, + num_heads, + rotary_embedding_dim, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const T* input_src = input->Data(); + const T* cos_cache_data = cos_cache->Data(); + const T* sin_cache_data = sin_cache->Data(); + const int64_t* pos_ids_data = position_ids->Data(); + T* output_dest = output->MutableData(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto* tp = context->GetOperatorThreadPool(); + + return RunRotaryEmbeddingONNX(tp, parameters, input_src, cos_cache_data, sin_cache_data, pos_ids_data, output_dest, + interleaved); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h new file mode 100644 index 0000000000000..2cc08fbb3fcbf --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "rotary_embedding_onnx_helper.h" + +namespace onnxruntime { +namespace contrib { + +template +Status RunRotaryEmbeddingONNX(onnxruntime::concurrency::ThreadPool* tp, rotary_embedding_onnx_helper::RotaryParameters parameters, const T* input, + const T* cos_cache, const T* sin_cache, const int64_t* position_ids, T* output, + bool interleaved); + +template +class RotaryEmbeddingONNX final : public OpKernel { + public: + RotaryEmbeddingONNX(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + bool interleaved; + int rotary_embedding_dim; + int num_heads; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h new file mode 100644 index 0000000000000..f826f78c5262d --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace contrib { +namespace rotary_embedding_onnx_helper { + +// Parameters deduced from node attributes and inputs/outputs. +struct RotaryParameters { + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size + int rotary_embedding_dim; // Rotary embedding dimension. + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int head_stride; // Head stride + int seq_stride; // Sequence stride + int batch_stride; // Batch stride + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) +}; + +template +Status CheckInputs(const T* input, + const T* cos_cache, + const T* sin_cache, + const T* position_ids, + int num_heads, + int rotary_embedding_dim, + void* parameters) { + // input : (batch_size, sequence_length, hidden_size) + // cos cache : (max_sequence_length, rotary_embedding_dim / 2) + // sin cache : (max_sequence_length, rotary_embedding_dim / 2) + // position ids : (1) or (batch_size, sequence_length) + + // Check input + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3 && input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ", + input_dims.size()); + } + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + // Check position_ids + const auto& position_ids_dims = position_ids->Shape().GetDims(); + if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ", + "dimensions, got ", position_ids_dims.size()); + } + + // Check num_heads and rotary_embedding_dim + if (rotary_embedding_dim > 0 && num_heads == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ", + "specified"); + } + + // Get attributes from inputs + int batch_size = static_cast(input_dims[0]); + int sequence_length = static_cast(input_dims[1]); + int hidden_size = static_cast(input_dims[2]); + + bool transposed = false; + if (input_dims.size() == 4) { + // input is [batch, seq, num_heads, head_size] + hidden_size = static_cast(input_dims[2]) * static_cast(input_dims[3]); + transposed = true; + } + int max_sequence_length = static_cast(cos_cache_dims[0]); + int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 + : static_cast(hidden_size / num_heads); + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + + int position_ids_format = -1; + + // Check position_ids input shapes + if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { + if (batch_size != static_cast(position_ids_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", + "batch_size, got ", position_ids_dims[0]); + } + if (sequence_length != static_cast(position_ids_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", + "sequence_length, got ", position_ids_dims[1]); + } + position_ids_format = 1; + } else { + position_ids_format = 0; + } + + // Check cos_cache input shapes + if (max_sequence_length != static_cast(cos_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", + "max_sequence_length, got ", cos_cache_dims[0]); + } + if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); + } + + num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); + // Calculate stride values + int head_stride; + int seq_stride; + int batch_stride; + if (transposed) { + // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } else { + // Default input tensor shape is [batch, seq_len, hidden_size] + head_stride = head_size; + seq_stride = num_heads * head_stride; + batch_stride = sequence_length * seq_stride; + } + + // Set rotary parameters + if (parameters != nullptr) { + RotaryParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; + output_parameters->hidden_size = hidden_size; + output_parameters->head_size = head_size; + output_parameters->num_heads = num_heads; + output_parameters->max_sequence_length = max_sequence_length; + output_parameters->head_stride = head_stride; + output_parameters->seq_stride = seq_stride; + output_parameters->batch_stride = batch_stride; + output_parameters->position_ids_format = position_ids_format; + output_parameters->transposed = transposed; + output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; + } + + return Status::OK(); +} + +} // namespace rotary_embedding_onnx_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index f2a2a52f8334f..f590ce2bbd9b4 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1386,6 +1386,299 @@ ONNX_MS_OPERATOR_SET_SCHEMA( propagateShapeFromInputToOutput(ctx, 0, 0); })); +static const char* RotaryEmbeddingONNX_ver1_doc = R"DOC( +RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864. +The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances +between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token's absolute position (position_ids). + +The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles. +For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the +embedding vector either into two halves or interleaving every alternate token and applying the rotation matrix to each half of the embedding vector. +The rotation matrix is parameterized by the token's position in the sequence. The rotated halves of the embedding vector are concatenated +to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism. +The rotation ensures that the model captures both absolute and relative positional information. + +Rotary embeddings are defined using the following algorithm: + + ```python + def compute_rotary_embedding(input, position_ids, sin_cache, cos_cache, interleaved=0, rotary_embedding_dim=0 ,num_heads=0) + + # First ensure input has shape [batch_size, num_heads, seq_len, head_size] + batch_size = input.shape[0] + sequence_length = input.shape[1] + if len(input.shape) == 3: + hidden_size = input.shape[2] + assert num_heads != 0 + head_size = int(hidden_size / num_heads) + new_shape = [batch_size, sequence_length, num_heads, head_size] + input = np.reshape(input, new_shape) + assert len(input.shape) == 4 + head_size = input.shape[3] + + # Fully or partially perform rotation on input based on rotary_embedding_dim attribute + if rotary_embedding_dim == 0: + # If rotary_embedding_dim not provided, perform full rotation by using head_size + rotary_embedding_dim = head_size + x_rotate = input[:, :, :, :rotary_embedding_dim] + x_not_rotate = input[:, :, :, rotary_embedding_dim:] + rotary_embedding_dim_half = int(rotary_embedding_dim / 2) + + # Retrieve sin and cos caches using position ids + if position_ids is not None: + cos = cos_cache[position_ids] # Shape: [batch_size, sequence_length, head_size/2] + sin = sin_cache[position_ids] # Shape: [batch_size, sequence_length, head_size/2] + else: + cos = cos_cache + sin = sin_cache + cos = cos[:, :, :rotary_embedding_dim_half] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + sin = sin[:, :, :rotary_embedding_dim_half] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + cos = np.expand_dims(cos, axis=2) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] + sin = np.expand_dims(sin, axis=2) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] + + # Either divide the input in halves or interleave (based on interleaved attribute) + if interleaved: + x1 = x_rotate[:, :, :, 0::2] + x2 = x_rotate[:, :, :, 1::2] + else: + x1, x2 = np.split(x_rotate, 2, axis=-1) + + # Calculate real and imaginary values + real = cos * x1 - sin * x2 + imag = sin * x1 + cos * x2 + + # Inserted rotated embeddings back to the original input + if interleaved: + # x_rotate[:, :, :, 0::2] = real + # x_rotate[:, :, :, 1::2] = imag + real = np.expand_dims(real, axis=-1) + imag = np.expand_dims(imag, axis=-1) + x_rotate_concat = np.concatenate((real, imag), axis=-1) + x_rotate = np.reshape(x_rotate_concat, x_rotate.shape) + else: + x_rotate = np.concatenate((real, imag), axis=-1) + output = np.concatenate((x_rotate, x_not_rotate), axis=-1) + if len(input.shape) == 3: + output = np.reshape(output, input.shape) + return output + ``` +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + RotaryEmbeddingONNX, + 1, + OpSchema() + .SetDoc(RotaryEmbeddingONNX_ver1_doc) + .Attr( + "interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + static_cast(0)) + .Attr( + "rotary_embedding_dim", + "Rotary embedding dimension used to apply partial rotary embeddings.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "num_heads", + "Number of attention heads. Must be provided when input is a 3D tensor. ", + AttributeProto::INT, + OPTIONAL_VALUE) + .Input( + 0, + "X", + "The input tensor representing the token embeddings. " + "4D tensor with shape (batch_size, sequence_length, num_heads, head_size) or 3D tensor with shape (batch_size, sequence_length, hidden_size). " + "For cases with a 4D input tensor, `head_size` has to be even. For cases with a 3D input tensor, `num_heads` attribute must be provided and " + "hidden_size must be an even multiple of num_heads where `hidden_size = num_heads * head_size`", + "T") + .Input( + 1, + "cos_cache", + "The cosine values for the rotation. " + "2D tensor with shape (max_position_id_plus_1, head_size / 2) for full rotation or (max_position_id_plus_1, rotary_embedding_dim / 2) " + "for partial rotation when position_ids are provided. 3D tensor with shape (batch_size, sequence_length, head_size / 2) " + "for full rotation or (batch_size, sequence_length, rotary_embedding_dim / 2) for partial rotation when position_ids are not provided. " + "`max_position_id_plus_1` is a parameter to the model.", + "T") + .Input( + 2, + "sin_cache", + "The sine values for the rotation. " + "2D tensor with shape (max_position_id_plus_1, head_size / 2) for full rotation or (max_position_id_plus_1, rotary_embedding_dim / 2) " + "for partial rotation when position_ids are provided. 3D tensor with shape (batch_size, sequence_length, head_size / 2) " + "for full rotation or (batch_size, sequence_length, rotary_embedding_dim / 2) for partial rotation when position_ids are not provided. " + "`max_position_id_plus_1` is a parameter to the model.", + "T") + .Input( + 3, + "position_ids", + "The position indices for the tokens. 2D tensor with shape (batch_size, sequence_length)", + "M", + OpSchema::Optional) + .Output(0, "Y", "Tensor with same shape as input.", "T") + .TypeConstraint( + "T", + {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateShapeFromInputToOutput(ctx, 0, 0); + + // we need at least one input to have a shape for this inference. + if (!hasNInputShapes(ctx, 1)) { + return; + } + + auto input_shape = ctx.getInputType(0)->tensor_type().shape(); + if ((input_shape.dim_size() < 3) || (input_shape.dim_size() > 4)) { + fail_shape_inference("Input tensor must have at least 3 and at most 4 dimensions"); + } + + auto* num_heads_attr = ctx.getAttribute("num_heads"); + if ((input_shape.dim_size() == 3) && (num_heads_attr == nullptr)) { + fail_shape_inference("Input shape is 3D, num_heads attribute must be provided"); + } + }) + .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, + const OpSchema& schema, + FunctionProto& functionProto) { + // RotaryEmbedding (X, position_ids, cos_cache, + // sin_cache) => Y + + int64_t int_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; + auto* interleaved_attr = ctx.getAttribute("interleaved"); + int64_t interleaved = (interleaved_attr != nullptr) ? interleaved_attr->i() : 0; + auto* rotary_embedding_dim_attr = ctx.getAttribute("rotary_embedding_dim"); + int64_t rotary_embedding_dim = (rotary_embedding_dim_attr != nullptr) ? rotary_embedding_dim_attr->i() : 0; + auto* num_heads_attr = ctx.getAttribute("num_heads"); + int64_t num_heads = (num_heads_attr != nullptr) ? num_heads_attr->i() : 0; + + FunctionBuilder builder(functionProto); + // Set input tensor to the correct shape if input shape is 3D + // NewShape = [batch_size, sequence_length, num_heads, head_size] + builder.Const1D("Zero1D", (int64_t)0) + .Const1D("NumHeads", num_heads) // num_heads + .Const1D("NegOne", (int64_t)(-1)) // head_size, inferred from other dimensions + .Add("NewShape = Concat (Zero1D, Zero1D, NumHeads, NegOne)") + .Add("XReshaped = Reshape (X, NewShape)"); + + // Rotary embedding dimension is the value along which the input is to be split + // There are two cases for the rotary embedding dimension: + // 1. Complete rotation: rotary embedding dimension defaults to head_size, rotary_embedding_dim = cos.shape[3] + // * 2 or head_size + // 2. Partial rotation: rotary embedding dimension is provided, rotary_embedding_dim = rotary_embedding_dim + + builder.Add("HeadSize = Shape (XReshaped)"); + if (rotary_embedding_dim > 0) { + builder.Const1D("RotaryEmbedDim", rotary_embedding_dim); + } else { + builder.Add("RotaryEmbedDim = Identity(HeadSize)"); + } + builder.Const1D("Two1D", (int64_t)2) + .Add("NoRotateLength = Sub(HeadSize, RotaryEmbedDim)") + .Add("RotateSplitLengths = Concat (RotaryEmbedDim, NoRotateLength)"); + // shape of input to rotate = input[:,:,:,:rotary_embedding_dim] + // shape of input not to rotate = input[:,:,:,rotary_embedding_dim:] + builder.Add("XToRotate, XNoRotate = Split (XReshaped, RotateSplitLengths)"); + + // Gather the cos and sine matrices from the respective caches using position ids if provided. + // Otherwise Gather op functions as an Identity op. + // Unsqueeze applied to make cos and sin matrices have dimensions that are + // valid for multiplication with input when is split. For cases where rotary_embedding_dim is provided, + // slice the matrix values until that index only + if (ctx.hasInput(3)) { + builder + .Add("CosCacheGather = Gather(cos_cache, position_ids)") // shape of cos matrix: [batch_size, + // sequence_length, head_size / 2] + .Add("SinCacheGather = Gather(sin_cache, position_ids)"); // shape of cos matrix: [batch_size, + // sequence_length, head_size / 2] + } else { + builder + .Add("CosCacheGather = Identity(cos_cache)") // shape of cos matrix: [batch_size, sequence_length, + // head_size / 2] + .Add("SinCacheGather = Identity(sin_cache)"); // shape of cos matrix: [batch_size, sequence_length, + // head_size / 2] + } + + builder.Add("RotaryEmbedDimHalf = Div(RotaryEmbedDim, Two1D)") + .Add("RotaryEmbedDimHalfInt = Cast (RotaryEmbedDimHalf)", "to", int_type) + .Add( + "CosCacheSliced = Slice(CosCacheGather, Zero1D, RotaryEmbedDimHalfInt, Two1D)") // shape of cos + // matrix: + // [batch_size, + // sequence_length, + // rotary_embedding_dim + // / 2] + .Add( + "SinCacheSliced = Slice(SinCacheGather, Zero1D, RotaryEmbedDimHalfInt, Two1D)") // shape of sin + // matrix: + // [batch_size, + // sequence_length, + // rotary_embedding_dim + // / 2] + .Add("CosCacheUnsqueezed = Unsqueeze(CosCacheSliced, Two1D)") // shape of cos matrix: [batch_size, + // sequence_length, 1, rotary_embedding_dim + // / 2] + .Add("SinCacheUnsqueezed = Unsqueeze(SinCacheSliced, Two1D)"); // shape of sin matrix: [batch_size, + // sequence_length, 1, rotary_embedding_dim + // / 2] + + // Create slices of inputs to multiply with sin and cos matrices based on interleaved parameter + // Choose the correct slices based on interleaved parameter + // real = cos_x * x1 - sin_x * x2 + // imag = sin_x * x1 + cos_x * x2 + if (interleaved == 0) { + // For non-interleaved (basic) rotation, slices are created as follows, + builder.Add( + "X1, X2 = Split (XToRotate)"); // shape of X1 = + // input[:,:,:,:rotary_embedding_dim/2], + // X2 = + // input[:,:,:,rotary_embedding_dim/2:rotary_embedding_dim] + } else { + // For interleaved rotation, slices are created as follows, + builder.Const1D("One1D", (int64_t)1) + .Const1D("AxesRotaryDim", (int64_t)3) + .Add("RotaryEmbedDimInclusive = Add(RotaryEmbedDim, One1D)") + .Add( + "X1 = Slice(XToRotate, Zero1D, RotaryEmbedDim, AxesRotaryDim, Two1D)") // shape of X1 = + // input[:,:,:,0:rotary_embedding_dim:2] + .Add( + "X2 = Slice(XToRotate, One1D, RotaryEmbedDimInclusive, AxesRotaryDim, Two1D)"); // shape of + // X2 = + // input[:,:,:,1:rotary_embedding_dim:2] + } + + builder.Add("CosX1 = Mul(CosCacheUnsqueezed, X1)") + .Add("SinX2 = Mul(SinCacheUnsqueezed, X2)") + .Add("Real = Sub(CosX1, SinX2)") + .Add("SinX1 = Mul(SinCacheUnsqueezed, X1)") + .Add("CosX2 = Mul(CosCacheUnsqueezed, X2)") + .Add("Imaginary = Add(SinX1, CosX2)"); + + // Insert the real and imaginary values into the original input to be rotated based on interleaved parameter + if (interleaved == 0) { + builder.Add("XRotated = Concat (Real, Imaginary)"); + } else { + builder + .Add("RealInterleave = Unsqueeze(Real, NegOne)") // shape of indices = + // input[:,:,:,0:rotary_embedding_dim:2, 1] + .Add("ImaginaryInterleave = Unsqueeze(Imaginary, NegOne)") // shape of indices = + // input[:,:,:,1:rotary_embedding_dim+1:2, 1] + .Add("XRotatedInterleavedConcat = Concat (RealInterleave, ImaginaryInterleave)") + .Add("XRotatedShape = Shape(XToRotate)") + .Add("XRotated = Reshape(XRotatedInterleavedConcat, XRotatedShape)"); + } + + // Combine rotated parts with non-rotated parts + builder.Add("XConcat = Concat (XRotated, XNoRotate)"); + // Reshape back to 3D shape if input is a 3D tensor + builder.Add("XShape = Shape(X)").Add("Y = Reshape(XConcat, XShape)"); + + schema.BuildFunction(functionProto); + return true; + })); + constexpr const char* GemmaRotaryEmbedding_ver1_doc = R"DOC( GemmaRotaryEmbedding is the implementation of below part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py. diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index a9a89f756b071..f0bb294a532b6 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbeddingONNX); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmaRotaryEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); @@ -207,9 +208,10 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); - fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc new file mode 100644 index 0000000000000..5badf821def9d --- /dev/null +++ b/onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc @@ -0,0 +1,633 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunTest( + const std::vector& input_data, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& position_ids, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size, + int num_heads, + int max_sequence_length, + int64_t interleaved, + bool use_float16, + bool disable_cpu, + bool disable_cuda) { + // input : (batch_size, sequence_length, hidden_size) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + // position ids : (1) or (batch_size, sequence_length) + // interleaved : 0 = false, 1 = true + + int hidden_size = num_heads * head_size; + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector pos_dims; + std::vector cache_dims = {max_sequence_length, head_size / 2}; + + assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); + assert(max_sequence_length >= sequence_length); + if (position_ids.size() == 1) { + pos_dims = {1}; + } else { + pos_dims = {batch_size, sequence_length}; + } + + std::string op_type = "RotaryEmbeddingONNX"; + std::vector> execution_providers; + + int min_cuda_architecture = use_float16 ? 530 : 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + if (enable_cuda && !disable_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (!use_float16 && !disable_cpu) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) + return; + } + + OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", interleaved); + test.AddAttribute("num_heads", static_cast(num_heads)); + + if (!use_float16) { + test.AddInput("input", input_dims, input_data); + test.AddInput("cos_cache", cache_dims, cos_cache); + test.AddInput("sin_cache", cache_dims, sin_cache); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddOutput("output", input_dims, output_data); + } else { + test.AddInput("input", input_dims, ToFloat16(input_data)); + test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); + test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddOutput("output", input_dims, ToFloat16(output_data)); + } + test.SetOutputAbsErr("output", 0.002f); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +static void RunTests(const std::vector& input_data, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& position_ids, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size = 0, + int num_heads = 0, + int max_sequence_length = 0, + int64_t interleaved = 0, + bool use_float16 = true) { + // FP32 test for CPU + RunTest(input_data, + cos_cache, + sin_cache, + position_ids, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + false, /* use_fp16 */ + false, /* disable_cpu */ + true /* disable_cuda */); + + // FP32 test for CUDA +// RunTest(input_data, +// position_ids, +// cos_cache, +// sin_cache, +// output_data, +// batch_size, +// sequence_length, +// head_size, +// num_heads, +// max_sequence_length, +// interleaved, +// false, /* use_fp16 */ +// false, /* disable_cpu */ +// false /* disable_cuda */); + +// // FP16 test for CUDA +// if (use_float16) { +// RunTest(input_data, +// position_ids, +// cos_cache, +// sin_cache, +// output_data, +// batch_size, +// sequence_length, +// head_size, +// num_heads, +// max_sequence_length, +// interleaved, +// true, /* use_fp16 */ +// true, /* disable_cpu */ +// false /* disable_cuda*/); +// } +} + +// Interleaved = true, pos ids shape = (1) +TEST(RotaryEmbeddingONNXTest, RotaryEmbedding_ONNX_Interleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 3; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.2188f, 1.1676f, -1.0574f, -0.1188f, -0.7396f, -1.2425f, -0.1752f, 0.6990f, + -0.8110f, 0.6737f, -1.1233f, -0.0919f, -0.6861f, 0.7202f, 0.1963f, 0.6142f}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f, -0.9900f, 0.9996f, + -0.6536f, 0.9992f, 0.2837f, 0.9988f, 0.9602f, 0.9982f, 0.7539f, 0.9976f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f, 0.1411f, 0.0300f, + -0.7568f, 0.0400f, -0.9589f, 0.0500f, -0.2794f, 0.0600f, 0.6570f, 0.0699f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.6411f, -0.3948f, -1.0561f, -0.1294f, 0.6460f, -1.2937f, -0.1822f, 0.6972f, + -0.2751f, -1.0178f, -1.1212f, -0.1143f, -0.3694f, -0.9235f, 0.1840f, 0.6180f}; + + std::vector position_ids = {0}; + + RunTests(input_data, + cos_cache, + sin_cache, + position_ids, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = true, pos ids shape = (1) +TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_Interleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int max_sequence_length = 16; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector position_ids = {0}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.4713f, + -0.9540f, -0.9229f, 0.3027f, -0.5708f, -0.2363f, + -1.2713f, 0.1137f, 0.8112f, -1.1659f, -0.5824f, + -0.4419f, -0.7649f, 0.7011f, -0.4569f, -0.5639f, + -0.5328f, -0.6424f, 1.0979f, 0.8773f, 0.5462f, + 0.0793f, 0.2582f, 0.8576f, 0.2653f, 1.2295f, + -0.1839f, -0.4517f, -1.5052f, -0.4651f, 0.1155f, + -2.1237f, -0.7586f, -0.2110f, 1.1441f, -0.6304f, + 0.4186f, 0.2303f, -0.1519f, 1.1903f, 0.5382f, + -0.1906f, -1.0080f, 2.3112f, -0.2220f, -0.9655f, + -0.0099f, 1.5198f, 0.7652f, -0.6410f, 0.0365f, + -0.0452f, 1.0593f, 0.8929f, 1.4856f, 0.0038f, + -1.0865f, 1.4794f, -0.2417f, 0.9428f, -0.6894f, + -0.6293f, 0.2904f, 1.5747f, -0.4956f, 0.9199f, + -0.2424f, 0.1801f, 0.7503f, -1.4576f, 0.6529f, + -1.1340f, -0.6807f, -0.0252f, -0.3834f, 2.7394f, + 0.1308f, 1.1203f, -2.1196f, -0.9618f, 0.1970f, + -0.0972f, -0.2764f, 0.3332f, -0.4522f, 1.1844f, + 0.3867f, -0.6626f, -0.9405f, 1.8656f, 0.5053f, + -1.2361f, 1.2072f, 0.1789f, -1.1002f, 1.0129f, + 1.7702f, 0.1949f, -1.1653f, 1.6049f, -0.2755f, + -0.2749f, 2.1087f, 0.4272f, 0.8076f, 0.2900f, + -0.0714f, 0.8261f, -1.1016f, -1.3814f, -0.1366f, + 0.2981f, 0.6060f, -1.4132f, 0.0893f, -0.1939f, + 0.2779f, 0.3910f, -0.8906f, -0.6489f, -1.2496f, + 0.3383f, -0.0315f, -0.7461f, 1.1510f, 0.4445f, + 0.3203f, -0.9031f, 0.2727f, 0.2609f, 2.0968f, + 1.0974f, 0.7120f, -0.5164f, 0.7415f, -0.0031f, + -0.1568f, 0.1533f, 0.5487f, -0.3357f, -0.9064f, + 1.0546f, 0.0542f, 1.1870f, -0.4045f, -1.3431f, + -0.6094f, -1.1105f, -0.9631f, -0.1137f, -0.7219f, + 0.8582f, -1.3443f, -0.6684f, -1.0227f, -1.5929f, + -0.2622f, 0.2264f, 0.0713f, 0.1843f, -1.3387f, + -1.6797f, 2.3165f, 0.1009f, 0.1081f, -0.9969f, + -1.4488f, 0.6291f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, 0.5985f, -1.0968f, 1.5662f, 1.4693f, + 0.8776f, 0.3408f, 0.4345f, 1.2549f, 0.6631f, + 1.4543f, 0.3374f, 0.0445f, 1.2320f, 1.4311f, + -2.0483f, -0.7272f, 0.4114f, -1.1449f, 1.6283f, + -0.9524f, -1.6435f, 0.5422f, 0.9907f, -0.0708f, + 0.3972f, 0.7376f, -1.5947f, 1.6138f, -0.9586f, + -0.4600f, 0.3993f, -1.5884f, 1.2934f, -1.4467f, + 1.2833f, -1.2459f, -0.7760f, 0.3108f, -3.3677f, + -0.0287f, 0.6942f, -0.7601f, -0.6993f, 2.3690f, + 1.3834f, -0.5234f, 0.3435f, 1.0053f, 0.1604f, + -0.9560f, -1.2641f, 0.2406f, 0.4973f, 0.9206f, + -1.9987f, -1.1733f, -0.4197f, -0.0366f, -0.6720f, + -1.3350f, -1.5960f, -0.1097f, 0.6386f, 0.5624f, + -0.6184f, 0.0778f, 0.1867f, 0.9643f, -1.3629f, + -0.0972f, -1.7907f, -0.3037f, 0.8245f, -0.0789f, + -0.2940f, -0.2833f, -0.2165f, 0.6264f, -1.1726f, + 0.7926f, 1.3621f, 1.3586f, -0.9007f, -0.8138f, + -2.7421f, 1.3155f, 2.4507f, 0.0507f, 0.6305f, + 1.6900f, 0.5210f, -0.3309f, 2.0630f, 1.8026f, + -0.7859f, -0.6802f, -1.1003f, -0.1990f, -0.5391f, + -0.9370f, 0.0857f, -2.3330f, -2.0112f, 0.7193f, + -0.1272f, -0.9981f, -0.1818f, 0.3973f, -0.9963f, + 1.4929f, -1.0109f, 0.4304f, 1.0160f, -1.4590f, + 0.2682f, 1.5658f, 0.1762f, 0.3038f, -0.7491f, + 0.3052f, -1.1534f, -0.0478f, 0.0021f, -0.0665f, + -0.8118f, 0.1310f, 0.2171f, 0.5485f, -0.1610f, + -1.5784f, -0.8660f, 0.7289f, -0.4678f, 0.1937f, + 1.1287f, -0.5772f, -0.0259f, -0.2212f, 0.2479f, + 0.6336f, 0.6407f, -0.6543f, 0.3838f, 0.9039f, + 0.4724f, 0.7117f, 1.0165f, 1.0270f, 1.1908f, + 1.3750f, -0.0850f, 0.5517f, -1.3842f, 0.3703f, + -0.8806f, 0.9336f, 0.8362f, 0.8105f, -1.1566f, + -0.6813f, 0.0294f, -0.1122f, 0.5620f, -0.2884f, + -2.0803f, 0.4684f, 0.6009f, -1.4160f}; + + RunTests(input_data, + cos_cache, + sin_cache, + position_ids, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (1) +TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_NotInterleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int max_sequence_length = 16; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector position_ids = {0}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.8618f, + -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + 0.6923f, 1.1571f, 0.7572f, -1.1471f, -0.5302f, + -0.4391f, 0.5516f, 1.0461f, -0.4812f, -0.1443f, + -0.4862f, -0.6423f, 0.6740f, -0.4614f, 0.5475f, + 1.1495f, 0.2389f, 0.8582f, -0.0259f, -0.6099f, + -0.2230f, 1.0963f, -1.5704f, -0.4595f, 0.9507f, + 0.6696f, -0.7721f, -1.7415f, 1.2087f, -0.6387f, + -1.1052f, -0.5243f, -0.0400f, -0.4671f, 0.4909f, + -0.1931f, -0.1937f, -0.0447f, -0.3171f, 2.6839f, + -0.0076f, 1.5185f, 0.8465f, 0.3737f, 0.0242f, + -0.0703f, 1.1279f, 0.8862f, 1.2275f, -0.1786f, + -0.8767f, -1.8072f, -0.2630f, 0.9387f, -0.8021f, + 0.7813f, 0.5001f, -1.4202f, -0.3850f, 0.9263f, + -0.0443f, -0.2323f, 0.5480f, 1.5696f, 0.6193f, + -1.1346f, 1.7878f, -0.5160f, 0.1192f, -2.1572f, + 0.0460f, 1.1202f, -1.4812f, -0.9082f, 0.1728f, + -1.5132f, -0.4489f, 0.3370f, -0.1541f, -0.9266f, + 0.2416f, 0.9270f, -1.1146f, 1.8758f, -0.4312f, + 1.3714f, 1.2106f, -0.4272f, -0.8529f, 1.0328f, + 1.8441f, 1.7698f, -0.7620f, 0.2168f, 0.1322f, + -0.2802f, 0.1460f, 2.1002f, 0.8437f, -0.1534f, + 0.4321f, 0.8360f, 0.5955f, -1.5452f, -0.0491f, + -0.8794f, 0.2418f, -1.4203f, 0.3635f, 0.2362f, + 0.3672f, -0.1128f, -0.8664f, -0.6354f, -1.4409f, + -0.3413f, -0.2409f, -0.3188f, 1.1054f, 0.4265f, + 0.5867f, -1.3279f, 0.3201f, 0.0125f, 1.8157f, + 1.0745f, 0.7372f, -0.2429f, 0.7100f, -0.4299f, + -0.2304f, 0.1645f, 0.9489f, -0.1816f, -0.5968f, + 1.0394f, 0.0204f, 1.1786f, -0.3315f, -0.3997f, + -0.9304f, -1.4268f, -1.1526f, -0.1132f, 0.1490f, + 1.3967f, -1.4634f, -0.1412f, -0.6339f, -1.5995f, + -0.1366f, 0.7604f, 0.1514f, 0.0824f, -1.1830f, + -1.6572f, 2.0099f, -0.9108f, -0.2256f, 0.4527f, + -1.8254f, 0.6475f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -1.4979f, -1.1358f, 1.6320f, 0.2493f, + 0.8266f, 0.3424f, -0.4992f, 0.2964f, 0.7298f, + 1.8544f, 0.3516f, 0.0454f, 1.5415f, -0.2822f, + -2.0774f, 1.2323f, 0.3963f, -1.1503f, -0.4775f, + -1.9287f, -1.6164f, 0.3998f, 0.9020f, -0.0764f, + -1.8059f, -0.5762f, -1.4362f, -0.2706f, -1.0183f, + -0.4620f, 2.0891f, 0.1782f, 1.1591f, -0.8151f, + 1.3000f, -1.2464f, -0.5099f, 0.5098f, -3.3525f, + 0.4326f, 0.7414f, -0.7775f, -0.4271f, -0.3807f, + 1.3245f, 2.4936f, 0.3139f, 1.0095f, 0.2323f, + 0.8450f, -1.2244f, -0.4511f, 0.6266f, 0.9095f, + -1.7981f, 1.5241f, -0.4121f, 0.2341f, -0.4737f, + -1.3333f, -1.6150f, 0.4164f, 0.7100f, -0.2429f, + -0.5656f, 0.0863f, 0.0352f, -0.7227f, -1.3613f, + -0.0988f, -1.9114f, -0.3009f, 0.1435f, 0.7029f, + -0.3467f, 0.5092f, -0.0828f, 0.6253f, 0.7113f, + -1.2138f, 1.5964f, -0.8346f, -1.1515f, -0.7923f, + -0.8254f, -3.0038f, 2.4033f, -0.3398f, 0.0922f, + 1.7053f, 1.1114f, 0.7462f, 2.3660f, -0.8409f, + -0.6654f, -0.6530f, -0.7899f, -1.0957f, -0.7149f, + -0.1072f, -0.1967f, -2.3416f, -1.2609f, -1.6375f, + -0.3576f, 0.9413f, -0.5694f, 0.3954f, 0.1383f, + -0.7477f, -0.8689f, 1.8286f, 0.8510f, -1.4793f, + -0.1597f, 0.8541f, 0.2380f, 1.4392f, -0.5644f, + 0.3158f, -1.0686f, -0.1313f, -0.0181f, 0.2438f, + -0.8801f, 0.1413f, -0.3587f, 0.8002f, -0.5982f, + -1.4301f, -0.6620f, 0.7324f, -0.7250f, 0.0610f, + 0.9293f, -0.6902f, -0.0125f, -0.2089f, -0.1664f, + 0.5428f, 0.4245f, -0.7901f, 0.5665f, 0.9044f, + 0.1948f, -0.1723f, 1.2705f, 1.0303f, 1.2202f, + 1.3762f, -0.2959f, 0.7237f, -1.2077f, 0.7937f, + -0.6705f, 0.9287f, 1.0583f, 0.0496f, -1.3118f, + 0.5556f, 0.0459f, -0.1324f, -0.5513f, -0.7409f, + -1.8002f, 0.9892f, 0.3619f, -1.4522f}; + + RunTests(input_data, + cos_cache, + sin_cache, + position_ids, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_NotInterleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 3; + int head_size = 6; + int max_sequence_length = 4; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f, -0.9647f, -0.0991f, + -0.2994f, -0.0650f, -1.5720f, -1.3211f}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, 0.0043f, + 0.1411f, 0.1388f, 0.0065f}; + + std::vector position_ids = {0, 1}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f, -0.4666f, -0.0261f, + -0.2965f, -0.8469f, -1.5749f, -1.3217f}; + + RunTests(input_data, + cos_cache, + sin_cache, + position_ids, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +} // namespace test +} // namespace onnxruntime From 820dbbf23955431ddf524d974bcc7859c94f1812 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 29 Jan 2025 02:18:52 +0000 Subject: [PATCH 2/2] make position_ids optional --- .../cpu/onnx_std_exp/rotary_embedding_onnx.cc | 25 ++-- .../rotary_embedding_onnx_helper.h | 137 +++++++++++------- .../rotary_embedding_onnx_op_test.cc | 107 +++++++++++++- 3 files changed, 207 insertions(+), 62 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc index f37dcb6ea7179..a1eb093a72313 100644 --- a/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc +++ b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc @@ -69,13 +69,20 @@ Status RunRotaryEmbeddingONNX(concurrency::ThreadPool* tp, RotaryParameters para const T* input_data = input + block_offset; T* output_data = output + block_offset; - // Cache is (M, H/2) or (M, rotary_embedding_dim/2) - const int position_id = (position_ids_format == 0) - ? static_cast(position_ids[0]) + s - : static_cast(position_ids[b * sequence_length + s]); - const int cache_offset = position_id * half_rotary_emb_dim; - const T* cos_data = cos_cache + cache_offset; - const T* sin_data = sin_cache + cache_offset; + const T* cos_data; + const T* sin_data; + int cache_offset; + if (position_ids_format == -1) { + cache_offset = (b * sequence_length + s) * half_rotary_emb_dim; + } else { + // Cache is (M, H/2) or (M, rotary_embedding_dim/2) + const int position_id = (position_ids_format == 0) + ? static_cast(position_ids[0]) + s + : static_cast(position_ids[b * sequence_length + s]); + cache_offset = position_id * half_rotary_emb_dim; + } + cos_data = cos_cache + cache_offset; + sin_data = sin_cache + cache_offset; MlasRotaryEmbedOneRow(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data); @@ -103,7 +110,7 @@ Status RotaryEmbeddingONNX::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* cos_cache = context->Input(1); const Tensor* sin_cache = context->Input(2); - const Tensor* position_ids = context->Input(3); + const Tensor* position_ids = context->Input(3); // position_ids are optional RotaryParameters parameters = {}; ORT_RETURN_IF_ERROR(rotary_embedding_onnx_helper::CheckInputs(input, @@ -124,7 +131,7 @@ Status RotaryEmbeddingONNX::Compute(OpKernelContext* context) const { const T* input_src = input->Data(); const T* cos_cache_data = cos_cache->Data(); const T* sin_cache_data = sin_cache->Data(); - const int64_t* pos_ids_data = position_ids->Data(); + const int64_t* pos_ids_data = (nullptr == position_ids) ? nullptr : position_ids->Data(); T* output_dest = output->MutableData(); AllocatorPtr allocator; diff --git a/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h index f826f78c5262d..784d3a9c931a5 100644 --- a/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h +++ b/onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx_helper.h @@ -44,34 +44,12 @@ Status CheckInputs(const T* input, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ", input_dims.size()); } - // Check cos_cache and sin_cache - const auto& cos_cache_dims = cos_cache->Shape().GetDims(); - if (cos_cache_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", - cos_cache_dims.size()); - } - const auto& sin_cache_dims = sin_cache->Shape().GetDims(); - if (sin_cache_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", - sin_cache_dims.size()); - } - if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", - "the same shape"); - } - // Check position_ids - const auto& position_ids_dims = position_ids->Shape().GetDims(); - if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ", - "dimensions, got ", position_ids_dims.size()); - } // Check num_heads and rotary_embedding_dim if (rotary_embedding_dim > 0 && num_heads == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ", - "specified"); + "specified"); } - // Get attributes from inputs int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); @@ -83,41 +61,102 @@ Status CheckInputs(const T* input, hidden_size = static_cast(input_dims[2]) * static_cast(input_dims[3]); transposed = true; } - int max_sequence_length = static_cast(cos_cache_dims[0]); - int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 - : static_cast(hidden_size / num_heads); - if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", - "head_size"); - } int position_ids_format = -1; + int max_sequence_length; + int head_size; - // Check position_ids input shapes - if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { - if (batch_size != static_cast(position_ids_dims[0])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", - "batch_size, got ", position_ids_dims[0]); + if (nullptr == position_ids) { + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 3 dimensions, got ", + cos_cache_dims.size()); } - if (sequence_length != static_cast(position_ids_dims[1])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", - "sequence_length, got ", position_ids_dims[1]); + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 3 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1] || cos_cache_dims[2] != sin_cache_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + + max_sequence_length = static_cast(cos_cache_dims[1]); + head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[2]) * 2 + : static_cast(hidden_size / num_heads); + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + // Check cos_cache input shapes + if (max_sequence_length != static_cast(cos_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", + "max_sequence_length, got ", cos_cache_dims[0]); + } + if ((head_size / 2) != static_cast(cos_cache_dims[2]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[2]))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 2 should be same as ", + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); } - position_ids_format = 1; } else { - position_ids_format = 0; - } + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + // Check position_ids + const auto& position_ids_dims = position_ids->Shape().GetDims(); + if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ", + "dimensions, got ", position_ids_dims.size()); + } - // Check cos_cache input shapes - if (max_sequence_length != static_cast(cos_cache_dims[0])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", - "max_sequence_length, got ", cos_cache_dims[0]); - } - if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", - "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); + max_sequence_length = static_cast(cos_cache_dims[0]); + head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 + : static_cast(hidden_size / num_heads); + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + + // Check position_ids input shapes + if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { + if (batch_size != static_cast(position_ids_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", + "batch_size, got ", position_ids_dims[0]); + } + if (sequence_length != static_cast(position_ids_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", + "sequence_length, got ", position_ids_dims[1]); + } + position_ids_format = 1; + } else { + position_ids_format = 0; + } + + // Check cos_cache input shapes + if (max_sequence_length != static_cast(cos_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", + "max_sequence_length, got ", cos_cache_dims[0]); + } + if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); + } } + num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); // Calculate stride values int head_stride; diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc index 5badf821def9d..5cca419dd1f48 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc @@ -35,11 +35,18 @@ static void RunTest( int hidden_size = num_heads * head_size; std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector pos_dims; - std::vector cache_dims = {max_sequence_length, head_size / 2}; + std::vector cache_dims; + if (position_ids.size() != 0) { + cache_dims = {max_sequence_length, head_size / 2}; + } else { + cache_dims = {batch_size, sequence_length, head_size / 2}; + } assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); assert(max_sequence_length >= sequence_length); - if (position_ids.size() == 1) { + if (position_ids.size() == 0) { + pos_dims = {}; + } else if (position_ids.size() == 1) { pos_dims = {1}; } else { pos_dims = {batch_size, sequence_length}; @@ -69,13 +76,21 @@ static void RunTest( test.AddInput("input", input_dims, input_data); test.AddInput("cos_cache", cache_dims, cos_cache); test.AddInput("sin_cache", cache_dims, sin_cache); - test.AddInput("position_ids", pos_dims, position_ids); + if (position_ids.size()) { + test.AddInput("position_ids", pos_dims, position_ids); + } else { + test.AddOptionalInputEdge(); + } test.AddOutput("output", input_dims, output_data); } else { test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); - test.AddInput("position_ids", pos_dims, position_ids); + if (position_ids.size()) { + test.AddInput("position_ids", pos_dims, position_ids); + } else { + test.AddOptionalInputEdge(); + } test.AddOutput("output", input_dims, ToFloat16(output_data)); } test.SetOutputAbsErr("output", 0.002f); @@ -629,5 +644,89 @@ TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_NotInterleaved_SmallData_Llama interleaved); } +// Interleaved = false, pos ids = nullptr +TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_NotInterleaved_NoPosIds_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 3; + int head_size = 6; + int max_sequence_length = 4; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f, -0.9647f, -0.0991f, + -0.2994f, -0.0650f, -1.5720f, -1.3211f}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f}; + + std::vector position_ids = {}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f, -0.4666f, -0.0261f, + -0.2965f, -0.8469f, -1.5749f, -1.3217f}; + + RunTests(input_data, + cos_cache, + sin_cache, + position_ids, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = true, pos ids = nullptr +TEST(RotaryEmbeddingONNXTest, RotaryEmbedding_ONNX_Interleaved_NoPosIds_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 3; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.2188f, 1.1676f, -1.0574f, -0.1188f, -0.7396f, -1.2425f, -0.1752f, 0.6990f, + -0.8110f, 0.6737f, -1.1233f, -0.0919f, -0.6861f, 0.7202f, 0.1963f, 0.6142f}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.6411f, -0.3948f, -1.0561f, -0.1294f, 0.6460f, -1.2937f, -0.1822f, 0.6972f, + -0.2751f, -1.0178f, -1.1212f, -0.1143f, -0.3694f, -0.9235f, 0.1840f, 0.6180f}; + + std::vector position_ids = {}; + + RunTests(input_data, + cos_cache, + sin_cache, + position_ids, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + } // namespace test } // namespace onnxruntime