@@ -19,24 +19,31 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
     scaled_dot_product_attention(const primitive_id& id,
                                  const std::vector<cldnn::input_info> inputs,
                                  bool is_causal,
+                                 int64_t indirect_axis = -1,
                                  const std::vector<int64_t>& input_q_transpose_order = {},
                                  const std::vector<int64_t>& input_k_transpose_order = {},
                                  const std::vector<int64_t>& input_v_transpose_order = {},
                                  const std::vector<int64_t>& output_transpose_order = {},
                                  const padding& output_padding = padding())
         : primitive_base(id, inputs, {output_padding})
         , is_causal(is_causal)
-        , has_attn_mask_input(inputs.size() > 3)
-        , has_scale_input(inputs.size() > 4)
+        , indirect_axis(indirect_axis)
         , input_q_transpose_order(input_q_transpose_order)
         , input_k_transpose_order(input_k_transpose_order)
         , input_v_transpose_order(input_v_transpose_order)
-        , output_transpose_order(output_transpose_order) {}
+        , output_transpose_order(output_transpose_order) {
+        auto data_inputs_num = inputs.size();
+        if (indirect_axis != -1)
+            data_inputs_num--;
 
+        has_attn_mask_input = data_inputs_num > 3;
+        has_scale_input = data_inputs_num > 4;
+    }
 
     bool is_causal = false;
     bool has_attn_mask_input = false;
     bool has_scale_input = false;
+    int64_t indirect_axis = -1;
 
     std::vector<int64_t> input_q_transpose_order;
     std::vector<int64_t> input_k_transpose_order;
@@ -48,6 +55,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
         seed = hash_combine(seed, is_causal);
         seed = hash_combine(seed, has_attn_mask_input);
         seed = hash_combine(seed, has_scale_input);
+        seed = hash_combine(seed, indirect_axis);
         seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
         seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
         seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
@@ -64,6 +72,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
         return is_causal == rhs_casted.is_causal &&
                has_attn_mask_input == rhs_casted.has_attn_mask_input &&
                has_scale_input == rhs_casted.has_scale_input &&
+               indirect_axis == rhs_casted.indirect_axis &&
                input_q_transpose_order == rhs_casted.input_q_transpose_order &&
                input_k_transpose_order == rhs_casted.input_k_transpose_order &&
                input_v_transpose_order == rhs_casted.input_v_transpose_order &&
@@ -75,6 +84,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
         ob << is_causal;
         ob << has_attn_mask_input;
         ob << has_scale_input;
+        ob << indirect_axis;
         ob << input_q_transpose_order;
         ob << input_k_transpose_order;
         ob << input_v_transpose_order;
@@ -86,6 +96,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
         ib >> is_causal;
         ib >> has_attn_mask_input;
         ib >> has_scale_input;
+        ib >> indirect_axis;
         ib >> input_q_transpose_order;
         ib >> input_k_transpose_order;
         ib >> input_v_transpose_order;
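
A minimal usage sketch of the new constructor parameter, not part of the diff: it assumes the intel_gpu/primitives/ header path, and the input ids ("q", "k", "v", "attn_mask", "beam_table") are hypothetical. The assumption is that when indirect_axis != -1 one extra, non-data input (a beam table) is appended after the regular inputs, which is why the constructor subtracts one before deciding whether an attention mask or scale input is present.

// Hedged sketch, not from the diff: ids and the trailing beam-table input are assumptions.
#include <intel_gpu/primitives/scaled_dot_product_attention.hpp>

void build_indirect_sdpa_example() {
    using namespace cldnn;

    // Five inputs; the last one is treated as a non-data input because
    // indirect_axis is set, so it is excluded from data_inputs_num.
    std::vector<input_info> inputs = {input_info("q"),
                                      input_info("k"),
                                      input_info("v"),
                                      input_info("attn_mask"),
                                      input_info("beam_table")};

    scaled_dot_product_attention sdpa("sdpa",
                                      inputs,
                                      /*is_causal=*/false,
                                      /*indirect_axis=*/0);

    // data_inputs_num = 5 - 1 = 4, so per the constructor above:
    //   sdpa.has_attn_mask_input == true   (4 > 3)
    //   sdpa.has_scale_input     == false  (4 > 4 is false)
}

With the old code, the beam-table input would have been miscounted as a scale input (inputs.size() > 4); moving the flags into the constructor body lets the count be corrected before they are derived.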