Pipeline pass hanging #1652

Closed
Jokeren opened this issue May 11, 2023 · 2 comments
Jokeren (Contributor) commented May 11, 2023

import torch

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    input_ptr,  # pointer to the input features [N, ZIN]
    kernel_ptr,  # pointer to the kernel weights [K, ZIN, ZOUT]
    output_ptr,  # pointer to the output features [N, ZOUT]
    kernel_mask_ptr,  # pointer to an int mask [K]; offsets whose entry is 0 are skipped
    n_indices,  # number of input rows
    K: tl.constexpr,
    ZIN: tl.constexpr,
    ZOUT: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # ZIN and ZOUT are assumed to be powers of 2
    offset_zout = tl.arange(0, ZOUT)
    offset_zin = tl.arange(0, ZIN)

    # output features: [num_voxels, ZOUT]
    mask_indices = offsets < n_indices
    mask_zin = offset_zin < ZIN
    mask_zout = offset_zout < ZOUT

    input_mask = mask_indices[:, None] & mask_zin[None, :]
    output_mask = mask_indices[:, None] & mask_zout[None, :]

    input_a_offsets = offsets[:, None] * ZIN + offset_zin[None, :]
    input_a = tl.load(
        input_ptr + input_a_offsets,
        mask=input_mask,
        other=0,
    )
    output_features = tl.zeros((BLOCK_SIZE, ZOUT), dtype=tl.float32)
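    # Accumulate partial products in fp32; only kernel offsets whose mask
    # entry is nonzero contribute to the output.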
    for k in range(0, K):
        mask_val = tl.load(kernel_mask_ptr + k)
        if mask_val > 0:
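            # Gather the k-th [ZIN, ZOUT] weight slab (row-major layout).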
            input_b_offsets = offset_zin[:, None] * ZOUT + offset_zout[None, :]
            input_b_offsets += k * ZIN * ZOUT
            input_b = tl.load(kernel_ptr + input_b_offsets)
            output_features += tl.dot(input_a, input_b)

    output_offsets = offsets[:, None] * ZOUT + offset_zout[None, :]
    tl.store(output_ptr + output_offsets, output_features, mask=output_mask)


def matmul_fn(input: torch.Tensor, kernel: torch.Tensor):
    (n_elements, ZIN) = input.shape
    (K, ZIN, ZOUT) = kernel.shape
    output = torch.zeros((n_elements, ZOUT), dtype=torch.float16, device="cuda")
    mask = torch.ones(K, dtype=torch.int32, device="cuda")
    mask[0] = 0
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    matmul_kernel[grid](input, kernel, output, mask, n_elements, K, ZIN, ZOUT, BLOCK_SIZE=32)
    return output


if __name__ == '__main__':
    ZIN = 16
    ZOUT = 16
    K = 3
    n_elements = 32
    input = torch.randn((n_elements, ZIN), dtype=torch.float16, device="cuda")
    kernel = torch.randn((K, ZIN, ZOUT), dtype=torch.float16, device="cuda")

    output = matmul_fn(input, kernel)
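
For reference, the kernel is intended to accumulate input @ kernel[k] over every kernel offset k whose mask entry is nonzero. A plain PyTorch equivalent, useful as a correctness check once the hang is fixed (an illustrative sketch, not part of the original report; matmul_ref is a hypothetical name):

def matmul_ref(input: torch.Tensor, kernel: torch.Tensor, mask: torch.Tensor):
    # Mirror the Triton loop: skip offsets whose mask entry is zero,
    # accumulate in fp32, and cast back to fp16 at the end.
    out = torch.zeros((input.shape[0], kernel.shape[2]), dtype=torch.float32, device=input.device)
    for k in range(kernel.shape[0]):
        if mask[k] > 0:
            out += input.float() @ kernel[k].float()
    return out.to(torch.float16)

With the same mask that matmul_fn builds internally, torch.testing.assert_close(output, matmul_ref(input, kernel, mask)) should then pass up to fp16 tolerance.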

Jokeren added the bug label May 11, 2023

ptillet (Collaborator) commented May 11, 2023

Gah. I think this might be related to #1608.

Jokeren (Contributor, Author) commented May 11, 2023

Maybe we should revisit the layout conversion part soon.
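
A possible workaround, assuming the hang is triggered by the data-dependent branch around tl.dot inside the pipelined loop (an assumption, not confirmed in this thread): load the weights unconditionally and scale the dot product by the mask value, so the loop body contains no control flow. An untested sketch of the rewritten loop:

    for k in range(0, K):
        # Assumed workaround: a zero mask value simply zeroes out that
        # offset's contribution instead of branching around the dot.
        mask_val = tl.load(kernel_mask_ptr + k).to(tl.float32)
        input_b_offsets = offset_zin[:, None] * ZOUT + offset_zout[None, :] + k * ZIN * ZOUT
        input_b = tl.load(kernel_ptr + input_b_offsets)
        output_features += tl.dot(input_a, input_b) * mask_val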
