diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 80647b934cd9..164d900a1c71 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -305,9 +305,9 @@ struct DistributeBroadcast final : OpDistributionPattern<vector::BroadcastOp> {
     auto vectorType = VectorType::get(distShape, elementType);
 
     VectorValue srcVector = dyn_cast<VectorValue>(broadcastOp.getSource());
-    // If the srcVector is a scalar (like f32) or a rank-0 vector (like
-    // vector<f32>), we proceed with the scalar distribution branch.
-    if (!srcVector || !isNonZeroRank(srcVector)) {
+    // If the srcVector is a scalar (like f32) we proceed with the scalar
+    // distribution branch.
+    if (!srcVector) {
       // The way distribution currently works, there is no partial thread
       // distribution, so a scalar is available to all threads. Scalar
       // distribution is simply a broadcast from scalar to the distributed
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 7e927b499077..a8831809e25b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -132,16 +132,14 @@ void DistributionPattern::replaceOpWithDistributedValues(
   for (auto [opResult, replacement] :
        llvm::zip_equal(op->getOpResults(), values)) {
     // If this value is a vector type, it must be converted back to simd.
-    if (auto replacementType = dyn_cast<VectorType>(replacement.getType())) {
-      if (replacementType.getRank() != 0) {
-        auto oldResult = cast<OpResult>(opResult);
-        // Create a toSIMD op to convert the value back to the simd.
-        rewriter.setInsertionPointAfterValue(oldResult);
-        Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
-            oldResult.getLoc(), oldResult.getType(), replacement);
-        // Add to replacements.
-        replacement = toSIMD;
-      }
+    if (isa<VectorType>(replacement.getType())) {
+      auto oldResult = cast<OpResult>(opResult);
+      // Create a toSIMD op to convert the value back to the simd.
+      rewriter.setInsertionPointAfterValue(oldResult);
+      Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+          oldResult.getLoc(), oldResult.getType(), replacement);
+      // Add to replacements.
+      replacement = toSIMD;
     }
     replacements.push_back(replacement);
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index 71448ef84066..98455c93f3e0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -783,6 +783,47 @@ builtin.module attributes { transform.with_named_sequence } {
 
 // -----
 
+#layout = #iree_vector_ext.nested_layout<
+  subgroup_tile = [2, 2, 2],
+  batch_tile = [2, 2, 1],
+  outer_tile = [2, 1, 1],
+  thread_tile = [4, 16, 8],
+  element_tile = [1, 4, 4],
+  subgroup_strides = [4, 2, 1],
+  thread_strides = [128, 8, 1]
+>
+
+func.func @zero_rank_broadcast(%src: vector<f16>) -> (vector<32x256x64xf16>) {
+  %bcast = vector.broadcast %src : vector<f16> to vector<32x256x64xf16>
+  %bcastl = iree_vector_ext.to_layout %bcast to layout(#layout) : vector<32x256x64xf16>
+  return %bcastl : vector<32x256x64xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @zero_rank_broadcast
+// CHECK-SAME: (%[[SRC:.*]]: vector<f16>)
+// CHECK: %[[SRC_SIMT:.*]] = iree_vector_ext.to_simt %[[SRC]] : vector<f16>
+// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC_SIMT]]
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f16 to vector<1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: %[[OUT:.*]] = vector.insert %[[BCAST]], %{{.*}}
+// CHECK: iree_vector_ext.to_simd %[[OUT]] : vector<2x2x1x2x1x1x1x4x4xf16> -> vector<32x256x64xf16>
+
+// -----
+
 #layout = #iree_vector_ext.nested_layout<
   subgroup_tile = [2, 2, 2],
   batch_tile = [2, 2, 1],
diff --git a/third_party/llvm-project b/third_party/llvm-project
index ac39504813f8..889525fa99b2 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit ac39504813f8c52f10c0e364485569bff5a5f7a1
+Subproject commit 889525fa99b251dc962edb516e0108088ba7e44d