diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 0bba23842..252e94cdf 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -386,7 +386,7 @@ def mma( print(mma(a, b, c).module_op) # CHECK: func.func @mma(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, - # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index @@ -405,60 +405,63 @@ def mma( # CHECK: %[[D1:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index # CHECK: %[[D2:.+]] = arith.muli %[[D1]], %[[C16]] : index # CHECK: %[[D3:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D2]] : index - # CHECK: %[[D5:.+]] = vector.load %[[D0]][%[[D4]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>, - # CHECK-SAME: vector<4xf16> + # CHECK: %[[D4:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[D3]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D2]] : index + # CHECK: %[[D7:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D8:.+]] = arith.divsi %[[D7]], %[[C16]] : index + # CHECK: %[[D9:.+]] = arith.muli %[[D8]], %[[C4]] : index + # CHECK: %[[D10:.+]] = vector.load %[[D0]][%[[D6]], %[[D9]]] : memref<64x16xf16, strided<[16, 1], offset: + # CHECK-SAME: ?>>, vector<4xf16> # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> - # CHECK: vector.store %[[D5]], %[[ALLOC]][%[[D2]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[D11:.+]] = arith.addi %[[D4]], %[[D2]] : index + # CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D6:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D2]] : index - # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index - # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index - # CHECK: %[[D11:.+]] = vector.load %[[ALLOC]][%[[D7]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D12:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, + # CHECK: %[[D13:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, # CHECK-SAME: strided<[16, 1], offset: ?>> - # CHECK: %[[D13:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D14:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D13]] : index - # CHECK: %[[D16:.+]] = vector.load %[[D12]][%[[D15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset: + # CHECK: %[[D14:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D15:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D4]], %[[D15]] : index + # CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D14]] : index + # CHECK: %[[D18:.+]] = vector.load %[[D13]][%[[D17]], %[[D9]]] : memref<128x16xf16, strided<[16, 1], offset: # CHECK-SAME: ?>>, vector<4xf16> # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> - # CHECK: vector.store %[[D16]], %[[ALLOC_0]][%[[D13]], %[[C0]]] : memref<32x16xf16, + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D19:.+]] = arith.addi %[[D4]], %[[D14]] : index + # CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D17:.+]] = arith.addi %[[D6]], %[[D13]] : index - # CHECK: %[[D18:.+]] = vector.load %[[ALLOC_0]][%[[D17]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D19:.+]] = amdgpu.mfma %[[D11]] * %[[D18]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK: %[[D21:.+]] = amdgpu.mfma %[[D12]] * %[[D20]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - # CHECK: %[[D20:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK: %[[D22:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [0], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D21:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, + # CHECK: %[[D23:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, # CHECK-SAME: strided<[128, 1], offset: ?>> - # CHECK: %[[D22:.+]] = arith.addi %[[D4]], %[[D10]] : index - # CHECK: %[[D23:.+]] = arith.addi %[[D6]], %[[D14]] : index - # CHECK: %[[D24:.+]] = arith.addi %[[D23]], %[[D13]] : index - # CHECK: vector.store %[[D20]], %[[D21]][%[[D22]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D24:.+]] = arith.addi %[[D3]], %[[D2]] : index + # CHECK: %[[D25:.+]] = arith.addi %[[D24]], %[[D9]] : index + # CHECK: vector.store %[[D22]], %[[D23]][%[[D25]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D25:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK: %[[D26:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [1], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D26:.+]] = arith.addi %[[D22]], %[[C1]] : index - # CHECK: vector.store %[[D25]], %[[D21]][%[[D26]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D27:.+]] = arith.addi %[[D25]], %[[C1]] : index + # CHECK: vector.store %[[D26]], %[[D23]][%[[D27]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D27:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK: %[[D28:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [2], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D28:.+]] = arith.addi %[[D22]], %[[C2]] : index - # CHECK: vector.store %[[D27]], %[[D21]][%[[D28]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D29:.+]] = arith.addi %[[D25]], %[[C2]] : index + # CHECK: vector.store %[[D28]], %[[D23]][%[[D29]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D29:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK: %[[D30:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [3], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D30:.+]] = arith.addi %[[D22]], %[[C3]] : index - # CHECK: vector.store %[[D29]], %[[D21]][%[[D30]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D31:.+]] = arith.addi %[[D25]], %[[C3]] : index + # CHECK: vector.store %[[D30]], %[[D23]][%[[D31]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: return @run_test @@ -515,7 +518,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print(gemm(a, b, c).module_op) # CHECK: func.func @gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, - # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index @@ -531,77 +534,81 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> - # CHECK: %[[D22:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, - # CHECK-SAME: strided<[64, 1], offset: ?>> - # CHECK: %[[D23:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, - # CHECK-SAME: strided<[64, 1], offset: ?>> - # CHECK: %[[D24:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D25:.+]] = arith.muli %[[D24]], %[[C16]] : index - # CHECK: %[[D26:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D25]] : index - # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D25]] : index - # CHECK: %[[D32:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D33:.+]] = arith.divsi %[[D32]], %[[C16]] : index - # CHECK: %[[D34:.+]] = arith.muli %[[D33]], %[[C4]] : index - # CHECK: %[[D36:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D37:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D38:.+]] = arith.addi %[[D37]], %[[D36]] : index - # CHECK: %[[D40:.+]] = arith.addi %[[D30]], %[[D36]] : index - # CHECK: %[[D0:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] + # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, + # CHECK-SAME: strided<[64, 1], offset: ?>> + # CHECK: %[[D1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, + # CHECK-SAME: strided<[64, 1], offset: ?>> + # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] : index + # CHECK: %[[D4:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D5:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D4]] : index + # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D3]] : index + # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index + # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index + # CHECK: %[[D11:.+]] = arith.addi %[[D5]], %[[D3]] : index + # CHECK: %[[D12:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D13:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D14:.+]] = arith.addi %[[D5]], %[[D13]] : index + # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D12]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D5]], %[[D12]] : index + # CHECK: %[[D17:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] # CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[CST]]) -> (vector<4xf32>) { - # CHECK: %[[D28:.+]] = arith.muli %[[ARG3]], %[[C16]] : index - # CHECK: %[[D29:.+]] = vector.load %[[D22]][%[[D27]], %[[D28]]] : memref<64x64xf16, strided<[64, 1], - # CHECK-SAME: offset: ?>>, vector<4xf16> - # CHECK: vector.store %[[D29]], %[[ALLOC]][%[[D25]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[D39:.+]] = arith.muli %[[ARG3]], %[[C16]] : index + # CHECK: %[[D40:.+]] = arith.addi %[[D39]], %[[D10]] : index + # CHECK: %[[D41:.+]] = vector.load %[[D0]][%[[D7]], %[[D40]]] : memref<64x64xf16, strided<[64, 1], offset: + # CHECK-SAME: ?>>, vector<4xf16> + # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D35:.+]] = vector.load %[[ALLOC]][%[[D31]], %[[D34]]] : memref<32x16xf16, + # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D39:.+]] = vector.load %[[D23]][%[[D38]], %[[D28]]] : memref<128x64xf16, strided<[64, 1], + # CHECK: %[[D43:.+]] = vector.load %[[D1]][%[[D15]], %[[D40]]] : memref<128x64xf16, strided<[64, 1], # CHECK-SAME: offset: ?>>, vector<4xf16> - # CHECK: vector.store %[[D39]], %[[ALLOC_0]][%[[D36]], %[[C0]]] : memref<32x16xf16, + # CHECK: amdgpu.lds_barrier + # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D41:.+]] = vector.load %[[ALLOC_0]][%[[D40]], %[[D34]]] : memref<32x16xf16, + # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D42:.+]] = amdgpu.mfma %[[D35]] * %[[D41]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 + # CHECK: %[[D45:.+]] = amdgpu.mfma %[[D42]] * %[[D44]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - # CHECK: scf.yield %[[D42]] : vector<4xf32> + # CHECK: scf.yield %[[D45]] : vector<4xf32> # CHECK: } - # CHECK: %[[D1:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [0], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, + # CHECK: %[[D19:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, # CHECK-SAME: strided<[128, 1], offset: ?>> - # CHECK: %[[D3:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D4:.+]] = arith.divsi %[[D3]], %[[C16]] : index - # CHECK: %[[D5:.+]] = arith.muli %[[D4]], %[[C4]] : index - # CHECK: %[[D6:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D7:.+]] = arith.muli %[[D6]], %[[C16]] : index - # CHECK: %[[D8:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D9:.+]] = arith.addi %[[D8]], %[[D7]] : index - # CHECK: %[[D10:.+]] = arith.addi %[[D9]], %[[D5]] : index - # CHECK: %[[D11:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D12:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D13:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D14:.+]] = arith.addi %[[D13]], %[[D12]] : index - # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D11]] : index - # CHECK: vector.store %[[D1]], %[[D2]][%[[D10]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D20:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D21:.+]] = arith.divsi %[[D20]], %[[C16]] : index + # CHECK: %[[D22:.+]] = arith.muli %[[D21]], %[[C4]] : index + # CHECK: %[[D23:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D24:.+]] = arith.muli %[[D23]], %[[C16]] : index + # CHECK: %[[D25:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D26:.+]] = arith.addi %[[D25]], %[[D24]] : index + # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D22]] : index + # CHECK: %[[D28:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D29:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D29]] : index + # CHECK: %[[D32:.+]] = arith.addi %[[D31]], %[[D28]] : index + # CHECK: vector.store %[[D18]], %[[D19]][%[[D27]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D16:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK: %[[D33:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [1], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D17:.+]] = arith.addi %[[D10]], %[[C1]] : index - # CHECK: vector.store %[[D16]], %[[D2]][%[[D17]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D34:.+]] = arith.addi %[[D27]], %[[C1]] : index + # CHECK: vector.store %[[D33]], %[[D19]][%[[D34]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK: %[[D35:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [2], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D19:.+]] = arith.addi %[[D10]], %[[C2]] : index - # CHECK: vector.store %[[D18]], %[[D2]][%[[D19]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D36:.+]] = arith.addi %[[D27]], %[[C2]] : index + # CHECK: vector.store %[[D35]], %[[D19]][%[[D36]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D20:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK: %[[D37:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [3], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D21:.+]] = arith.addi %[[D10]], %[[C3]] : index - # CHECK: vector.store %[[D20]], %[[D2]][%[[D21]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D38:.+]] = arith.addi %[[D27]], %[[C3]] : index + # CHECK: vector.store %[[D37]], %[[D19]][%[[D38]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> # CHECK: return diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index f1bf0edee..6f4e2f290 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -243,23 +243,23 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) # CHECK-NEXT: get_result(value=reduction, res_idx=3) # CHECK-NEXT: get_result(value=reduction, res_idx=2) # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_1_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: write(register_=getresult_1_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_0_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: output # Reduction subgraph: @@ -389,11 +389,11 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0] # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: output(return_vals=(None,)) # Reduction subgraph: diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 4cdb997b5..2bebc6908 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -182,13 +182,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index b085d8b17..dcf6b2258 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -131,13 +131,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( @@ -146,13 +146,13 @@ def test_gemm(): # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_1_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: write(register_=getresult_1_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_0_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # Reduction subgraph: # CHECK: %acc_0_0_0 diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index 3a2d3d3b0..952f394ac 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -780,16 +780,6 @@ def custom_string(self, value_map: dict[str, str]) -> str: custom_str += f"acc={self.acc} (index = {self.acc_index}))" return custom_str - def post_expansion(self, constraints: list["Constraint"]) -> None: - """ - Once the arguments have been expanded, we set their indices, - ensuring that the LHS and RHS indices are consistent with their - corresponding address spaces. - """ - self.lhs.index = self.lhs_index - self.rhs.index = self.rhs_index - self.acc.index = self.acc_index - @define_op("read") @dataclass diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index 53796682d..e25cd7c09 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -19,7 +19,7 @@ from .._support.indexing import IndexingContext, IndexSequence from ...support.logging import get_logger from .._support.tracing import CapturedTrace -from .utils import get_mma_dimensional_mapping +from .utils import get_mma_dimensional_mapping, specialize_index_sequence from ..lang.global_symbols import * logger = get_logger("turbine.wave.expansion") @@ -141,6 +141,7 @@ def compute_stride( def set_node_index( constraints: Sequence[Constraint], mma_index: dict[IndexSymbol, int], + mma_slices: dict[IndexSymbol, list[fx.Node]], dim_tile_size: dict[IndexSymbol, int], custom: CustomOp, dim_scaling: dict[IndexSymbol, int], @@ -171,11 +172,7 @@ def set_node_index( for dim in custom.indexing_dims: index_seq = None for constraint in sorted_constraints: - mma_check = ( - isinstance(constraint, HardwareConstraint) - and dim in mma_index - and isinstance(custom, MMA) - ) + mma_check = isinstance(constraint, HardwareConstraint) and dim in mma_index vector_check = ( isinstance(constraint, HardwareConstraint) @@ -217,6 +214,8 @@ def set_node_index( index_seq = constraint.apply( constraint_index, dim, elements_per_thread, stride ) + if mma_index: + index_seq = specialize_index_sequence(index_seq, mma_slices, custom) else: if index_seq is None: @@ -246,10 +245,10 @@ def expand_graph( dim_scaling = constraints_or_scaling node_index_setter = lambda *args: None else: - mma_index = get_mma_dimensional_mapping(trace) + mma_index, mma_slices = get_mma_dimensional_mapping(trace) dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index) node_index_setter = partial( - set_node_index, constraints_or_scaling, mma_index, dim_tile_size + set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size ) # Start from the back and expand in the corresponding indexing dimensions of a node diff --git a/shark_turbine/kernel/wave/index_sequence_analysis.py b/shark_turbine/kernel/wave/index_sequence_analysis.py index cec8b60b1..b9212f014 100644 --- a/shark_turbine/kernel/wave/index_sequence_analysis.py +++ b/shark_turbine/kernel/wave/index_sequence_analysis.py @@ -24,7 +24,7 @@ def get_vector_shape( hardware_constraint: HardwareConstraint, symbolic_shape: list[IndexSymbol], ) -> list[int]: - mma_indices = get_mma_dimensional_mapping(trace) + mma_indices, _ = get_mma_dimensional_mapping(trace) return [ get_hardware_vector_size(dim, hardware_constraint, mma_indices) for dim in symbolic_shape diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index affd5fef5..42e5bca3f 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -1,5 +1,4 @@ # Copyright 2024 The IREE Authors -# # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -16,7 +15,16 @@ from .._support.tracing import CapturedTrace from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence from ..lang.global_symbols import * -from ..ops.wave_ops import get_custom, Output, Write, MMA +from ..ops.wave_ops import ( + get_custom, + Output, + Write, + MMA, + CustomOp, + Reduction, + GetResult, + IterArg, +) from .constraints import Constraint, HardwareConstraint, TilingConstraint import torch.fx as fx import shark_turbine.kernel.lang as tkl @@ -145,7 +153,9 @@ def simplify_index(index: IndexExpr) -> IndexExpr: return subs_idxc(index.subs(mapping)) -def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[IndexSymbol, int]: +def get_mma_dimensional_mapping( + trace: CapturedTrace, +) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]: """ Given a trace, determine the MMA dimensional mapping for all the MMA operations in the graph. For example, if we have @@ -159,7 +169,8 @@ def is_mma(node): return isinstance(get_custom(node), MMA) mapping: dict[IndexSymbol, int] = {} - for node in trace.walk(is_mma): + mma_nodes = trace.walk(is_mma) + for node in mma_nodes: custom: MMA = get_custom(node) m, n = custom.acc_type.symbolic_shape[-2:] lhs_shape = custom.lhs_type.symbolic_shape @@ -170,7 +181,7 @@ def is_mma(node): mapping[n] = 1 mapping[k] = 2 - return mapping + return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes]) def get_hardware_vector_size( @@ -378,3 +389,132 @@ def erase_graph(graph: fx.Graph): for user in node.users: graph.erase_node(user) graph.erase_node(node) + + +def get_users( + node: fx.Node, reduction: fx.Node = None +) -> tuple[list[fx.Node], fx.Node]: + """ + Return the users of a node, propagating through reductions. + """ + users = [] + for user in node.users: + custom = get_custom(user) + if isinstance(custom, Reduction): + # Map init arg to iter arg + reduction = custom + init_arg_idx = custom.init_args.index(node) + users.append(custom.iter_args[init_arg_idx]) + continue + if isinstance(custom, Output) and reduction: + # Map output to get result + return_vals = custom.return_vals[0] + get_results = sorted( + [x for x in reduction.users if isinstance(get_custom(x), GetResult)], + lambda x: get_custom(x).res_idx, + ) + if isinstance(return_vals, list): + output_idx = return_vals.index(node) + users.append(get_results[output_idx]) + else: + users.append(get_results[0]) + continue + users.append(user) + return users, reduction + + +def get_inputs( + node: fx.Node, reduction: fx.Node = None +) -> tuple[list[fx.Node], fx.Node]: + """ + Return the inputs of a node, propagating through reductions. + """ + inputs = [] + for input in node.all_input_nodes: + custom = get_custom(input) + if isinstance(custom, GetResult): + reduction = custom.value + assert isinstance( + reduction, Reduction + ), "GetResult must be used by a Reduction" + # Map get result to output + inputs.append(reduction.outputs[custom.res_idx]) + continue + if isinstance(custom, IterArg): + # Map iter args to init args + iter_arg_idx = reduction.iter_args.index(node) + inputs.append(reduction.init_args[iter_arg_idx]) + continue + inputs.append(input) + return inputs, reduction + + +def bfs( + node: fx.Node, + get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]], +) -> set[fx.Node]: + """ + Run BFS on the graph to capture the forward slice of a node. + """ + visited: set[fx.Node] = set() + queue: list[fx.Node] = [] + visited.add(node) + queue.append(node) + reduction = None + while queue: + s = queue.pop(0) + neighbors, reduction = get_neighbors(s, reduction) + for neighbor in neighbors: + if neighbor not in visited: + visited.add(neighbor) + queue.append(neighbor) + return visited + + +def capture_forward_slice(node: fx.Node) -> set[fx.Node]: + """ + Run BFS on the graph to capture the forward slice of a node. + """ + return bfs(node, lambda x, y: get_users(x, y)) + + +def capture_backward_slice(node: fx.Node) -> set[fx.Node]: + """ + Capture backward slice from a node and return the tree. + Assumes graph is directed. + """ + return bfs(node, lambda x, y: get_inputs(x, y)) + + +def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]: + """ + Given an index sequence, specialize it to a LHS, RHS or ACC index sequence + based on whether the node is used as the LHS, RHS or ACC in the MMA node. + """ + mma_slices = {x: [] for x in [MMA_LHS, MMA_RHS, MMA_ACC]} + for mma in mma_nodes: + mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs) + mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs) + mma_slices[MMA_ACC] += capture_forward_slice(mma.acc) + return mma_slices + + +def specialize_index_sequence( + index_seq: IndexSequence, + mma_slices: dict[IndexSymbol, list[fx.Node]], + custom: CustomOp, +) -> IndexSequence: + """ + Given an index sequence, specialize it to a LHS, RHS or ACC index sequence + based on whether the node is used as the LHS, RHS or ACC in the MMA node. + If the node is not used as any of the operands, return the original index sequence + with all the MMA symbols zeroed out. + """ + if isinstance(custom, MMA): + return index_seq + operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0} + for key in mma_slices: + if custom.fx_node in mma_slices[key]: + operand_map[key] = 1 + return index_seq.subs(operand_map) + return index_seq.subs(operand_map)