Wrong results when one of the args is assigned to a constant inside the kernel #741

Closed

ngimel opened this issue Oct 5, 2022 · 7 comments

ngimel commented Oct 5, 2022

This might be related to #714. Repro below (comments inside; it unfortunately requires torchdynamo). tl;dr: if the kernel contains an assignment xnumel = <const>, where xnumel is also a kernel arg and the constant equals the value of xnumel passed to the kernel (so the assignment should be a no-op, or, even if it's used for optimization, shouldn't change results), the kernel produces wrong results. The only difference between kernel1 and kernel2 below is the xnumel = 2000 line in kernel2. Note that this is using the new runtime; with the old runtime both versions of the kernel produce wrong results.
I'm happy to provide the generated PTX or any additional info if needed, given that the repro requires dynamo; to get the wrong results without dynamo, minor changes can be made to disable pre-compilation and drop the dynamo dependency.

from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torchinductor.codecache import AsyncCompile

aten = torch.ops.aten
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


kernel1 = async_compile.triton('''
import triton
import triton.language as tl
from torchinductor.ir import ReductionHint
from torchinductor.triton_ops.autotune import pointwise
from torchinductor.utils import instance_descriptor

@pointwise(size_hints=[2048], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'constants': {}})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 1000
    x1 = (xindex // 1000)
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr1 + x2, xmask)
    tl.store(out_ptr0 + x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32), tmp1, xmask)
''')

kernel2 = async_compile.triton('''
import triton
import triton.language as tl
from torchinductor.ir import ReductionHint
from torchinductor.triton_ops.autotune import pointwise
from torchinductor.utils import instance_descriptor

@pointwise(size_hints=[2048], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'constants': {}})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 1000
    x1 = (xindex // 1000)
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr1 + x2, xmask)
    tl.store(out_ptr0 + x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32), tmp1, xmask)
''')
async_compile.wait(globals())
del async_compile

def call(arg4_1, arg7_1):
    s1 = 2
    buf0 = empty_strided((2, 196, 1000), (196000, 1000, 1), device='cuda', dtype=torch.float32).fill_(0)
    stream0 = get_cuda_stream(0)
    kernel1_xnumel = 1000*s1
    kernel1.run(arg4_1, arg7_1, buf0, kernel1_xnumel, grid=grid(kernel1_xnumel), stream=stream0)
    #kernel1[grid(kernel1_xnumel)](arg4_1, arg7_1, buf0, kernel1_xnumel, 1024)
    print(buf0[0].amax(-1)[:10], buf0.sum()) #no xnumel=2000 in the kernel, correct answer
    buf0 = empty_strided((2, 196, 1000), (196000, 1000, 1), device='cuda', dtype=torch.float32).fill_(0)
    kernel2_xnumel = 1000*s1
    kernel2.run(arg4_1, arg7_1, buf0, kernel2_xnumel, grid=grid(kernel2_xnumel), stream=stream0)
    #kernel2[grid(kernel2_xnumel)](arg4_1, arg7_1, buf0, kernel2_xnumel, 1024)

    print(buf0[0].amax(-1)[:10], buf0.sum()) #xnumel=2000 in the kernel, wrong answer
    return (buf0, )


if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from torchinductor.utils import print_performance
    torch.manual_seed(12345)
    arg = rand_strided((2, 1, 1000), (1000, 1000, 1), device='cuda', dtype=torch.int64)
    arg4_1 = torch.arange(2000, device="cuda", dtype=torch.int64).reshape(2,1,1000) % 196 #torch.randint(196, arg.size(), device=arg.device, dtype = arg.dtype)
    arg7_1 = torch.arange(1, 2001, device="cuda", dtype=torch.float32).reshape(2, 1000)
    buf = torch.zeros(2, 196, 1000, device="cuda")
    buf.scatter_(1, arg4_1, arg7_1.reshape(2, 1, 1000))
    print(buf[0].amax(-1)[:10]) #correct answer
    call(arg4_1, arg7_1)

Output:

tensor([981., 982., 983., 984., 985., 986., 987., 988., 989., 990.],
       device='cuda:0') #correct, output of torch.Tensor.scatter
tensor([981., 982., 983., 984., 985., 986., 987., 988., 989., 990.],
       device='cuda:0') tensor(2001000., device='cuda:0') #correct, w/o xnumel=2000 in the kernel
tensor([984.,   0.,   0.,   0., 988.,   0.,   0.,   0., 992.,   0.],
       device='cuda:0') tensor(2001000., device='cuda:0') #wrong, xnumel=2000 in the kernel

ptillet (Collaborator) commented Oct 5, 2022

I am tracking down a major bug that was introduced with the shared signature entry point. I suspect this is also related.

ptillet (Collaborator) commented Oct 5, 2022

Hmmm, https://github.com/openai/triton/pull/742/files was a bug in the JIT, so it shouldn't affect users of the compilation API, and it was also specific to constexprs that aren't clustered at the end of the signature. I'll dig deeper into this case.
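For reference, a hypothetical kernel (made up for illustration, not from this repro) of the kind #742 was about, with a constexpr parameter that is not clustered at the end, unlike the kernels above where XBLOCK is the last argument:

import triton
import triton.language as tl

# Hypothetical example only: BLOCK (a tl.constexpr) sits in the middle of the argument
# list instead of at the end, which is the JIT entry-point case that #742 fixed. The
# kernels in this repro keep XBLOCK last, so they fall outside that fix.
@triton.jit
def middle_constexpr_kernel(in_ptr, BLOCK: tl.constexpr, out_ptr, numel):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel
    tl.store(out_ptr + offs, tl.load(in_ptr + offs, mask=mask), mask=mask)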

ngimel (Author) commented Oct 5, 2022

The JIT produces wrong values in both versions of the kernel, while with the compilation API one kernel is correct and one is wrong, so it still might be related to #742; let me check.

ptillet (Collaborator) commented Oct 5, 2022

I did a little digging, and the big difference is that the failing kernel is vectorized, while the working one isn't. This is funny since 2000 is a multiple of 16.

ngimel (Author) commented Oct 5, 2022

Yes, but the stores use indirect addressing:

tl.store(out_ptr0 + x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32), tmp1, xmask)

tmp0 was just read from the indices tensor.
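
For instance, with the inputs in the repro above (in_ptr0 holds arange(2000) % 196), four consecutive lanes end up with store offsets 1001 apart rather than contiguous:

# Worked example with the repro's data (first four lanes, x1 = 0):
# xindex               0     1     2     3
# tmp0 = in_ptr0[x]    0     1     2     3     (arange(2000) % 196)
# x0 = xindex % 1000   0     1     2     3
# x0 + 1000*tmp0       0  1001  2002  3003     -> stride 1001, not a contiguous range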

ptillet (Collaborator) commented Oct 5, 2022

Setting divisible_by_16=(0, 1, 2, 3) in the first kernel also leads to failure, which confirms that this is an issue with vectorization rather than constexprs. Indeed vectorization shouldn't be allowed, since x0 + (1000*tmp0) does not correspond to a contiguous range.
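
For concreteness, that experiment amounts to this one-line change to kernel1's decorator (a sketch; everything else in the repro stays as posted):

# Same @pointwise metadata as kernel1, except argument index 3 (xnumel) is also marked
# divisible by 16. With this, the first kernel is vectorized too and fails the same way,
# even though it has no `xnumel = 2000` assignment.
@pointwise(size_hints=[2048], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())], 'constants': {}})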

Setting

tl.store(out_ptr0 + (x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)

(i.e., just adding parentheses) also seems to resolve the issue. So my money is on a bug deep down in the alignment analysis pass.

The good news is that in general it is preferable to add parentheses around the offset math, since it promotes int32 math over int64 pointer arithmetic (though in this particular case tmp0 is int64, so it doesn't matter), so this is something torchinductor should probably do anyway. The bad news is that I will be away until Sunday, so I won't have time to properly fix this issue until next week.
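
A minimal sketch comparing the two groupings discussed above, using the same names as the repro (both store lines are quoted from this thread):

# Without parentheses: `out_ptr0 + x0` is already a pointer, so every further `+` is
# pointer arithmetic in the 64-bit index type. This is the form that currently hits the
# vectorization/alignment bug.
tl.store(out_ptr0 + x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32), tmp1, xmask)

# With parentheses: the whole offset is summed first as integers (int32 where the inputs
# allow it; here tmp0 is int64, so there is no practical difference) and added to the
# pointer once. This grouping avoids the bug in the current build.
tl.store(out_ptr0 + (x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)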

ngimel (Author) commented Oct 5, 2022

Thanks! Adding parentheses for indexing math is totally doable on our side!
