plotfi smem evolution july #4

Open · plotfi wants to merge 11 commits into plotfi-smem-evolution-july-base from plotfi-smem-evolution-july
Conversation

@plotfi (Owner) commented Jul 18, 2024

The following is a redo of Triton shared-memory (SMEM) support, based on discussions around ttg.local_gather.

The RFC for local_gather is at https://docs.google.com/document/d/1rmYVe8tTRrPcVHkS5GcmcVl_tVjgMbcfiFWc_ihkTuU/edit?usp=sharing

With these patches, the target usage pattern for shared memory in Triton is trending toward something like this:

@triton.jit
def triton_local_gather(data_ptr, index_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xbase = tl.arange(0, XBLOCK)
    rbase = tl.arange(0, RBLOCK)

    x = xoffset + xbase
    t = tl.load(data_ptr + x)
    s = tl.local_copy(t)  # stage the tile in shared memory

    for roffset in range(0, RBLOCK):
        r = roffset + rbase
        indices = tl.load(index_ptr + r)
        g = tl.gather(s, indices)  # indexed load served from shared memory
        # accumulate based on g

    # tl.store based on g accumulation

Local gathers take a pointer to shared memory and a tensor of indices to
load from. This patch implements the Op.
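
For reference, the op's gather semantics can be modeled in plain NumPy (an illustrative model only, not the implementation):

import numpy as np

def local_gather_reference(smem: np.ndarray, indices: np.ndarray) -> np.ndarray:
    # Each output element i is the shared-memory element at indices[i],
    # i.e. g[i] = smem[indices[i]].
    return smem[indices]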

(cherry picked from commit 3dac0e3)
@plotfi force-pushed the plotfi-smem-evolution-july branch from 08112f0 to 6420b05 on July 19, 2024 at 17:25
Removed cast, set default order based on tensor rank.
local_copy -> local_alloc

gather -> local_gather

Eventually I plan to make tl.gather generate a tt::LoadOp when the pointer
type refers to global memory. Ideally, tt::GatherOp should either fold back
into tt::LoadOp or lower automatically to either a tt::LoadOp or a
ttg::LocalGatherOp.
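
A rough pseudocode sketch of that planned dispatch (my reading of the paragraph above; the builder helpers here are hypothetical, not actual Triton frontend code):

def lower_gather(src, indices, builder):
    if is_global_pointer(src.type):
        # Source is a pointer to global memory: fold the gather back into
        # a plain tt::LoadOp on the computed addresses.
        return builder.create_load(builder.create_addptr(src, indices))
    # Source is a shared-memory MemDesc: emit a ttg::LocalGatherOp.
    return builder.create_local_gather(src, indices)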
This is required for passing MemDesc values via CallOps.
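
As an illustration of what this enables, here is a hedged sketch of a @triton.jit helper that receives the shared-memory handle (a MemDesc value) as a call argument. The helper name is hypothetical, and the tl.local_gather spelling follows the rename above; treat the exact signature as an assumption.

import triton
import triton.language as tl

@triton.jit
def gather_tile(smem, index_ptr, roffset, RBLOCK: tl.constexpr):
    # smem is a shared-memory MemDesc value; passing it into this helper
    # is exactly the CallOp case this commit enables.
    indices = tl.load(index_ptr + roffset + tl.arange(0, RBLOCK))
    return tl.local_gather(smem, indices)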

Review comment on python/src/ir.cc (outdated, resolved): TODO: Replace this with a generic opaque Python type wrapper for MLIR Types
"triton_gpu.threads-per-warp" = 32 : i32
}
{
tt.func public @triton_global_gather(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
Review comment: change func name to triton_local_gather

plotfi added a commit to plotfi/generative-recommenders that referenced this pull request Aug 22, 2024
…totuning

The following change requires a private patchset that is not yet
available outside of plotfi/triton#4

This patch adds shared-memory usage via the tl.local_copy and tl.gather
operations for the TW (time-bias) and PW (position-bias) tensors in the
forward-pass kernel.

Autotuning is also hooked up to the usage of these shared-memory operators.
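
As a hedged illustration of that hookup (the kernel and parameter names are hypothetical; only the tl.local_copy and tl.gather calls come from this patchset), an autotuner flag could select between the global-memory and shared-memory paths:

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'XBLOCK': 64, 'USE_SMEM': 0}),
        triton.Config({'XBLOCK': 64, 'USE_SMEM': 1}),
    ],
    key=['N'],
)
@triton.jit
def fwd_bias_kernel(tw_ptr, index_ptr, out_ptr, N,
                    XBLOCK: tl.constexpr, USE_SMEM: tl.constexpr):
    x = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)
    idx = tl.load(index_ptr + x, mask=x < N)
    if USE_SMEM:
        tw = tl.load(tw_ptr + x, mask=x < N)
        s = tl.local_copy(tw)      # stage the time-bias tile in shared memory
        g = tl.gather(s, idx)      # gather served from shared memory
    else:
        g = tl.load(tw_ptr + idx)  # baseline: gather directly from global memory
    tl.store(out_ptr + x, g, mask=x < N)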