[BUG] Loop carried dependences should be SSA values not memory operations #123

Open
matth2k opened this issue Dec 15, 2023 · 4 comments
Labels
enhancement New feature or request

Comments

@matth2k

matth2k commented Dec 15, 2023

Describe the bug
Neither writing kernels with primitives like matmul() nor using allo.grid() makes use of affine.for's ability to carry iteration arguments (iter_args). For us, this is important for pipelining. Here is an example of the MLIR produced by test_reduce() (shown further down):

module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    affine.for %arg1 = 0 to 20 {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      // The value has to be reloaded here, although it could be forwarded from the previous iteration's store
      %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
      %3 = arith.addi %2, %1 : i32
      affine.store %3, %alloc[0] {to = "sum"} : memref<1xi32>
    } {loop_name = "i", op_name = "S_i_0", reduction}
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %0 : i32
  }
}

To Reproduce
The linalg dialect compounds the issue, because linalg ops are lowered to affine loops that accumulate directly into the output memref instead of through an accumulator:

import allo
from allo.ir.types import int32

def test_linalg_matmul():
    N = 16
    from allo import matmul

    def kernel(A: int32[N, N], B: int32[N, N]) -> int32[N, N]:
        return matmul(A, B)

    s = allo.customize(kernel)
    print(s.module)

But even with an explicit scalar accumulator (which ends up in a single-element memref), I can't get it raised to SSA values:

def test_reduce():
    N = 20

    def kernel(A: int32[N]) -> int32:
        sum: int32 = 0
        for i in allo.reduction(N):
            sum += A[i]
        return sum

    s = allo.customize(kernel)
    print(s.module)

Buggy output
I was not hopeful that the existing MLIR passes would help with this issue, but I tried anyway by running:

mlir-opt --convert-linalg-to-affine-loops --affine-scalrep --lower-affine --convert-scf-to-cf --mem2reg

--mem2reg is only expected to work on unstructured control flow (hence the lowering to cf first), but I could not get it to work even there.

Expected behavior
Here is an example of how we do matmul in affine that uses iteration arguments to assist the pipelining pass:

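func.func @matmul(%A: memref<16x16xi32>, %B: memref<16x16xi32>, %C: memref<16x16xi32>) {
  %c0_i32 = arith.constant 0 : i32
  affine.for %arg3 = 0 to 16 {
    affine.for %arg4 = 0 to 16 {
      // The partial sum is passed from one %arg5 iteration to the next as an SSA value
      %sum = affine.for %arg5 = 0 to 16
          iter_args(%sum_iter = %c0_i32) -> (i32) {
        %2 = affine.load %A[%arg3, %arg5] : memref<16x16xi32>
        %3 = affine.load %B[%arg5, %arg4] : memref<16x16xi32>
        %4 = arith.muli %2, %3 : i32
        %sum_next = arith.addi %4, %sum_iter : i32
        affine.yield %sum_next : i32
      }
      affine.store %sum, %C[%arg3, %arg4] : memref<16x16xi32>
    }
  }
  return
}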
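The two loop forms differ only in where the partial sum lives. The following Python model is a sketch to make that concrete; the helper affine_for_iter_args and both matmul functions are hypothetical, written only for this illustration, and assume square matrices as in the 16x16 example. In the iter_args version the running sum is handed from one iteration to the next as a value, while the memory version reads and writes a one-element buffer on every iteration, which is the pattern shown in the reduce output above.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def affine_for_iter_args(lb: int, ub: int, init: T, body: Callable[[int, T], T]) -> T:
    """Hypothetical model of `affine.for %i = lb to ub iter_args(%acc = init)`.

    The carried value is an ordinary local (an SSA value in MLIR), so each
    iteration receives the previous iteration's yielded value directly.
    """
    acc = init
    for i in range(lb, ub):
        acc = body(i, acc)  # the body's return value plays the role of affine.yield
    return acc  # the loop's result, like %sum


def matmul_iter_args(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Reduction over k carried as a value, mirroring the expected MLIR."""
    n = len(A)  # square matrices assumed
    C = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            C[i][j] = affine_for_iter_args(0, n, 0, lambda k, acc: A[i][k] * B[k][j] + acc)
    return C


def matmul_memory_accumulator(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Reduction over k carried through a one-element buffer, like memref<1xi32>."""
    n = len(A)  # square matrices assumed
    C = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            cell = [0]
            for k in range(n):
                cell[0] = cell[0] + A[i][k] * B[k][j]  # load, add, store every iteration
            C[i][j] = cell[0]
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(matmul_iter_args(A, B))  # [[19, 22], [43, 50]]
assert matmul_iter_args(A, B) == matmul_memory_accumulator(A, B)
```

Both versions compute the same result. The difference is that in the first form the dependence between iterations is visible as a value flowing through the loop, rather than hidden behind a load/store pair on the same address that a pipelining pass would have to analyze.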

Perhaps there are the right patterns/passes in MLIR to accomplish what we want, but I haven't found them yet. Maybe we will have to write our own pass for this or lower the AST differently.

@matth2k matth2k added the bug Something isn't working label Dec 15, 2023
@chhzh123
Member

I agree that generating iteration arguments may be helpful for some compiler passes. However, it is not easy for the frontend to determine whether a variable is a reduction variable, so we currently do not support this feature. allo.reduction is just an annotation (it produces the reduction attribute on the loop, as in the output above); it does not generate a loop with iteration arguments.

I haven't figured out a good way to resolve this issue yet. A more sophisticated frontend analysis pass could probably help generate this kind of reduction loop.
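As a rough idea of where such a frontend analysis could start, here is a hypothetical heuristic built on Python's standard ast module. The names find_reduction_vars and REDUCTION_OPS are made up for this sketch, and this is not how Allo's frontend works. It accepts a scalar as a reduction variable of a for loop only when, inside that loop's body, every mention of the scalar is the target of an augmented assignment using one associative operator.

```python
import ast
from typing import Dict, List, Set, Type

# Operators treated as reductions in this sketch (an assumption, not Allo's rule).
REDUCTION_OPS = (ast.Add, ast.Mult, ast.BitAnd, ast.BitOr, ast.BitXor)


def find_reduction_vars(source: str) -> Dict[int, List[str]]:
    """Map each for-loop's line number to the scalar names it reduces into.

    A name qualifies only if every occurrence of it inside the loop body is
    the target of `name op= expr` with a single operator from REDUCTION_OPS.
    Any other read (including inside `expr`) or write disqualifies it.
    """
    result: Dict[int, List[str]] = {}
    for loop in ast.walk(ast.parse(source)):
        if not isinstance(loop, ast.For):
            continue
        uses: Dict[str, int] = {}     # every Name node in the body
        updates: Dict[str, int] = {}  # Names used as `x op= ...` targets
        ops: Dict[str, Set[Type[ast.operator]]] = {}
        for stmt in loop.body:
            for node in ast.walk(stmt):
                if isinstance(node, ast.Name):
                    uses[node.id] = uses.get(node.id, 0) + 1
                elif isinstance(node, ast.AugAssign) and isinstance(node.target, ast.Name):
                    name = node.target.id
                    updates[name] = updates.get(name, 0) + 1
                    ops.setdefault(name, set()).add(type(node.op))
        reductions = sorted(
            name
            for name, n in updates.items()
            if uses[name] == n  # the only mentions are the updates themselves
            and len(ops[name]) == 1
            and issubclass(next(iter(ops[name])), REDUCTION_OPS)
        )
        if reductions:
            result[loop.lineno] = reductions
    return result


SRC = '''\
def mm(A, B, C):
    for i in range(4):
        for j in range(4):
            acc = 0
            for k in range(4):
                acc += A[i][k] * B[k][j]
            C[i][j] = acc
'''
print(find_reduction_vars(SRC))  # {5: ['acc']}
```

On this matmul-style input only the innermost k loop reports acc, because the enclosing loops also reset and read it. Even so, the sketch misses the equivalent acc = acc + ... spelling and ignores types, aliasing, and control flow, which is part of why a robust version is more than a small pass.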

@chhzh123 chhzh123 added enhancement New feature or request and removed bug Something isn't working labels Dec 15, 2023
@andrewb1999

I think I found a solution to this problem: https://github.com/cornell-zhang/amc-dialect/pull/64

@chhzh123
Member

This looks cool! Could you provide an example of the original MLIR code and the code after this pass? @andrewb1999

@andrewb1999

andrewb1999 commented Dec 15, 2023

Yeah, so this is the code before the pass:

module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    affine.for %arg1 = 0 to 20 {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
      %3 = arith.addi %2, %1 : i32
      affine.store %3, %alloc[0] {to = "sum"} : memref<1xi32>
    } {loop_name = "i", op_name = "S_i_0", reduction}
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %0 : i32
  }
}

and this is the code after the pass:

module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    %1 = affine.for %arg1 = 0 to 20 iter_args(%arg2 = %0) -> (i32) {
      %3 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %4 = arith.addi %arg2, %3 : i32
      affine.yield %4 : i32
    }
    affine.store %1, %alloc[0] : memref<1xi32>
    %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %2 : i32
  }
}

You can see that the load and store on sum inside the loop have been removed and replaced with iter_args and an affine.yield. The sum memref should then be removable entirely with store-to-load forwarding (e.g. --affine-scalrep).
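The rewrite above can be sketched on a toy IR. The following is a hypothetical Python model; the classes Load, Store, AddI, For and the function raise_accumulators are made up for illustration, and this is not the implementation in the linked amc-dialect PR. It looks for a loop body that loads one constant memref cell and later stores back to that same cell, with no other access to it. It then hoists the load before the loop as the initial value, replaces uses of the loaded value with a carried iteration argument, turns the store into a yielded value, and stores the loop's result once after the loop. It assumes flat loop bodies and that distinct memref names never alias.

```python
from dataclasses import dataclass, field, replace
from itertools import count
from typing import List, Tuple, Union


# Hypothetical miniature IR: just enough ops to express the `sum` kernel.
@dataclass
class Load:
    result: str
    memref: str
    index: Union[int, str]  # constant index, or an SSA name such as the IV


@dataclass
class Store:
    value: str
    memref: str
    index: Union[int, str]


@dataclass
class AddI:
    result: str
    lhs: str
    rhs: str


@dataclass
class For:
    iv: str
    lb: int
    ub: int
    body: list  # flat list of Load/Store/AddI ops
    iter_args: List[Tuple[str, str]] = field(default_factory=list)  # (block arg, init)
    yields: List[str] = field(default_factory=list)
    results: List[str] = field(default_factory=list)


def _rename(op, old: str, new: str):
    """Return a copy of `op` with operand `old` replaced by `new`."""
    def sub(v):
        return new if v == old else v

    if isinstance(op, AddI):
        return AddI(op.result, sub(op.lhs), sub(op.rhs))
    if isinstance(op, Store):
        return Store(sub(op.value), op.memref, sub(op.index))
    if isinstance(op, Load):
        return Load(op.result, op.memref, sub(op.index))
    return op


def raise_accumulators(func_body: list) -> list:
    """Rewrite single-cell load ... store accumulators in loops into iter_args.

    Assumes distinct memref names never alias and loop bodies are flat.
    """
    fresh = count()
    out: list = []
    for op in func_body:
        if not isinstance(op, For):
            out.append(op)
            continue
        loop = replace(op, body=list(op.body), iter_args=list(op.iter_args),
                       yields=list(op.yields), results=list(op.results))
        after: list = []
        for m in sorted({a.memref for a in op.body if isinstance(a, (Load, Store))}):
            acc = [a for a in loop.body if isinstance(a, (Load, Store)) and a.memref == m]
            # Require exactly one load followed by one store to the same constant cell;
            # any other access to `m` inside the loop blocks the rewrite.
            if not (len(acc) == 2 and isinstance(acc[0], Load) and isinstance(acc[1], Store)
                    and isinstance(acc[0].index, int) and acc[0].index == acc[1].index):
                continue
            load, store = acc
            k = next(fresh)
            init, arg, res = f"%init{k}", f"%iter{k}", f"%res{k}"
            out.append(Load(init, m, load.index))  # read the initial value once, before the loop
            loop.body = [_rename(o, load.result, arg) for o in loop.body
                         if o is not load and o is not store]
            loop.yields = [arg if y == load.result else y for y in loop.yields]
            loop.iter_args.append((arg, init))
            loop.yields.append(arg if store.value == load.result else store.value)
            loop.results.append(res)
            after.append(Store(res, m, load.index))  # write the final value once, after the loop
        out.append(loop)
        out.extend(after)
    return out


# The `sum` kernel before the rewrite (the constant and the alloc are left implicit).
before = [
    Store("%c0_i32", "%alloc", 0),
    For("%arg1", 0, 20, [
        Load("%1", "%arg0", "%arg1"),
        Load("%2", "%alloc", 0),
        AddI("%3", "%2", "%1"),
        Store("%3", "%alloc", 0),
    ]),
    Load("%0", "%alloc", 0),
]
for op in raise_accumulators(before):
    print(op)
```

On the sum kernel this yields the same shape as the MLIR after the pass: one load before the loop, an iteration argument feeding the add, a yield of the sum, and one store after the loop. The remaining store/load pairs around the loop are what store-to-load forwarding would then clean up.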
