Update after 'loop'->'scf' Op namespace renaming in MLIR #1900

Closed · wants to merge 1 commit
44 changes: 22 additions & 22 deletions iree/compiler/Translation/SPIRV/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -40,7 +40,7 @@ namespace iree_compiler {
// Loop utilities
//===----------------------------------------------------------------------===//

-/// Builds an empty loop.for operation. The default builder adds an entry basic
+/// Builds an empty scf.for operation. The default builder adds an entry basic
/// block which needs to be avoided here.
static scf::ForOp buildEmptyForOp(Location loc, OpBuilder &builder, Value lb,
Value ub, Value step) {
@@ -58,10 +58,10 @@ struct LoopBounds {
};
} // namespace

-/// Replaces a loop.parallelOp with an optional loop.parallel op and nested
-/// loop.for operations. To create the loop.parallel op as the outermost loop,
+/// Replaces a scf.parallelOp with an optional scf.parallel op and nested
+/// scf.for operations. To create the scf.parallel op as the outermost loop,
/// pass the lower bound, upper bound and steps in `newPLoopLbs`, `newPLoopUbs`,
-/// and `newPLoopStep` respectively. The bounds of the inner loop.for operations
+/// and `newPLoopStep` respectively. The bounds of the inner scf.for operations
/// to be created are passed in `forLbs`, `forUbs`, and `forStep`. The
/// `permutation` vector contains a mapping from the original loop order, to the
/// loop order to be generated.
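For intuition, this is the shape of the rewrite on a 2-D loop. The sketch below is illustrative only (not part of the patch); the function names `@before`/`@after` and the bound values `%lb`/`%ub`/`%step` are hypothetical, and the loop bodies are left empty:

```mlir
// Before: a 2-D scf.parallel.
func @before(%lb: index, %ub: index, %step: index) {
  scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%step, %step) {
    scf.yield
  }
  return
}

// After: an outermost scf.parallel over the dimension kept parallel, with
// the remaining dimension rewritten as a nested scf.for.
func @after(%lb: index, %ub: index, %step: index) {
  scf.parallel (%i) = (%lb) to (%ub) step (%step) {
    scf.for %j = %lb to %ub step %step {
    }
    scf.yield
  }
  return
}
```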
@@ -70,21 +70,21 @@ static Operation *replacePLoopOp(ConversionPatternRewriter &rewriter,
ArrayRef<LoopBounds> newPLoopBounds,
ArrayRef<LoopBounds> forBounds,
ArrayRef<unsigned> permutation) {
-assert(!forBounds.empty() && "unhandled case of no loop.for created");
+assert(!forBounds.empty() && "unhandled case of no scf.for created");
unsigned numLoops = pLoopOp.getNumLoops();
Location loc = pLoopOp.getLoc();
assert(forBounds.size() + newPLoopBounds.size() == numLoops &&
"cannot drop loops when splitting loop.parallel operation");
"cannot drop loops when splitting scf.parallel operation");
assert(permutation.size() == numLoops);
OpBuilder::InsertionGuard guard(rewriter);

-// Need a signature conversion for the body of the loop.parallel operation,
+// Need a signature conversion for the body of the scf.parallel operation,
// before it can be used as the body of the innermost loop created here.
TypeConverter::SignatureConversion signatureConverter(numLoops);
Operation *outermostLoop = nullptr;
auto permuteIt = permutation.begin();

-// Create the loop.parallel operation as the outermost loop, if specified.
+// Create the scf.parallel operation as the outermost loop, if specified.
if (!newPLoopBounds.empty()) {
auto lbs = llvm::to_vector<2>(llvm::map_range(
newPLoopBounds, [](LoopBounds bounds) -> Value { return bounds.lb; }));
@@ -101,7 +101,7 @@ static Operation *replacePLoopOp(ConversionPatternRewriter &rewriter,
outermostLoop = newPLoop.getOperation();
}

-// Generate the nested loop.for operations with the bounds passed.
+// Generate the nested scf.for operations with the bounds passed.
for (auto it : enumerate(forBounds)) {
Value lb = it.value().lb, ub = it.value().ub, step = it.value().step;
if (it.index() != forBounds.size() - 1) {
@@ -110,7 +110,7 @@ static Operation *replacePLoopOp(ConversionPatternRewriter &rewriter,
signatureConverter.remapInput(*permuteIt, forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
} else {
-// For the last loop, move the body of the loop.parallel op as the body of
+// For the last loop, move the body of the scf.parallel op as the body of
// the loop after signature conversion.
auto forOp = buildEmptyForOp(loc, rewriter, lb, ub, step);
if (!outermostLoop) outermostLoop = forOp.getOperation();
@@ -127,8 +127,8 @@ static Operation *replacePLoopOp(ConversionPatternRewriter &rewriter,
return outermostLoop;
}

-/// Serializes the dimensions of the loop.parallel specified in
-/// `serializedDimensions`, by creating an nested loop.for operation for each
+/// Serializes the dimensions of the scf.parallel specified in
+/// `serializedDimensions`, by creating a nested scf.for operation for each
/// dimension.
// TODO(ravishankarm): Move this into LoopUtils.h in MLIR.
static Operation *serializeDimensions(ConversionPatternRewriter &rewriter,
@@ -141,7 +141,7 @@ static Operation *serializeDimensions(ConversionPatternRewriter &rewriter,
serializedDimSet.insert(serializedDimensions.begin(),
serializedDimensions.end());
assert(serializedDimSet.size() == serializedDimensions.size() &&
"cannot repeat dimensions during serialization of loop.parallel");
"cannot repeat dimensions during serialization of scf.parallel");
SmallVector<LoopBounds, 2> newPLoopBounds, forBounds;
SmallVector<unsigned, 2> permutation;
auto lbs = pLoopOp.lowerBound();
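In IR terms, serializing the trailing dimension of a 4-D scf.parallel looks like this. This sketch is patterned on the parallel_4D test further below, not taken from the patch; `%ub`, `%c0`, `%c1`, and the function names are hypothetical:

```mlir
// Hypothetical input: all four dimensions parallel.
func @before(%ub: index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.parallel (%i0, %i1, %i2, %i3) = (%c0, %c0, %c0, %c0)
      to (%ub, %ub, %ub, %ub) step (%c1, %c1, %c1, %c1) {
    scf.yield
  }
  return
}

// After serializing dimension 3: three parallel dimensions remain, and the
// fourth becomes a sequential scf.for nested inside.
func @after(%ub: index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.parallel (%i0, %i1, %i2) = (%c0, %c0, %c0)
      to (%ub, %ub, %ub) step (%c1, %c1, %c1) {
    scf.for %i3 = %c0 to %ub step %c1 {
    }
    scf.yield
  }
  return
}
```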
@@ -178,9 +178,9 @@ static Operation *serializeDimensionsFrom(ConversionPatternRewriter &rewriter,
// GPU processor ID mapping utilities
//===----------------------------------------------------------------------===//

-/// Distribute loop.parallel to processors with the processors logically
+/// Distribute scf.parallel to processors with the processors logically
/// arranged with the same dimensionality as the number of loops, i.e. a
-/// loop.parallel with 2 loops to a 2D grid of processors. `processorIDs` and
+/// scf.parallel with 2 loops to a 2D grid of processors. `processorIDs` and
/// `numProcessors` must be of same size as the number of loops and are the
/// values to use for process ID and number of processors along each dimension
/// in the distributed code.
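The distribution is cyclic: each processor starts at its own ID and strides by the processor count, which is exactly the pattern the loop_to_gpu.mlir tests below check for. A minimal sketch along a single x dimension (illustrative only; `%ub` and the function name are hypothetical):

```mlir
func @distribute(%ub: index) {
  %id = "gpu.thread_id"() {dimension = "x"} : () -> index
  %count = "gpu.block_dim"() {dimension = "x"} : () -> index
  // Each workitem handles iterations %id, %id + %count, %id + 2*%count, ...
  scf.for %i = %id to %ub step %count {
  }
  return
}
```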
@@ -251,7 +251,7 @@ ProcessorIdAndCount getGPUProcessorIdAndCount<GPUGlobalId, GPUGlobalCount>(
rewriter.create<MulIOp>(loc, blockDim, gridDim)};
}
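The specialization above computes the global workitem count as blockDim * gridDim; combined with the global-ID formula used by mapToGlobalInvocationId below, the emitted IR looks roughly like this sketch (matching the CHECK lines at the end of loop_to_gpu.mlir; `%ub` and the function name are hypothetical):

```mlir
func @global_id(%ub: index) {
  %bid = "gpu.block_id"() {dimension = "x"} : () -> index
  %bdim = "gpu.block_dim"() {dimension = "x"} : () -> index
  %tid = "gpu.thread_id"() {dimension = "x"} : () -> index
  %ngrid = "gpu.grid_dim"() {dimension = "x"} : () -> index
  %t = muli %bid, %bdim : index
  %gid = addi %t, %tid : index       // global invocation ID
  %cnt = muli %bdim, %ngrid : index  // total workitems along x
  scf.for %i = %gid to %ub step %cnt {
  }
  return
}
```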

-/// Distribute loop.parallel to processors where `IdOp` is used to get the
+/// Distribute scf.parallel to processors where `IdOp` is used to get the
/// processor ID and `DimOp` is used to get the number of processors along a
/// dimension.
template <typename GPUIdOp, typename GPUCountOp>
@@ -277,19 +277,19 @@ static LogicalResult mapToProcessor(ConversionPatternRewriter &rewriter,
return mapToProcessors(rewriter, pLoopOp, id, count);
}

-/// Distribute the loop.parallel to workgroups.
+/// Distribute the scf.parallel to workgroups.
static LogicalResult mapToWorkgroups(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp) {
return mapToProcessor<gpu::BlockIdOp, gpu::GridDimOp>(rewriter, pLoopOp);
}

-/// Distribute loop.parallel to workitems using local invocation ID.
+/// Distribute scf.parallel to workitems using local invocation ID.
static LogicalResult mapToLocalInvocationId(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp) {
return mapToProcessor<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter, pLoopOp);
}

-/// Distribute loop.parallel to workitems using global invocation ID. The GPU
+/// Distribute scf.parallel to workitems using global invocation ID. The GPU
/// dialect doesn't have a direct operation to do this. This could be done using
/// id = blockIdx * blockDim + threadIdx and count = blockDim * gridDim.
static LogicalResult mapToGlobalInvocationId(
@@ -307,7 +307,7 @@ struct ConvertToGPUPass : public PassWrapper<ConvertToGPUPass, FunctionPass> {
void runOnFunction() override;
};

-/// Pattern to map loop.parallel to workgroups.
+/// Pattern to map scf.parallel to workgroups.
struct PartitionPLoopToWorkgroups
: public OpConversionPattern<scf::ParallelOp> {
using OpConversionPattern<scf::ParallelOp>::OpConversionPattern;
@@ -318,7 +318,7 @@ struct PartitionPLoopToWorkgroups
}
};

-/// Map tiled linalg op to workitems by lowering it to loop.parallel and
+/// Map tiled linalg op to workitems by lowering it to scf.parallel and
/// partitioning it to workitems.
template <typename LinalgOpTy>
struct MapLinalgOpToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
@@ -394,7 +394,7 @@ void ConvertToGPUPass::runOnFunction() {

MLIRContext *context = &getContext();
ConversionTarget target(*context);
-// Ater this pass Linalg and loop.parallel ops should be gone.
+// After this pass, Linalg and scf.parallel ops should be gone.
target.addIllegalOp<scf::ParallelOp>();
target.addIllegalDialect<linalg::LinalgDialect>();
// Reshape ops are treated legal since they just change the way the underlying
@@ -78,8 +78,8 @@ void addLinalgToSPIRVPasses(OpPassManager &pm,
// - All Linalg ops have buffer semantics.
//
// Post-conditions:
-// - loop.parallel ops are generated for mapping to workgroups.
-// - Linalg ops are nested inside loop.parallel ops and ready for mapping
+// - scf.parallel ops are generated for mapping to workgroups.
+// - Linalg ops are nested inside scf.parallel ops and ready for mapping
// to workitems.
//===--------------------------------------------------------------------===//
pm.addPass(createLinalgTileAndFusePass(workGroupSize));
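After tiling, each Linalg op sits inside an scf.parallel over tiles, operating on subviews, which is the IR shape the first test file below verifies. Schematically (a sketch, not actual pipeline output; all names and types here are hypothetical):

```mlir
func @tiled(%lb: index, %ub: index, %step: index,
            %A: memref<?x?xi32>, %B: memref<?x?xi32>, %C: memref<?x?xi32>) {
  scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%step, %step) {
    // Subviews of %A, %B, %C for this tile, followed by a linalg op on them,
    // would appear here; ConvertToGPUPass later maps these loops to
    // workgroups and the Linalg op bodies to workitems.
    scf.yield
  }
  return
}
```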
@@ -88,9 +88,9 @@
// Map to GPU processor IDs.
//
// Post-conditions:
-// - loop.parallel ops are converted to loop.for ops and mapped to
+// - scf.parallel ops are converted to scf.for ops and mapped to
// workgroups.
-// - Linalg ops are converted to loop.for ops and mapped to workitems.
+// - Linalg ops are converted to scf.for ops and mapped to workitems.
//===--------------------------------------------------------------------===//
pm.addPass(createConvertToGPUPass());
pm.addPass(createLowerAffinePass());
@@ -7,7 +7,7 @@ module {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
-// CHECK: loop.parallel
+// CHECK: scf.parallel
// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]]
// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
@@ -41,7 +41,7 @@ module {
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK: loop.parallel
+// CHECK: scf.parallel
// CHECK-DAG: %[[VIEW0:.+]] = subview %[[ARG0]]
// CHECK-DAG: %[[VIEW1:.+]] = subview %[[ARG1]]
// CHECK-DAG: %[[VIEW2READ:.+]] = subview %[[ARG2]]
@@ -88,7 +88,7 @@ module {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK: loop.parallel (%{{.+}})
+// CHECK: scf.parallel (%{{.+}})
// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
// CHECK: linalg.conv
@@ -113,7 +113,7 @@ module {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK: loop.parallel (%{{.+}}, %{{.+}}, %{{.+}})
+// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
// CHECK: linalg.conv
@@ -134,7 +134,7 @@ module {
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
// CHECK-LABEL: func @parallel_4D
-// CHECK: loop.parallel (%{{.+}}, %{{.+}}, %{{.+}})
+// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
func @parallel_4D(%arg0: memref<?x?x?x?xf32>,
%arg1 : memref<?x?x?x?xf32>,
%arg2 : memref<?x?x?x?xf32>)
38 changes: 19 additions & 19 deletions iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/loop_to_gpu.mlir
@@ -13,7 +13,7 @@ module {
%c4 = constant 4 : index
%c8 = constant 8 : index
%c1 = constant 1 : index
-loop.parallel (%arg3, %arg4) = (%c0, %c0) to (%c4, %c8) step (%c4, %c32) {
+scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c4, %c8) step (%c4, %c32) {
%0 = affine.min #map0(%c4, %c4, %arg3)
%1 = affine.min #map0(%c32, %c8, %arg4)
%2 = subview %arg0[%arg3, %arg4] [%0, %1] [%c1, %c1]
Expand All @@ -35,7 +35,7 @@ module {
%9 = addi %arg5, %arg6 : i32
linalg.yield %9 : i32
} : memref<?x?xi32, #map1>, memref<?x?xi32, #map1>, memref<?x?xi32, #map1>
-loop.yield
+scf.yield
}
return
}
@@ -50,14 +50,14 @@ module {
// CHECK: %[[NEWSTEPY:.+]] = muli %[[NBLOCKSY]], %[[STEPY]]
// CHECK: %[[NEWLBX:.+]] = muli %[[BIDX]], %[[STEPX]]
// CHECK: %[[NEWSTEPX:.+]] = muli %[[NBLOCKSX]], %[[STEPX]]
-// CHECK: loop.for %{{.+}} = %[[NEWLBY]] to %{{.+}} step %[[NEWSTEPY]]
-// CHECK: loop.for %{{.+}} = %[[NEWLBX]] to %{{.+}} step %[[NEWSTEPX]]
+// CHECK: scf.for %{{.+}} = %[[NEWLBY]] to %{{.+}} step %[[NEWSTEPY]]
+// CHECK: scf.for %{{.+}} = %[[NEWLBX]] to %{{.+}} step %[[NEWSTEPX]]
// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
// CHECK-DAG: %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"}
-// CHECK: loop.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
-// CHECK: loop.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
+// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
+// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]

// -----

@@ -84,7 +84,7 @@ module {
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK: loop.for %{{.+}} = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C4]] step %[[C1]]
// CHECK-NOT: loop

// -----
@@ -104,7 +104,7 @@
%1 = dim %arg0, 1 : memref<?x?x?x?xf32>
%2 = dim %arg0, 2 : memref<?x?x?x?xf32>
%3 = dim %arg0, 3 : memref<?x?x?x?xf32>
-loop.parallel (%arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0) to (%0, %1, %2, %3) step (%c2, %c2, %c2, %c32) {
+scf.parallel (%arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0) to (%0, %1, %2, %3) step (%c2, %c2, %c2, %c32) {
%12 = affine.min #map0(%arg3)[%0]
%13 = affine.min #map0(%arg4)[%1]
%14 = affine.min #map0(%arg5)[%2]
@@ -122,7 +122,7 @@
%19 = addf %arg7, %arg8 : f32
linalg.yield %19 : f32
} : memref<?x?x?x?xf32, #map2>, memref<?x?x?x?xf32, #map2>, memref<?x?x?x?xf32, #map2>
-loop.yield
+scf.yield
}
return
}
@@ -145,20 +145,20 @@
// CHECK-DAG: %[[STEP1:.+]] = muli %[[NBLOCKSY]], %[[C2]]
// CHECK-DAG: %[[LB2:.+]] = muli %[[BIDX]], %[[C2]]
// CHECK-DAG: %[[STEP2:.+]] = muli %[[NBLOCKSX]], %[[C2]]
-// CHECK: loop.for %{{.+}} = %[[LB0]] to %{{.+}} step %[[STEP0]]
-// CHECK: loop.for %{{.+}} = %[[LB1]] to %{{.+}} step %[[STEP1]]
-// CHECK: loop.for %{{.+}} = %[[LB2]] to %{{.+}} step %[[STEP2]]
-// CHECK: loop.for %{{.+}} = %[[C0]] to %[[SERIALDIMOUTER]] step %[[C32]]
+// CHECK: scf.for %{{.+}} = %[[LB0]] to %{{.+}} step %[[STEP0]]
+// CHECK: scf.for %{{.+}} = %[[LB1]] to %{{.+}} step %[[STEP1]]
+// CHECK: scf.for %{{.+}} = %[[LB2]] to %{{.+}} step %[[STEP2]]
+// CHECK: scf.for %{{.+}} = %[[C0]] to %[[SERIALDIMOUTER]] step %[[C32]]
// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"} : () -> index
// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"} : () -> index
// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"} : () -> index
// CHECK-DAG: %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"} : () -> index
// CHECK-DAG: %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"} : () -> index
// CHECK-DAG: %[[NTHREADSZ:.+]] = "gpu.block_dim"() {dimension = "z"} : () -> index
-// CHECK: loop.for %{{.+}} = %[[TIDZ]] to %{{.+}} step %[[NTHREADSZ]]
-// CHECK: loop.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
-// CHECK: loop.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
-// CHECK: loop.for %{{.+}} = %[[C0]] to %{{.+}} step %[[C1]]
+// CHECK: scf.for %{{.+}} = %[[TIDZ]] to %{{.+}} step %[[NTHREADSZ]]
+// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
+// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
+// CHECK: scf.for %{{.+}} = %[[C0]] to %{{.+}} step %[[C1]]

// -----

@@ -197,5 +197,5 @@ module {
// CHECK: %[[T6:.+]] = muli %[[BIDY]], %[[BLOCKSIZEY]]
// CHECK: %[[GIDY:.+]] = addi %[[T6]], %[[TIDY]]
// CHECK: %[[NPROCSY:.+]] = muli %[[BLOCKSIZEY]], %[[NBLOCKSY]]
-// CHECK: loop.for %{{.+}} = %[[GIDY]] to %[[UBY]] step %[[NPROCSY]]
-// CHECK: loop.for %{{.+}} = %[[GIDX]] to %[[UBX]] step %[[NPROCSX]]
+// CHECK: scf.for %{{.+}} = %[[GIDY]] to %[[UBY]] step %[[NPROCSY]]
+// CHECK: scf.for %{{.+}} = %[[GIDX]] to %[[UBX]] step %[[NPROCSX]]