Question about the structured-to-memref pass
#126
-
Hello, I have a few questions about the choices made to lower loads in the structured-to-memref pass.

If we take a closer look at the conversion pipeline, the pass lowers each load by reinterpreting the source pointer as a memref and then allocating a local buffer and copying the data into it.

This suggests that the reinterpretation step already gives us something that behaves like a sub-view of the pointer created earlier in the pipeline. If this is the case, I'm not sure why we need a separate allocation and copy rather than simply taking a sub-view.

Thank you.
-
@mfrancepillois Hey, thank you for your question, and sorry for the late reply. I don't get notifications from the discussion, so I missed this. I will get back to you with some thoughts as soon as possible.
-
This question touches on a few key design points that I will try to address.

Our triton-to-linalg pass was originally created to help with compiling triton on accelerators. Because accelerators potentially have multiple levels of memory, tensors are first created from the main memory and distributed to many triton program instances, each operating on a slice of the main tensor at another memory level. As a result, each triton program has to copy a slice of the data to its local memory. The lowering to memref as you see it now is therefore a design choice that we made for accelerators. For CPU this doesn't apply, and it unfortunately generates sub-optimal code in certain cases. However, always generating a pair of alloc and copy helps us maintain value semantics for cases like the following:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def test_mem(
    A,
    B,
):
    # [1, 1]
    t1 = tl.load(A + tl.arange(0, 2))
    # [1, 1]
    t2 = tl.load(B + tl.arange(0, 2))
    # should be [2, 2]
    t3 = t1 + t2
    # writing back to A! t1 should still have [1, 1] as its value
    # if t1 is simply a subview of A, this store would change t1's value to [2, 2]
    tl.store(A + tl.arange(0, 2), t3)
    # should be [3, 3]
    # but if we simply let t1 be a subview of A, t1 now sees A's updated
    # contents, which are already [2, 2].
    t4 = t1 + t3
    tl.store(B + tl.arange(0, 2), t4)


def test():
    n_cols = 2
    x = torch.full([n_cols], 1, device="cuda", dtype=torch.float32)
    y = torch.full([n_cols], 1, device="cuda", dtype=torch.float32)
    grid = lambda meta: (n_cols,)
    test_mem[grid](x, y)
    print('x: ', x)
    print('y: ', y)
    src = triton.compiler.ASTSource(
        test_mem,
        signature="*fp32,*fp32"
    )
    ret = triton.compile(
        src,
    )
    print(ret.asm["ttir"])

test()
```

Running this, we would expect to see `x: [2, 2]` and `y: [3, 3]`.
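
To make the aliasing hazard concrete outside of Triton, here is a minimal CPU-only sketch (not part of the actual pipeline; it just uses plain torch tensors as a stand-in for the lowered buffers) contrasting what the alloc-and-copy lowering models with what a pure sub-view lowering would model:

```python
# Minimal sketch: torch tensors stand in for the buffers produced by the lowering.
import torch

A = torch.full([2], 1.0)
B = torch.full([2], 1.0)

# --- copy-based lowering (what alloc + copy models): t1 owns its data ---
t1 = A.clone()
t2 = B.clone()
t3 = t1 + t2              # [2, 2]
A[:] = t3                 # store back to A; t1 is unaffected
t4 = t1 + t3              # [3, 3], matching the Triton semantics above
print("copy-based:", t4)  # tensor([3., 3.])

# --- sub-view-based lowering (hypothetical): t1 aliases A ---
A = torch.full([2], 1.0)
t1 = A[:]                 # a view, not a copy
t2 = B.clone()
t3 = t1 + t2              # [2, 2]
A[:] = t3                 # the store also rewrites t1 through the alias
t4 = t1 + t3              # [4, 4] -- value semantics are broken
print("view-based:", t4)  # tensor([4., 4.])
```

The copy variant reproduces the expected `[3, 3]`, while the view variant silently drifts to `[4, 4]` because the store into `A` is visible through `t1`. Keeping the alloc/copy pair in the lowering rules out this class of bugs, at the cost of an extra copy on targets like CPU where it isn't otherwise needed.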