[GPU] Extend gemm to fuse unsqueeze layer (#23734)
### Details:
- Follow up on review comments from #23513
- Fuse the `unsqueeze` layer into the `gemm` layer for indirect gemm (see the sketch after this list):
  - before: [`kv_cache`] --> [`unsqueeze`] --> [`gemm`]
  - after: [`kv_cache`] --> [`gemm`]
- Simplify the fusion pass and its logic, since the `unsqueeze` is now fused as well
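
For intuition, here is a minimal sketch of the subgraph this targets. The head counts and shapes below are illustrative grouped-query-attention values, not taken from this commit: the cached key/value tensor has fewer heads than the query, so an explicit `unsqueeze`/broadcast used to expand it before the `gemm`, and the fused kernel now performs that head expansion implicitly.

```cpp
#include <cstdio>

int main() {
    // Hypothetical GQA configuration: 32 query heads read from 8 KV heads.
    const long q_heads = 32, kv_heads = 8;
    const long group = q_heads / kv_heads;  // 4 query heads per KV head

    // before: [b, 8, L, d] --unsqueeze--> [b, 8, 1, L, d] --broadcast/reshape--> [b, 32, L, d]
    // after : the fused gemm maps query head h to KV head h / group on the fly
    for (long h = 0; h < q_heads; h += 8)
        std::printf("query head %2ld -> kv head %ld\n", h, h / group);
    return 0;
}
```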

### Tickets:
- 136567

---------

Signed-off-by: Andrew Park <andrew.park@intel.com>
andrew-k-park authored Apr 17, 2024
1 parent 6e961cd commit 4b9f92a
Showing 14 changed files with 182 additions and 430 deletions.
src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp (0 additions, 23 deletions)

```diff
@@ -26,27 +26,12 @@ class Gemm : public ov::op::v0::MatMul {
          const std::vector<int64_t>& order_c,
          const ov::element::Type output_type = ov::element::undefined);
 
-    Gemm(const ov::Output<Node>& A,
-         const ov::Output<Node>& B,
-         const std::vector<int32_t>& target_shape_a,
-         const std::vector<int32_t>& target_shape_b,
-         const std::vector<int64_t>& output_pattern_a,
-         const std::vector<int64_t>& output_pattern_b,
-         const std::vector<int64_t>& order_a,
-         const std::vector<int64_t>& order_b,
-         const std::vector<int64_t>& order_c,
-         const ov::element::Type output_type = ov::element::undefined);
-
     bool visit_attributes(ov::AttributeVisitor &visitor) override;
 
     void validate_and_infer_types() override;
 
     std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
 
-    std::vector<int32_t> get_input0_broadcast_target_shape() const { return m_target_shape_a; }
-    std::vector<int32_t> get_input1_broadcast_target_shape() const { return m_target_shape_b; }
-    std::vector<int64_t> get_input0_reshape_pattern() const { return m_output_pattern_a; }
-    std::vector<int64_t> get_input1_reshape_pattern() const { return m_output_pattern_b; }
     std::vector<int64_t> get_input0_transpose_order() const { return m_order_a; }
     std::vector<int64_t> get_input1_transpose_order() const { return m_order_b; }
     std::vector<int64_t> get_output_transpose_order() const { return m_order_c; }
@@ -59,10 +44,6 @@ class Gemm : public ov::op::v0::MatMul {
     }
 
 protected:
-    std::vector<int32_t> m_target_shape_a;
-    std::vector<int32_t> m_target_shape_b;
-    std::vector<int64_t> m_output_pattern_a;
-    std::vector<int64_t> m_output_pattern_b;
     std::vector<int64_t> m_order_a;
     std::vector<int64_t> m_order_b;
     std::vector<int64_t> m_order_c;
@@ -71,10 +52,6 @@ class Gemm : public ov::op::v0::MatMul {
 
 std::vector<ov::PartialShape> shape_infer(const Gemm* op,
                                           std::vector<ov::PartialShape> input_shapes,
-                                          const std::vector<int32_t>& target_shape_a,
-                                          const std::vector<int32_t>& target_shape_b,
-                                          const std::vector<int64_t>& output_pattern_a,
-                                          const std::vector<int64_t>& output_pattern_b,
                                           const std::vector<int64_t>& order_a,
                                           const std::vector<int64_t>& order_b,
                                           const std::vector<int64_t>& order_c);
```
src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp (0 additions, 36 deletions)

```diff
@@ -54,10 +54,6 @@ struct gemm : public primitive_base<gemm> {
         : primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
           transpose_input0(transpose_input0 ? 1 : 0),
           transpose_input1(transpose_input1 ? 1 : 0),
-          input0_broadcast_target_shape({}),
-          input1_broadcast_target_shape({}),
-          input0_reshape_pattern({}),
-          input1_reshape_pattern({}),
           alpha(alpha),
           beta(beta),
           input_rank(input_rank),
@@ -90,21 +86,13 @@
     gemm(const primitive_id& id,
          const std::vector<input_info>& inputs,
          const data_types data_type,
-         const std::vector<int32_t>& input0_broadcast_target_shape = {},
-         const std::vector<int32_t>& input1_broadcast_target_shape = {},
-         const std::vector<int64_t>& input0_reshape_pattern = {},
-         const std::vector<int64_t>& input1_reshape_pattern = {},
          const std::vector<int64_t>& input0_transpose_order = {0, 1, 2, 3},
          const std::vector<int64_t>& input1_transpose_order = {0, 1, 2, 3},
          const std::vector<int64_t>& output_transpose_order = {},
          const float alpha = 1.0f,
          const float beta = 0.0f,
          const padding& output_padding = padding())
         : primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
-          input0_broadcast_target_shape(input0_broadcast_target_shape),
-          input1_broadcast_target_shape(input1_broadcast_target_shape),
-          input0_reshape_pattern(input0_reshape_pattern),
-          input1_reshape_pattern(input1_reshape_pattern),
           input0_transpose_order(input0_transpose_order),
           input1_transpose_order(input1_transpose_order),
           output_transpose_order(output_transpose_order),
@@ -133,10 +121,6 @@
          const float beta = 0.0f,
          const padding& output_padding = padding())
         : primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
-          input0_broadcast_target_shape({}),
-          input1_broadcast_target_shape({}),
-          input0_reshape_pattern({}),
-          input1_reshape_pattern({}),
           input0_transpose_order(input0_transpose_order),
           input1_transpose_order(input1_transpose_order),
           output_transpose_order(output_transpose_order),
@@ -159,14 +143,6 @@
     uint32_t transpose_input0 = 0;
     /// @brief Flag for transposing second input matrix
     uint32_t transpose_input1 = 0;
-    /// @brief broadcasted target shape of input 0
-    std::vector<int32_t> input0_broadcast_target_shape;
-    /// @brief broadcasted target shape of input 1
-    std::vector<int32_t> input1_broadcast_target_shape;
-    /// @brief reshaped output pattern of input 0
-    std::vector<int64_t> input0_reshape_pattern;
-    /// @brief reshaped output pattern of input 1
-    std::vector<int64_t> input1_reshape_pattern;
     /// @brief order of input 0
     std::vector<int64_t> input0_transpose_order;
     /// @brief order of input 1
@@ -193,10 +169,6 @@
         seed = hash_combine(seed, transpose_input1);
         seed = hash_combine(seed, indirect_a);
         seed = hash_combine(seed, indirect_b);
-        seed = hash_range(seed, input0_broadcast_target_shape.begin(), input0_broadcast_target_shape.end());
-        seed = hash_range(seed, input1_broadcast_target_shape.begin(), input1_broadcast_target_shape.end());
-        seed = hash_range(seed, input0_reshape_pattern.begin(), input0_reshape_pattern.end());
-        seed = hash_range(seed, input1_reshape_pattern.begin(), input1_reshape_pattern.end());
        seed = hash_range(seed, input0_transpose_order.begin(), input0_transpose_order.end());
         seed = hash_range(seed, input1_transpose_order.begin(), input1_transpose_order.end());
         seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
@@ -225,10 +197,6 @@
         primitive_base<gemm>::save(ob);
         ob << transpose_input0;
         ob << transpose_input1;
-        ob << input0_broadcast_target_shape;
-        ob << input1_broadcast_target_shape;
-        ob << input0_reshape_pattern;
-        ob << input1_reshape_pattern;
         ob << input0_transpose_order;
         ob << input1_transpose_order;
         ob << output_transpose_order;
@@ -246,10 +214,6 @@
         primitive_base<gemm>::load(ib);
         ib >> transpose_input0;
         ib >> transpose_input1;
-        ib >> input0_broadcast_target_shape;
-        ib >> input1_broadcast_target_shape;
-        ib >> input0_reshape_pattern;
-        ib >> input1_reshape_pattern;
         ib >> input0_transpose_order;
         ib >> input1_transpose_order;
         ib >> output_transpose_order;
```
src/plugins/intel_gpu/src/graph/gemm.cpp (4 additions, 52 deletions)

```diff
@@ -10,18 +10,6 @@
 
 #include "intel_gpu/op/gemm.hpp"
 
-namespace {
-template <typename T, typename DT, typename = typename std::enable_if<std::is_convertible<DT, T>::value>::type>
-int find_index_from_vec(const std::vector<T>& vec, const DT value) {
-    int idx = 0;
-    for (auto v : vec) {
-        if (v != static_cast<T>(value))
-            break;
-        idx += 1;
-    }
-    return idx;
-}
-} // namespace
 namespace cldnn {
 GPU_DEFINE_PRIMITIVE_TYPE_ID(gemm)
 
@@ -139,10 +127,6 @@ std::vector<layout> gemm_inst::calc_output_layouts(gemm_node const& node, const
 
     std::vector<ShapeType> output_shapes = ov::intel_gpu::op::shape_infer(&op,
                                                                           input_shapes,
-                                                                          prim->input0_broadcast_target_shape,
-                                                                          prim->input1_broadcast_target_shape,
-                                                                          prim->input0_reshape_pattern,
-                                                                          prim->input1_reshape_pattern,
                                                                           prim->input0_transpose_order,
                                                                           prim->input1_transpose_order,
                                                                           prim->output_transpose_order);
@@ -158,28 +142,6 @@ template std::vector<layout> gemm_inst::calc_output_layouts<ov::PartialShape>(ge
 
 std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<const gemm> primitive,
                                                        const std::vector<layout>& input_layouts) {
-    auto get_reshaped_input_shape = [&](const ov::PartialShape& input_pshape,
-                                        const std::vector<int32_t>& broadcast_target_shape,
-                                        const std::vector<int64_t>& reshape_pattern) {
-        ov::PartialShape reshaped_input_pshape;
-
-        if (broadcast_target_shape.size() > 0 && reshape_pattern.size() > 0) {
-            std::vector<ov::Dimension> dims(input_pshape);
-            int idx_recalc = find_index_from_vec(broadcast_target_shape, 1);
-            int idx_target = find_index_from_vec(reshape_pattern, 0);
-            if (dims[idx_recalc].is_static() && dims[idx_target].is_static()) {
-                dims[idx_recalc] *= dims[idx_target];
-            } else {
-                dims[idx_recalc] = ov::Dimension::dynamic();
-            }
-            dims.erase(dims.begin() + idx_target);
-            reshaped_input_pshape = ov::PartialShape(dims);
-        } else {
-            reshaped_input_pshape = input_pshape;
-        }
-        return reshaped_input_pshape;
-    };
-
     auto get_transposed_input_shape = [&](const ov::PartialShape& input_pshape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) {
         ov::PartialShape transposed_input_pshape;
 
@@ -214,30 +176,20 @@ std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<con
         return transposed_input_pshape;
     };
 
-    auto reshaped_input0_pshape = get_reshaped_input_shape(input_layouts[0].get_partial_shape(),
-                                                           primitive->input0_broadcast_target_shape,
-                                                           primitive->input0_reshape_pattern);
-    auto reshaped_input1_pshape = get_reshaped_input_shape(input_layouts[1].get_partial_shape(),
-                                                           primitive->input1_broadcast_target_shape,
-                                                           primitive->input1_reshape_pattern);
+    auto input0_pshape = input_layouts[0].get_partial_shape();
+    auto input1_pshape = input_layouts[1].get_partial_shape();
 
     bool reordered = primitive->input_rank > 4 || primitive->weight_rank > 4;
     size_t output_rank = std::max(primitive->input_rank, primitive->weight_rank);
     size_t input_rank = reordered ? output_rank : primitive->input_rank;
     size_t weight_rank = reordered ? output_rank : primitive->weight_rank;
 
-    auto transposed_input0_pshape = get_transposed_input_shape(reshaped_input0_pshape, input_rank, output_rank, primitive->transpose_input0, true);
-    auto transposed_input1_pshape = get_transposed_input_shape(reshaped_input1_pshape, weight_rank, output_rank, primitive->transpose_input1, false);
+    auto transposed_input0_pshape = get_transposed_input_shape(input0_pshape, input_rank, output_rank, primitive->transpose_input0, true);
+    auto transposed_input1_pshape = get_transposed_input_shape(input1_pshape, weight_rank, output_rank, primitive->transpose_input1, false);
 
     std::vector<layout> layouts = input_layouts;
     layouts[0].set_partial_shape(transposed_input0_pshape);
-    if (primitive->input0_broadcast_target_shape.size() > input_rank) {
-        layouts[0].format = format::adjust_to_rank(layouts[0].format, input_rank);
-    }
     layouts[1].set_partial_shape(transposed_input1_pshape);
-    if (primitive->input1_broadcast_target_shape.size() > weight_rank) {
-        layouts[1].format = format::adjust_to_rank(layouts[1].format, weight_rank);
-    }
 
     if (primitive->input_size() == 3) {
         auto bias_pshape = input_layouts[2].get_partial_shape();
```
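
With the reshape path gone, `transform_input_layouts` only has to apply the transpose orders. Here is a standalone sketch of that remaining shape handling, simplified to static `int64_t` shapes (the real code operates on `ov::PartialShape`, and `transpose_shape` is a hypothetical stand-in):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Permute a shape by a transpose order, as the retained
// get_transposed_input_shape path does for each gemm input.
std::vector<int64_t> transpose_shape(const std::vector<int64_t>& shape,
                                     const std::vector<int64_t>& order) {
    std::vector<int64_t> out(shape.size());
    for (size_t i = 0; i < order.size(); ++i)
        out[i] = shape[order[i]];
    return out;
}

int main() {
    // e.g. [b, L, heads, d] with order {0, 2, 1, 3} -> [b, heads, L, d]
    auto t = transpose_shape({1, 128, 8, 64}, {0, 2, 1, 3});
    assert((t == std::vector<int64_t>{1, 8, 128, 64}));
    return 0;
}
```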
src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp (38 additions, 4 deletions)

```diff
@@ -2,6 +2,8 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
+#include "intel_gpu/op/gemm.hpp"
+#include "intel_gpu/plugin/common_utils.hpp"
 #include "intel_gpu/graph/kernel_impl_params.hpp"
 #include "multi_stage_primitive.hpp"
 
@@ -173,14 +175,46 @@ struct gemm_impl : multi_stage_primitive<gemm> {
         params.beta = primitive->beta;
         params.transpose_input0 = primitive->transpose_input0;
         params.transpose_input1 = primitive->transpose_input1;
-        params.input0_target_shape = primitive->input0_broadcast_target_shape;
-        params.input1_target_shape = primitive->input1_broadcast_target_shape;
-        params.input0_output_pattern = primitive->input0_reshape_pattern;
-        params.input1_output_pattern = primitive->input0_reshape_pattern;
         params.input0_order = primitive->input0_transpose_order;
         params.input1_order = primitive->input1_transpose_order;
         params.output_order = primitive->output_transpose_order;
 
+        auto input0_pshape = impl_param.input_layouts[0].get_partial_shape();
+        auto input1_pshape = impl_param.input_layouts[1].get_partial_shape();
+        const auto is_broadcastable = input0_pshape.rank().is_static() &&
+                                      input1_pshape.rank().is_static() &&
+                                      input0_pshape.size() > 1 &&
+                                      input1_pshape.size() > 1 &&
+                                      (primitive->input_rank == primitive->weight_rank);
+        if (is_broadcastable) {
+            auto transpose_pshape = [](const ov::PartialShape pshape, const std::vector<int64_t>& order) {
+                auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank());
+                for (size_t i = 0; i < order.size(); i++) {
+                    transposed_pshape[i] = pshape[order[i]];
+                }
+                return transposed_pshape;
+            };
+            size_t max_rank = input0_pshape.size();
+            auto default_order = ov::intel_gpu::op::Gemm::default_order(max_rank);
+            auto input0_trans_pshape = (primitive->input0_transpose_order != default_order) ?
+                                       transpose_pshape(input0_pshape, primitive->input0_transpose_order) :
+                                       input0_pshape;
+            auto input1_trans_pshape = (primitive->input1_transpose_order != default_order) ?
+                                       transpose_pshape(input1_pshape, primitive->input1_transpose_order) :
+                                       input1_pshape;
+            for (size_t i = 0; i < max_rank - 2; ++i) {
+                if (input0_trans_pshape[i].is_static() && input1_trans_pshape[i].is_static()) {
+                    if (input1_trans_pshape[i].get_length() > input0_trans_pshape[i].get_length()) {
+                        params.input0_reshape_axes = primitive->input0_transpose_order[i];
+                        params.input0_broadcast_val = input1_trans_pshape[i].get_length() / input0_trans_pshape[i].get_length();
+                    } else if (input0_trans_pshape[i].get_length() > input1_trans_pshape[i].get_length()) {
+                        params.input1_reshape_axes = primitive->input1_transpose_order[i];
+                        params.input1_broadcast_val = input0_trans_pshape[i].get_length() / input1_trans_pshape[i].get_length();
+                    }
+                }
+            }
+        }
+
         params.indirect_input0 = primitive->indirect_a && indirect;
         params.indirect_input1 = primitive->indirect_b && indirect;
         if (indirect && (primitive->indirect_a || primitive->indirect_b)) {
```
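
The added loop recovers what the removed `unsqueeze` metadata used to state explicitly: for each leading (batch) dimension of the transposed shapes, the ratio between the larger and the smaller static extent gives the broadcast factor, and the axis is recorded alongside it. A self-contained sketch of that derivation, simplified to static `int64_t` shapes with identity transpose orders (function name and values are illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct BroadcastInfo { int64_t axes = 0; int64_t val = 0; };  // val == 0 means no broadcast

// For equal-rank, already-transposed shapes, only the dims ahead of the
// trailing two (the matmul M/N/K dims) may need broadcasting.
BroadcastInfo derive_input0_broadcast(const std::vector<int64_t>& in0,
                                      const std::vector<int64_t>& in1) {
    BroadcastInfo info;
    for (size_t i = 0; i + 2 < in0.size(); ++i) {
        if (in1[i] > in0[i]) {  // input 0 is the smaller, broadcast side
            info.axes = static_cast<int64_t>(i);
            info.val = in1[i] / in0[i];
        }
    }
    return info;
}

int main() {
    // 8 KV heads vs 32 query heads -> broadcast factor 4 on the head axis.
    auto bi = derive_input0_broadcast({1, 8, 128, 64}, {1, 32, 128, 64});
    std::printf("axes=%lld val=%lld\n", (long long)bi.axes, (long long)bi.val);
    return 0;
}
```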
Kernel selector, GemmKernelBase::GetJitConstants (file header not captured on this page):

```diff
@@ -215,41 +215,37 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const {
         jit.AddConstant(MakeJitConstant("BIAS_TERM", 1));
     }
 
-    auto get_broadcast_input_str = [](const std::vector<int32_t>& target_shape) {
-        const size_t target_rank = target_shape.size();
+    auto get_broadcast_input_str = [](const size_t input_rank, const int64_t axes, const int64_t val) {
         std::vector<std::string> dims;
-        if (target_rank == 1) {
+        if (input_rank == 1) {
             dims = {"x"};
-        } else if (target_rank == 2) {
+        } else if (input_rank == 2) {
             dims = {"y", "x"};
-        } else if (target_rank == 3) {
+        } else if (input_rank == 3) {
             dims = {"f", "y", "x"};
-        } else if (target_rank == 4) {
+        } else if (input_rank == 4) {
             dims = {"b", "f", "y", "x"};
-        } else if (target_rank == 5) {
+        } else if (input_rank == 5) {
             dims = {"b", "f", "z", "y", "x"};
-        } else if (target_rank == 6) {
+        } else if (input_rank == 6) {
             dims = {"b", "f", "w", "z", "y", "x"};
         }
-        int pos = 0;
-        for (auto ts : target_shape) {
-            if (ts != 1)
-                break;
-            pos += 1;
-        }
-        std::string str = dims[pos] + " /= " + std::to_string(target_shape[pos]) + ";";
-        return str;
+        return dims[axes] + " /= " + std::to_string(val) + ";";
     };
-    if (params.input0_target_shape.size() > 1) {
+    if (params.input0_broadcast_val != 0) {
         jit.AddConstants({
             MakeJitConstant("BROADCAST_INPUT0", true),
-            MakeJitConstant("DO_BROADCAST_INPUT0", get_broadcast_input_str(params.input0_target_shape)),
+            MakeJitConstant("DO_BROADCAST_INPUT0", get_broadcast_input_str(params.inputs[0].GetDims().size(),
+                                                                           params.input0_reshape_axes,
+                                                                           params.input0_broadcast_val)),
         });
     }
-    if (params.input1_target_shape.size() > 1) {
+    if (params.input1_broadcast_val != 0) {
         jit.AddConstants({
             MakeJitConstant("BROADCAST_INPUT1", true),
-            MakeJitConstant("DO_BROADCAST_INPUT1", get_broadcast_input_str(params.input1_target_shape)),
+            MakeJitConstant("DO_BROADCAST_INPUT1", get_broadcast_input_str(params.inputs[1].GetDims().size(),
+                                                                           params.input1_reshape_axes,
+                                                                           params.input1_broadcast_val)),
         });
     }
 
```
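
The reworked lambda emits a single index-fixup statement into the kernel: for a rank-4 input with `axes = 1` and `val = 4`, `DO_BROADCAST_INPUT0` becomes `f /= 4;`, folding a broadcast head coordinate back into the smaller input's range. A minimal host-side sketch of the same string construction (`broadcast_fixup` is a hypothetical name, illustrative values):

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Pick the dim name for the broadcast axis at the given rank
// and emit "<dim> /= <val>;", as the JIT constant does.
std::string broadcast_fixup(size_t rank, int64_t axes, int64_t val) {
    std::vector<std::string> dims;
    switch (rank) {
        case 1: dims = {"x"}; break;
        case 2: dims = {"y", "x"}; break;
        case 3: dims = {"f", "y", "x"}; break;
        case 4: dims = {"b", "f", "y", "x"}; break;
        case 5: dims = {"b", "f", "z", "y", "x"}; break;
        default: dims = {"b", "f", "w", "z", "y", "x"}; break;
    }
    return dims[axes] + " /= " + std::to_string(val) + ";";
}

int main() {
    std::printf("%s\n", broadcast_fixup(4, 1, 4).c_str());  // prints: f /= 4;
    return 0;
}
```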
Kernel selector, gemm_params (file header not captured on this page):

```diff
@@ -19,13 +19,13 @@ struct gemm_params : public base_params {
     float beta;
     uint32_t transpose_input0;
     uint32_t transpose_input1;
-    std::vector<int32_t> input0_target_shape;
-    std::vector<int32_t> input1_target_shape;
-    std::vector<int64_t> input0_output_pattern;
-    std::vector<int64_t> input1_output_pattern;
     std::vector<int64_t> input0_order;
     std::vector<int64_t> input1_order;
     std::vector<int64_t> output_order;
+    int64_t input0_reshape_axes = 0;
+    int64_t input1_reshape_axes = 0;
+    int64_t input0_broadcast_val = 0;
+    int64_t input1_broadcast_val = 0;
     DataTensor beam_table;
     bool indirect_input0 = false;
     bool indirect_input1 = false;
```
[Diff truncated: 6 of the 14 changed files are shown above.]
