diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8f938a0e..5146f144 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -60,7 +60,6 @@ jobs: if: "contains(matrix.os, 'mi300') && !cancelled()" run: | export WAVE_RUN_E2E_TESTS=1 - export TEST_PARAMS_PATH=./tests/kernel/wave/test_param.json pytest -n 4 ./tests/kernel/wave/ - name: Run LIT tests diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index 1b446dc0..8f8b4a6f 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -207,9 +207,9 @@ def test_gemm(): # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_0_0_0 # CHECK-NEXT: %read_0_0_1 diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 0bba2384..be4b04bb 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -231,51 +231,40 @@ def test( print(test(a, b).module_op) # CHECK: func.func @test(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding) - # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16> - # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index - # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index - # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index - # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index - # CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index - # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index - # CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index - # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index - # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x - # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y - # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>> - # CHECK: %[[D1:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C4]] : index - # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C4]] : index - # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D1]] : index - # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index - # CHECK: %[[D6:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C4]] : index - # CHECK: %[[D7:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C8]] : index - # CHECK: %[[D8:.+]] = arith.addi %[[D7]], %[[D6]] : index - # CHECK: %[[D9:.+]] = vector.constant_mask [4] : vector<4xi1> - # CHECK: %[[D10:.+]] = arith.cmpi slt, %[[D5]], %[[C1]] : index - # CHECK: %[[D11:.+]] = arith.cmpi slt, %[[D8]], %[[C3]] : index - # CHECK: %[[D12:.+]] = arith.andi %[[D10]], %[[D11]] : i1 - # CHECK: %[[D13:.+]] = vector.insertelement %[[D12]], %[[D9]][%[[C0]] : index] : vector<4xi1> - # CHECK: %[[D14:.+]] = arith.addi %[[D8]], %[[C1]] : index - # CHECK: %[[D15:.+]] = arith.cmpi slt, %[[D14]], %[[C3]] : index - # CHECK: %[[D16:.+]] = arith.andi %[[D10]], %[[D15]] : i1 - # CHECK: %[[D17:.+]] = vector.insertelement %[[D16]], %[[D13]][%[[C1]] : index] : vector<4xi1> - # CHECK: %[[D18:.+]] = arith.addi %[[D8]], %[[C2]] : index - # CHECK: %[[D19:.+]] = arith.cmpi slt, %[[D18]], %[[C3]] : index - # CHECK: %[[D20:.+]] = arith.andi %[[D10]], %[[D19]] : i1 - # CHECK: %[[D21:.+]] = vector.insertelement %[[D20]], %[[D17]][%[[C2]] : index] : vector<4xi1> - # CHECK: %[[D22:.+]] = arith.addi %[[D8]], %[[C3]] : index - # CHECK: %[[D23:.+]] = arith.cmpi slt, %[[D22]], %[[C3]] : index - # CHECK: %[[D24:.+]] = arith.andi %[[D10]], %[[D23]] : i1 - # CHECK: %[[D25:.+]] = vector.insertelement %[[D24]], %[[D21]][%[[C3]] : index] : vector<4xi1> - # CHECK: %[[D26:.+]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D25]], %[[CST]] : - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> into vector<4xf16> - # CHECK: %[[D27:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>> - # CHECK: vector.maskedstore %[[D27]][%[[D5]], %[[D8]]], %[[D25]], %[[D26]] : - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> + # CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf16> + # CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<3> : vector<4xindex> + # CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> + # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + # CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + # CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + # CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + # CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index + # CHECK: %[[WORKGROUP_ID_0:.*]] = stream.dispatch.workgroup.id[0] : index + # CHECK: %[[WORKGROUP_ID_1:.*]] = stream.dispatch.workgroup.id[1] : index + # CHECK-DAG: %[[THREAD_ID_X:.*]] = gpu.thread_id x + # CHECK-DAG: %[[THREAD_ID_Y:.*]] = gpu.thread_id y + # CHECK: %[[D0:.*]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>> + # CHECK: %[[D1:.*]] = arith.muli %[[WORKGROUP_ID_0]], %[[C4]] : index + # CHECK: %[[D2:.*]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.*]] = arith.muli %[[D2]], %[[C4]] : index + # CHECK: %[[D4:.*]] = arith.addi %[[D3]], %[[D1]] : index + # CHECK: %[[D5:.*]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index + # CHECK: %[[D6:.*]] = arith.muli %[[WORKGROUP_ID_1]], %[[C4]] : index + # CHECK: %[[D7:.*]] = arith.muli %[[THREAD_ID_Y]], %[[C8]] : index + # CHECK: %[[D8:.*]] = arith.addi %[[D7]], %[[D6]] : index + # CHECK: %[[D9:.*]] = vector.splat %[[D8]] : vector<4xindex> + # CHECK: %[[D10:.*]] = arith.addi %[[D9]], %[[CST_1]] : vector<4xindex> + # CHECK: %[[D11:.*]] = arith.cmpi slt, %[[D10]], %[[CST_0]] : vector<4xindex> + # CHECK: %[[D12:.*]] = arith.cmpi slt, %[[D5]], %[[C1]] : index + # CHECK: %[[D13:.*]] = vector.splat %[[D12]] : vector<4xi1> + # CHECK: %[[D14:.*]] = arith.andi %[[D11]], %[[D13]] : vector<4xi1> + # CHECK: %[[D15:.*]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D14]], %[[CST]] : memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> into vector<4xf16> + # CHECK: %[[D16:.*]] = stream.binding.subspan %arg1[%[[C0]]] : !stream.binding -> memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>> + # CHECK: vector.maskedstore %[[D16]][%[[D5]], %[[D8]]], %[[D14]], %[[D15]] : memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> @run_test @@ -386,7 +375,7 @@ def mma( print(mma(a, b, c).module_op) # CHECK: func.func @mma(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, - # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index @@ -405,60 +394,63 @@ def mma( # CHECK: %[[D1:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index # CHECK: %[[D2:.+]] = arith.muli %[[D1]], %[[C16]] : index # CHECK: %[[D3:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D2]] : index - # CHECK: %[[D5:.+]] = vector.load %[[D0]][%[[D4]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>, - # CHECK-SAME: vector<4xf16> + # CHECK: %[[D4:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[D3]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D2]] : index + # CHECK: %[[D7:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D8:.+]] = arith.divsi %[[D7]], %[[C16]] : index + # CHECK: %[[D9:.+]] = arith.muli %[[D8]], %[[C4]] : index + # CHECK: %[[D10:.+]] = vector.load %[[D0]][%[[D6]], %[[D9]]] : memref<64x16xf16, strided<[16, 1], offset: + # CHECK-SAME: ?>>, vector<4xf16> # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> - # CHECK: vector.store %[[D5]], %[[ALLOC]][%[[D2]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[D11:.+]] = arith.addi %[[D4]], %[[D2]] : index + # CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D6:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D2]] : index - # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index - # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index - # CHECK: %[[D11:.+]] = vector.load %[[ALLOC]][%[[D7]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D12:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, + # CHECK: %[[D13:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, # CHECK-SAME: strided<[16, 1], offset: ?>> - # CHECK: %[[D13:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D14:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D13]] : index - # CHECK: %[[D16:.+]] = vector.load %[[D12]][%[[D15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset: + # CHECK: %[[D14:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D15:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D4]], %[[D15]] : index + # CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D14]] : index + # CHECK: %[[D18:.+]] = vector.load %[[D13]][%[[D17]], %[[D9]]] : memref<128x16xf16, strided<[16, 1], offset: # CHECK-SAME: ?>>, vector<4xf16> # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> - # CHECK: vector.store %[[D16]], %[[ALLOC_0]][%[[D13]], %[[C0]]] : memref<32x16xf16, + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D19:.+]] = arith.addi %[[D4]], %[[D14]] : index + # CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D17:.+]] = arith.addi %[[D6]], %[[D13]] : index - # CHECK: %[[D18:.+]] = vector.load %[[ALLOC_0]][%[[D17]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D19:.+]] = amdgpu.mfma %[[D11]] * %[[D18]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK: %[[D21:.+]] = amdgpu.mfma %[[D12]] * %[[D20]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - # CHECK: %[[D20:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK: %[[D22:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [0], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D21:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, + # CHECK: %[[D23:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, # CHECK-SAME: strided<[128, 1], offset: ?>> - # CHECK: %[[D22:.+]] = arith.addi %[[D4]], %[[D10]] : index - # CHECK: %[[D23:.+]] = arith.addi %[[D6]], %[[D14]] : index - # CHECK: %[[D24:.+]] = arith.addi %[[D23]], %[[D13]] : index - # CHECK: vector.store %[[D20]], %[[D21]][%[[D22]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D24:.+]] = arith.addi %[[D3]], %[[D2]] : index + # CHECK: %[[D25:.+]] = arith.addi %[[D24]], %[[D9]] : index + # CHECK: vector.store %[[D22]], %[[D23]][%[[D25]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D25:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK: %[[D26:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [1], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D26:.+]] = arith.addi %[[D22]], %[[C1]] : index - # CHECK: vector.store %[[D25]], %[[D21]][%[[D26]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D27:.+]] = arith.addi %[[D25]], %[[C1]] : index + # CHECK: vector.store %[[D26]], %[[D23]][%[[D27]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D27:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK: %[[D28:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [2], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D28:.+]] = arith.addi %[[D22]], %[[C2]] : index - # CHECK: vector.store %[[D27]], %[[D21]][%[[D28]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D29:.+]] = arith.addi %[[D25]], %[[C2]] : index + # CHECK: vector.store %[[D28]], %[[D23]][%[[D29]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D29:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK: %[[D30:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [3], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D30:.+]] = arith.addi %[[D22]], %[[C3]] : index - # CHECK: vector.store %[[D29]], %[[D21]][%[[D30]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D31:.+]] = arith.addi %[[D25]], %[[C3]] : index + # CHECK: vector.store %[[D30]], %[[D23]][%[[D31]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: return @run_test @@ -515,7 +507,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print(gemm(a, b, c).module_op) # CHECK: func.func @gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, - # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index @@ -531,77 +523,81 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> - # CHECK: %[[D22:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, - # CHECK-SAME: strided<[64, 1], offset: ?>> - # CHECK: %[[D23:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, - # CHECK-SAME: strided<[64, 1], offset: ?>> - # CHECK: %[[D24:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D25:.+]] = arith.muli %[[D24]], %[[C16]] : index - # CHECK: %[[D26:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D25]] : index - # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D25]] : index - # CHECK: %[[D32:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D33:.+]] = arith.divsi %[[D32]], %[[C16]] : index - # CHECK: %[[D34:.+]] = arith.muli %[[D33]], %[[C4]] : index - # CHECK: %[[D36:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D37:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D38:.+]] = arith.addi %[[D37]], %[[D36]] : index - # CHECK: %[[D40:.+]] = arith.addi %[[D30]], %[[D36]] : index - # CHECK: %[[D0:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] + # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, + # CHECK-SAME: strided<[64, 1], offset: ?>> + # CHECK: %[[D1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, + # CHECK-SAME: strided<[64, 1], offset: ?>> + # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] : index + # CHECK: %[[D4:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D5:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D4]] : index + # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D3]] : index + # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index + # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index + # CHECK: %[[D11:.+]] = arith.addi %[[D5]], %[[D3]] : index + # CHECK: %[[D12:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D13:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D14:.+]] = arith.addi %[[D5]], %[[D13]] : index + # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D12]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D5]], %[[D12]] : index + # CHECK: %[[D17:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] # CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[CST]]) -> (vector<4xf32>) { - # CHECK: %[[D28:.+]] = arith.muli %[[ARG3]], %[[C16]] : index - # CHECK: %[[D29:.+]] = vector.load %[[D22]][%[[D27]], %[[D28]]] : memref<64x64xf16, strided<[64, 1], - # CHECK-SAME: offset: ?>>, vector<4xf16> - # CHECK: vector.store %[[D29]], %[[ALLOC]][%[[D25]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[D39:.+]] = arith.muli %[[ARG3]], %[[C16]] : index + # CHECK: %[[D40:.+]] = arith.addi %[[D39]], %[[D10]] : index + # CHECK: %[[D41:.+]] = vector.load %[[D0]][%[[D7]], %[[D40]]] : memref<64x64xf16, strided<[64, 1], offset: + # CHECK-SAME: ?>>, vector<4xf16> + # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D35:.+]] = vector.load %[[ALLOC]][%[[D31]], %[[D34]]] : memref<32x16xf16, + # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D39:.+]] = vector.load %[[D23]][%[[D38]], %[[D28]]] : memref<128x64xf16, strided<[64, 1], + # CHECK: %[[D43:.+]] = vector.load %[[D1]][%[[D15]], %[[D40]]] : memref<128x64xf16, strided<[64, 1], # CHECK-SAME: offset: ?>>, vector<4xf16> - # CHECK: vector.store %[[D39]], %[[ALLOC_0]][%[[D36]], %[[C0]]] : memref<32x16xf16, + # CHECK: amdgpu.lds_barrier + # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D41:.+]] = vector.load %[[ALLOC_0]][%[[D40]], %[[D34]]] : memref<32x16xf16, + # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D42:.+]] = amdgpu.mfma %[[D35]] * %[[D41]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 + # CHECK: %[[D45:.+]] = amdgpu.mfma %[[D42]] * %[[D44]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - # CHECK: scf.yield %[[D42]] : vector<4xf32> + # CHECK: scf.yield %[[D45]] : vector<4xf32> # CHECK: } - # CHECK: %[[D1:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [0], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, + # CHECK: %[[D19:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, # CHECK-SAME: strided<[128, 1], offset: ?>> - # CHECK: %[[D3:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D4:.+]] = arith.divsi %[[D3]], %[[C16]] : index - # CHECK: %[[D5:.+]] = arith.muli %[[D4]], %[[C4]] : index - # CHECK: %[[D6:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D7:.+]] = arith.muli %[[D6]], %[[C16]] : index - # CHECK: %[[D8:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D9:.+]] = arith.addi %[[D8]], %[[D7]] : index - # CHECK: %[[D10:.+]] = arith.addi %[[D9]], %[[D5]] : index - # CHECK: %[[D11:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D12:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D13:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D14:.+]] = arith.addi %[[D13]], %[[D12]] : index - # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D11]] : index - # CHECK: vector.store %[[D1]], %[[D2]][%[[D10]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D20:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D21:.+]] = arith.divsi %[[D20]], %[[C16]] : index + # CHECK: %[[D22:.+]] = arith.muli %[[D21]], %[[C4]] : index + # CHECK: %[[D23:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D24:.+]] = arith.muli %[[D23]], %[[C16]] : index + # CHECK: %[[D25:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D26:.+]] = arith.addi %[[D25]], %[[D24]] : index + # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D22]] : index + # CHECK: %[[D28:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D29:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D29]] : index + # CHECK: %[[D32:.+]] = arith.addi %[[D31]], %[[D28]] : index + # CHECK: vector.store %[[D18]], %[[D19]][%[[D27]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D16:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK: %[[D33:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [1], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D17:.+]] = arith.addi %[[D10]], %[[C1]] : index - # CHECK: vector.store %[[D16]], %[[D2]][%[[D17]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D34:.+]] = arith.addi %[[D27]], %[[C1]] : index + # CHECK: vector.store %[[D33]], %[[D19]][%[[D34]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK: %[[D35:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [2], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D19:.+]] = arith.addi %[[D10]], %[[C2]] : index - # CHECK: vector.store %[[D18]], %[[D2]][%[[D19]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D36:.+]] = arith.addi %[[D27]], %[[C2]] : index + # CHECK: vector.store %[[D35]], %[[D19]][%[[D36]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D20:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK: %[[D37:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [3], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D21:.+]] = arith.addi %[[D10]], %[[C3]] : index - # CHECK: vector.store %[[D20]], %[[D2]][%[[D21]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D38:.+]] = arith.addi %[[D27]], %[[C3]] : index + # CHECK: vector.store %[[D37]], %[[D19]][%[[D38]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> # CHECK: return @@ -756,6 +752,94 @@ def test( # CHECK: arith.addf {{.*}} : vector<1xf16> +# This test is to ensure that the propagation of indexing_dims between reduction and operations +# outside the reduction is working properly. +@run_test +def test_reduction_and_elemwise(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, N, 0)] + constraints += [tkw.TilingConstraint(N, BLOCK_N)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + ): + init_max = tkl.Register[M, tkl.f16](-1e6) + + @tkw.reduction(N, init_args=[init_max]) + def repeat( + partial_max: tkl.Register[M, tkl.f16], + ) -> tkl.Register[M, tkl.f16]: + lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + partial_max = tkw.max(lhs, partial_max, dim=N) + return partial_max + + result = repeat + repeat + tkw.write(result, c, elements_per_thread=1) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 512) + a = torch.randn(shape, dtype=torch.float16) + c = torch.zeros((shape[0],), dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 2, + BLOCK_N: 128, + ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ): + print(test(a, c).module_op) + # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[INIT:.+]] = arith.constant dense<0xFC00> : vector<1xf16> + + # Tile Reduction Loop + # CHECK: %[[TILED:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT]], %[[ACC1:.+]] = %[[INIT]]) -> (vector<1xf16>, vector<1xf16>) { + # 1st Expanded Local Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 1st Expanded Global Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Reduction + # CHECK: %[[ACC_REDUCE_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}} + + # 2nd Expanded Local Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 2nd Expanded Global Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Reduction + # CHECK: %[[ACC_REDUCE_1:.+]] = arith.maximumf %[[ACC1]], %{{.*}} + + # CHECK: scf.yield %[[ACC_REDUCE_0]], %[[ACC_REDUCE_1]] : vector<1xf16>, vector<1xf16> + # CHECK: %[[POST_TILE_ELEMWISE_0:.+]] = arith.addf %[[TILED]]#0, %[[TILED]]#0 : vector<1xf16> + # CHECK: %[[POST_TILE_ELEMWISE_1:.+]] = arith.addf %[[TILED]]#1, %[[TILED]]#1 : vector<1xf16> + # CHECK: vector.store %[[POST_TILE_ELEMWISE_0:.+]], %{{.*}} + # CHECK: vector.store %[[POST_TILE_ELEMWISE_1:.+]], %{{.*}} + + @run_test def test_tiled_reduce_max(): M = tkl.sym.M @@ -851,6 +935,111 @@ def repeat( # CHECK: scf.yield %[[ACC_REDUCE]] : vector<1xf16> +# This test is to ensure that the we can handle multiple IV in reduction properly. +@run_test +def test_multiple_reduction_iv(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, N, 0)] + constraints += [tkw.TilingConstraint(N, BLOCK_N)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + d: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + ): + init_max = tkl.Register[M, tkl.f16](-1e6) + init_sum = tkl.Register[M, tkl.f16](0) + + @tkw.reduction(N, init_args=[init_max, init_sum]) + def repeat( + partial_max: tkl.Register[M, tkl.f16], + partial_sum: tkl.Register[M, tkl.f16], + ) -> tkl.Register[M, tkl.f16]: + lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + partial_max = tkw.max(lhs, partial_max, dim=N) + partial_sum = tkw.sum(lhs, partial_sum, dim=N) + return partial_max, partial_sum + + res_max, res_sum = repeat + tkw.write(res_max, c, elements_per_thread=1) + tkw.write(res_sum, d, elements_per_thread=1) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 512) + a = torch.randn(shape, dtype=torch.float16) + c = torch.zeros((shape[0],), dtype=torch.float16) + d = torch.zeros((shape[0],), dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 2, + BLOCK_N: 128, + ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ): + print(test(a, c).module_op) + # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[INIT_MAX:.+]] = arith.constant dense<0xFC00> : vector<1xf16> + # CHECK-DAG: %[[INIT_SUM:.+]] = arith.constant dense<0.000000e+00> : vector<1xf16> + + # Tile Reduction Loop + # CHECK: %[[TILED:.+]]:4 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT_MAX]], %[[ACC1:.+]] = %[[INIT_SUM]], %[[ACC2:.+]] = %[[INIT_MAX]], %[[ACC3:.+]] = %[[INIT_SUM]]) + # CHECK-SAME: -> (vector<1xf16>, vector<1xf16>, vector<1xf16>, vector<1xf16>) { + # 1st Expanded Local Max Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 1st Expanded Global Max Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Max Reduction + # CHECK: %[[ACC_MAX_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}} + + # 2nd Expanded Local Max Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 2nd Expanded Global Max Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Max Reduction + # CHECK: %[[ACC_MAX_1:.+]] = arith.maximumf %[[ACC2]], %{{.*}} + + # 1st Expanded Local Sum Reduction + # CHECK: arith.addf {{.*}} : vector<1xf16> + # 1st Expanded Global Sum Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Sum Reduction + # CHECK: %[[ACC_SUM_0:.+]] = arith.addf %[[ACC1]], %{{.*}} + + # 2nd Expanded Local Sum Reduction + # CHECK: arith.addf {{.*}} : vector<1xf16> + # 2nd Expanded Global Sum Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Sum Reduction + # CHECK: %[[ACC_SUM_1:.+]] = arith.addf %[[ACC3]], %{{.*}} + + # CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]] + + @run_test def test_binary_lowerings(): constraints: list[tkw.Constraint] = [ diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index a20965f3..6f4e2f29 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -243,31 +243,31 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) # CHECK-NEXT: get_result(value=reduction, res_idx=3) # CHECK-NEXT: get_result(value=reduction, res_idx=2) # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_1_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: write(register_=getresult_1_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_0_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: output # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_0_0_0 @@ -305,13 +305,13 @@ def test_gemm(): # CHECK-SAME: (%read_0_0_0, %read_0_1_0, %acc_0_1_0) # CHECK-NEXT: %mma_0_1_1 # CHECK-SAME: (%read_0_0_1, %read_0_1_1, %mma_0_1_0) - # CHECK-NEXT: return [mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1] + # CHECK-NEXT: return [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1] # Custom format: # CHECK-NEXT: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) @@ -346,7 +346,7 @@ def test_gemm(): # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-SAME: rhs=read_0_1_1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-SAME: acc=mma_0_1_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})) - # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1],)) + # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1],)) # CHECK-NEXT: ----- @@ -389,11 +389,11 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0] # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: output(return_vals=(None,)) # Reduction subgraph: diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index d49ee3b2..2bebc690 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -182,13 +182,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( @@ -210,16 +210,16 @@ def test_gemm(): # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 32}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_4, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16), N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_5, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_6, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_7, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_0_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_8, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 32}) @@ -234,22 +234,22 @@ def test_gemm(): # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 32}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_12, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16), N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_13, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_14, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_15, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 48}) # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_4 # CHECK-SAME: (%a, 8, None, None) @@ -303,9 +303,9 @@ def test_gemm(): # Reduction subgraph (custom format): # CHECK: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=8, # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64), K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1}) diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index 310e9ef4..dcf6b225 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -131,13 +131,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( @@ -146,19 +146,19 @@ def test_gemm(): # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_1_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: write(register_=getresult_1_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_0_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_4 # CHECK-SAME: (%a, 8, None, None) @@ -215,9 +215,9 @@ def test_gemm(): # Reduction subgraph (custom format): # CHECK: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=8, # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64), K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1}) diff --git a/shark_turbine/aot/compiled_module.py b/shark_turbine/aot/compiled_module.py index 3f44c8b9..5fffd6a0 100644 --- a/shark_turbine/aot/compiled_module.py +++ b/shark_turbine/aot/compiled_module.py @@ -41,6 +41,7 @@ from .support.ir_utils import ( ModuleBuilder, + ModuleBuilderOptions, ) @@ -162,11 +163,13 @@ class CompiledModuleClassInfo: __slots__ = [ "all_exports", "ir_module_name", + "options", ] - def __init__(self, *, ir_module_name: str): + def __init__(self, *, ir_module_name: str, options: ModuleBuilderOptions): self.ir_module_name = ir_module_name self.all_exports: Dict[str, Exportable] = dict() + self.options = options def add_export(self, key: str, value: Exportable): if key in self.all_exports: @@ -370,13 +373,23 @@ class CompiledModuleMeta(type): # It is passed the dictionary of declared attributes and any keyword # arguments from the class declaration: # class Foo(Bar, kwarg="you probably just learned this is possible"): - def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None): + def __new__( + mcls, + name: str, + bases, + dct, + *, + export_name: Optional[str] = None, + options: Optional[ModuleBuilderOptions] = None, + ): if not _metaclass_setup_complete: return type.__new__(mcls, name, bases, dct) ir_module_name = _derive_ir_module_name(name, export_name) logger.debug("Create new CompiledModule: %s", ir_module_name) - info = CompiledModuleClassInfo(ir_module_name=ir_module_name) + info = CompiledModuleClassInfo( + ir_module_name=ir_module_name, options=options or ModuleBuilderOptions() + ) # Process that attributes that were set as part of class definition. # Any attributes that we decide are part of the compiled module @@ -436,6 +449,7 @@ def create_from_dict( dct: dict, *, export_name: Optional[str] = None, + options: Optional[ModuleBuilderOptions] = None, ) -> CompiledModuleMeta: """Creates a CompiledModule subclass with an explicit dictionary of members. @@ -446,7 +460,9 @@ class Foo(CompiledModule, export_name="bar"): def member(): ... ``` """ - return CompiledModuleMeta(name, (cls,), dct, export_name=export_name) + return CompiledModuleMeta( + name, (cls,), dct, export_name=export_name, options=options + ) @staticmethod def get_class_info(cls: CompiledModuleMeta) -> CompiledModuleClassInfo: @@ -596,7 +612,7 @@ def __new__( module_op.attributes["sym_name"] = StringAttr.get( class_info.ir_module_name, context=context ) - module_builder = ModuleBuilder(module_op) + module_builder = ModuleBuilder(module_op, options=class_info.options) info = CompiledModuleInstanceInfo(class_info, module_builder=module_builder) _all_compiled_module_instance_infos[self] = info diff --git a/shark_turbine/aot/exporter.py b/shark_turbine/aot/exporter.py index 4c0e0160..c1adb527 100644 --- a/shark_turbine/aot/exporter.py +++ b/shark_turbine/aot/exporter.py @@ -26,6 +26,7 @@ from .builtins import * from .compiled_module import ( CompiledModule, + ModuleBuilderOptions, ImportPhase, ) from .fx_programs import FxPrograms @@ -175,6 +176,7 @@ def export( module_name: Optional[str] = None, function_name: Optional[str] = None, strict_export: bool = True, + import_symbolic_shape_expressions: bool = False, ) -> ExportOutput: """Exports a torch.nn.Module. @@ -223,6 +225,7 @@ def export( module_name: Optional[str] = None, function_name: Optional[str] = None, strict_export: bool = True, + import_symbolic_shape_expressions: bool = False, ) -> ExportOutput: """Generic export of supported entities. @@ -270,11 +273,19 @@ def export( "LambdaCompiledModule", {(function_name or "main"): mdl}, export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif isinstance(mdl, FxPrograms): TransformedModule = CompiledModule.create_from_dict( - "LambdaCompiledModule", mdl.programs, export_name=module_name or "module" + "LambdaCompiledModule", + mdl.programs, + export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif isinstance(mdl, torch.nn.Module): # Normalize arguments for torch.export. @@ -302,6 +313,9 @@ def export( "LambdaCompiledModule", {(function_name or "main"): exported_program}, export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif issubclass(mdl, CompiledModule): TransformedModule = mdl diff --git a/shark_turbine/aot/support/ir_utils.py b/shark_turbine/aot/support/ir_utils.py index a662c15c..e1eb9d56 100644 --- a/shark_turbine/aot/support/ir_utils.py +++ b/shark_turbine/aot/support/ir_utils.py @@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Sequence, Tuple +from dataclasses import dataclass from pathlib import Path import tempfile @@ -148,6 +149,12 @@ def infer_external_from_tensor( ############################################################################### +@dataclass +class ModuleBuilderOptions: + # Whether to import torch symbolic shape expressions for ExportedPrograms. + import_symbolic_shape_expressions: bool = False + + class ModuleBuilder: """Wrapper around module and IR accounting for a module being built.""" @@ -159,14 +166,18 @@ class ModuleBuilder: "last_global_op", "ip", "module_op", + "options", "symbol_table", "global_ref_tracker", "native_type_converter", "_auto_symbol_counts", ] - def __init__(self, module_op: Operation): + def __init__( + self, module_op: Operation, *, options: Optional[ModuleBuilderOptions] = None + ): self.module_op = module_op + self.options = options or ModuleBuilderOptions() self.context = module_op.context self.body = module_op.regions[0].blocks[0] self.symbol_table = SymbolTable(module_op) diff --git a/shark_turbine/aot/support/procedural/exported_program.py b/shark_turbine/aot/support/procedural/exported_program.py index 331a7345..f6540bab 100644 --- a/shark_turbine/aot/support/procedural/exported_program.py +++ b/shark_turbine/aot/support/procedural/exported_program.py @@ -181,7 +181,10 @@ def import_exported_program( ) -> ExportedProgramIntrinsic: fx_importer = _create_fx_importer(module_builder) entry_func_op = fx_importer.import_program( - exported_program, func_name=symbol_name, func_visibility=symbol_visibility + exported_program, + func_name=symbol_name, + func_visibility=symbol_visibility, + import_symbolic_shape_expressions=module_builder.options.import_symbolic_shape_expressions, ) module_call_graph = exported_program.module_call_graph @@ -234,6 +237,8 @@ def store_produced_value( raise ValueError(f"Cannot store value to unmapped global for: {info}") logger.debug("Resolved global for store %r", mapping) materialized_global: MaterializedGlobal = mapping.value # type: ignore + assert isinstance(materialized_global.global_op, util_d.GlobalOp) + materialized_global.global_op.is_mutable = True converted_value = Operation.create( "torch_c.to_builtin_tensor", results=[materialized_global.ir_type], @@ -251,7 +256,7 @@ def resolve_literal( return None # See if we know about it. - materialized_global = self._lift_tensor_to_global(literal) + materialized_global = self._lift_tensor_to_global(literal, info) if not materialized_global: # If it is unknown, just let the default importer take it on. return None @@ -269,7 +274,7 @@ def resolve_literal( return converted_value def _lift_tensor_to_global( - self, literal: torch.Tensor + self, literal: torch.Tensor, info: InputInfo | None ) -> Optional[MaterializedGlobal]: module_builder = self.module_builder mapping = module_builder.global_ref_tracker.track(literal) @@ -282,7 +287,7 @@ def _lift_tensor_to_global( # Policy check: Should we auto-import? Generally, we keep "small" # tensors as inline as they can be optimized. external_trait = ExternalTensorTrait.get(literal) - if not self._should_lift_tensor_to_global(literal, external_trait): + if not self._should_lift_tensor_to_global(literal, external_trait, info): return None # If it is a tensor we haven't seen yet, materialize it @@ -304,8 +309,13 @@ def _lift_tensor_to_global( return materialized_global def _should_lift_tensor_to_global( - self, literal: torch.Tensor, external_trait: Optional[ExternalTensorTrait] + self, + literal: torch.Tensor, + external_trait: Optional[ExternalTensorTrait], + info: InputInfo | None, ) -> bool: + if info is not None and info.store_producer_node: + return True if external_trait is not None: return True volume = math.prod(literal.shape) diff --git a/shark_turbine/dynamo/type_conversion.py b/shark_turbine/dynamo/type_conversion.py index 8206e10f..e829bafc 100644 --- a/shark_turbine/dynamo/type_conversion.py +++ b/shark_turbine/dynamo/type_conversion.py @@ -32,7 +32,7 @@ # 1. Local name (int, float, vtensor) # 2. Parameter block ("<...>"), including the delimitters # 3. Inner parameter block (no delimitters) -DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch.([^<]+)(<([^>]*)>)?$") +DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch\.([^<]+)(<(.*)>)?$") # Decomposes a vtensor parameter block into a dimension list and dtype. Groups: # 1. Dimension list diff --git a/shark_turbine/kernel/_support/indexing.py b/shark_turbine/kernel/_support/indexing.py index 3f092278..b99d7b5b 100644 --- a/shark_turbine/kernel/_support/indexing.py +++ b/shark_turbine/kernel/_support/indexing.py @@ -99,6 +99,7 @@ class IndexingContext: __slots__ = [ "subs", + "special_subs", "shaped_bindings", "dyn_dims", "frozen_subs", @@ -109,6 +110,7 @@ class IndexingContext: def __init__(self): self.subs: dict[IndexSymbol, int] = {} + self.special_subs: dict[IndexSymbol, Any] = {} # Indexed by .instance self.shaped_bindings: dict[Any, _ShapedBinding] = {} self.dyn_dims: list[IndexSymbol] = [] @@ -245,6 +247,20 @@ def get_static_value(self, expr: IndexExpr | int) -> Optional[int]: except TypeError: return None + def iota(self, n: int) -> IndexExpr: + sym = index_symbol(f"$IOTA{n}") + if sym not in self.special_subs: + self.special_subs[sym] = tuple(range(n)) + + return sym + + def get_val(self, sym: IndexSymbol) -> Any: + res = self.subs.get(sym, None) + if res is None: + res = self.special_subs.get(sym, None) + + return res + ##### Context management. @staticmethod def current() -> "IndexingContext": diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index 0298a065..2c38c9c2 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -365,12 +365,15 @@ def copy( new_name: Optional[str] = None, new_graph: Optional[fx.Graph] = None, arg_transform: Optional[Callable[[Any], Any]] = lambda x: x, + anchor: Optional[fx.Node] = None, ) -> Self: """Returns a duplicate of this node.""" graph = new_graph if new_graph is None: graph = self.graph - graph.inserting_after(self.fx_node) + if anchor is None: + anchor = self.fx_node + graph.inserting_after(anchor) new_node = graph.node_copy(self.fx_node, arg_transform=arg_transform) new_node.tkw_op = self new_node.tkw_op_name = self.tkw_op_name @@ -483,7 +486,6 @@ def post_expansion(self, constraints: list["Constraint"]) -> None: pass -@define_py_op(operator.getitem) @define_py_op(operator.add) @define_py_op(operator.sub) @define_py_op(operator.mul) @@ -777,16 +779,6 @@ def custom_string(self, value_map: dict[str, str]) -> str: custom_str += f"acc={self.acc} (index = {self.acc_index}))" return custom_str - def post_expansion(self, constraints: list["Constraint"]) -> None: - """ - Once the arguments have been expanded, we set their indices, - ensuring that the LHS and RHS indices are consistent with their - corresponding address spaces. - """ - self.lhs.index = self.lhs_index - self.rhs.index = self.rhs_index - self.acc.index = self.acc_index - @define_op("read") @dataclass @@ -858,12 +850,23 @@ def wrapper(f): return wrapper @property - def indexing_dims(self) -> list[IndexSymbol]: + def indexing_dims(self) -> list[IndexSymbol] | list[list[IndexSymbol]]: expand_dims: list[IndexSymbol] = [] - for user in self.users: - for indexing_dim in user.indexing_dims: - if indexing_dim not in expand_dims: - expand_dims.append(indexing_dim) + return_node = [ + nested_node + for nested_node in self.graph.subgraphs[self.subgraph_name].nodes + if isinstance(get_custom(nested_node), Output) + ] + assert len(return_node) == 1 + return_vals = get_custom(return_node[0]).return_vals[0] + if not isinstance(return_vals, Sequence): + return_vals = [return_vals] + for return_val in return_vals: + return_dims = get_custom(return_val).indexing_dims + reduced_dims = [dims for dims in return_dims if dims != self.axis] + expand_dims.append(reduced_dims) + if len(expand_dims) == 1: + expand_dims = expand_dims[0] return expand_dims def iter_args(self, graph: fx.Graph) -> list[fx.Node]: @@ -941,6 +944,7 @@ def register_index(self) -> dict[IndexSymbol, IndexSequence]: return custom.index +@define_py_op(operator.getitem) @define_op("get_result") @dataclass class GetResult(CustomOp): @@ -949,16 +953,24 @@ class GetResult(CustomOp): @property def type(self) -> "Memory": - return get_custom(self.value).type[self.res_idx] + src_type = get_custom(self.value).type + if isinstance(src_type, list): + return src_type[self.res_idx] + else: + return src_type @property - def indexing_dims(self) -> list[IndexSymbol]: - expand_dims: list[IndexSymbol] = [] - for user in self.users: - for indexing_dim in user.indexing_dims: - if indexing_dim not in expand_dims: - expand_dims.append(indexing_dim) - return expand_dims + def indexing_dims(self) -> list[IndexExpr]: + has_multiple_value = lambda x: all(isinstance(el, list) for el in x) + is_valid_indexing_dim = lambda x: isinstance(src_indexing, list) and all( + isinstance(el, IndexExpr) for el in x + ) + src_indexing = get_custom(self.value).indexing_dims + if has_multiple_value(src_indexing): + assert self.res_idx <= len(src_indexing) - 1 + src_indexing = src_indexing[self.res_idx] + assert is_valid_indexing_dim(src_indexing) + return src_indexing @property def index(self) -> dict[IndexSymbol, IndexSequence]: diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py index e4a8cf72..aff72cf3 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/shark_turbine/kernel/wave/codegen.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools import operator import sympy import math @@ -77,7 +78,7 @@ WorkgroupConstraint, TilingConstraint, ) -from .utils import subs_idxc +from .utils import subs_idxc, find_index_bounds # Indexing imports. from .._support.indexing import IndexingContext, IndexExpr, IndexSequence @@ -171,6 +172,32 @@ def get_type_or_element_type(operand_type: IrType): def gen_sympy_index(emitter: WaveEmitter, expr: sympy.Expr) -> OpResult: stack: list[OpResult] = [] + def _broadcast(a, b): + if not isinstance(a, (Value, OpResult)): + a = a.result + + if not isinstance(b, (Value, OpResult)): + b = b.result + + if a.type == b.type: + return a, b + + if isinstance(a.type, VectorType) and isinstance( + b.type, (IndexType, IntegerType) + ): + assert a.type.element_type == b.type + b = vector_d.splat(a.type, b) + return a, b + + if isinstance(a.type, (IndexType, IntegerType)) and isinstance( + b.type, VectorType + ): + assert b.type.element_type == a.type + a = vector_d.splat(b.type, a) + return a, b + + raise CodegenError(f"Cannot broadcast {a.type} and {b.type}") + def _process_mul_add_ops(term, is_mul): args = [] callables = [] @@ -187,9 +214,9 @@ def _process_mul_add_ops(term, is_mul): continue if is_mul: - operation = arith_d.MulIOp(operation, arg) + operation = arith_d.MulIOp(*_broadcast(operation, arg)) else: - operation = arith_d.AddIOp(operation, arg) + operation = arith_d.AddIOp(*_broadcast(operation, arg)) for arg in callables: operation = arg(operation, is_mul) @@ -197,16 +224,29 @@ def _process_mul_add_ops(term, is_mul): stack.append(operation) def _get_mul(numerator): - return lambda x: arith_d.MulIOp(x, numerator) + return lambda x: arith_d.MulIOp(*_broadcast(x, numerator)) def _get_add(numerator, denominator): - return lambda x: arith_d.AddIOp(arith_d.MulIOp(x, denominator), numerator) + return lambda x: arith_d.AddIOp( + *_broadcast(arith_d.MulIOp(*_broadcast(x, denominator)), numerator) + ) def _get_div(mul, add, denominator): return lambda x, is_mul: arith_d.DivSIOp( - mul(x) if is_mul else add(x), denominator + *_broadcast(mul(x) if is_mul else add(x), denominator) ) + def _get_const(val): + if isinstance(val, int): + return arith_d.constant(IndexType.get(), res) + + if isinstance(val, (tuple, list)): + vec_type = VectorType.get([len(val)], IndexType.get()) + vals = [IntegerAttr.get(IndexType.get(), v) for v in val] + return arith_d.constant(vec_type, DenseElementsAttr.get(vals, vec_type)) + + raise CodegenError(f"Unsupported const val {val} : {type(val)}") + induction_var_syms = [] induction_vars = [] for constraint in emitter.constraints: @@ -237,9 +277,9 @@ def _get_div(mul, add, denominator): for term in sympy.postorder_traversal(expr): match term: case sympy.Symbol(): - if term in idxc.subs.keys(): - cst = arith_d.constant(IndexType.get(), idxc.subs[term]) - stack.append(cst) + res = idxc.get_val(term) + if res is not None: + stack.append(_get_const(res)) elif term in dynamics.keys(): stack.append(dynamics[term]) else: @@ -253,7 +293,7 @@ def _get_div(mul, add, denominator): case sympy.Mod(): rhs = stack.pop() lhs = stack.pop() - mod = arith_d.RemSIOp(lhs, rhs) + mod = arith_d.RemSIOp(*_broadcast(lhs, rhs)) stack.append(mod) case sympy.floor(): # TODO: Since divsi rounds to zero, this seems to work. @@ -267,17 +307,27 @@ def _get_div(mul, add, denominator): # Assumes that the negative term is always carried on the numerator if abs(term.p) > term.p: zero = arith_d.constant(IndexType.get(), int(0)) - numerator = arith_d.SubIOp(zero, numerator) + numerator = arith_d.SubIOp(*_broadcast(zero, numerator)) mul = lambda x: x if abs(term.p) != 1: mul = _get_mul(numerator) add = _get_add(numerator, denominator) operation = _get_div(mul, add, denominator) stack.append(operation) + case sympy.StrictLessThan(): + rhs = stack.pop() + lhs = stack.pop() + res = arith_d.cmpi(arith_d.CmpIPredicate.slt, *_broadcast(lhs, rhs)) + stack.append(res) + case sympy.And(): + rhs = stack.pop() + lhs = stack.pop() + res = arith_d.andi(*_broadcast(lhs, rhs)) + stack.append(res) case sympy.UnevaluatedExpr(): continue case _: - raise CodegenError(f"Can not handle {term} yet") + raise CodegenError(f"Can not handle {type(term)} : {term}") if len(stack) != 1: raise CodegenError(f"Expected single result, got {len(stack)}") return stack[0] @@ -392,44 +442,24 @@ def _is_identity_mapping( def _build_mask( emitter: WaveEmitter, index: Dict[IndexExpr, IndexExpr], elements_per_thread: int ) -> Optional[OpResult]: - bounds = [] - for constraint in emitter.constraints: - if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)): - continue - - dim = constraint.dim - if dim not in index: - continue - - work_size = constraint.count * constraint.tile_size - if subs_idxc(work_size) == subs_idxc(dim): - continue - - bounds.append((dim, gen_sympy_index(emitter, dim))) - - if len(bounds) == 0: + bounds = find_index_bounds(emitter.constraints, index) + if bounds is None: return None - mask_vec_type = VectorType.get([elements_per_thread], IntegerType.get_signless(1)) - mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread]) - + idxc = IndexingContext.current() last_dim = tuple(index.keys())[-1] new_index = {k: _get_start_index(v) for k, v in index.items()} - for i in range(elements_per_thread): - cond = None - for dim, bound in bounds: - idx = gen_sympy_index(emitter, new_index[dim]) - lt = arith_d.cmpi(arith_d.CmpIPredicate.slt, idx, bound) - if cond is None: - cond = lt - else: - cond = arith_d.andi(cond, lt) + new_index[last_dim] = new_index[last_dim] + idxc.iota(elements_per_thread) - pos = arith_d.ConstantOp(IndexType.get(), i) - mask = vector_d.insertelement(cond, mask, position=pos) + mask_expr = functools.reduce( + lambda a, b: sympy.And(a, b), (new_index[dim] < dim for dim in bounds) + ) + mask = gen_sympy_index(emitter, mask_expr) - new_index[last_dim] = new_index[last_dim] + 1 + mask_vec_type = VectorType.get([elements_per_thread], IntegerType.get_signless(1)) + if mask.type != mask_vec_type: + mask = vector_d.splat(mask_vec_type, mask) return mask diff --git a/shark_turbine/kernel/wave/docs/mlsys/.gitignore b/shark_turbine/kernel/wave/docs/mlsys/.gitignore index f2e31fe2..b4c7d64b 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/.gitignore +++ b/shark_turbine/kernel/wave/docs/mlsys/.gitignore @@ -3,3 +3,4 @@ *.out *.pdf *.synctex.gz +*.blg diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl b/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl index 1295a6d4..5ca46234 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl +++ b/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl @@ -1,54 +1,208 @@ -\begin{thebibliography}{8} +\begin{thebibliography}{6} \providecommand{\natexlab}[1]{#1} \providecommand{\url}[1]{\texttt{#1}} \expandafter\ifx\csname urlstyle\endcsname\relax \providecommand{\doi}[1]{doi: #1}\else \providecommand{\doi}{doi: \begingroup \urlstyle{rm}\Url}\fi -\bibitem[Author(2018)]{anonymous} -Author, N.~N. -\newblock Suppressed for anonymity, 2018. +\bibitem[Chetlur et~al.(2014)Chetlur, Woolley, Vandermersch, Cohen, Tran, + Catanzaro, and Shelhamer]{chetlur_cudnn_2014} +Chetlur, S., Woolley, C., Vandermersch, P., Cohen, J., Tran, J., Catanzaro, B., + and Shelhamer, E. +\newblock {cuDNN}: {Efficient} {Primitives} for {Deep} {Learning}, December + 2014. +\newblock URL \url{http://arxiv.org/abs/1410.0759}. +\newblock arXiv:1410.0759 [cs]. -\bibitem[Duda et~al.(2000)Duda, Hart, and Stork]{DudaHart2nd} -Duda, R.~O., Hart, P.~E., and Stork, D.~G. -\newblock \emph{Pattern Classification}. -\newblock John Wiley and Sons, 2nd edition, 2000. +\bibitem[Dubey et~al.(2024)Dubey, Jauhri, Pandey, Kadian, Al-Dahle, Letman, + Mathur, Schelten, Yang, Fan, Goyal, Hartshorn, Yang, Mitra, Sravankumar, + Korenev, Hinsvark, Rao, Zhang, Rodriguez, Gregerson, Spataru, Roziere, Biron, + Tang, Chern, Caucheteux, Nayak, Bi, Marra, McConnell, Keller, Touret, Wu, + Wong, Ferrer, Nikolaidis, Allonsius, Song, Pintz, Livshits, Esiobu, + Choudhary, Mahajan, Garcia-Olano, Perino, Hupkes, Lakomkin, AlBadawy, + Lobanova, Dinan, Smith, Radenovic, Zhang, Synnaeve, Lee, Anderson, Nail, + Mialon, Pang, Cucurell, Nguyen, Korevaar, Xu, Touvron, Zarov, Ibarra, + Kloumann, Misra, Evtimov, Copet, Lee, Geffert, Vranes, Park, Mahadeokar, + Shah, van~der Linde, Billock, Hong, Lee, Fu, Chi, Huang, Liu, Wang, Yu, + Bitton, Spisak, Park, Rocca, Johnstun, Saxe, Jia, Alwala, Upasani, Plawiak, + Li, Heafield, Stone, El-Arini, Iyer, Malik, Chiu, Bhalla, Rantala-Yeary, + van~der Maaten, Chen, Tan, Jenkins, Martin, Madaan, Malo, Blecher, Landzaat, + de~Oliveira, Muzzi, Pasupuleti, Singh, Paluri, Kardas, Oldham, Rita, Pavlova, + Kambadur, Lewis, Si, Singh, Hassan, Goyal, Torabi, Bashlykov, Bogoychev, + Chatterji, Duchenne, Çelebi, Alrassy, Zhang, Li, Vasic, Weng, Bhargava, + Dubal, Krishnan, Koura, Xu, He, Dong, Srinivasan, Ganapathy, Calderer, + Cabral, Stojnic, Raileanu, Girdhar, Patel, Sauvestre, Polidoro, Sumbaly, + Taylor, Silva, Hou, Wang, Hosseini, Chennabasappa, Singh, Bell, Kim, Edunov, + Nie, Narang, Raparthy, Shen, Wan, Bhosale, Zhang, Vandenhende, Batra, + Whitman, Sootla, Collot, Gururangan, Borodinsky, Herman, Fowler, Sheasha, + Georgiou, Scialom, Speckbacher, Mihaylov, Xiao, Karn, Goswami, Gupta, + Ramanathan, Kerkez, Gonguet, Do, Vogeti, Petrovic, Chu, Xiong, Fu, Meers, + Martinet, Wang, Tan, Xie, Jia, Wang, Goldschlag, Gaur, Babaei, Wen, Song, + Zhang, Li, Mao, Coudert, Yan, Chen, Papakipos, Singh, Grattafiori, Jain, + Kelsey, Shajnfeld, Gangidi, Victoria, Goldstand, Menon, Sharma, Boesenberg, + Vaughan, Baevski, Feinstein, Kallet, Sangani, Yunus, Lupu, Alvarado, Caples, + Gu, Ho, Poulton, Ryan, Ramchandani, Franco, Saraf, Chowdhury, Gabriel, + Bharambe, Eisenman, Yazdan, James, Maurer, Leonhardi, Huang, Loyd, De~Paola, + Paranjape, Liu, Wu, Ni, Hancock, Wasti, Spence, Stojkovic, Gamido, Montalvo, + Parker, Burton, Mejia, Wang, Kim, Zhou, Hu, Chu, Cai, Tindal, Feichtenhofer, + Civin, Beaty, Kreymer, Li, Wyatt, Adkins, Xu, Testuggine, David, Parikh, + Liskovich, Foss, Wang, Le, Holland, Dowling, Jamil, Montgomery, Presani, + Hahn, Wood, Brinkman, Arcaute, Dunbar, Smothers, Sun, Kreuk, Tian, Ozgenel, + Caggioni, Guzmán, Kanayet, Seide, Florez, Schwarz, Badeer, Swee, Halpern, + Thattai, Herman, Sizov, Guangyi, Zhang, Lakshminarayanan, Shojanazeri, Zou, + Wang, Zha, Habeeb, Rudolph, Suk, Aspegren, Goldman, Damlaj, Molybog, Tufanov, + Veliche, Gat, Weissman, Geboski, Kohli, Asher, Gaya, Marcus, Tang, Chan, + Zhen, Reizenstein, Teboul, Zhong, Jin, Yang, Cummings, Carvill, Shepard, + McPhie, Torres, Ginsburg, Wang, Wu, U, Saxena, Prasad, Khandelwal, Zand, + Matosich, Veeraraghavan, Michelena, Li, Huang, Chawla, Lakhotia, Huang, Chen, + Garg, A, Silva, Bell, Zhang, Guo, Yu, Moshkovich, Wehrstedt, Khabsa, Avalani, + Bhatt, Tsimpoukelli, Mankus, Hasson, Lennie, Reso, Groshev, Naumov, Lathi, + Keneally, Seltzer, Valko, Restrepo, Patel, Vyatskov, Samvelyan, Clark, Macey, + Wang, Hermoso, Metanat, Rastegari, Bansal, Santhanam, Parks, White, Bawa, + Singhal, Egebo, Usunier, Laptev, Dong, Zhang, Cheng, Chernoguz, Hart, + Salpekar, Kalinli, Kent, Parekh, Saab, Balaji, Rittner, Bontrager, Roux, + Dollar, Zvyagina, Ratanchandani, Yuvraj, Liang, Alao, Rodriguez, Ayub, + Murthy, Nayani, Mitra, Li, Hogan, Battey, Wang, Maheswari, Howes, Rinott, + Bondu, Datta, Chugh, Hunt, Dhillon, Sidorov, Pan, Verma, Yamamoto, Ramaswamy, + Lindsay, Lindsay, Feng, Lin, Zha, Shankar, Zhang, Zhang, Wang, Agarwal, + Sajuyigbe, Chintala, Max, Chen, Kehoe, Satterfield, Govindaprasad, Gupta, + Cho, Virk, Subramanian, Choudhury, Goldman, Remez, Glaser, Best, Kohler, + Robinson, Li, Zhang, Matthews, Chou, Shaked, Vontimitta, Ajayi, Montanez, + Mohan, Kumar, Mangla, Albiero, Ionescu, Poenaru, Mihailescu, Ivanov, Li, + Wang, Jiang, Bouaziz, Constable, Tang, Wang, Wu, Wang, Xia, Wu, Gao, Chen, + Hu, Jia, Qi, Li, Zhang, Zhang, Adi, Nam, Yu, Wang, Hao, Qian, He, Rait, + DeVito, Rosnbrick, Wen, Yang, and Zhao]{dubey_llama_2024} +Dubey, A., Jauhri, A., Pandey, A., Kadian, A., Al-Dahle, A., Letman, A., + Mathur, A., Schelten, A., Yang, A., Fan, A., Goyal, A., Hartshorn, A., Yang, + A., Mitra, A., Sravankumar, A., Korenev, A., Hinsvark, A., Rao, A., Zhang, + A., Rodriguez, A., Gregerson, A., Spataru, A., Roziere, B., Biron, B., Tang, + B., Chern, B., Caucheteux, C., Nayak, C., Bi, C., Marra, C., McConnell, C., + Keller, C., Touret, C., Wu, C., Wong, C., Ferrer, C.~C., Nikolaidis, C., + Allonsius, D., Song, D., Pintz, D., Livshits, D., Esiobu, D., Choudhary, D., + Mahajan, D., Garcia-Olano, D., Perino, D., Hupkes, D., Lakomkin, E., + AlBadawy, E., Lobanova, E., Dinan, E., Smith, E.~M., Radenovic, F., Zhang, + F., Synnaeve, G., Lee, G., Anderson, G.~L., Nail, G., Mialon, G., Pang, G., + Cucurell, G., Nguyen, H., Korevaar, H., Xu, H., Touvron, H., Zarov, I., + Ibarra, I.~A., Kloumann, I., Misra, I., Evtimov, I., Copet, J., Lee, J., + Geffert, J., Vranes, J., Park, J., Mahadeokar, J., Shah, J., van~der Linde, + J., Billock, J., Hong, J., Lee, J., Fu, J., Chi, J., Huang, J., Liu, J., + Wang, J., Yu, J., Bitton, J., Spisak, J., Park, J., Rocca, J., Johnstun, J., + Saxe, J., Jia, J., Alwala, K.~V., Upasani, K., Plawiak, K., Li, K., Heafield, + K., Stone, K., El-Arini, K., Iyer, K., Malik, K., Chiu, K., Bhalla, K., + Rantala-Yeary, L., van~der Maaten, L., Chen, L., Tan, L., Jenkins, L., + Martin, L., Madaan, L., Malo, L., Blecher, L., Landzaat, L., de~Oliveira, L., + Muzzi, M., Pasupuleti, M., Singh, M., Paluri, M., Kardas, M., Oldham, M., + Rita, M., Pavlova, M., Kambadur, M., Lewis, M., Si, M., Singh, M.~K., Hassan, + M., Goyal, N., Torabi, N., Bashlykov, N., Bogoychev, N., Chatterji, N., + Duchenne, O., Çelebi, O., Alrassy, P., Zhang, P., Li, P., Vasic, P., Weng, + P., Bhargava, P., Dubal, P., Krishnan, P., Koura, P.~S., Xu, P., He, Q., + Dong, Q., Srinivasan, R., Ganapathy, R., Calderer, R., Cabral, R.~S., + Stojnic, R., Raileanu, R., Girdhar, R., Patel, R., Sauvestre, R., Polidoro, + R., Sumbaly, R., Taylor, R., Silva, R., Hou, R., Wang, R., Hosseini, S., + Chennabasappa, S., Singh, S., Bell, S., Kim, S.~S., Edunov, S., Nie, S., + Narang, S., Raparthy, S., Shen, S., Wan, S., Bhosale, S., Zhang, S., + Vandenhende, S., Batra, S., Whitman, S., Sootla, S., Collot, S., Gururangan, + S., Borodinsky, S., Herman, T., Fowler, T., Sheasha, T., Georgiou, T., + Scialom, T., Speckbacher, T., Mihaylov, T., Xiao, T., Karn, U., Goswami, V., + Gupta, V., Ramanathan, V., Kerkez, V., Gonguet, V., Do, V., Vogeti, V., + Petrovic, V., Chu, W., Xiong, W., Fu, W., Meers, W., Martinet, X., Wang, X., + Tan, X.~E., Xie, X., Jia, X., Wang, X., Goldschlag, Y., Gaur, Y., Babaei, Y., + Wen, Y., Song, Y., Zhang, Y., Li, Y., Mao, Y., Coudert, Z.~D., Yan, Z., Chen, + Z., Papakipos, Z., Singh, A., Grattafiori, A., Jain, A., Kelsey, A., + Shajnfeld, A., Gangidi, A., Victoria, A., Goldstand, A., Menon, A., Sharma, + A., Boesenberg, A., Vaughan, A., Baevski, A., Feinstein, A., Kallet, A., + Sangani, A., Yunus, A., Lupu, A., Alvarado, A., Caples, A., Gu, A., Ho, A., + Poulton, A., Ryan, A., Ramchandani, A., Franco, A., Saraf, A., Chowdhury, A., + Gabriel, A., Bharambe, A., Eisenman, A., Yazdan, A., James, B., Maurer, B., + Leonhardi, B., Huang, B., Loyd, B., De~Paola, B., Paranjape, B., Liu, B., Wu, + B., Ni, B., Hancock, B., Wasti, B., Spence, B., Stojkovic, B., Gamido, B., + Montalvo, B., Parker, C., Burton, C., Mejia, C., Wang, C., Kim, C., Zhou, C., + Hu, C., Chu, C.-H., Cai, C., Tindal, C., Feichtenhofer, C., Civin, D., Beaty, + D., Kreymer, D., Li, D., Wyatt, D., Adkins, D., Xu, D., Testuggine, D., + David, D., Parikh, D., Liskovich, D., Foss, D., Wang, D., Le, D., Holland, + D., Dowling, E., Jamil, E., Montgomery, E., Presani, E., Hahn, E., Wood, E., + Brinkman, E., Arcaute, E., Dunbar, E., Smothers, E., Sun, F., Kreuk, F., + Tian, F., Ozgenel, F., Caggioni, F., Guzmán, F., Kanayet, F., Seide, F., + Florez, G.~M., Schwarz, G., Badeer, G., Swee, G., Halpern, G., Thattai, G., + Herman, G., Sizov, G., Guangyi, Zhang, Lakshminarayanan, G., Shojanazeri, H., + Zou, H., Wang, H., Zha, H., Habeeb, H., Rudolph, H., Suk, H., Aspegren, H., + Goldman, H., Damlaj, I., Molybog, I., Tufanov, I., Veliche, I.-E., Gat, I., + Weissman, J., Geboski, J., Kohli, J., Asher, J., Gaya, J.-B., Marcus, J., + Tang, J., Chan, J., Zhen, J., Reizenstein, J., Teboul, J., Zhong, J., Jin, + J., Yang, J., Cummings, J., Carvill, J., Shepard, J., McPhie, J., Torres, J., + Ginsburg, J., Wang, J., Wu, K., U, K.~H., Saxena, K., Prasad, K., Khandelwal, + K., Zand, K., Matosich, K., Veeraraghavan, K., Michelena, K., Li, K., Huang, + K., Chawla, K., Lakhotia, K., Huang, K., Chen, L., Garg, L., A, L., Silva, + L., Bell, L., Zhang, L., Guo, L., Yu, L., Moshkovich, L., Wehrstedt, L., + Khabsa, M., Avalani, M., Bhatt, M., Tsimpoukelli, M., Mankus, M., Hasson, M., + Lennie, M., Reso, M., Groshev, M., Naumov, M., Lathi, M., Keneally, M., + Seltzer, M.~L., Valko, M., Restrepo, M., Patel, M., Vyatskov, M., Samvelyan, + M., Clark, M., Macey, M., Wang, M., Hermoso, M.~J., Metanat, M., Rastegari, + M., Bansal, M., Santhanam, N., Parks, N., White, N., Bawa, N., Singhal, N., + Egebo, N., Usunier, N., Laptev, N.~P., Dong, N., Zhang, N., Cheng, N., + Chernoguz, O., Hart, O., Salpekar, O., Kalinli, O., Kent, P., Parekh, P., + Saab, P., Balaji, P., Rittner, P., Bontrager, P., Roux, P., Dollar, P., + Zvyagina, P., Ratanchandani, P., Yuvraj, P., Liang, Q., Alao, R., Rodriguez, + R., Ayub, R., Murthy, R., Nayani, R., Mitra, R., Li, R., Hogan, R., Battey, + R., Wang, R., Maheswari, R., Howes, R., Rinott, R., Bondu, S.~J., Datta, S., + Chugh, S., Hunt, S., Dhillon, S., Sidorov, S., Pan, S., Verma, S., Yamamoto, + S., Ramaswamy, S., Lindsay, S., Lindsay, S., Feng, S., Lin, S., Zha, S.~C., + Shankar, S., Zhang, S., Zhang, S., Wang, S., Agarwal, S., Sajuyigbe, S., + Chintala, S., Max, S., Chen, S., Kehoe, S., Satterfield, S., Govindaprasad, + S., Gupta, S., Cho, S., Virk, S., Subramanian, S., Choudhury, S., Goldman, + S., Remez, T., Glaser, T., Best, T., Kohler, T., Robinson, T., Li, T., Zhang, + T., Matthews, T., Chou, T., Shaked, T., Vontimitta, V., Ajayi, V., Montanez, + V., Mohan, V., Kumar, V.~S., Mangla, V., Albiero, V., Ionescu, V., Poenaru, + V., Mihailescu, V.~T., Ivanov, V., Li, W., Wang, W., Jiang, W., Bouaziz, W., + Constable, W., Tang, X., Wang, X., Wu, X., Wang, X., Xia, X., Wu, X., Gao, + X., Chen, Y., Hu, Y., Jia, Y., Qi, Y., Li, Y., Zhang, Y., Zhang, Y., Adi, Y., + Nam, Y., Yu, Wang, Hao, Y., Qian, Y., He, Y., Rait, Z., DeVito, Z., + Rosnbrick, Z., Wen, Z., Yang, Z., and Zhao, Z. +\newblock The {Llama} 3 {Herd} of {Models}, August 2024. +\newblock URL \url{http://arxiv.org/abs/2407.21783}. +\newblock arXiv:2407.21783 [cs]. -\bibitem[Kearns(1989)]{kearns89} -Kearns, M.~J. -\newblock \emph{Computational Complexity of Machine Learning}. -\newblock PhD thesis, Department of Computer Science, Harvard University, 1989. +\bibitem[Paszke et~al.(2019)Paszke, Gross, Massa, Lerer, Bradbury, Chanan, + Killeen, Lin, Gimelshein, Antiga, Desmaison, Köpf, Yang, DeVito, Raison, + Tejani, Chilamkurthy, Steiner, Fang, Bai, and Chintala]{paszke_pytorch_2019} +Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, + T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Köpf, A., Yang, E., + DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., + Bai, J., and Chintala, S. +\newblock {PyTorch}: {An} {Imperative} {Style}, {High}-{Performance} {Deep} + {Learning} {Library}, December 2019. +\newblock URL \url{http://arxiv.org/abs/1912.01703}. +\newblock arXiv:1912.01703 [cs, stat]. -\bibitem[Langley(2000)]{langley00} -Langley, P. -\newblock Crafting papers on machine learning. -\newblock In Langley, P. (ed.), \emph{Proceedings of the 17th International - Conference on Machine Learning (ICML 2000)}, pp.\ 1207--1216, Stanford, CA, - 2000. Morgan Kaufmann. +\bibitem[Podell et~al.(2023)Podell, English, Lacey, Blattmann, Dockhorn, + Müller, Penna, and Rombach]{podell_sdxl_2023} +Podell, D., English, Z., Lacey, K., Blattmann, A., Dockhorn, T., Müller, J., + Penna, J., and Rombach, R. +\newblock {SDXL}: {Improving} {Latent} {Diffusion} {Models} for + {High}-{Resolution} {Image} {Synthesis}, July 2023. +\newblock URL \url{http://arxiv.org/abs/2307.01952}. +\newblock arXiv:2307.01952 [cs]. -\bibitem[Michalski et~al.(1983)Michalski, Carbonell, and - Mitchell]{MachineLearningI} -Michalski, R.~S., Carbonell, J.~G., and Mitchell, T.~M. (eds.). -\newblock \emph{Machine Learning: An Artificial Intelligence Approach, Vol. I}. -\newblock Tioga, Palo Alto, CA, 1983. +\bibitem[Sun et~al.(2023)Sun, Li, Geng, Stuijk, and + Corporaal]{sun_dissecting_2023} +Sun, W., Li, A., Geng, T., Stuijk, S., and Corporaal, H. +\newblock Dissecting {Tensor} {Cores} via {Microbenchmarks}: {Latency}, + {Throughput} and {Numeric} {Behaviors}. +\newblock \emph{IEEE Transactions on Parallel and Distributed Systems}, + 34\penalty0 (1):\penalty0 246--261, January 2023. +\newblock ISSN 1045-9219, 1558-2183, 2161-9883. +\newblock \doi{10.1109/TPDS.2022.3217824}. +\newblock URL \url{https://ieeexplore.ieee.org/document/9931992/}. -\bibitem[Mitchell(1980)]{mitchell80} -Mitchell, T.~M. -\newblock The need for biases in learning generalizations. -\newblock Technical report, Computer Science Department, Rutgers University, - New Brunswick, MA, 1980. - -\bibitem[Newell \& Rosenbloom(1981)Newell and Rosenbloom]{Newell81} -Newell, A. and Rosenbloom, P.~S. -\newblock Mechanisms of skill acquisition and the law of practice. -\newblock In Anderson, J.~R. (ed.), \emph{Cognitive Skills and Their - Acquisition}, chapter~1, pp.\ 1--51. Lawrence Erlbaum Associates, Inc., - Hillsdale, NJ, 1981. - -\bibitem[Samuel(1959)]{Samuel59} -Samuel, A.~L. -\newblock Some studies in machine learning using the game of checkers. -\newblock \emph{IBM Journal of Research and Development}, 3\penalty0 - (3):\penalty0 211--229, 1959. +\bibitem[Tillet et~al.(2019)Tillet, Kung, and Cox]{tillet_triton_2019} +Tillet, P., Kung, H.~T., and Cox, D. +\newblock Triton: an intermediate language and compiler for tiled neural + network computations. +\newblock In \emph{Proceedings of the 3rd {ACM} {SIGPLAN} {International} + {Workshop} on {Machine} {Learning} and {Programming} {Languages}}, pp.\ + 10--19, Phoenix AZ USA, June 2019. ACM. +\newblock ISBN 978-1-4503-6719-6. +\newblock \doi{10.1145/3315508.3329973}. +\newblock URL \url{https://dl.acm.org/doi/10.1145/3315508.3329973}. \end{thebibliography} diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.bib b/shark_turbine/kernel/wave/docs/mlsys/tkw.bib index 6bd0e3ee..61f02789 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.bib +++ b/shark_turbine/kernel/wave/docs/mlsys/tkw.bib @@ -1,75 +1,102 @@ -@inproceedings{langley00, - author = {P. Langley}, - title = {Crafting Papers on Machine Learning}, - year = {2000}, - pages = {1207--1216}, - editor = {Pat Langley}, - booktitle = {Proceedings of the 17th International Conference - on Machine Learning (ICML 2000)}, - address = {Stanford, CA}, - publisher = {Morgan Kaufmann} -} - -@TechReport{mitchell80, - author = "T. M. Mitchell", - title = "The Need for Biases in Learning Generalizations", - institution = "Computer Science Department, Rutgers University", - year = "1980", - address = "New Brunswick, MA", -} -@phdthesis{kearns89, - author = {M. J. Kearns}, - title = {Computational Complexity of Machine Learning}, - school = {Department of Computer Science, Harvard University}, - year = {1989} +@inproceedings{tillet_triton_2019, + address = {Phoenix AZ USA}, + title = {Triton: an intermediate language and compiler for tiled neural network computations}, + isbn = {978-1-4503-6719-6}, + shorttitle = {Triton}, + url = {https://dl.acm.org/doi/10.1145/3315508.3329973}, + doi = {10.1145/3315508.3329973}, + abstract = {The validation and deployment of novel research ideas in the field of Deep Learning is often limited by the availability of efficient compute kernels for certain basic primitives. In particular, operations that cannot leverage existing vendor libraries (e.g., cuBLAS, cuDNN) are at risk of facing poor device utilization unless custom implementations are written by experts – usually at the expense of portability. For this reason, the development of new programming abstractions for specifying custom Deep Learning workloads at a minimal performance cost has become crucial.}, + language = {en}, + urldate = {2024-09-25}, + booktitle = {Proceedings of the 3rd {ACM} {SIGPLAN} {International} {Workshop} on {Machine} {Learning} and {Programming} {Languages}}, + publisher = {ACM}, + author = {Tillet, Philippe and Kung, H. T. and Cox, David}, + month = jun, + year = {2019}, + pages = {10--19}, + file = {PDF:/Users/harsh/Zotero/storage/FMLLYK4M/Tillet et al. - 2019 - Triton an intermediate language and compiler for tiled neural network computations.pdf:application/pdf}, } -@Book{MachineLearningI, - editor = "R. S. Michalski and J. G. Carbonell and T. - M. Mitchell", - title = "Machine Learning: An Artificial Intelligence - Approach, Vol. I", - publisher = "Tioga", - year = "1983", - address = "Palo Alto, CA" +@misc{podell_sdxl_2023, + title = {{SDXL}: {Improving} {Latent} {Diffusion} {Models} for {High}-{Resolution} {Image} {Synthesis}}, + shorttitle = {{SDXL}}, + url = {http://arxiv.org/abs/2307.01952}, + abstract = {We present SDXL, a latent diffusion model for text-to-image synthesis. Compared to previous versions of Stable Diffusion, SDXL leverages a three times larger UNet backbone: The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. We design multiple novel conditioning schemes and train SDXL on multiple aspect ratios. We also introduce a refinement model which is used to improve the visual fidelity of samples generated by SDXL using a post-hoc image-to-image technique. We demonstrate that SDXL shows drastically improved performance compared to previous versions of Stable Diffusion and achieves results competitive with those of black-box state-of-the-art image generators. In the spirit of promoting open research and fostering transparency in large model training and evaluation, we provide access to code and model weights.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Podell, Dustin and English, Zion and Lacey, Kyle and Blattmann, Andreas and Dockhorn, Tim and Müller, Jonas and Penna, Joe and Rombach, Robin}, + month = jul, + year = {2023}, + note = {arXiv:2307.01952 [cs]}, + keywords = {Computer Science - Artificial Intelligence, Computer Science - Computer Vision and Pattern Recognition}, + file = {PDF:/Users/harsh/Zotero/storage/ARJZQZ42/Podell et al. - 2023 - SDXL Improving Latent Diffusion Models for High-Resolution Image Synthesis.pdf:application/pdf}, } -@Book{DudaHart2nd, - author = "R. O. Duda and P. E. Hart and D. G. Stork", - title = "Pattern Classification", - publisher = "John Wiley and Sons", - edition = "2nd", - year = "2000" +@misc{dubey_llama_2024, + title = {The {Llama} 3 {Herd} of {Models}}, + url = {http://arxiv.org/abs/2407.21783}, + abstract = {Modern artificial intelligence (AI) systems are powered by foundation models. This paper presents a new set of foundation models, called Llama 3. It is a herd of language models that natively support multilinguality, coding, reasoning, and tool usage. Our largest model is a dense Transformer with 405B parameters and a context window of up to 128K tokens. This paper presents an extensive empirical evaluation of Llama 3. We find that Llama 3 delivers comparable quality to leading language models such as GPT-4 on a plethora of tasks. We publicly release Llama 3, including pre-trained and post-trained versions of the 405B parameter language model and our Llama Guard 3 model for input and output safety. The paper also presents the results of experiments in which we integrate image, video, and speech capabilities into Llama 3 via a compositional approach. We observe this approach performs competitively with the state-of-the-art on image, video, and speech recognition tasks. The resulting models are not yet being broadly released as they are still under development.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Dubey, Abhimanyu and Jauhri, Abhinav and Pandey, Abhinav and Kadian, Abhishek and Al-Dahle, Ahmad and Letman, Aiesha and Mathur, Akhil and Schelten, Alan and Yang, Amy and Fan, Angela and Goyal, Anirudh and Hartshorn, Anthony and Yang, Aobo and Mitra, Archi and Sravankumar, Archie and Korenev, Artem and Hinsvark, Arthur and Rao, Arun and Zhang, Aston and Rodriguez, Aurelien and Gregerson, Austen and Spataru, Ava and Roziere, Baptiste and Biron, Bethany and Tang, Binh and Chern, Bobbie and Caucheteux, Charlotte and Nayak, Chaya and Bi, Chloe and Marra, Chris and McConnell, Chris and Keller, Christian and Touret, Christophe and Wu, Chunyang and Wong, Corinne and Ferrer, Cristian Canton and Nikolaidis, Cyrus and Allonsius, Damien and Song, Daniel and Pintz, Danielle and Livshits, Danny and Esiobu, David and Choudhary, Dhruv and Mahajan, Dhruv and Garcia-Olano, Diego and Perino, Diego and Hupkes, Dieuwke and Lakomkin, Egor and AlBadawy, Ehab and Lobanova, Elina and Dinan, Emily and Smith, Eric Michael and Radenovic, Filip and Zhang, Frank and Synnaeve, Gabriel and Lee, Gabrielle and Anderson, Georgia Lewis and Nail, Graeme and Mialon, Gregoire and Pang, Guan and Cucurell, Guillem and Nguyen, Hailey and Korevaar, Hannah and Xu, Hu and Touvron, Hugo and Zarov, Iliyan and Ibarra, Imanol Arrieta and Kloumann, Isabel and Misra, Ishan and Evtimov, Ivan and Copet, Jade and Lee, Jaewon and Geffert, Jan and Vranes, Jana and Park, Jason and Mahadeokar, Jay and Shah, Jeet and van der Linde, Jelmer and Billock, Jennifer and Hong, Jenny and Lee, Jenya and Fu, Jeremy and Chi, Jianfeng and Huang, Jianyu and Liu, Jiawen and Wang, Jie and Yu, Jiecao and Bitton, Joanna and Spisak, Joe and Park, Jongsoo and Rocca, Joseph and Johnstun, Joshua and Saxe, Joshua and Jia, Junteng and Alwala, Kalyan Vasuden and Upasani, Kartikeya and Plawiak, Kate and Li, Ke and Heafield, Kenneth and Stone, Kevin and El-Arini, Khalid and Iyer, Krithika and Malik, Kshitiz and Chiu, Kuenley and Bhalla, Kunal and Rantala-Yeary, Lauren and van der Maaten, Laurens and Chen, Lawrence and Tan, Liang and Jenkins, Liz and Martin, Louis and Madaan, Lovish and Malo, Lubo and Blecher, Lukas and Landzaat, Lukas and de Oliveira, Luke and Muzzi, Madeline and Pasupuleti, Mahesh and Singh, Mannat and Paluri, Manohar and Kardas, Marcin and Oldham, Mathew and Rita, Mathieu and Pavlova, Maya and Kambadur, Melanie and Lewis, Mike and Si, Min and Singh, Mitesh Kumar and Hassan, Mona and Goyal, Naman and Torabi, Narjes and Bashlykov, Nikolay and Bogoychev, Nikolay and Chatterji, Niladri and Duchenne, Olivier and Çelebi, Onur and Alrassy, Patrick and Zhang, Pengchuan and Li, Pengwei and Vasic, Petar and Weng, Peter and Bhargava, Prajjwal and Dubal, Pratik and Krishnan, Praveen and Koura, Punit Singh and Xu, Puxin and He, Qing and Dong, Qingxiao and Srinivasan, Ragavan and Ganapathy, Raj and Calderer, Ramon and Cabral, Ricardo Silveira and Stojnic, Robert and Raileanu, Roberta and Girdhar, Rohit and Patel, Rohit and Sauvestre, Romain and Polidoro, Ronnie and Sumbaly, Roshan and Taylor, Ross and Silva, Ruan and Hou, Rui and Wang, Rui and Hosseini, Saghar and Chennabasappa, Sahana and Singh, Sanjay and Bell, Sean and Kim, Seohyun Sonia and Edunov, Sergey and Nie, Shaoliang and Narang, Sharan and Raparthy, Sharath and Shen, Sheng and Wan, Shengye and Bhosale, Shruti and Zhang, Shun and Vandenhende, Simon and Batra, Soumya and Whitman, Spencer and Sootla, Sten and Collot, Stephane and Gururangan, Suchin and Borodinsky, Sydney and Herman, Tamar and Fowler, Tara and Sheasha, Tarek and Georgiou, Thomas and Scialom, Thomas and Speckbacher, Tobias and Mihaylov, Todor and Xiao, Tong and Karn, Ujjwal and Goswami, Vedanuj and Gupta, Vibhor and Ramanathan, Vignesh and Kerkez, Viktor and Gonguet, Vincent and Do, Virginie and Vogeti, Vish and Petrovic, Vladan and Chu, Weiwei and Xiong, Wenhan and Fu, Wenyin and Meers, Whitney and Martinet, Xavier and Wang, Xiaodong and Tan, Xiaoqing Ellen and Xie, Xinfeng and Jia, Xuchao and Wang, Xuewei and Goldschlag, Yaelle and Gaur, Yashesh and Babaei, Yasmine and Wen, Yi and Song, Yiwen and Zhang, Yuchen and Li, Yue and Mao, Yuning and Coudert, Zacharie Delpierre and Yan, Zheng and Chen, Zhengxing and Papakipos, Zoe and Singh, Aaditya and Grattafiori, Aaron and Jain, Abha and Kelsey, Adam and Shajnfeld, Adam and Gangidi, Adithya and Victoria, Adolfo and Goldstand, Ahuva and Menon, Ajay and Sharma, Ajay and Boesenberg, Alex and Vaughan, Alex and Baevski, Alexei and Feinstein, Allie and Kallet, Amanda and Sangani, Amit and Yunus, Anam and Lupu, Andrei and Alvarado, Andres and Caples, Andrew and Gu, Andrew and Ho, Andrew and Poulton, Andrew and Ryan, Andrew and Ramchandani, Ankit and Franco, Annie and Saraf, Aparajita and Chowdhury, Arkabandhu and Gabriel, Ashley and Bharambe, Ashwin and Eisenman, Assaf and Yazdan, Azadeh and James, Beau and Maurer, Ben and Leonhardi, Benjamin and Huang, Bernie and Loyd, Beth and De Paola, Beto and Paranjape, Bhargavi and Liu, Bing and Wu, Bo and Ni, Boyu and Hancock, Braden and Wasti, Bram and Spence, Brandon and Stojkovic, Brani and Gamido, Brian and Montalvo, Britt and Parker, Carl and Burton, Carly and Mejia, Catalina and Wang, Changhan and Kim, Changkyu and Zhou, Chao and Hu, Chester and Chu, Ching-Hsiang and Cai, Chris and Tindal, Chris and Feichtenhofer, Christoph and Civin, Damon and Beaty, Dana and Kreymer, Daniel and Li, Daniel and Wyatt, Danny and Adkins, David and Xu, David and Testuggine, Davide and David, Delia and Parikh, Devi and Liskovich, Diana and Foss, Didem and Wang, Dingkang and Le, Duc and Holland, Dustin and Dowling, Edward and Jamil, Eissa and Montgomery, Elaine and Presani, Eleonora and Hahn, Emily and Wood, Emily and Brinkman, Erik and Arcaute, Esteban and Dunbar, Evan and Smothers, Evan and Sun, Fei and Kreuk, Felix and Tian, Feng and Ozgenel, Firat and Caggioni, Francesco and Guzmán, Francisco and Kanayet, Frank and Seide, Frank and Florez, Gabriela Medina and Schwarz, Gabriella and Badeer, Gada and Swee, Georgia and Halpern, Gil and Thattai, Govind and Herman, Grant and Sizov, Grigory and Guangyi and Zhang and Lakshminarayanan, Guna and Shojanazeri, Hamid and Zou, Han and Wang, Hannah and Zha, Hanwen and Habeeb, Haroun and Rudolph, Harrison and Suk, Helen and Aspegren, Henry and Goldman, Hunter and Damlaj, Ibrahim and Molybog, Igor and Tufanov, Igor and Veliche, Irina-Elena and Gat, Itai and Weissman, Jake and Geboski, James and Kohli, James and Asher, Japhet and Gaya, Jean-Baptiste and Marcus, Jeff and Tang, Jeff and Chan, Jennifer and Zhen, Jenny and Reizenstein, Jeremy and Teboul, Jeremy and Zhong, Jessica and Jin, Jian and Yang, Jingyi and Cummings, Joe and Carvill, Jon and Shepard, Jon and McPhie, Jonathan and Torres, Jonathan and Ginsburg, Josh and Wang, Junjie and Wu, Kai and U, Kam Hou and Saxena, Karan and Prasad, Karthik and Khandelwal, Kartikay and Zand, Katayoun and Matosich, Kathy and Veeraraghavan, Kaushik and Michelena, Kelly and Li, Keqian and Huang, Kun and Chawla, Kunal and Lakhotia, Kushal and Huang, Kyle and Chen, Lailin and Garg, Lakshya and A, Lavender and Silva, Leandro and Bell, Lee and Zhang, Lei and Guo, Liangpeng and Yu, Licheng and Moshkovich, Liron and Wehrstedt, Luca and Khabsa, Madian and Avalani, Manav and Bhatt, Manish and Tsimpoukelli, Maria and Mankus, Martynas and Hasson, Matan and Lennie, Matthew and Reso, Matthias and Groshev, Maxim and Naumov, Maxim and Lathi, Maya and Keneally, Meghan and Seltzer, Michael L. and Valko, Michal and Restrepo, Michelle and Patel, Mihir and Vyatskov, Mik and Samvelyan, Mikayel and Clark, Mike and Macey, Mike and Wang, Mike and Hermoso, Miquel Jubert and Metanat, Mo and Rastegari, Mohammad and Bansal, Munish and Santhanam, Nandhini and Parks, Natascha and White, Natasha and Bawa, Navyata and Singhal, Nayan and Egebo, Nick and Usunier, Nicolas and Laptev, Nikolay Pavlovich and Dong, Ning and Zhang, Ning and Cheng, Norman and Chernoguz, Oleg and Hart, Olivia and Salpekar, Omkar and Kalinli, Ozlem and Kent, Parkin and Parekh, Parth and Saab, Paul and Balaji, Pavan and Rittner, Pedro and Bontrager, Philip and Roux, Pierre and Dollar, Piotr and Zvyagina, Polina and Ratanchandani, Prashant and Yuvraj, Pritish and Liang, Qian and Alao, Rachad and Rodriguez, Rachel and Ayub, Rafi and Murthy, Raghotham and Nayani, Raghu and Mitra, Rahul and Li, Raymond and Hogan, Rebekkah and Battey, Robin and Wang, Rocky and Maheswari, Rohan and Howes, Russ and Rinott, Ruty and Bondu, Sai Jayesh and Datta, Samyak and Chugh, Sara and Hunt, Sara and Dhillon, Sargun and Sidorov, Sasha and Pan, Satadru and Verma, Saurabh and Yamamoto, Seiji and Ramaswamy, Sharadh and Lindsay, Shaun and Lindsay, Shaun and Feng, Sheng and Lin, Shenghao and Zha, Shengxin Cindy and Shankar, Shiva and Zhang, Shuqiang and Zhang, Shuqiang and Wang, Sinong and Agarwal, Sneha and Sajuyigbe, Soji and Chintala, Soumith and Max, Stephanie and Chen, Stephen and Kehoe, Steve and Satterfield, Steve and Govindaprasad, Sudarshan and Gupta, Sumit and Cho, Sungmin and Virk, Sunny and Subramanian, Suraj and Choudhury, Sy and Goldman, Sydney and Remez, Tal and Glaser, Tamar and Best, Tamara and Kohler, Thilo and Robinson, Thomas and Li, Tianhe and Zhang, Tianjun and Matthews, Tim and Chou, Timothy and Shaked, Tzook and Vontimitta, Varun and Ajayi, Victoria and Montanez, Victoria and Mohan, Vijai and Kumar, Vinay Satish and Mangla, Vishal and Albiero, Vítor and Ionescu, Vlad and Poenaru, Vlad and Mihailescu, Vlad Tiberiu and Ivanov, Vladimir and Li, Wei and Wang, Wenchen and Jiang, Wenwen and Bouaziz, Wes and Constable, Will and Tang, Xiaocheng and Wang, Xiaofang and Wu, Xiaojian and Wang, Xiaolan and Xia, Xide and Wu, Xilun and Gao, Xinbo and Chen, Yanjun and Hu, Ye and Jia, Ye and Qi, Ye and Li, Yenda and Zhang, Yilin and Zhang, Ying and Adi, Yossi and Nam, Youngjin and Yu and Wang and Hao, Yuchen and Qian, Yundi and He, Yuzi and Rait, Zach and DeVito, Zachary and Rosnbrick, Zef and Wen, Zhaoduo and Yang, Zhenyu and Zhao, Zhiwei}, + month = aug, + year = {2024}, + note = {arXiv:2407.21783 [cs]}, + keywords = {Computer Science - Artificial Intelligence, Computer Science - Computer Vision and Pattern Recognition, Computer Science - Computation and Language}, + file = {PDF:/Users/harsh/Zotero/storage/BQKY8VZZ/Dubey et al. - 2024 - The Llama 3 Herd of Models.pdf:application/pdf}, } -@misc{anonymous, - title= {Suppressed for Anonymity}, - author= {Author, N. N.}, - year= {2018} +@article{sun_dissecting_2023, + title = {Dissecting {Tensor} {Cores} via {Microbenchmarks}: {Latency}, {Throughput} and {Numeric} {Behaviors}}, + volume = {34}, + copyright = {https://ieeexplore.ieee.org/Xplorehelp/downloads/license-information/IEEE.html}, + issn = {1045-9219, 1558-2183, 2161-9883}, + shorttitle = {Dissecting {Tensor} {Cores} via {Microbenchmarks}}, + url = {https://ieeexplore.ieee.org/document/9931992/}, + doi = {10.1109/TPDS.2022.3217824}, + abstract = {Tensor Cores have been an important unit to accelerate Fused Matrix Multiplication Accumulation (MMA) in all NVIDIA GPUs since Volta Architecture. To program Tensor Cores, users have to use either legacy wmma APIs or current mma APIs. Legacy wmma APIs are more easy-to-use but can only exploit limited features and power of Tensor Cores. Specifically, wmma APIs support fewer operand shapes and can not leverage the new sparse matrix multiplication feature of the newest Ampere Tensor Cores. However, the performance of current programming interface has not been well explored. Furthermore, the computation numeric behaviors of lowprecision floating points (TF32, BF16, and FP16) supported by the newest Ampere Tensor Cores are also mysterious. In this paper, we explore the throughput and latency of current programming APIs. We also intuitively study the numeric behaviors of Tensor Cores MMA and profile the intermediate operations including multiplication, addition of inner product, and accumulation. All codes used in this work can be found in https://github.com/sunlex0717/DissectingTensorCores.}, + language = {en}, + number = {1}, + urldate = {2024-09-25}, + journal = {IEEE Transactions on Parallel and Distributed Systems}, + author = {Sun, Wei and Li, Ang and Geng, Tong and Stuijk, Sander and Corporaal, Henk}, + month = jan, + year = {2023}, + pages = {246--261}, + file = {PDF:/Users/harsh/Zotero/storage/NZD3FJUB/Sun et al. - 2023 - Dissecting Tensor Cores via Microbenchmarks Latency, Throughput and Numeric Behaviors.pdf:application/pdf}, } -@InCollection{Newell81, - author = "A. Newell and P. S. Rosenbloom", - title = "Mechanisms of Skill Acquisition and the Law of - Practice", - booktitle = "Cognitive Skills and Their Acquisition", - pages = "1--51", - publisher = "Lawrence Erlbaum Associates, Inc.", - year = "1981", - editor = "J. R. Anderson", - chapter = "1", - address = "Hillsdale, NJ" +@misc{paszke_pytorch_2019, + title = {{PyTorch}: {An} {Imperative} {Style}, {High}-{Performance} {Deep} {Learning} {Library}}, + shorttitle = {{PyTorch}}, + url = {http://arxiv.org/abs/1912.01703}, + abstract = {Deep learning frameworks have often focused on either usability or speed, but not both. PyTorch is a machine learning library that shows that these two goals are in fact compatible: it provides an imperative and Pythonic programming style that supports code as a model, makes debugging easy and is consistent with other popular scientific computing libraries, while remaining efficient and supporting hardware accelerators such as GPUs.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Paszke, Adam and Gross, Sam and Massa, Francisco and Lerer, Adam and Bradbury, James and Chanan, Gregory and Killeen, Trevor and Lin, Zeming and Gimelshein, Natalia and Antiga, Luca and Desmaison, Alban and Köpf, Andreas and Yang, Edward and DeVito, Zach and Raison, Martin and Tejani, Alykhan and Chilamkurthy, Sasank and Steiner, Benoit and Fang, Lu and Bai, Junjie and Chintala, Soumith}, + month = dec, + year = {2019}, + note = {arXiv:1912.01703 [cs, stat]}, + keywords = {Computer Science - Machine Learning, Computer Science - Mathematical Software, Statistics - Machine Learning}, + annote = {Comment: 12 pages, 3 figures, NeurIPS 2019}, + file = {PDF:/Users/harsh/Zotero/storage/D72HUVME/Paszke et al. - 2019 - PyTorch An Imperative Style, High-Performance Deep Learning Library.pdf:application/pdf}, } - -@Article{Samuel59, - author = "A. L. Samuel", - title = "Some Studies in Machine Learning Using the Game of - Checkers", - journal = "IBM Journal of Research and Development", - year = "1959", - volume = "3", - number = "3", - pages = "211--229" +@misc{chetlur_cudnn_2014, + title = {{cuDNN}: {Efficient} {Primitives} for {Deep} {Learning}}, + shorttitle = {{cuDNN}}, + url = {http://arxiv.org/abs/1410.0759}, + doi = {10.48550/arXiv.1410.0759}, + abstract = {We present a library of efficient implementations of deep learning primitives. Deep learning workloads are computationally intensive, and optimizing their kernels is difficult and time-consuming. As parallel architectures evolve, kernels must be reoptimized, which makes maintaining codebases difficult over time. Similar issues have long been addressed in the HPC community by libraries such as the Basic Linear Algebra Subroutines (BLAS). However, there is no analogous library for deep learning. Without such a library, researchers implementing deep learning workloads on parallel processors must create and optimize their own implementations of the main computational kernels, and this work must be repeated as new parallel processors emerge. To address this problem, we have created a library similar in intent to BLAS, with optimized routines for deep learning workloads. Our implementation contains routines for GPUs, although similarly to the BLAS library, these routines could be implemented for other platforms. The library is easy to integrate into existing frameworks, and provides optimized performance and memory usage. For example, integrating cuDNN into Caffe, a popular framework for convolutional networks, improves performance by 36\% on a standard model while also reducing memory consumption.}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Chetlur, Sharan and Woolley, Cliff and Vandermersch, Philippe and Cohen, Jonathan and Tran, John and Catanzaro, Bryan and Shelhamer, Evan}, + month = dec, + year = {2014}, + note = {arXiv:1410.0759 [cs]}, + keywords = {Computer Science - Machine Learning, Computer Science - Mathematical Software, Computer Science - Neural and Evolutionary Computing}, } diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.blg b/shark_turbine/kernel/wave/docs/mlsys/tkw.blg deleted file mode 100644 index ef864a1b..00000000 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.blg +++ /dev/null @@ -1,46 +0,0 @@ -This is BibTeX, Version 0.99d (TeX Live 2020) -Capacity: max_strings=200000, hash_size=200000, hash_prime=170003 -The top-level auxiliary file: example_paper.aux -The style file: mlsys2024.bst -Database file #1: example_paper.bib -You've used 8 entries, - 2773 wiz_defined-function locations, - 645 strings with 5916 characters, -and the built_in function-call counts, 3248 in all, are: -= -- 293 -> -- 140 -< -- 9 -+ -- 49 -- -- 41 -* -- 223 -:= -- 507 -add.period$ -- 25 -call.type$ -- 8 -change.case$ -- 36 -chr.to.int$ -- 8 -cite$ -- 16 -duplicate$ -- 174 -empty$ -- 295 -format.name$ -- 51 -if$ -- 691 -int.to.chr$ -- 1 -int.to.str$ -- 1 -missing$ -- 6 -newline$ -- 47 -num.names$ -- 37 -pop$ -- 81 -preamble$ -- 1 -purify$ -- 29 -quote$ -- 0 -skip$ -- 127 -stack$ -- 0 -substring$ -- 100 -swap$ -- 24 -text.length$ -- 3 -text.prefix$ -- 0 -top$ -- 0 -type$ -- 78 -warning$ -- 0 -while$ -- 34 -width$ -- 0 -write$ -- 113 diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.tex b/shark_turbine/kernel/wave/docs/mlsys/tkw.tex index cb56cab1..33282648 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.tex +++ b/shark_turbine/kernel/wave/docs/mlsys/tkw.tex @@ -30,7 +30,7 @@ \begin{document} \twocolumn[ -\mlsystitle{Submission and Formatting Instructions for MLSys 2024} +\mlsystitle{Wave : A Python DSL for High Performance Machine Learning} % It is OKAY to include author information, even for blind % submissions: the style file will automatically remove it for you @@ -94,6 +94,81 @@ %\printAffiliationsAndNotice{} % leave blank if no need to mention equal contribution \printAffiliationsAndNotice{\mlsysEqualContribution} % otherwise use the standard text. +\section{Introduction} +Generative models have seen tremendous success in a wide variety of +domains ranging from image generation to natural language processing and beyond. +\cite{podell_sdxl_2023,dubey_llama_2024}. Much of this success is being +driven by graphics processing units (GPUs) which while originally +designed for graphics, are now being optimized for machine learning. +Both datacenter and consumer grade GPUs feature powerful matrix multiplication hardware units +and specialized instructions to enable high performance inference and training \cite{sun_dissecting_2023}. +\\ \\ +Given the importance of GPUs in machine learning, significant +effort has been put into developing frameworks that allow developers to +write high performance machine learning models with a low barrier to entry. Frameworks such +as Pytorch \cite{paszke_pytorch_2019} have become extremely popular +because they expose a Python based approach to programming GPUs. Prior +to the advent of these frameworks, developers had to write CUDA or OpenCL +kernels by hand which required significant expertise to achieve +good performance and did not scale well to new operators. +\\ \\ +Under the hood, these machine learning frameworks rely heavily +on vendor-specific libraries such as cuDNN \cite{chetlur_cudnn_2014} to achieve high performance. +These libraries are performant but are black boxes consisting of +hand-written kernels and often do not support the full set of +fused operators encountered in machine learning models. +To address these limitations, recent work has focused on developing +Python domain specific languages (DSL) that allow developers to get high performance +while reducing the kernel complexity. Triton \cite{tillet_triton_2019}. +is a popular Python DSL that exposes a workgroup level programming +model and allows developers to author high performance kernels. +However, Triton kernels often get quite complex and start to +resemble hand-written kernels as the kernel complexity grows. +Furthermore, fusion of Triton kernels is limited to a few operators +and remains an open problem. + +In this paper, we introduce Wave, a Python DSL for high performance machine learning. +Wave exposes a subgroup (wave or warp) level programming model that allows +for much simpler kernels compared to Triton. Through the use of constraints, Wave forces developers to +come up with the distribution strategy for their kernel - +which dimensions are parallel and which are sequential and how to distribute those +dimensions across the memory and compute hierarchy of the GPU. This allows for a separation +between the kernel and the distribution strategy and makes the kernel simpler. +Wave also embraces symbolic data types using sympy to represent the shapes and +memory access patterns of tensors in the kernel. +It has a Python based compiler that uses torch.fx tracing to define +and trace operators written in the language. The torch.fx graphs are then run through a series of optimization passes +on the computation graph and are finally lowered to MLIR and subsequently LLVM. This code generation flow allows compiler writers +to blend high productivity in Python with high performance from the MLIR and LLVM +code generation flow. +\\ \\ +In summary, the contributions of this paper are as follows: +\begin{itemize} + \item A novel subgroup programming model for GPU with a Python DSL that separates distribution strategies from the core kernel allowing for simpler kernels, + \item A symbolic data type system that allows for reasoning about tensor shapes and memory access patterns in the kernel, + \item A Python compiler that leverages torch.fx for tracing and maps torch.fx graphs to MLIR and LLVM for high performance code generation. +\end{itemize} + + +\section{Memory Access Patterns} +We represent memory access patterns in the language using the standard +triplet notation consisting of an offset, number of elements, and absolute stride and associate +a triplet with each tensor dimension. The memory access pattern for a given operation +is determined by the access patterns of the operands of the operation as well as +the user-specified constraints. For example, the memory access pattern for the output +of an elementwise operation is determined from the access patterns of the inputs, +whereas for a matrix-multiply accumulate operation, the memory access patterns of the operands are specified by +the hardware constraint. +\\ \\ +One of the advantages of the dimension based specification is that it obviates +the need for any propagation of memory access patterns through the computation graph, +as is commonly done in other frameworks. When setting the access pattern for a specific +dimension of a tensor, the access pattern is taken to be the union of all possible +access patterns with the determination of which access pattern to use based on +the minimization of an appropriate metric across the entire graph (see Section 3). + + +\iffalse \section{Electronic Submission} \label{submission} @@ -527,8 +602,9 @@ \section*{Acknowledgements} % In the unusual situation where you want a paper to appear in the % references without citing it in the main text, use \nocite \nocite{langley00} +\fi -\bibliography{example_paper} +\bibliography{tkw} \bibliographystyle{mlsys2024} diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index b96f778c..69785031 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -19,15 +19,15 @@ from .._support.indexing import IndexingContext, IndexSequence from ...support.logging import get_logger from .._support.tracing import CapturedTrace -from .utils import get_mma_dimensional_mapping +from .utils import get_mma_dimensional_mapping, specialize_index_sequence from ..lang.global_symbols import * logger = get_logger("turbine.wave.expansion") -# This represents a mapping of a node + indexing into the dimensions to the -# corresponding expanded node in these specific dimensions. An example for a -# record in this map is (read_0_0_0, ((M,0),(N,0),(K,1)) -> read_0_0_1 +# This represents a mapping of a node + indexing + res_idx(output index for op with multiple results) +# of node into the dimensions to the corresponding expanded node in these specific dimensions. +# An example for a record in this map is (read_0_0_0, ((M,0),(N,0),(K,1), 0) -> read_0_0_1. ExpandedNodeMap: TypeAlias = dict[ - tuple[CustomOp, tuple[tuple[IndexSymbol, int], ...]], CustomOp + tuple[CustomOp, tuple[tuple[IndexSymbol, int], int, ...]], CustomOp ] @@ -81,6 +81,11 @@ def get_indexed_dims( """ if isinstance(nodeOrDims, CustomOp): nodeOrDims = nodeOrDims.indexing_dims + # Flatten dims for node with multiple values or expanded Reduction. + if all(isinstance(el, Sequence) for el in nodeOrDims): + flattened_dims = list(itertools.chain.from_iterable(nodeOrDims)) + flatten_dims_set = dict.fromkeys(flattened_dims) + nodeOrDims = list(flatten_dims_set) return tuple((key, all_dims[key]) for key in nodeOrDims if key in all_dims) @@ -141,6 +146,7 @@ def compute_stride( def set_node_index( constraints: Sequence[Constraint], mma_index: dict[IndexSymbol, int], + mma_slices: dict[IndexSymbol, list[fx.Node]], dim_tile_size: dict[IndexSymbol, int], custom: CustomOp, dim_scaling: dict[IndexSymbol, int], @@ -171,11 +177,7 @@ def set_node_index( for dim in custom.indexing_dims: index_seq = None for constraint in sorted_constraints: - mma_check = ( - isinstance(constraint, HardwareConstraint) - and dim in mma_index - and isinstance(custom, MMA) - ) + mma_check = isinstance(constraint, HardwareConstraint) and dim in mma_index vector_check = ( isinstance(constraint, HardwareConstraint) @@ -217,6 +219,8 @@ def set_node_index( index_seq = constraint.apply( constraint_index, dim, elements_per_thread, stride ) + if mma_index: + index_seq = specialize_index_sequence(index_seq, mma_slices, custom) else: if index_seq is None: @@ -246,10 +250,10 @@ def expand_graph( dim_scaling = constraints_or_scaling node_index_setter = lambda *args: None else: - mma_index = get_mma_dimensional_mapping(trace) + mma_index, mma_slices = get_mma_dimensional_mapping(trace) dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index) node_index_setter = partial( - set_node_index, constraints_or_scaling, mma_index, dim_tile_size + set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size ) # Start from the back and expand in the corresponding indexing dimensions of a node @@ -298,6 +302,7 @@ def _expand_node( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int = 0, ) -> CustomOp: """Expand a single node or list of nodes in specific dimensions and recursively proceed to its inputs.""" if isinstance(node, list): @@ -305,23 +310,31 @@ def _expand_node( for elem in node: expanded_nodes.append( _expand_node( - elem, trace, dim_query, dim_scaling, node_index_setter, context + elem, + trace, + dim_query, + dim_scaling, + node_index_setter, + context, + res_idx, ).fx_node ) return expanded_nodes # If we expanded a node in the same dimensions before, we can reuse it - if (node, get_indexed_dims(dim_query, node)) in context: + if (node, get_indexed_dims(dim_query, node), res_idx) in context: logger.debug(f"Already expanded node: {node} in {dim_query}") - return context[(node, get_indexed_dims(dim_query, node))] + return context[(node, get_indexed_dims(dim_query, node), res_idx)] elif isinstance(node, Reduction): return _expand_reduction( node, trace, dim_query, dim_scaling, node_index_setter, context ) - elif isinstance(node, GetResult): + elif isinstance(node, Getitem): + res_idx = node.res_idx + elif isinstance(node, GetResult) and not isinstance(node, Getitem): # The presence of a GetResult node indicates that the reduction has already # been expanded. Simply return the corresponding node. reduction = get_custom(node.value) - return context[(reduction, get_indexed_dims(dim_query, reduction))] + return context[(reduction, get_indexed_dims(dim_query, reduction), res_idx)] elif isinstance(node, Allocate): # Allocate nodes are not expanded. return node @@ -329,14 +342,28 @@ def _expand_node( # Filter out the dimensions that are not indexed by the node restricted_dims = filter_and_zero_unselected_dims(dim_query, node.indexing_dims) logger.debug(f"Expanding node: {node} in {restricted_dims}") + + # For iter args, we want to insert + if not hasattr(_expand_node, "last_expanded_iter_arg"): + _expand_node.last_expanded_iter_arg = None + # Clone the node for the new expansion. The original node is reused for the # case of all dimensions being zero. if expansion_needed(restricted_dims, node.indexing_dims): - new_node = node.copy() + new_node = node.copy( + anchor=( + _expand_node.last_expanded_iter_arg + if isinstance(node, IterArg) + else None + ) + ) else: new_node = node logger.debug(f"did not clone node: {node} in {restricted_dims}") + if isinstance(node, IterArg): + _expand_node.last_expanded_iter_arg = new_node.fx_node + new_node.fx_node.expanded_dims = restricted_dims new_node.fx_node.name = get_expanded_name(node, restricted_dims) node_index_setter(new_node, restricted_dims) @@ -353,12 +380,13 @@ def _expand_node( dim_scaling, node_index_setter, context, + res_idx, ) new_node.update_arg(i, new_arg) new_node.post_expansion(constraints) - context[(node, get_indexed_dims(restricted_dims, node))] = new_node + context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node return new_node @@ -369,6 +397,7 @@ def _expand_reduction( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int = 0, ) -> CustomOp: """Expand a reduction in a specific dimension and recursively proceed to its inputs.""" # Determine the dimensions to expand the reduction from the indexing of its users @@ -391,32 +420,41 @@ def _expand_reduction( new_output_args = [] new_init_args = [] for dim_vals in get_dim_combinations(dim_scaling, expand_dims): - for arg_idx, arg in output.node_args.items(): - dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)} + return_vals = output.return_vals[0] + dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)} + if not isinstance(return_vals, Sequence): + return_vals = [return_vals] + for arg_idx, arg in enumerate(return_vals): + arg = get_custom(arg) # Add GetResult nodes for the corresponding dimensions reduction.graph.inserting_after(reduction.fx_node) new_node = GetResult(reduction.fx_node, len(new_output_args)) new_node.add_to_graph(reduction.graph) new_node.fx_node.name = get_expanded_name(new_node, dims) - context[(reduction, get_indexed_dims(dims, expand_dims))] = new_node + context[ + (reduction, get_indexed_dims(dims, expand_dims), arg_idx) + ] = new_node # Proceed with expansion inside the reduction new_output_args.append( - _expand_node(arg, trace, dims, dim_scaling, node_index_setter, context) + _expand_node( + arg, trace, dims, dim_scaling, node_index_setter, context, res_idx + ) ) - # Proceed with expansion outside the reduction - for init_arg in reduction.init_args: - new_init_args.append( - _expand_node( - get_custom(init_arg), - trace, - dims, - dim_scaling, - node_index_setter, - context, - ) + # Proceed with expansion outside the reduction + for init_arg in reduction.init_args: + new_init_args.append( + _expand_node( + get_custom(init_arg), + trace, + dims, + dim_scaling, + node_index_setter, + context, + res_idx, ) + ) # Update init_args and return values reduction.update_arg( @@ -424,11 +462,17 @@ def _expand_reduction( ) output.update_arg("return_vals", [node.fx_node for node in new_output_args]) _handle_reduction_dim( - reduction, output, trace, dim_scaling, node_index_setter, context + reduction, + output, + trace, + dim_scaling, + node_index_setter, + context, + res_idx, ) # Even though we expanded the reduction in multiple dimensions, we only return # the node corresponding to the original query - return context[(reduction, get_indexed_dims(dim_query, expand_dims))] + return context[(reduction, get_indexed_dims(dim_query, expand_dims), res_idx)] def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: @@ -518,6 +562,7 @@ def _handle_reduction_dim( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int, ): # Rediscover iter args # TODO: Register iter args with the reduction initially so accessing them is easier @@ -554,7 +599,13 @@ def _handle_reduction_dim( saved_arg = user.node_args[index] user.update_arg(index, dummy) new_node = _expand_node( - user, trace, dims, dim_scaling, node_index_setter, context + user, + trace, + dims, + dim_scaling, + node_index_setter, + context, + res_idx, ) # This expansion always happens, user should never be reused diff --git a/shark_turbine/kernel/wave/index_sequence_analysis.py b/shark_turbine/kernel/wave/index_sequence_analysis.py index cec8b60b..b9212f01 100644 --- a/shark_turbine/kernel/wave/index_sequence_analysis.py +++ b/shark_turbine/kernel/wave/index_sequence_analysis.py @@ -24,7 +24,7 @@ def get_vector_shape( hardware_constraint: HardwareConstraint, symbolic_shape: list[IndexSymbol], ) -> list[int]: - mma_indices = get_mma_dimensional_mapping(trace) + mma_indices, _ = get_mma_dimensional_mapping(trace) return [ get_hardware_vector_size(dim, hardware_constraint, mma_indices) for dim in symbolic_shape diff --git a/shark_turbine/kernel/wave/minimize_global_loads.py b/shark_turbine/kernel/wave/minimize_global_loads.py index 3ea1a3d0..17971354 100644 --- a/shark_turbine/kernel/wave/minimize_global_loads.py +++ b/shark_turbine/kernel/wave/minimize_global_loads.py @@ -63,12 +63,11 @@ def materialize_shape( constraint_tile_size: dict[IndexSymbol, int], symbolic_shape: list[IndexSymbol] ) -> list[int]: materialized_shape = [] - idxc = IndexingContext.current() for dim in symbolic_shape: if dim in constraint_tile_size: - materialized_shape.append(constraint_tile_size[dim].subs(idxc.subs)) + materialized_shape.append(subs_idxc(constraint_tile_size[dim])) else: - materialized_shape.append(dim.subs(idxc.subs)) + materialized_shape.append(subs_idxc(dim)) return materialized_shape diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index affd5fef..c2d0a582 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -1,5 +1,4 @@ # Copyright 2024 The IREE Authors -# # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -12,12 +11,26 @@ transform_d, UnitAttr, ) -from typing import Callable, Any, List, Tuple +from typing import Optional, Callable, Any, List, Tuple from .._support.tracing import CapturedTrace from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence from ..lang.global_symbols import * -from ..ops.wave_ops import get_custom, Output, Write, MMA -from .constraints import Constraint, HardwareConstraint, TilingConstraint +from ..ops.wave_ops import ( + get_custom, + Output, + Write, + MMA, + CustomOp, + Reduction, + GetResult, + IterArg, +) +from .constraints import ( + Constraint, + WorkgroupConstraint, + HardwareConstraint, + TilingConstraint, +) import torch.fx as fx import shark_turbine.kernel.lang as tkl @@ -115,6 +128,19 @@ def is_removable_operator(node: fx.Node) -> bool: get_custom(node).graph.erase_node(node) +def remove_chained_getresult(trace: CapturedTrace): + def is_chained_getresult(node: fx.Node) -> bool: + custom = get_custom(node) + return isinstance(custom, GetResult) and isinstance( + get_custom(custom.value), GetResult + ) + + while removable_nodes := trace.walk(is_chained_getresult): + for node in removable_nodes: + get_custom(node).replace_all_uses_with(get_custom(node).value) + get_custom(node).graph.erase_node(node) + + def delinearize_index(index: IndexExpr, shape: list[int]) -> list[IndexExpr]: """ Delinearizes a 1D index into a multi-dimensional index @@ -145,7 +171,9 @@ def simplify_index(index: IndexExpr) -> IndexExpr: return subs_idxc(index.subs(mapping)) -def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[IndexSymbol, int]: +def get_mma_dimensional_mapping( + trace: CapturedTrace, +) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]: """ Given a trace, determine the MMA dimensional mapping for all the MMA operations in the graph. For example, if we have @@ -159,7 +187,8 @@ def is_mma(node): return isinstance(get_custom(node), MMA) mapping: dict[IndexSymbol, int] = {} - for node in trace.walk(is_mma): + mma_nodes = trace.walk(is_mma) + for node in mma_nodes: custom: MMA = get_custom(node) m, n = custom.acc_type.symbolic_shape[-2:] lhs_shape = custom.lhs_type.symbolic_shape @@ -170,7 +199,7 @@ def is_mma(node): mapping[n] = 1 mapping[k] = 2 - return mapping + return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes]) def get_hardware_vector_size( @@ -378,3 +407,156 @@ def erase_graph(graph: fx.Graph): for user in node.users: graph.erase_node(user) graph.erase_node(node) + + +def get_users( + node: fx.Node, reduction: fx.Node = None +) -> tuple[list[fx.Node], fx.Node]: + """ + Return the users of a node, propagating through reductions. + """ + users = [] + for user in node.users: + custom = get_custom(user) + if isinstance(custom, Reduction): + # Map init arg to iter arg + reduction = custom + init_arg_idx = custom.init_args.index(node) + users.append(custom.iter_args[init_arg_idx]) + continue + if isinstance(custom, Output) and reduction: + # Map output to get result + return_vals = custom.return_vals[0] + get_results = sorted( + [x for x in reduction.users if isinstance(get_custom(x), GetResult)], + lambda x: get_custom(x).res_idx, + ) + if isinstance(return_vals, list): + output_idx = return_vals.index(node) + users.append(get_results[output_idx]) + else: + users.append(get_results[0]) + continue + users.append(user) + return users, reduction + + +def get_inputs( + node: fx.Node, reduction: fx.Node = None +) -> tuple[list[fx.Node], fx.Node]: + """ + Return the inputs of a node, propagating through reductions. + """ + inputs = [] + for input in node.all_input_nodes: + custom = get_custom(input) + if isinstance(custom, GetResult): + reduction = custom.value + assert isinstance( + reduction, Reduction + ), "GetResult must be used by a Reduction" + # Map get result to output + inputs.append(reduction.outputs[custom.res_idx]) + continue + if isinstance(custom, IterArg): + # Map iter args to init args + iter_arg_idx = reduction.iter_args.index(node) + inputs.append(reduction.init_args[iter_arg_idx]) + continue + inputs.append(input) + return inputs, reduction + + +def bfs( + node: fx.Node, + get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]], +) -> set[fx.Node]: + """ + Run BFS on the graph to capture the forward slice of a node. + """ + visited: set[fx.Node] = set() + queue: list[fx.Node] = [] + visited.add(node) + queue.append(node) + reduction = None + while queue: + s = queue.pop(0) + neighbors, reduction = get_neighbors(s, reduction) + for neighbor in neighbors: + if neighbor not in visited: + visited.add(neighbor) + queue.append(neighbor) + return visited + + +def capture_forward_slice(node: fx.Node) -> set[fx.Node]: + """ + Run BFS on the graph to capture the forward slice of a node. + """ + return bfs(node, lambda x, y: get_users(x, y)) + + +def capture_backward_slice(node: fx.Node) -> set[fx.Node]: + """ + Capture backward slice from a node and return the tree. + Assumes graph is directed. + """ + return bfs(node, lambda x, y: get_inputs(x, y)) + + +def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]: + """ + Given an index sequence, specialize it to a LHS, RHS or ACC index sequence + based on whether the node is used as the LHS, RHS or ACC in the MMA node. + """ + mma_slices = {x: [] for x in [MMA_LHS, MMA_RHS, MMA_ACC]} + for mma in mma_nodes: + mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs) + mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs) + mma_slices[MMA_ACC] += capture_forward_slice(mma.acc) + return mma_slices + + +def specialize_index_sequence( + index_seq: IndexSequence, + mma_slices: dict[IndexSymbol, list[fx.Node]], + custom: CustomOp, +) -> IndexSequence: + """ + Given an index sequence, specialize it to a LHS, RHS or ACC index sequence + based on whether the node is used as the LHS, RHS or ACC in the MMA node. + If the node is not used as any of the operands, return the original index sequence + with all the MMA symbols zeroed out. + """ + if isinstance(custom, MMA): + return index_seq + operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0} + for key in mma_slices: + if custom.fx_node in mma_slices[key]: + operand_map[key] = 1 + return index_seq.subs(operand_map) + return index_seq.subs(operand_map) + + +def find_index_bounds( + constraints: list[Constraint], index: dict[IndexExpr, IndexExpr] +) -> Optional[list[IndexExpr]]: + bounds = [] + for constraint in constraints: + if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)): + continue + + dim = constraint.dim + if dim not in index: + continue + + work_size = constraint.count * constraint.tile_size + if subs_idxc(work_size) == subs_idxc(dim): + continue + + bounds.append(dim) + + if len(bounds) == 0: + return None + + return bounds diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index eb6003de..4d19d99f 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -23,7 +23,12 @@ from .expansion import expand_graph from .promotion import promote_placeholders from .hoisting import hoist_allocs -from .utils import canonicalize_module, compile_and_invoke, safe_subs +from .utils import ( + canonicalize_module, + compile_and_invoke, + safe_subs, + remove_chained_getresult, +) from .minimize_global_loads import minimize_global_loads from .decompose_reduce_ops import decompose_reduce_ops from .barriers import add_shared_memory_barriers @@ -205,6 +210,9 @@ def _trace_and_get_kernel_signature( # Expansion expand_graph(graph, self.constraints) + # Clean up chains of GetResults + remove_chained_getresult(graph) + # Register analysis to determine register shapes. determine_register_shape(graph, self.constraints) diff --git a/tests/aot/dynamic_shape_export_test.py b/tests/aot/dynamic_shape_export_test.py new file mode 100644 index 00000000..da8c11b7 --- /dev/null +++ b/tests/aot/dynamic_shape_export_test.py @@ -0,0 +1,50 @@ +import torch + +import pytest + +from shark_turbine.aot import * + + +@pytest.mark.parametrize( + "import_symbolic_shape_expressions", + [ + True, + False, + ], +) +def test_exported_program_dynamic_shapes(import_symbolic_shape_expressions): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer = torch.ones(32) + + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer, out2) + + example_args = (torch.randn(32, 64), torch.randn(32, 128)) + + # Create a dynamic batch size + batch = torch.export.Dim("batch") + # Specify that the first dimension of each input is that batch size + dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + + output = export( + M(), + args=example_args, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) + output.print_readable() + asm = str(output.mlir_module) + + if import_symbolic_shape_expressions: + assert "bind_symbolic_shape" in asm + else: + assert "bind_symbolic_shape" not in asm diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index 26bab1a6..607382fd 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -425,6 +425,68 @@ def testUnsupportedCombinations(self): export_global(AbstractF32, external=True, uninitialized=True) +class SimpleCache(torch.nn.Module): + def __init__(self, max_size, dtype=torch.float32): + super().__init__() + self.register_buffer("cache", torch.zeros(max_size, dtype=dtype)) + + def forward(self, input_pos, values): + # input_pos: [S], values: [S] + assert input_pos.shape[0] == values.shape[0] + + # Writing the values to the buffer at the specified positions + cache = torch.ops.aten.index_put_(self.cache, [input_pos], values) + + return cache + + +class ReadWriteReadCache(torch.nn.Module): + def __init__(self, max_size, dtype=torch.float32): + super().__init__() + self.register_buffer("cache", torch.zeros(max_size, dtype=dtype)) + + def forward(self, input_pos, values): + # input_pos: [S], values: [S] + assert input_pos.shape[0] == values.shape[0] + cache_value_0 = self.cache[2].clone() + # Writing the values to the buffer at the specified positions + cache = torch.ops.aten.index_put_(self.cache, [input_pos], values) + cache_value_1 = cache[2].clone() + return cache, cache_value_0, cache_value_1 + + +class BufferTest(unittest.TestCase): + def testMutableBuffer(self): + max_size = 10 + simple_cache = SimpleCache(max_size) + + input_pos = torch.tensor([2, 5, 7]) + values = torch.tensor([1.0, 2.0, 3.0]) + simple_cache(input_pos, values) + exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values)) + exported_programm = export(exported_fx_graph) + module_str = str(exported_programm.mlir_module) + self.assertIn( + "util.global private mutable @__auto.constant_10_torch.float32", + module_str, + ) + + def testReadWriteReadMutableBuffer(self): + max_size = 10 + simple_cache = ReadWriteReadCache(max_size) + + input_pos = torch.tensor([2, 5, 7]) + values = torch.tensor([1.0, 2.0, 3.0]) + simple_cache(input_pos, values) + exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values)) + exported_programm = export(exported_fx_graph) + module_str = str(exported_programm.mlir_module) + self.assertIn( + "util.global private mutable @__auto.constant_10_torch.float32", + module_str, + ) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() diff --git a/tests/dynamo/type_conversion_test.py b/tests/dynamo/type_conversion_test.py index dfc3de25..617c5d05 100644 --- a/tests/dynamo/type_conversion_test.py +++ b/tests/dynamo/type_conversion_test.py @@ -32,6 +32,7 @@ def testValueTensors(self): self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>") self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor") self._compareNative("!torch.vtensor<[],f32>", "tensor") + self._compareNative("!torch.vtensor<[],complex>", "tensor>") def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True): with self.conv._context: diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 4c2f04db..dbe88424 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -44,6 +44,15 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]: return default_test_shapes +def xfail_unaligned(func): + def wrapper(shape): + if shape[-1] % 2 != 0: + pytest.xfail("Unaligned shape is not expected to work on this test yet.") + func(shape) + + return wrapper + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) def test_copy(shape): @@ -269,13 +278,14 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_tiled_reduce_max")) -def test_tiled_reduce_max(shape): +@xfail_unaligned +def test_toy_online_softmax(shape): M = tkl.sym.M N = tkl.sym.N wave_size = 64 BLOCK_M = 1 BLOCK_N = tkl.sym.BLOCK_N - ELEMS_PER_THREAD = BLOCK_N / wave_size + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE constraints: list[tkw.Constraint] = [ @@ -293,35 +303,44 @@ def test_tiled_reduce_max(shape): @tkw.wave(constraints) def test( - a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f32], ): - init_max = tkl.Register[M, tkl.f16](-1e6) + init_max = tkl.Register[M, tkl.f32](-1e6) + init_sum = tkl.Register[M, tkl.f32](0) - @tkw.reduction(N, init_args=[init_max]) + @tkw.reduction(N, init_args=[init_max, init_sum]) def repeat( - partial_max: tkl.Register[M, tkl.f16], - ) -> tkl.Register[M, tkl.f16]: + partial_max: tkl.Register[M, tkl.f32], + partial_sum: tkl.Register[M, tkl.f32], + ) -> tkl.Register[M, tkl.f32]: lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) rhs = tkw.read(b, elements_per_thread=ELEMS_PER_THREAD) res = lhs * rhs partial_max = tkw.max(res, partial_max, dim=N) - return partial_max + partial_sum = tkw.sum(res, partial_sum, dim=N) + return partial_max, partial_sum - tkw.write(repeat, c, elements_per_thread=1) + res_max, res_sum = repeat + result = res_max / res_sum + tkw.write(result, c, elements_per_thread=1) config = {"backend": "rocm", "device": "hip", "target": "gfx942"} - a = torch.randn(shape, dtype=torch.float16) - b = torch.randn(shape, dtype=torch.float16) - c = torch.zeros((shape[0],), dtype=torch.float16) - ref = torch.max((a * b), dim=-1) + torch.manual_seed(1) + a = torch.randn(shape, dtype=torch.float32) + b = torch.randn(shape, dtype=torch.float32) + c = torch.zeros((shape[0],), dtype=torch.float32) + ref_max = torch.max((a * b), dim=-1).values + ref_sum = torch.sum((a * b), dim=-1) + ref = ref_max / ref_sum with tk.gen.TestLaunchContext( { M: shape[0], N: shape[1], BLOCK_N: min(128, shape[1]), + ELEMS_PER_THREAD: min(128, shape[1]) // wave_size, ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, }, canonicalize=True, @@ -332,7 +351,7 @@ def repeat( # Assert equal does cast to boolean on torch.Tensor # which causes issues, hence we cast to numpy before # checking. - assert_equal(c, ref.values.numpy()) + assert_allclose(ref, c, atol=0.015) @require_e2e @@ -571,7 +590,8 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: @pytest.mark.parametrize("c", [1, 3, 4, 10]) @pytest.mark.parametrize("nf", [1, 2, 16]) @pytest.mark.parametrize("stride", [1, 2, 3]) -def test_igemm_conv(n, c, nf, stride): +@pytest.mark.parametrize("mem_space", [GLOBAL_ADDRESS_SPACE, SHARED_ADDRESS_SPACE]) +def test_igemm_conv(n, c, nf, stride, mem_space): h, w = 5, 5 # Image. cf, hf, wf = c, 2, 2 # Filters. padding = 0 # TODO: only pad=0 is supported for now @@ -691,7 +711,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: BLOCK_M: 16, BLOCK_N: 16, ELEMS_PER_THREAD: 4, - ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE: mem_space, }, canonicalize=True, run=True, diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 344032a4..566cbc6a 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -126,6 +126,119 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: assert torch.equal(c, iree_ref) +# Format: (M, K, N, B) +intermediate_size = 28672 +tensor_parallel_shape = 8 +hidden_size = 8192 +# Batch size can be 1, 2, 3, 4 +batch_size = 1 +gemm_silu_shapes = [ + ( + intermediate_size / tensor_parallel_shape, + hidden_size, + hidden_size, + 1, + ) +] + + +@require_e2e +@pytest.mark.parametrize("shape", gemm_silu_shapes) +def testGemmSilu(shape: tuple[int]): + + # FC1 and FC2 GEMM Sizes + # Weights matrices are of size (M0, K0). + # Input matrix is of size (BS, K0). + M = tkl.sym.M # Reduction + K = tkl.sym.K # Reduction + N = tkl.sym.N # Parallel + B = tkl.sym.B # Parallel + + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_K = tkl.sym.BLOCK_K + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_B = tkl.sym.BLOCK_B + + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(B, BLOCK_B, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.TilingConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 2, 1)) + ] + + @tkw.wave(constraints) + def gemm_silu( + x: tkl.Memory[B, K, ADDRESS_SPACE, tkl.f16], + w0: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + w1: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + w2: tkl.Memory[N, M, ADDRESS_SPACE, tkl.f16], + output: tkl.Memory[B, N, ADDRESS_SPACE, tkl.f32], + ): + + c_reg = tkl.Register[B, N, tkl.f32](0.0) + + @tkw.reduction(M, init_args=[c_reg]) + def outer_loop(acc: tkl.Register[B, N, tkl.f32]) -> tkl.Register[B, N, tkl.f32]: + + c_reg0 = tkl.Register[B, M, tkl.f32](0.0) + c_reg1 = tkl.Register[B, M, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg0, c_reg1]) + def inner_loop( + acc0: tkl.Register[B, M, tkl.f32], acc1: tkl.Register[B, M, tkl.f32] + ) -> tuple[tkl.Register[B, M, tkl.f32], tkl.Register[B, M, tkl.f32]]: + x_reg = tkw.read(x, elements_per_thread=LOAD_ELEMS_PER_THREAD) + w0_reg = tkw.read(w0, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc0 = tkw.mma(x_reg, w0_reg, acc0) + w1_reg = tkw.read(w1, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc1 = tkw.mma(x_reg, w1_reg, acc1) + return acc0, acc1 + + w2_reg = tkw.read(w2, elements_per_thread=LOAD_ELEMS_PER_THREAD) + mm0, mm1 = inner_loop + silu = 1.0 / (1.0 + tkw.exp(-mm0)) + y = silu * mm1 + acc = tkw.mma(y, w2_reg, acc) + return acc + + # repeat represents the results of the loop + tkw.write(outer_loop, output, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: shape[0], + K: shape[1], + N: shape[2], + B: shape[3], + } + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + with tk.gen.TestLaunchContext( + hyperparams, canonicalize=True, run=True, run_config=config + ): + x = torch.randn(shape[3], shape[1], dtype=torch.float16) + w0 = torch.randn(shape[0], shape[1], dtype=torch.float16) + w1 = torch.zeros(shape[0], shape[1], dtype=torch.float16) + w2 = torch.zeros(shape[2], shape[0], dtype=torch.float16) + output = torch.zeros(shape[3], shape[2], dtype=torch.float32) + gemm_silu(x, w0, w1, w2, output) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main()