Wrong values in broadcasted computation #1074

Closed
Edenhofer opened this issue Jan 19, 2023 · 2 comments

Triton produces wrong values in a broadcasted subtraction:

#!/usr/bin/env python3

import torch

import triton
import triton.language as tl


@triton.jit
def toy_kernel(
    dist_ptr,
    coo_bounds_ptr,
    output_ptr,
    block_size: tl.constexpr,
    n_elements: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * block_size

    block_off = block_start + tl.arange(0, block_size)
    block_mask = block_off < n_elements
    dist = tl.load(dist_ptr + block_off, mask=block_mask)
    coo_bounds = tl.load(coo_bounds_ptr + block_off, mask=block_mask)

    # Broadcast: (1, n_elements) - (block_size, 1) -> (block_size, n_elements)
    coo_bounds = tl.reshape(coo_bounds, (1, n_elements))
    dist = tl.reshape(dist, (block_size, 1))
    nr = coo_bounds
    print(nr)
    nr -= dist
    print(nr)

    # Reduce over the broadcasted axis; one output value per row
    output = tl.sum(nr, axis=1)
    block_off = block_start + tl.arange(0, block_size)
    tl.store(output_ptr + block_off, output, mask=block_mask)


def toy(dist, coo_bounds):
    block_size = n_elements = 4
    grid = ((dist.shape[0] - 1) // block_size + 1, )
    output = torch.empty_like(dist)
    toy_kernel[grid](
        dist,
        coo_bounds,
        output,
        block_size=block_size,
        n_elements=n_elements,
        num_warps=2,
    )
    return output


dist = 1. + torch.arange(4, dtype=torch.float32, device='cuda')
coo_bounds = 2. + torch.arange(4, dtype=torch.float32, device='cuda')
print(toy(dist, coo_bounds))

# Reference implementation in torch
o = (coo_bounds.reshape(1, -1) - dist.reshape(-1, 1)).sum(axis=1)
print(o)

produces results that disagree with the torch reference:

tensor([1., 0., 0., 0.], device='cuda:0')  # triton
tensor([10.,  6.,  2., -2.], device='cuda:0')  # reference implementation
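
Each element of the reference output is sum_j coo_bounds[j] - n_elements * dist[i]; with coo_bounds = [2, 3, 4, 5] (sum 14) and dist = [1, 2, 3, 4], that gives 14 - 4 * [1, 2, 3, 4] = [10, 6, 2, -2], which matches the torch result, while the Triton kernel instead returns [1, 0, 0, 0].
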
Jokeren self-assigned this Jan 19, 2023
Jokeren (Contributor) commented Jan 19, 2023:

The problem should already be fixed on Triton master. Please verify.
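
One quick way to verify is to compare the kernel output against the torch reference directly. This is a hypothetical check, reusing the toy, dist, and coo_bounds definitions from the reproduction script above:

import torch

# Hypothetical verification snippet; assumes the reproduction script above has already run.
expected = (coo_bounds.reshape(1, -1) - dist.reshape(-1, 1)).sum(axis=1)
torch.testing.assert_close(toy(dist, coo_bounds), expected)
print("broadcasted subtraction matches the torch reference")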

Edenhofer (Author) commented:

Yup, it works when building from master. Sorry for the noise.
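
As an aside, the same broadcast can also be written with None-indexing instead of tl.reshape. The sketch below is a hypothetical, equivalent formulation of the kernel (the kernel name is made up); it is not claimed to behave any differently on the affected release:

import triton
import triton.language as tl


@triton.jit
def toy_kernel_noneidx(
    dist_ptr,
    coo_bounds_ptr,
    output_ptr,
    block_size: tl.constexpr,
    n_elements: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * block_size
    block_off = block_start + tl.arange(0, block_size)
    block_mask = block_off < n_elements
    dist = tl.load(dist_ptr + block_off, mask=block_mask)
    coo_bounds = tl.load(coo_bounds_ptr + block_off, mask=block_mask)

    # nr[i, j] = coo_bounds[j] - dist[i], broadcast via None-indexing
    nr = coo_bounds[None, :] - dist[:, None]
    output = tl.sum(nr, axis=1)
    tl.store(output_ptr + block_off, output, mask=block_mask)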

ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this issue Aug 5, 2024
…tel SPIR-V Extension (triton-lang#1074)

Related to issue triton-lang#1001.

This pass already lowers `arith::TruncFOp` and `arith::ExtFOp`, so the original suggestion of lowering to arith operators didn't make sense; instead, I have replaced most of the bit operations with calls to an Intel SPIR-V extension that translates to a MOV instruction in vISA. I couldn't remove the round-to-zero mode of `convertFp32ToBf16`, since the extension only supports round to nearest even. The code that calls `convertFp32ToBf16` uses round to nearest even by default, so that's fine.
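
For context on the two rounding modes mentioned: converting fp32 to bf16 keeps only the top 16 bits of the fp32 bit pattern, and the modes differ in how the discarded low 16 bits are rounded. A minimal sketch in plain Python (hypothetical helper names, ignoring NaN/Inf handling), not the actual pass code:

import struct

def fp32_to_bf16_trunc(x: float) -> int:
    # Round toward zero: simply drop the low 16 bits of the fp32 pattern.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 16

def fp32_to_bf16_rne(x: float) -> int:
    # Round to nearest even: add a bias based on the bit that will become the LSB.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1
    return (bits + 0x7FFF + lsb) >> 16

# Example where the two modes differ:
# 1.01171875 sits exactly between two bf16 values; truncation keeps 0x3F81
# (1.0078125) while round-to-nearest-even picks 0x3F82 (1.015625).
print(hex(fp32_to_bf16_trunc(1.01171875)), hex(fp32_to_bf16_rne(1.01171875)))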