
Commit 141a5ad

Merge branch 'master' into new_inputs_rank_align
2 parents a281e79 + ebf4301

26 files changed: +866 -115 lines

src/core/include/openvino/op/scaled_dot_product_attention.hpp (+4)

@@ -50,6 +50,10 @@ class OPENVINO_API ScaledDotProductAttention : public Op {
         return m_causal;
     }
 
+    void set_causal(bool causal) {
+        m_causal = causal;
+    }
+
 private:
     bool m_causal = false;
 };
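For context, a minimal sketch of how the new setter might be used from a transformation pass; the helper and its logic are hypothetical, only set_causal()/get_causal() come from the op itself:

#include <memory>
#include "openvino/op/scaled_dot_product_attention.hpp"

// Force causal masking on an existing SDPA node (illustration only).
void make_causal(const std::shared_ptr<ov::op::v13::ScaledDotProductAttention>& sdpa) {
    if (!sdpa->get_causal()) {
        sdpa->set_causal(true);           // setter introduced by this commit
        sdpa->validate_and_infer_types(); // re-run shape/type inference
    }
}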

src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp (+28 -4)

@@ -382,8 +382,7 @@ struct ConvertPrecision<std::tuple<src_t, ov::float16>> {
         src_t lbound, ubound;
         std::tie(lbound, ubound) = ctx.range<src_t>();
 
-        if (std::is_integral<src_t>::value
-            || ctx.interimPrc.is_real()) {
+        if (std::is_integral<src_t>::value) {
             parallel_for(iterations, [&](size_t i) {
                 batch_type tmp;
                 const size_t offset = i * batch;
@@ -392,6 +391,19 @@ struct ConvertPrecision<std::tuple<src_t, ov::float16>> {
                     tmp[j] = static_cast<float>(std::max(std::min(src[offset + j], ubound), lbound));
                 jit_convert(tmp, dst + offset, current_batch_size); // fp32 -> fp16
             });
+        } else if (ctx.interimPrc.is_real()) {
+            parallel_for(iterations, [&](size_t i) {
+                const size_t offset = i * batch;
+                const size_t current_batch_size = std::min(ctx.size - offset, batch);
+                if (std::is_same<typename std::remove_cv<src_t>::type, float>::value) { // fp32 -> fp16
+                    jit_convert(reinterpret_cast<const float *>(src) + offset, dst + offset, current_batch_size);
+                } else {
+                    batch_type tmp;
+                    for (size_t j = 0; j < current_batch_size; ++j) // src_t -> fp32
+                        tmp[j] = static_cast<float>(src[offset + j]);
+                    jit_convert(tmp, dst + offset, current_batch_size); // fp32 -> fp16
+                }
+            });
         } else {
             parallel_for(iterations, [&](size_t i) {
                 batch_type tmp;
@@ -420,8 +432,7 @@ struct ConvertPrecision<std::tuple<ov::float16, dst_t>> {
         float lbound, ubound;
         std::tie(lbound, ubound) = ctx.range<ov::float16>();
 
-        if (ctx.interimPrc.is_real()
-            || std::is_integral<dst_t>::value) {
+        if (std::is_integral<dst_t>::value) {
             parallel_for(iterations, [&](size_t i) {
                 batch_type tmp;
                 const size_t offset = i * batch;
@@ -430,6 +441,19 @@ struct ConvertPrecision<std::tuple<ov::float16, dst_t>> {
                 for (size_t j = 0; j < current_batch_size; ++j) // fp32 -> dst_t
                     dst[offset + j] = static_cast<dst_t>(std::max(std::min(tmp[j], ubound), lbound));
             });
+        } else if (ctx.interimPrc.is_real()) {
+            parallel_for(iterations, [&](size_t i) {
+                const size_t offset = i * batch;
+                const size_t current_batch_size = std::min(ctx.size - offset, batch);
+                if (std::is_same<typename std::remove_cv<dst_t>::type, float>::value) { // fp16 -> fp32
+                    jit_convert(src + offset, reinterpret_cast<float *>(dst) + offset, current_batch_size);
+                } else {
+                    batch_type tmp;
+                    jit_convert(src + offset, tmp, current_batch_size); // fp16 -> fp32
+                    for (size_t j = 0; j < current_batch_size; ++j) // fp32 -> dst_t
+                        dst[offset + j] = static_cast<dst_t>(tmp[j]);
+                }
+            });
         } else {
             parallel_for(iterations, [&](size_t i) {
                 batch_type tmp;
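The effect of the split: previously a real interim precision took the same path as integral sources, clamping every element into a temporary batch buffer before the fp32 -> fp16 JIT kernel ran. Now, when the interim precision is real and the source is already fp32, the kernel consumes the source buffer directly; other real sources are only widened to fp32, with no clamping. A minimal scalar sketch of that dispatch, where fp32_to_fp16 is a hypothetical stand-in for the JIT kernel:

#include <cstddef>
#include <cstring>
#include <type_traits>

struct f16 { unsigned short bits; };   // illustration-only fp16 payload

// Crude fp32 -> fp16 conversion (round toward zero, denormals flushed to
// zero); the real code runs a JIT-compiled batch kernel instead.
f16 fp32_to_fp16(float v) {
    unsigned int b;
    std::memcpy(&b, &v, sizeof(b));
    unsigned short sign = static_cast<unsigned short>((b >> 16) & 0x8000u);
    int exp = static_cast<int>((b >> 23) & 0xffu) - 127 + 15;
    unsigned short mant = static_cast<unsigned short>((b >> 13) & 0x3ffu);
    if (exp <= 0)  return f16{sign};                                        // underflow -> signed zero
    if (exp >= 31) return f16{static_cast<unsigned short>(sign | 0x7c00u)}; // overflow -> inf
    return f16{static_cast<unsigned short>(sign | (exp << 10) | mant)};
}

// Fast-path decision mirrored from the patch: fp32 sources skip the
// staging buffer; other real sources are widened element-wise first.
template <typename src_t>
void to_fp16_real_interim(const src_t* src, f16* dst, std::size_t n) {
    if (std::is_same<typename std::remove_cv<src_t>::type, float>::value) {
        for (std::size_t j = 0; j < n; ++j)   // direct fp32 -> fp16
            dst[j] = fp32_to_fp16(reinterpret_cast<const float*>(src)[j]);
    } else {
        for (std::size_t j = 0; j < n; ++j)   // src_t -> fp32 -> fp16
            dst[j] = fp32_to_fp16(static_cast<float>(src[j]));
    }
}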
src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp (new file, +78)

@@ -0,0 +1,78 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "intel_gpu/op/sdpa.hpp"
+#include "openvino/core/node.hpp"
+#include "openvino/core/partial_shape.hpp"
+#include "openvino/op/op.hpp"
+
+namespace ov {
+namespace intel_gpu {
+namespace op {
+
+class IndirectSDPA : public ov::intel_gpu::op::SDPA {
+public:
+    OPENVINO_OP("IndirectSDPA", "gpu_opset");
+
+    IndirectSDPA() = default;
+
+    IndirectSDPA(const ov::Output<Node>& Q,
+                 const ov::Output<Node>& K,
+                 const ov::Output<Node>& V,
+                 const ov::Output<Node>& beam_table,
+                 const bool is_causal,
+                 const int64_t indirect_axis,
+                 const std::vector<int64_t>& order_q,
+                 const std::vector<int64_t>& order_k,
+                 const std::vector<int64_t>& order_v,
+                 const std::vector<int64_t>& order_out,
+                 const ov::element::Type output_type = ov::element::undefined);
+
+    IndirectSDPA(const ov::Output<Node>& Q,
+                 const ov::Output<Node>& K,
+                 const ov::Output<Node>& V,
+                 const ov::Output<Node>& attn_mask,
+                 const ov::Output<Node>& beam_table,
+                 const bool is_causal,
+                 const int64_t indirect_axis,
+                 const std::vector<int64_t>& order_q,
+                 const std::vector<int64_t>& order_k,
+                 const std::vector<int64_t>& order_v,
+                 const std::vector<int64_t>& order_out,
+                 const ov::element::Type output_type = ov::element::undefined);
+
+    IndirectSDPA(const ov::Output<Node>& Q,
+                 const ov::Output<Node>& K,
+                 const ov::Output<Node>& V,
+                 const ov::Output<Node>& attn_mask,
+                 const ov::Output<Node>& scale,
+                 const ov::Output<Node>& beam_table,
+                 const bool is_causal,
+                 const int64_t indirect_axis,
+                 const std::vector<int64_t>& order_q,
+                 const std::vector<int64_t>& order_k,
+                 const std::vector<int64_t>& order_v,
+                 const std::vector<int64_t>& order_out,
+                 const ov::element::Type output_type = ov::element::undefined);
+
+    bool visit_attributes(ov::AttributeVisitor &visitor) override;
+    void validate_and_infer_types() override;
+
+    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
+
+    ov::element::Type get_output_type() const { return m_output_type; }
+
+    int64_t get_indirect_axis() const { return m_indirect_axis; }
+
+    using ov::intel_gpu::op::SDPA::default_order;
+
+protected:
+    int64_t m_indirect_axis = -1;
+};
+
+} // namespace op
+} // namespace intel_gpu
+} // namespace ov
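A minimal sketch of how the first constructor overload could be wired up inside a GPU-plugin transformation; the helper name, the identity transpose orders, and the choice of indirect_axis are illustrative assumptions:

#include <memory>
#include <vector>
#include "intel_gpu/op/indirect_sdpa.hpp"

// Build an IndirectSDPA over Q/K/V plus a beam table that redirects K/V
// reads across beams; indirect_axis selects the gathered (beam) dimension.
std::shared_ptr<ov::intel_gpu::op::IndirectSDPA> make_indirect_sdpa(
        const ov::Output<ov::Node>& Q,
        const ov::Output<ov::Node>& K,
        const ov::Output<ov::Node>& V,
        const ov::Output<ov::Node>& beam_table) {
    const std::vector<int64_t> order = {0, 1, 2, 3};  // identity transpose order
    return std::make_shared<ov::intel_gpu::op::IndirectSDPA>(
        Q, K, V, beam_table,
        /*is_causal=*/true,
        /*indirect_axis=*/0,  // assumed beam dimension of the KV cache
        order, order, order, order);
}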

src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp (+1)

@@ -285,3 +285,4 @@ REGISTER_FACTORY(internal, IndirectGemm);
 REGISTER_FACTORY(internal, Convolution);
 REGISTER_FACTORY(internal, Placeholder);
 REGISTER_FACTORY(internal, SDPA);
+REGISTER_FACTORY(internal, IndirectSDPA);

src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp (+14 -3)

@@ -19,24 +19,31 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
     scaled_dot_product_attention(const primitive_id& id,
                                  const std::vector<cldnn::input_info> inputs,
                                  bool is_causal,
+                                 int64_t indirect_axis = -1,
                                  const std::vector<int64_t>& input_q_transpose_order = {},
                                  const std::vector<int64_t>& input_k_transpose_order = {},
                                  const std::vector<int64_t>& input_v_transpose_order = {},
                                  const std::vector<int64_t>& output_transpose_order = {},
                                  const padding& output_padding = padding())
         : primitive_base(id, inputs, {output_padding})
         , is_causal(is_causal)
-        , has_attn_mask_input(inputs.size() > 3)
-        , has_scale_input(inputs.size() > 4)
+        , indirect_axis(indirect_axis)
         , input_q_transpose_order(input_q_transpose_order)
         , input_k_transpose_order(input_k_transpose_order)
         , input_v_transpose_order(input_v_transpose_order)
-        , output_transpose_order(output_transpose_order) {}
+        , output_transpose_order(output_transpose_order) {
+        auto data_inputs_num = inputs.size();
+        if (indirect_axis != -1)
+            data_inputs_num--;
 
+        has_attn_mask_input = data_inputs_num > 3;
+        has_scale_input = data_inputs_num > 4;
+    }
 
     bool is_causal = false;
     bool has_attn_mask_input = false;
     bool has_scale_input = false;
+    int64_t indirect_axis = -1;
 
     std::vector<int64_t> input_q_transpose_order;
     std::vector<int64_t> input_k_transpose_order;
@@ -48,6 +55,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
         seed = hash_combine(seed, is_causal);
         seed = hash_combine(seed, has_attn_mask_input);
         seed = hash_combine(seed, has_scale_input);
+        seed = hash_combine(seed, indirect_axis);
         seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
         seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
         seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
@@ -64,6 +72,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
         return is_causal == rhs_casted.is_causal &&
                has_attn_mask_input == rhs_casted.has_attn_mask_input &&
                has_scale_input == rhs_casted.has_scale_input &&
+               indirect_axis == rhs_casted.indirect_axis &&
               input_q_transpose_order == rhs_casted.input_q_transpose_order &&
               input_k_transpose_order == rhs_casted.input_k_transpose_order &&
               input_v_transpose_order == rhs_casted.input_v_transpose_order &&
@@ -75,6 +84,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
         ob << is_causal;
         ob << has_attn_mask_input;
         ob << has_scale_input;
+        ob << indirect_axis;
         ob << input_q_transpose_order;
         ob << input_k_transpose_order;
         ob << input_v_transpose_order;
@@ -86,6 +96,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_attention> {
         ib >> is_causal;
         ib >> has_attn_mask_input;
         ib >> has_scale_input;
+        ib >> indirect_axis;
         ib >> input_q_transpose_order;
         ib >> input_k_transpose_order;
         ib >> input_v_transpose_order;
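A sketch of how the updated constructor might be called when a beam table is appended as the last input; the primitive id and input names are illustrative:

#include "intel_gpu/primitives/scaled_dot_product_attention.hpp"

// Q, K, V plus a trailing beam-table input. A non-negative indirect_axis
// tells the constructor one input is the beam table, so it is excluded
// when deriving has_attn_mask_input / has_scale_input.
cldnn::scaled_dot_product_attention sdpa_prim(
    "sdpa",
    {cldnn::input_info("q"), cldnn::input_info("k"),
     cldnn::input_info("v"), cldnn::input_info("beam_table")},
    /*is_causal=*/true,
    /*indirect_axis=*/0);
// 4 inputs - 1 beam table = 3 data inputs, so neither an attention mask
// nor a scale input is assumed to be present.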

src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp (+1)

@@ -129,6 +129,7 @@ class debug_configuration {
     std::vector<std::string> forced_impl_types; // Force implementation type either ocl or onednn
     int max_kernels_per_batch; // Maximum number of kernels in a batch during compiling kernels
     int impls_cache_capacity; // The maximum number of entries in the kernel impl cache
+    int enable_sdpa; // Allows to control SDPA decomposition
     int disable_async_compilation; // Disable async compilation
     int disable_winograd_conv; // Disable Winograd conv
     int disable_dynamic_impl; // Disable dynamic implementation
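The new field follows the existing debug_configuration pattern, so a consumer would presumably gate SDPA decomposition with the plugin's standard GPU_DEBUG macros. A hypothetical fragment; the helper and the gating logic are assumptions, only the enable_sdpa member comes from the diff:

#include "intel_gpu/runtime/debug_configuration.hpp"

// Hypothetical gate inside a GPU transformation pipeline.
bool sdpa_decomposition_enabled() {
    GPU_DEBUG_GET_INSTANCE(debug_config);
    GPU_DEBUG_IF(debug_config->enable_sdpa == 1) {
        return false;  // keep SDPA fused, skip decomposition
    }
    return true;
}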
