Skip to content

Commit

Permalink
add support for SPV_INTEL_subgroup_matrix_multiply_accumulate (#5928)
Browse files Browse the repository at this point in the history
* add support for SPV_INTEL_subgroup_matrix_multiply_accumulate

* Update DEPS

---------

Co-authored-by: Alan Baker <alanbaker@google.com>
  • Loading branch information
bashbaug and alan-baker authored Jan 16, 2025
1 parent a6107ed commit f942f65
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 4 deletions.
2 changes: 1 addition & 1 deletion DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ vars = {

're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca',

'spirv_headers_revision': '0659679d9648a4dfdb5513efe25c495a3712dbf4',
'spirv_headers_revision': '2b2e05e088841c63c0b6fd4c9fb380d8688738d3',
}

deps = {
Expand Down
3 changes: 3 additions & 0 deletions include/spirv-tools/libspirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ typedef enum spv_operand_type_t {
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE,
// Enum type from SPV_NV_cooperative_matrix2
SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS,
// Optional types from SPV_INTEL_subgroup_matrix_multiply_accumulate
SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS,
SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS,

// This is a sentinel value, and does not represent an operand type.
// It should come last.
Expand Down
7 changes: 6 additions & 1 deletion source/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,9 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS: {
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
case SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS: {
// This operand is a mask.

// Map an optional operand type to its corresponding concrete type.
Expand All @@ -738,6 +740,9 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
parsed_operand.type = SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS;
if (type == SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS)
parsed_operand.type = SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS;
if (type == SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS)
parsed_operand.type =
SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS;

// Check validity of set mask bits. Also prepare for operands for those
// masks if they have any. To get operand order correct, scan from
Expand Down
5 changes: 5 additions & 0 deletions source/operand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
return "cooperative matrix reduce";
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
return "tensor addressing operands";
case SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
return "matrix multiply accumulate operands";
case SPV_OPERAND_TYPE_INITIALIZATION_MODE_QUALIFIER:
return "initialization mode qualifier";
case SPV_OPERAND_TYPE_HOST_ACCESS_QUALIFIER:
Expand Down Expand Up @@ -415,6 +418,7 @@ bool spvOperandIsConcreteMask(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
case SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
Expand All @@ -437,6 +441,7 @@ bool spvOperandIsOptional(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_CIV:
case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_FPENCODING:
Expand Down
3 changes: 2 additions & 1 deletion source/text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar,
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE: {
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS: {
uint32_t value;
if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
return context->diagnostic(error)
Expand Down
27 changes: 27 additions & 0 deletions test/binary_to_text_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,33 @@ INSTANTIATE_TEST_SUITE_P(
"OpDecorate %1 HostAccessINTEL ReadWriteINTEL \"readwrite\"\n",
})));

// clang-format off
INSTANTIATE_TEST_SUITE_P(
MatrixMultiplyAccumulateOperands, RoundTripInstructionsTest,
Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0),
::testing::ValuesIn(std::vector<std::string>{
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 None\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixASignedComponentsINTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBSignedComponentsINTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixCBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixResultBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixAPackedInt8INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBPackedInt8INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixAPackedInt4INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBPackedInt4INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixATF32INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBTF32INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixCBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixAPackedFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBPackedFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixAPackedBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBPackedBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 "
"MatrixASignedComponentsINTEL|MatrixBSignedComponentsINTEL|MatrixAPackedInt8INTEL|MatrixBPackedInt8INTEL\n",
})));
// clang-format on

using MaskSorting = TextToBinaryTest;

TEST_F(MaskSorting, MasksAreSortedFromLSBToMSB) {
Expand Down
2 changes: 1 addition & 1 deletion utils/generate_grammar_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def generate_operand_kind_table(enums):

# We have a few operand kinds that require their optional counterpart to
# exist in the operand info table.
optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat', 'CooperativeMatrixOperands', 'RawAccessChainOperands', 'FPEncoding']
optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat', 'CooperativeMatrixOperands', 'MatrixMultiplyAccumulateOperands', 'RawAccessChainOperands', 'FPEncoding']
optional_enums = [e for e in enums if e[0] in optional_enums]
enums.extend(optional_enums)

Expand Down

0 comments on commit f942f65

Please sign in to comment.