[torch-frontend] use new register method to register byteir.flash_att… #473

Open
wants to merge 2 commits into base: main
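The "new register method" in the title is PyTorch's torch.library.custom_op / torch.library.register_fake decorator pair (available since PyTorch 2.4), which the diff below adopts in place of the hand-rolled op() helper built on Library(ns, "FRAGMENT"). As a rough sketch of the pattern on a toy op (the "mylib" namespace, op name, and body here are illustrative only, not part of this PR):

import torch

# Register a custom op; the schema is inferred from the type annotations.
@torch.library.custom_op("mylib::toy_scale", mutates_args=())
def toy_scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

# Register a fake (meta) implementation so FakeTensor/meta tracing can
# infer output shapes and dtypes without running the real kernel.
@torch.library.register_fake("mylib::toy_scale")
def _(x, s):
    return torch.empty_like(x)

# The op is then reachable through the torch.ops namespace.
y = torch.ops.mylib.toy_scale(torch.randn(4), 2.0)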
@@ -1,34 +1,36 @@
from typing import List
import torch
import math
from torch.library import Library

OPERATORS = []


def op(schema):
    def inner(f):
        # TODO: Refactor the Library API so this is less rage inducing
        # TODO: Perhaps the namespace should be directly based on Python
        # module
        if "::" in schema:
            ns = schema.split("::", 2)[0]
        else:
            ns = "contrib"
        # TODO: Library doesn't allow FRAGMENT, need to allow it
        lib = Library(ns, "FRAGMENT")
        name = lib.define(schema)
        if "::" in name:
            name = name.split("::", 2)[1]
        lib.impl(name, f, "CompositeExplicitAutograd")
        OPERATORS.append(lib)
        return getattr(getattr(torch.ops, ns), name)

    return inner


@op(
    "byteir::flash_attn_fwd(Tensor q, Tensor k, Tensor v, float dropout_p, float softmax_scale, bool causal, bool return_softmax) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
)

@torch.library.custom_op("byteir::flash_attn_fwd", mutates_args=())
def byteir_flash_attn_fwd(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float, softmax_scale: float, causal: bool, return_softmax: bool
) -> List[torch.Tensor]:
sizes = q.shape
batch_size = sizes[0]
seqlen_q = sizes[1]
num_heads = sizes[2]
seqlen_k = k.shape[1]

rng = torch.empty((2), dtype=torch.int64, device="meta")
softmax_lse = torch.empty(
(batch_size, num_heads, seqlen_q), dtype=torch.float, device="meta"
)
p = None
if return_softmax:
p = torch.empty(
(batch_size, num_heads, seqlen_q, seqlen_k),
dtype=torch.float,
device="meta",
)
q_padded = q
k_padded = k
v_padded = v
out = torch.empty_like(q_padded)
out_padded = torch.empty_like(out)
return out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng

@torch.library.register_fake("byteir::flash_attn_fwd")
def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
    sizes = q.shape
    batch_size = sizes[0]
@@ -55,9 +57,32 @@ def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_soft
    return out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng


@op(
    "byteir::flash_attn_bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, float dropout_p, float softmax_scale, bool causal, Tensor rng) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"
)
@torch.library.custom_op("byteir::flash_attn_bwd", mutates_args=())
def byteir_flash_attn_bwd(
dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, softmax_lse: torch.Tensor, dropout_p: float, softmax_scale: float, causal: bool, rng_state: torch.Tensor
) -> List[torch.Tensor]:
sizes = q.shape
batch_size = sizes[0]
seqlen_q = sizes[1]
num_heads = sizes[2]
seqlen_q_rounded = ((seqlen_q + 127) // 128) * 128
head_size = sizes[3]
head_size_rounded = ((head_size + 31) // 32) * 32
dq_accum = torch.empty(
(batch_size, num_heads, seqlen_q_rounded, head_size_rounded),
dtype=torch.float,
device="meta",
)
softmax_d = torch.empty(
(batch_size, num_heads, seqlen_q_rounded), dtype=torch.float, device="meta"
)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
return dq, dk, dv, softmax_d, dq_accum


@torch.library.register_fake("byteir::flash_attn_bwd")
def byteir_flash_attn_bwd(
    dout, q, k, v, out, softmax_lse, dropout_p, softmax_scale, causal, rng_state
):
@@ -82,9 +107,23 @@ def byteir_flash_attn_bwd(
    return dq, dk, dv, softmax_d, dq_accum


@op(
    "byteir::flash_attn_kvcache(Tensor q, Tensor k, Tensor v, Tensor kcache, Tensor vcache, Tensor seqlen_k, float softmax_scale, bool causal) -> (Tensor, Tensor)"
)
@torch.library.custom_op("byteir::flash_attn_kvcache", mutates_args=())
def byteir_flash_attn_kvcache(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kcache: torch.Tensor, vcache: torch.Tensor, seqlen_k: torch.Tensor, softmax_scale: float, causal: bool
) -> List[torch.Tensor]:
sizes = q.shape
batch_size = sizes[0]
seqlen_q = sizes[1]
num_heads = sizes[2]

softmax_lse = torch.empty(
(batch_size, num_heads, seqlen_q), dtype=torch.float, device="meta"
)
out = torch.empty_like(q)
return out, softmax_lse


@torch.library.register_fake("byteir::flash_attn_kvcache")
def byteir_flash_attn_kvcache(q, k, v, kcache, vcache, seqlen_k, softmax_scale, causal):
    sizes = q.shape
    batch_size = sizes[0]
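A minimal way to smoke-test the registrations in this diff (a sketch only: it assumes the module above has been imported so the byteir ops are registered, and it uses FakeTensorMode so the register_fake implementations drive shape inference; the tensor sizes are made up):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # (batch, seqlen, num_heads, head_size) layout, as implied by the shape
    # arithmetic in the fake implementations.
    q = torch.empty(2, 128, 8, 64, dtype=torch.float16)
    k = torch.empty(2, 256, 8, 64, dtype=torch.float16)
    v = torch.empty(2, 256, 8, 64, dtype=torch.float16)

    outs = torch.ops.byteir.flash_attn_fwd(q, k, v, 0.0, 0.125, True, True)
    # Returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng
    print([tuple(t.shape) for t in outs if t is not None])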