diff --git a/compiler/include/byteir/Dialect/mhlo/Analysis/ShapeAnalysis.h b/compiler/include/byteir/Dialect/mhlo/Analysis/ShapeAnalysis.h index 613b1cd91..f156d690c 100644 --- a/compiler/include/byteir/Dialect/mhlo/Analysis/ShapeAnalysis.h +++ b/compiler/include/byteir/Dialect/mhlo/Analysis/ShapeAnalysis.h @@ -144,7 +144,8 @@ class MhloShapeAnalysisBase : public ShapeAnalysis { wrapperShapeValueKnowledges); return inferFunc(op->getContext(), op->getLoc(), range, - op->getAttrDictionary(), op->getPropertiesStorage(), op->getRegions(), results); + op->getAttrDictionary(), op->getPropertiesStorage(), + op->getRegions(), results); } }; diff --git a/compiler/test/Dialect/SCF/forallCollapsing.mlir b/compiler/test/Dialect/SCF/forallCollapsing.mlir index 5ce12e26d..d71152ea6 100644 --- a/compiler/test/Dialect/SCF/forallCollapsing.mlir +++ b/compiler/test/Dialect/SCF/forallCollapsing.mlir @@ -25,9 +25,9 @@ func.func @Elementwise(%arg0: memref<32x1024x?x30xf32>) -> memref<32768x?x30xf32 %dim = memref.dim %arg0, %c2 : memref<32x1024x?x30xf32> %alloc = memref.alloc(%dim) : memref<32768x?x30xf32> scf.forall (%arg1, %arg2, %arg3) in (32768, %dim, 30) { - %subview = memref.subview %collapse_shape[%arg1, %arg2, %arg3] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> - %subview_0 = memref.subview %alloc[%arg1, %arg2, %arg3] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%subview : memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>) outs(%subview_0 : memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>) attrs = {__byteir_gpu_tile_elementwise_0} { + %subview = memref.subview %collapse_shape[%arg1, %arg2, %arg3] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>> + %subview_0 = memref.subview %alloc[%arg1, %arg2, %arg3] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%subview : memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>) outs(%subview_0 : memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>) attrs = {__byteir_gpu_tile_elementwise_0} { ^bb0(%in: f32, %out: f32): %0 = arith.mulf %in, %in : f32 linalg.yield %0 : f32 @@ -49,6 +49,6 @@ func.func @Elementwise(%arg0: memref<32x1024x?x30xf32>) -> memref<32768x?x30xf32 // CHECK-NEXT: %[[V2:.*]] = arith.divsi %arg1, %[[C30]] : index // CHECK-NEXT: %[[V3:.*]] = arith.remsi %[[V2]], %[[DIM]] : index // CHECK-NEXT: %[[V4:.*]] = arith.divsi %[[V2]], %[[DIM]] : index -// CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[COLLAPSE]][%[[V4]], %[[V3]], %[[V1]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> -// CHECK-NEXT: %[[SUBVIEW_0:.*]] = memref.subview %[[ALLOC]][%[[V4]], %[[V3]], %[[V1]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> -// CHCCK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[SUBVIEW]] : memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>) outs(%[[SUBVIEW_0]] : memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>) attrs = {__byteir_gpu_tile_elementwise_0} { +// CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[COLLAPSE]][%[[V4]], %[[V3]], %[[V1]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>> +// CHECK-NEXT: %[[SUBVIEW_0:.*]] = memref.subview %[[ALLOC]][%[[V4]], %[[V3]], %[[V1]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>> +// CHCCK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[SUBVIEW]] : memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>) outs(%[[SUBVIEW_0]] : memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>) attrs = {__byteir_gpu_tile_elementwise_0} {