[Operator][Torch] Add causal fmha and torch sdpa mapping #238

Merged (10 commits, May 23, 2023)
42 changes: 42 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -740,6 +740,48 @@ def mish(x: Tensor, inplace: bool = False):
return ops.multiply(x, ops.tanh(ops.softplus(x, 1, 20)))


@register_function(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(
q: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor = None, dropout_p: float = 0.0, is_causal: bool = False
):
import math

if not math.isclose(dropout_p, 0.0):
warnings.warn_once('hidet: attention dropout is not supported. Treat as dropout_p=0.0')

k_rank = len(k.shape)
# transpose last 2 dimensions of k, and normalize by sqrt(head_dim)
k_transpose_scaled = ops.transpose(k, [i for i in range(k_rank - 2)] + [k_rank - 1, k_rank - 2]) / math.sqrt(
k.shape[-1]
)

from hidet import boolean, float16

type_match = (
q.dtype == k.dtype == v.dtype == float16
and same_list(q.shape, v.shape)
and len(q.shape) == len(k_transpose_scaled.shape)
and (q.shape[-2], q.shape[-1]) == (k_transpose_scaled.shape[-1], k_transpose_scaled.shape[-2])
)
fmha_requirements = q.shape[-1] <= 160 and (
attn_mask is None or attn_mask is not None and attn_mask.dtype == float16
)
if type_match and fmha_requirements:
return ops.attention(q, k_transpose_scaled, v, attn_mask, is_causal)

qk = ops.matmul(q, k_transpose_scaled)
if attn_mask is not None:
if attn_mask.dtype.is_float():
qk = qk + attn_mask
elif attn_mask.dtype == boolean:
neginfs = ops.full(qk.shape, value=qk.dtype.min_value, dtype=qk.dtype, device=qk.device)
qk = ops.where(attn_mask, qk, neginfs)
else:
raise NotImplementedError('hidet: attention mask must be bool or float')
out = ops.matmul(ops.softmax(qk, axis=-1), v)
return out


@register_function(torch.gather)
def gather(x: Tensor, dim: int, index: Tensor, *, sparse_grad=False, out=None):
if sparse_grad:
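The mapping above is picked up whenever a traced PyTorch graph calls torch.nn.functional.scaled_dot_product_attention. A minimal sketch of exercising it end to end through hidet's torch.compile backend (assuming a CUDA device and float16 inputs; 'hidet' is the dynamo backend name the package registers on import):

import torch
import hidet  # noqa: F401  -- importing hidet registers the 'hidet' dynamo backend

def sdpa(q, k, v):
    # With float16 inputs and head dim <= 160 this should take the fused
    # attention path; otherwise the composed matmul/softmax fallback is used.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

q = torch.randn(1, 16, 1024, 64, dtype=torch.float16, device='cuda')
k = torch.randn(1, 16, 1024, 64, dtype=torch.float16, device='cuda')
v = torch.randn(1, 16, 1024, 64, dtype=torch.float16, device='cuda')

compiled = torch.compile(sdpa, backend='hidet')
out = compiled(q, k, v)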
72 changes: 54 additions & 18 deletions python/hidet/graph/ops/definitions/attention/attention.py
@@ -25,12 +25,13 @@
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode, compute, input_like
from hidet.graph.ops.definitions.utils import broadcast_shape, broadcast_shapes, broadcast_indices
from hidet.graph.ops.definitions.utils import can_broadcast
from hidet.utils import same_list
from hidet.utils.py import cdiv, prod
from .attention_mask import AttnMaskAddOp


class AttnTask(Task):
def __init__(self, name: str, q: TensorNode, k: TensorNode, v: TensorNode):
def __init__(self, name: str, q: TensorNode, k: TensorNode, v: TensorNode, is_causal: bool):
q_shape = q.const_shape
k_shape = k.const_shape
v_shape = v.const_shape
@@ -40,6 +41,7 @@ def __init__(self, name: str, q: TensorNode, k: TensorNode, v: TensorNode):
o_head, q_head, k_head, v_head = o_shape[:-2], q_shape[:-2], k_shape[:-2], v_shape[:-2]
qk_head = broadcast_shape(q_head, k_head)

# TODO: add causal mask to the compute definition (will not affect results, since the schedule template will be used)
qk = compute(
name='qk',
shape=qk_head + [n_size, n_size],
@@ -95,7 +97,7 @@ def __init__(self, name: str, q: TensorNode, k: TensorNode, v: TensorNode):
reduce_type='sum',
),
)
super().__init__(name=name, inputs=[q, k, v], outputs=[o])
super().__init__(name=name, inputs=[q, k, v], outputs=[o], attributes={'is_causal': is_causal})

def allow_prologue(self) -> bool:
return False
@@ -140,6 +142,7 @@ def calc_swizzle_size(d):
return -1, -1

task = self
is_causal = task.attrs['is_causal']
node_q, node_k, node_v, node_o = task.inputs[0], task.inputs[1], task.inputs[2], task.outputs[0]
q_shape: List[int] = list(node_q.const_shape)
k_shape: List[int] = list(node_k.const_shape)
@@ -164,8 +167,8 @@ def calc_swizzle_size(d):
tune.check(n_size >= 64)
block_j = min(block_j, n_size)

acc_dtype = f16
sm_dtype = f32
acc_dtype = f16 # must be f16 for now. f32 will fail to compile
sm_dtype = f32 # currently changing to f16 will not boost performance
mma_m = mma_config.m
mma_n = mma_config.n
mma_k = mma_config.k
@@ -228,7 +231,6 @@ def calc_swizzle_size(d):
i_split = n_tiles
i_tiles_per_tb = 1
i_rows_per_tb = i_tiles_per_tb * block_i
j_tiles = cdiv(n_size, block_j)

smem_bytes_q = dtype_size * block_i * dpad_size
# k and v requires double memory for double buffering pipeline
@@ -587,8 +589,11 @@ def attn_kernel(
regs_acc_o[a, b, c] = acc_dtype.zero
regs_o[a, b, c] = acc_dtype.zero

j_tiles = cdiv(n_size, block_j)
if is_causal:
j_tiles = cdiv((blockIdx.x + 1) * block_i, block_j)
for j in range(j_tiles):
offset_j = block_j * j # 256j
offset_j = block_j * j

# ----------------------------
# Compute QK = Qi * Kj
@@ -626,6 +631,24 @@
)
cp_async_wait_all()
syncthreads()

# Preload first tile of v into shared memory
copy_v_g2s(v, ~smem_v[0, 0, 0], offset_j)

# Apply Causal Masking
if is_causal:
for mma_i, mma_j in grid(mmas_per_warp_m, mmas_per_warp_n):
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
wi, wj, wk = spatial(warp_count_m, warp_count_n, warp_count_k).map(warp_id)
p = 0
for ti, tj in mma_config.c_store_map.on(lane_id):
delta_m = offset_i + wi * warp_elems_m + mma_i * mma_m + ti
delta_n = offset_j + wj * warp_elems_n + mma_j * mma_n + tj
if delta_n > delta_m:
regs_acc[mma_i, mma_j, p] = acc_dtype.min_value
p += 1

# Iterative softmax, and write result matrix into shared memory
qk_softmax_reduce(smem_qk, smem_mij, smem_lij, regs_acc)
# ----------------------------
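A note on the causal path added in the two hunks above: for the query tile starting at row offset_i = blockIdx.x * block_i, only key tiles whose columns can satisfy col <= row need to be visited, and within a tile any score whose key index exceeds its query index is set to the accumulator dtype's minimum so that softmax drives it to ~0. A small stand-alone sketch of that arithmetic (names here are illustrative, not the kernel's):

import math

def causal_j_tiles(block_idx_x: int, block_i: int, block_j: int) -> int:
    # Number of key tiles a query tile has to visit under causal masking,
    # mirroring j_tiles = cdiv((blockIdx.x + 1) * block_i, block_j) above.
    return math.ceil((block_idx_x + 1) * block_i / block_j)

def apply_causal_mask(scores, offset_i, offset_j, neg_value):
    # scores[ti][tj] holds the QK^T entry for query row offset_i + ti and
    # key column offset_j + tj; entries "in the future" (col > row) are masked.
    for ti, row in enumerate(scores):
        for tj in range(len(row)):
            if offset_j + tj > offset_i + ti:
                row[tj] = neg_value
    return scores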

@@ -634,11 +657,8 @@ def attn_kernel(
for a, b, c in grid(mmas_per_warp_m_o, mmas_per_warp_n_o, mma_config.c_elements):
regs_acc_o[a, b, c] = acc_dtype.zero

# Copy first tile of k into shared memory
copy_v_g2s(v, ~smem_v[0, 0, 0], offset_j)
cp_async_wait_all()
syncthreads()

for k1 in range(k_tiles_o):
# Load Vj into Smem
copy_v_g2s(v, ~smem_v[(k1 + 1) % 2, 0, 0], offset_j + (k1 + 1) * block_k_o)
@@ -719,24 +739,40 @@ def attn_kernel(


class AttnOp(Operator):
def __init__(self, q: Tensor, k: Tensor, v: Tensor):
def __init__(self, q: Tensor, k: Tensor, v: Tensor, is_causal: bool = False):
super().__init__(
inputs=[q, k, v],
task=AttnTask('attn', input_like(q, 'q'), input_like(k, 'k'), input_like(v, 'v')),
attributes={},
task=AttnTask('attn', input_like(q, 'q'), input_like(k, 'k'), input_like(v, 'v'), is_causal),
attributes={'is_causal': is_causal},
)


def attention(q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False) -> Tensor:
if mask is not None and is_causal is True:
raise ValueError("mask and is_causal cannot be set at the same time")

if not q.dtype == k.dtype == v.dtype == f16:
raise ValueError("Attention only supports float16 inputs")

if not (
same_list(q.shape, v.shape)
and len(q.shape) == len(k.shape)
and (q.shape[-2], q.shape[-1]) == (k.shape[-1], k.shape[-2])
):
raise ValueError(
'Attention expect tensor Q[..., S, D], K[..., D, S], V[..., S, D]'
+ ', got Q {}, K {}, V {}'.format(q.shape, k.shape, v.shape)
)

if q.shape[-1] > 160:
raise ValueError('Attention only supports head dim <= 160, got {}'.format(q.shape[-1]))

if mask is None:
return AttnOp(q, k, v).get_output(0)
return AttnOp(q, k, v, is_causal).get_output(0)

q_shape = q.shape
k_shape = k.shape
mask_shape = mask.shape
seq_len = q.shape[-2]

q_head, k_head = (q_shape[:-2], k_shape[:-2])
q_head, k_head = (q.shape[:-2], k.shape[:-2])
qk_head = broadcast_shape(q_head, k_head)
qk_shape = qk_head + [seq_len, seq_len]
if not can_broadcast(mask_shape, qk_shape):
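For reference, the operator-level entry point above now takes is_causal, and mask and is_causal are mutually exclusive. A minimal sketch of calling it directly (K is expected already transposed to [..., D, S] and pre-scaled by 1/sqrt(D), which is what the torch frontend mapping does; shapes and dtypes here are illustrative):

import math
import hidet
from hidet import ops

q = hidet.randn([1, 16, 1024, 64], dtype='float16', device='cuda')
k = hidet.randn([1, 16, 1024, 64], dtype='float16', device='cuda')
v = hidet.randn([1, 16, 1024, 64], dtype='float16', device='cuda')

k_t = ops.transpose(k, [0, 1, 3, 2]) / math.sqrt(k.shape[-1])  # [..., D, S], scaled
out = ops.attention(q, k_t, v, is_causal=True)   # fused causal path
# out = ops.attention(q, k_t, v, mask=f16_mask)  # masked path; cannot be combined with is_causal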
@@ -644,6 +644,10 @@ def attn_kernel(
)
cp_async_wait_all()
syncthreads()

# Preload first tile of v into shared memory
copy_v_g2s(v, ~smem_v[0, 0, 0], offset_j)

# Apply Masking
qk_head_index = list(spatial(*qk_head).map(blockIdx.y))
for mma_i, mma_j in grid(mmas_per_warp_m, mmas_per_warp_n):
@@ -659,6 +663,8 @@
)
]
p += 1

# Iterative softmax, and write result matrix into shared memory
qk_softmax_reduce(smem_qk, smem_mij, smem_lij, regs_acc)
# ----------------------------

@@ -667,11 +673,8 @@ def attn_kernel(
for a, b, c in grid(mmas_per_warp_m_o, mmas_per_warp_n_o, mma_config.c_elements):
regs_acc_o[a, b, c] = acc_dtype.zero

# Copy first tile of k into shared memory
copy_v_g2s(v, ~smem_v[0, 0, 0], offset_j)
cp_async_wait_all()
syncthreads()

for k1 in range(k_tiles_o):
# Load Vj into Smem
copy_v_g2s(v, ~smem_v[(k1 + 1) % 2, 0, 0], offset_j + (k1 + 1) * block_k_o)
46 changes: 46 additions & 0 deletions tests/frontends/torch/test_torch_sdpa.py
@@ -0,0 +1,46 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from hidet.testing.torch_utils import check_module, FunctionalModule


@pytest.mark.parametrize('shape', [[1, 16, 1024, 128], [4, 4, 4096, 64]])
@pytest.mark.parametrize('attn_mask_type', [None, 'bool', 'float16', 'causal'])
def test_sdpa(shape, attn_mask_type):
q = torch.randn(shape, dtype=torch.float16)
k = torch.randn(shape, dtype=torch.float16)
v = torch.randn(shape, dtype=torch.float16)
is_causal = False
attn_mask = None
mask_shape = q.shape[:-2] + (q.shape[-2], q.shape[-2])
if attn_mask_type == 'causal':
is_causal = True
elif attn_mask_type == 'bool':
attn_mask = torch.rand(mask_shape) > 0.5
elif attn_mask_type == 'float16':
attn_mask = torch.randn(mask_shape, dtype=torch.float16)

check_module(
FunctionalModule(
op=lambda _q, _k, _v, _attn_mask, _is_causal: torch.nn.functional.scaled_dot_product_attention(
_q, _k, _v, attn_mask=_attn_mask, is_causal=_is_causal
)
),
[q, k, v, attn_mask, is_causal],
atol=1e-2,
rtol=1e-2,
)


if __name__ == '__main__':
pytest.main([__file__])
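For intuition about what check_module compares against: scaled_dot_product_attention reduces to softmax(Q K^T / sqrt(D)) V with optional masking. A small self-contained reference in plain torch (independent of hidet; square sequence lengths, as in the test):

import math
import torch

def sdpa_reference(q, k, v, attn_mask=None, is_causal=False):
    # softmax(Q K^T / sqrt(D)) V with the same mask semantics the test exercises.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if is_causal:
        n = q.shape[-2]
        causal = torch.ones(n, n, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~causal, float('-inf'))
    elif attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            scores = scores.masked_fill(~attn_mask, float('-inf'))
        else:
            scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ v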