Skip to content

Commit

Permalink
apply comments
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew-k-park committed Mar 20, 2024
1 parent a748572 commit 0de05f9
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 162 deletions.
14 changes: 7 additions & 7 deletions src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ class Gemm : public ov::op::v0::MatMul {

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

std::vector<int32_t> get_input0_target_shape() const { return m_target_shape_a; }
std::vector<int32_t> get_input1_target_shape() const { return m_target_shape_b; }
std::vector<int64_t> get_input0_output_pattern() const { return m_output_pattern_a; }
std::vector<int64_t> get_input1_output_pattern() const { return m_output_pattern_b; }
std::vector<int64_t> get_input0_order() const { return m_order_a; }
std::vector<int64_t> get_input1_order() const { return m_order_b; }
std::vector<int64_t> get_output_order() const { return m_order_c; }
std::vector<int32_t> get_input0_broadcast_target_shape() const { return m_target_shape_a; }
std::vector<int32_t> get_input1_broadcast_target_shape() const { return m_target_shape_b; }
std::vector<int64_t> get_input0_reshape_pattern() const { return m_output_pattern_a; }
std::vector<int64_t> get_input1_reshape_pattern() const { return m_output_pattern_b; }
std::vector<int64_t> get_input0_transpose_order() const { return m_order_a; }
std::vector<int64_t> get_input1_transpose_order() const { return m_order_b; }
std::vector<int64_t> get_output_transpose_order() const { return m_order_c; }
ov::element::Type get_output_type() const { return m_output_type; }

static std::vector<int64_t> default_order(size_t rank) {
Expand Down
134 changes: 67 additions & 67 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ struct gemm : public primitive_base<gemm> {
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
transpose_input0(transpose_input0 ? 1 : 0),
transpose_input1(transpose_input1 ? 1 : 0),
input0_target_shape({}),
input1_target_shape({}),
input0_output_pattern({}),
input1_output_pattern({}),
input0_broadcast_target_shape({}),
input1_broadcast_target_shape({}),
input0_reshape_pattern({}),
input1_reshape_pattern({}),
alpha(alpha),
beta(beta),
input_rank(input_rank),
Expand All @@ -74,9 +74,9 @@ struct gemm : public primitive_base<gemm> {
return order;
};

input0_order = get_transposed_order(input_rank, transpose_input0);
input1_order = get_transposed_order(weight_rank, transpose_input1);
output_order = {};
input0_tranpose_order = get_transposed_order(input_rank, transpose_input0);
input1_transpose_order = get_transposed_order(weight_rank, transpose_input1);
output_transpose_order = {};
}

/// @brief Constructs gemm layer.
Expand All @@ -90,89 +90,89 @@ struct gemm : public primitive_base<gemm> {
gemm(const primitive_id& id,
const std::vector<input_info>& inputs,
const data_types data_type,
const std::vector<int32_t>& input0_target_shape = {},
const std::vector<int32_t>& input1_target_shape = {},
const std::vector<int64_t>& input0_output_pattern = {},
const std::vector<int64_t>& input1_output_pattern = {},
const std::vector<int64_t>& input0_order = {0, 1, 2, 3},
const std::vector<int64_t>& input1_order = {0, 1, 2, 3},
const std::vector<int64_t>& output_order = {},
const std::vector<int32_t>& input0_broadcast_target_shape = {},
const std::vector<int32_t>& input1_broadcast_target_shape = {},
const std::vector<int64_t>& input0_reshape_pattern = {},
const std::vector<int64_t>& input1_reshape_pattern = {},
const std::vector<int64_t>& input0_tranpose_order = {0, 1, 2, 3},
const std::vector<int64_t>& input1_transpose_order = {0, 1, 2, 3},
const std::vector<int64_t>& output_transpose_order = {},
const float alpha = 1.0f,
const float beta = 0.0f,
const padding& output_padding = padding())
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
input0_target_shape(input0_target_shape),
input1_target_shape(input1_target_shape),
input0_output_pattern(input0_output_pattern),
input1_output_pattern(input1_output_pattern),
input0_order(input0_order),
input1_order(input1_order),
output_order(output_order),
input0_broadcast_target_shape(input0_broadcast_target_shape),
input1_broadcast_target_shape(input1_broadcast_target_shape),
input0_reshape_pattern(input0_reshape_pattern),
input1_reshape_pattern(input1_reshape_pattern),
input0_tranpose_order(input0_tranpose_order),
input1_transpose_order(input1_transpose_order),
output_transpose_order(output_transpose_order),
alpha(alpha),
beta(beta),
input_rank(input0_order.size()),
weight_rank(input1_order.size()) {
input_rank(input0_tranpose_order.size()),
weight_rank(input1_transpose_order.size()) {
if (inputs.size() != 2 && inputs.size() != 3) {
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
}

transpose_input0 = get_transpose_mode(input0_order);
transpose_input1 = get_transpose_mode(input1_order);
transpose_input0 = get_transpose_mode(input0_tranpose_order);
transpose_input1 = get_transpose_mode(input1_transpose_order);
}

gemm(const primitive_id& id,
const std::vector<input_info>& inputs,
const input_info& beam_table,
const data_types data_type,
const std::vector<int64_t>& input0_order,
const std::vector<int64_t>& input1_order,
const std::vector<int64_t>& output_order,
const std::vector<int64_t>& input0_tranpose_order,
const std::vector<int64_t>& input1_transpose_order,
const std::vector<int64_t>& output_transpose_order,
bool indirect_a,
bool indirect_b,
const float alpha = 1.0f,
const float beta = 0.0f,
const padding& output_padding = padding())
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
input0_target_shape({}),
input1_target_shape({}),
input0_output_pattern({}),
input1_output_pattern({}),
input0_order(input0_order),
input1_order(input1_order),
output_order(output_order),
input0_broadcast_target_shape({}),
input1_broadcast_target_shape({}),
input0_reshape_pattern({}),
input1_reshape_pattern({}),
input0_tranpose_order(input0_tranpose_order),
input1_transpose_order(input1_transpose_order),
output_transpose_order(output_transpose_order),
alpha(alpha),
beta(beta),
input_rank(input0_order.size()),
weight_rank(input1_order.size()),
input_rank(input0_tranpose_order.size()),
weight_rank(input1_transpose_order.size()),
beam_table(beam_table),
indirect_a(indirect_a),
indirect_b(indirect_b) {
if (inputs.size() != 2 && inputs.size() != 3) {
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
}

transpose_input0 = get_transpose_mode(input0_order);
transpose_input1 = get_transpose_mode(input1_order);
transpose_input0 = get_transpose_mode(input0_tranpose_order);
transpose_input1 = get_transpose_mode(input1_transpose_order);
}

/// @brief Flag for transposing first input matrix
uint32_t transpose_input0 = 0;
/// @brief Flag for transposing second input matrix
uint32_t transpose_input1 = 0;
/// @brief broadcasted target shape of input 0
std::vector<int32_t> input0_target_shape;
std::vector<int32_t> input0_broadcast_target_shape;
/// @brief broadcasted target shape of input 1
std::vector<int32_t> input1_target_shape;
std::vector<int32_t> input1_broadcast_target_shape;
/// @brief reshaped output pattern of input 0
std::vector<int64_t> input0_output_pattern;
std::vector<int64_t> input0_reshape_pattern;
/// @brief reshaped output pattern of input 1
std::vector<int64_t> input1_output_pattern;
std::vector<int64_t> input1_reshape_pattern;
/// @brief order of input 0
std::vector<int64_t> input0_order;
std::vector<int64_t> input0_tranpose_order;
/// @brief order of input 1
std::vector<int64_t> input1_order;
std::vector<int64_t> input1_transpose_order;
/// @brief order of output
std::vector<int64_t> output_order;
std::vector<int64_t> output_transpose_order;
/// @brief Variable containing ALPHA parameter
float alpha = 1.0f;
/// @brief Variable containing BETA parameter
Expand All @@ -193,13 +193,13 @@ struct gemm : public primitive_base<gemm> {
seed = hash_combine(seed, transpose_input1);
seed = hash_combine(seed, indirect_a);
seed = hash_combine(seed, indirect_b);
seed = hash_range(seed, input0_target_shape.begin(), input0_target_shape.end());
seed = hash_range(seed, input1_target_shape.begin(), input1_target_shape.end());
seed = hash_range(seed, input0_output_pattern.begin(), input0_output_pattern.end());
seed = hash_range(seed, input1_output_pattern.begin(), input1_output_pattern.end());
seed = hash_range(seed, input0_order.begin(), input0_order.end());
seed = hash_range(seed, input1_order.begin(), input1_order.end());
seed = hash_range(seed, output_order.begin(), output_order.end());
seed = hash_range(seed, input0_broadcast_target_shape.begin(), input0_broadcast_target_shape.end());
seed = hash_range(seed, input1_broadcast_target_shape.begin(), input1_broadcast_target_shape.end());
seed = hash_range(seed, input0_reshape_pattern.begin(), input0_reshape_pattern.end());
seed = hash_range(seed, input1_reshape_pattern.begin(), input1_reshape_pattern.end());
seed = hash_range(seed, input0_tranpose_order.begin(), input0_tranpose_order.end());
seed = hash_range(seed, input1_transpose_order.begin(), input1_transpose_order.end());
seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
seed = hash_combine(seed, alpha);
seed = hash_combine(seed, beta);
return seed;
Expand All @@ -225,13 +225,13 @@ struct gemm : public primitive_base<gemm> {
primitive_base<gemm>::save(ob);
ob << transpose_input0;
ob << transpose_input1;
ob << input0_target_shape;
ob << input1_target_shape;
ob << input0_output_pattern;
ob << input1_output_pattern;
ob << input0_order;
ob << input1_order;
ob << output_order;
ob << input0_broadcast_target_shape;
ob << input1_broadcast_target_shape;
ob << input0_reshape_pattern;
ob << input1_reshape_pattern;
ob << input0_tranpose_order;
ob << input1_transpose_order;
ob << output_transpose_order;
ob << alpha;
ob << beta;
ob << input_rank;
Expand All @@ -246,13 +246,13 @@ struct gemm : public primitive_base<gemm> {
primitive_base<gemm>::load(ib);
ib >> transpose_input0;
ib >> transpose_input1;
ib >> input0_target_shape;
ib >> input1_target_shape;
ib >> input0_output_pattern;
ib >> input1_output_pattern;
ib >> input0_order;
ib >> input1_order;
ib >> output_order;
ib >> input0_broadcast_target_shape;
ib >> input1_broadcast_target_shape;
ib >> input0_reshape_pattern;
ib >> input1_reshape_pattern;
ib >> input0_tranpose_order;
ib >> input1_transpose_order;
ib >> output_transpose_order;
ib >> alpha;
ib >> beta;
ib >> input_rank;
Expand Down
Loading

0 comments on commit 0de05f9

Please sign in to comment.