Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 8ffc2d9

Browse files
Mahesh Ravishankarjpienaar
authored andcommitted
Enable (de)serialization support for spirv::AccessChainOp
Automatic generation of spirv::AccessChainOp (de)serialization needs the (de)serialization emitters to handle argument specified as Variadic<...>. To handle this correctly, this argument can only be the last entry in the arguments list. Add a test to (de)serialize spirv::AccessChainOp PiperOrigin-RevId: 260532598
1 parent a8d42cc commit 8ffc2d9

File tree

3 files changed

+45
-13
lines changed

3 files changed

+45
-13
lines changed

include/mlir/Dialect/SPIRV/SPIRVOps.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
9595
let results = (outs
9696
SPV_AnyPtr:$component_ptr
9797
);
98-
99-
let autogenSerialization = 0;
10098
}
10199

102100
// -----
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
2+
3+
func @foo() {
4+
spv.module "Logical" "VulkanKHR" {
5+
func @access_chain(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>,
6+
%arg1 : i32, %arg2 : i32) {
7+
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
8+
// CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
9+
%1 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
10+
%2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
11+
spv.Return
12+
}
13+
}
14+
return
15+
}

tools/mlir-tblgen/SPIRVUtilsGen.cpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,13 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
126126
auto argument = op.getArg(i);
127127
os << " {\n";
128128
if (argument.is<NamedTypeConstraint *>()) {
129-
os << " if (" << operandNum
130-
<< " < op.getOperation()->getNumOperands()) {\n";
131-
os << " auto arg = findValueID(op.getOperation()->getOperand("
132-
<< operandNum << "));\n";
133-
os << " if (!arg) {\n";
129+
os << " for (auto arg : op.getODSOperands(" << i << ")) {\n";
130+
os << " auto argID = findValueID(arg);\n";
131+
os << " if (!argID) {\n";
134132
os << " emitError(op.getLoc(), \"operand " << operandNum
135133
<< " has a use before def\");\n";
136134
os << " }\n";
137-
os << " operands.push_back(arg);\n";
135+
os << " operands.push_back(argID);\n";
138136
os << " }\n";
139137
operandNum++;
140138
} else {
@@ -243,32 +241,53 @@ static void emitDeserializationFunction(const Record *record,
243241
"SPIR-V ops can have only zero or one result");
244242
}
245243

246-
// Process arguments/attributes
244+
// Process operands/attributes
247245
os << " SmallVector<Value *, 4> operands;\n";
248246
os << " SmallVector<NamedAttribute, 4> attributes;\n";
249247
unsigned operandNum = 0;
250248
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
251249
auto argument = op.getArg(i);
252-
os << " if (wordIndex < words.size()) {\n";
253-
if (argument.is<NamedTypeConstraint *>()) {
250+
if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
251+
if (valueArg->isVariadic()) {
252+
if (i != e - 1) {
253+
PrintFatalError(record->getLoc(),
254+
"SPIR-V ops can have Variadic<..> argument only if "
255+
"it's the last argument");
256+
}
257+
os << " for (; wordIndex < words.size(); ++wordIndex)";
258+
} else {
259+
os << " if (wordIndex < words.size())";
260+
}
261+
os << " {\n";
254262
os << " auto arg = getValue(words[wordIndex]);\n";
255263
os << " if (!arg) {\n";
256264
os << " return emitError(unknownLoc, \"unknown result <id> : \") << "
257265
"words[wordIndex];\n";
258266
os << " }\n";
259267
os << " operands.push_back(arg);\n";
260-
os << " wordIndex++;\n";
268+
if (!valueArg->isVariadic()) {
269+
os << " wordIndex++;\n";
270+
}
261271
operandNum++;
272+
os << " }\n";
262273
} else {
274+
os << " if (wordIndex < words.size()) {\n";
263275
auto attr = argument.get<NamedAttribute *>();
264276
emitAttributeDeserialization(
265277
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
266278
record->getLoc(), "attributes", attr->name, "words", "wordIndex",
267279
"words.size()", os);
280+
os << " }\n";
268281
}
269-
os << " }\n";
270282
}
271283

284+
os << " if (wordIndex != words.size()) {\n";
285+
os << " return emitError(unknownLoc, \"found more operands than expected "
286+
"when deserializing "
287+
<< op.getQualCppClassName()
288+
<< ", only \") << wordIndex << \" of \" << words.size() << \" "
289+
"processed\";\n";
290+
os << " }\n";
272291
os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
273292
"operands, attributes); (void)op;\n",
274293
op.getQualCppClassName());

0 commit comments

Comments
 (0)