[Question] Making dO contiguous affects output? #100

Open

vedantroy opened this issue Aug 14, 2024 · 2 comments

@vedantroy
I've noticed that when using PyTorch's custom autograd functions, the stride of dO is sometimes (0, 0, 0, 0).
Here's a very simple example: https://discuss.pytorch.org/t/getting-unusual-strides-when-using-pytorchs-autograd/208093.
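
For a self-contained illustration (a sketch of my understanding, not code from the linked thread): the backward of .sum() expands a scalar gradient to the input's shape, so a custom Function's backward receives an expanded view whose strides are all zero.

import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, dO):
        # dO here is the expanded scalar gradient produced by sum()'s
        # backward -- an expanded view, so every stride is 0.
        print(dO.stride())  # prints (0, 0, 0, 0)
        return dO

x = torch.randn(2, 3, 4, 5, requires_grad=True)
Probe.apply(x).sum().backward()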

In my custom wrapper for cuDNN, I work around this by making dO contiguous if its stride is all zeros. Code (ctrl-f for "CHECK FOR WEIRD STRIDE"):

import cudnn
import torch
import math

def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")

def make_cudnn_autograd(*, num_heads, head_dim, dtype):
    assert dtype in [torch.float16, torch.bfloat16], f"Invalid dtype {dtype}"
    dtype = convert_to_cudnn_type(dtype)
    # match cuDNN's docs
    H, D = num_heads, head_dim
    del num_heads, head_dim

    cache = {}

    def assert_cudnn_shape(tensor, expected_shape):
        assert tuple(tensor.get_dim()) == expected_shape, f"Expected shape {expected_shape} but got {tensor.get_dim()}"

    def init_or_check_tensor_attrs(tensor_name, tensor):
        # Record each attribute on first sight; on later calls, compare the
        # tensor's current attribute against the cached value.
        nonlocal cache
        for attr in ['shape', 'stride', 'dtype', 'device']:
            key = f'{tensor_name}_{attr}'
            v = getattr(tensor, attr)
            if callable(v):
                v = v()
            if key not in cache:
                cache[key] = v
            else:
                assert cache[key] == v, f"Expected {cache[key]} but got {v}"


    class CuDNNAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, N, L, q, kv, seqlens_kv):
            assert q.shape == (B, N, H, D)
            assert kv.shape == (B, N + L, 2, H, D)
            assert seqlens_kv.shape == (B,)

            # CuDNN plans are compiled for a specific shape, stride, dtype
            # So we need to verify those attributes
            init_or_check_tensor_attrs('q', q)
            init_or_check_tensor_attrs('kv', kv)
            init_or_check_tensor_attrs('seqlens_kv', seqlens_kv)

            q = q.permute(0, 2, 1, 3)  # B N H D -> B H N D
            kv_view = kv.permute(2, 0, 3, 1, 4) # B S KV H D -> KV B H S D
            k_view, v_view = torch.unbind(kv_view, dim=0)

            assert not k_view.is_contiguous() and not v_view.is_contiguous(), "kv should not be contiguous (unnecessary copy)"
            assert k_view.shape == (B, H, (N + L), D), f"Got shape {k_view.shape} instead of {(B, H, (N + L), D)}"
            assert v_view.shape == (B, H, (N + L), D)

            # TODO: Is this safe?
            if 'stats' not in cache:
                cache['stats'] = torch.empty(B, H, N, 1, dtype=torch.float32, device=q.device)
                cache['seqlens_q'] = torch.tensor([N] * B, device=q.device, dtype=torch.int32).view(B, 1, 1, 1)
                cache['o'] = torch.empty_like(q)

            stats = cache['stats']
            seqlens_q = cache['seqlens_q']
            o = cache['o']

            seqlens_kv = seqlens_kv.view(B, 1, 1, 1)

            if 'compiled_graph_fwd' not in cache:
                print("Compiling CuDNN forward graph ...")
                g_fwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cache['name_to_cu_tensor'] = {
                    'q_cu': g_fwd.tensor_like(q.detach()),
                    'k_cu': g_fwd.tensor_like(k_view.detach()),
                    'v_cu': g_fwd.tensor_like(v_view.detach()),
                    'seqlens_q_cu': g_fwd.tensor_like(seqlens_q.detach()),
                    'seqlens_kv_cu': g_fwd.tensor_like(seqlens_kv.detach())
                }
                cu_tens = cache['name_to_cu_tensor']

                o_forward, stats_forward = g_fwd.sdpa(
                    name="sdpa",
                    q=cu_tens['q_cu'],
                    k=cu_tens['k_cu'],
                    v=cu_tens['v_cu'],
                    is_inference=False,
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu'],
                    seq_len_kv=cu_tens['seqlens_kv_cu']
                )

                o_forward.set_output(True).set_dim(o.shape).set_stride(o.stride()).set_data_type(dtype)
                stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_dim(stats.shape).set_stride(stats.stride())

                cu_tens['o_forward_cu'] = o_forward
                cu_tens['stats_forward_cu'] = stats_forward

                assert_cudnn_shape(cu_tens['q_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_forward_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_forward_cu'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu'], (B, 1, 1, 1))

                g_fwd.validate()
                g_fwd.build_operation_graph()
                g_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_fwd.check_support()
                g_fwd.build_plans()

                cache['compiled_graph_fwd'] = g_fwd

                # TODO: Is this safe?
                cache['workspace'] = torch.empty(
                    g_fwd.get_workspace_size(),
                    device=q.device, dtype=torch.uint8
                )
            
            name_to_cu_tensor = cache['name_to_cu_tensor']
            variant_pack_forward = {
                name_to_cu_tensor[name]: tensor for name, tensor in [
                    ('q_cu', q),
                    ('k_cu', k_view),
                    ('v_cu', v_view),
                    ('o_forward_cu', o),
                    ('stats_forward_cu', stats),
                    ('seqlens_q_cu', seqlens_q),
                    ('seqlens_kv_cu', seqlens_kv)
                ]
            }
            cache['compiled_graph_fwd'].execute(variant_pack_forward, cache['workspace'])
            ctx.save_for_backward(q, k_view, v_view, o, stats, seqlens_kv)
            ctx.B, ctx.N, ctx.L = B, N, L
            ctx.dtype = dtype
            return o

        @staticmethod
        def backward(ctx, dO):
            q, k_view, v_view, o, stats, seqlens_kv = ctx.saved_tensors
            B, N, L = ctx.B, ctx.N, ctx.L
            seqlens_q = cache['seqlens_q']
            cu_tens = cache['name_to_cu_tensor']

            init_or_check_tensor_attrs('dO', dO)

            # CHECK FOR WEIRD STRIDE
            # If every stride of dO is 0 (an expanded view of a single
            # element), materialize it with a contiguous copy.
            if all(s == 0 for s in dO.stride()):
                dO = dO.contiguous()
            assert dO.shape == (B, H, N, D)
            # dO = dO.contiguous()

            if 'dQ' not in cache:
                cache['dQ'] = torch.empty_like(q)
                cache['dK'] = torch.empty_like(k_view)
                cache['dV'] = torch.empty_like(v_view)

            dQ, dK, dV = cache['dQ'], cache['dK'], cache['dV']

            if 'compiled_graph_bwd' not in cache:
                print(f"Compiling CuDNN backward graph ...")
                g_bwd = cudnn.pygraph(
                     io_data_type=dtype,
                     intermediate_data_type=cudnn.data_type.FLOAT,
                     compute_data_type=cudnn.data_type.FLOAT,
                )

                cu_tens['q_cu_bwd'] = g_bwd.tensor_like(q.detach())
                cu_tens['k_cu_bwd'] = g_bwd.tensor_like(k_view.detach())
                cu_tens['v_cu_bwd'] = g_bwd.tensor_like(v_view.detach())
                cu_tens['o_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['dO_cu_bwd'] = g_bwd.tensor_like(dO.detach())
                cu_tens['stats_cu_bwd'] = g_bwd.tensor_like(stats.detach())
                cu_tens['seqlens_q_cu_bwd'] = g_bwd.tensor_like(seqlens_q.detach())
                cu_tens['seqlens_kv_cu_bwd'] = g_bwd.tensor_like(seqlens_kv.detach())

                dQ_bwd_cu, dK_bwd_cu, dV_bwd_cu = g_bwd.sdpa_backward(
                    name="sdpa_backward",
                    q=cu_tens['q_cu_bwd'],
                    k=cu_tens['k_cu_bwd'],
                    v=cu_tens['v_cu_bwd'],
                    o=cu_tens['o_cu_bwd'],
                    dO=cu_tens['dO_cu_bwd'],
                    stats=cu_tens['stats_cu_bwd'],
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu_bwd'],
                    seq_len_kv=cu_tens['seqlens_kv_cu_bwd']
                )

                dQ_bwd_cu.set_output(True).set_dim(dQ.size()).set_stride(dQ.stride())
                dK_bwd_cu.set_output(True).set_dim(dK.size()).set_stride(dK.stride())
                dV_bwd_cu.set_output(True).set_dim(dV.size()).set_stride(dV.stride())

                cu_tens['dQ_cu_bwd'] = dQ_bwd_cu
                cu_tens['dK_cu_bwd'] = dK_bwd_cu
                cu_tens['dV_cu_bwd'] = dV_bwd_cu

                assert_cudnn_shape(cu_tens['q_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dQ_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dK_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dV_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dO_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_cu_bwd'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu_bwd'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu_bwd'], (B, 1, 1, 1))

                g_bwd.validate()
                g_bwd.build_operation_graph()
                g_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_bwd.check_support()
                g_bwd.build_plans()

                cache['compiled_graph_bwd'] = g_bwd
                cache['workspace'] = torch.empty(
                    max(cache['compiled_graph_fwd'].get_workspace_size(), cache['compiled_graph_bwd'].get_workspace_size()),
                    device=q.device, dtype=torch.uint8
                )

            variant_pack_backward = {
                cu_tens[name]: tensor for name, tensor in [
                    ('dQ_cu_bwd', cache['dQ']),
                    ('dK_cu_bwd', cache['dK']),
                    ('dV_cu_bwd', cache['dV']),
                    ('q_cu_bwd', q),
                    ('k_cu_bwd', k_view),
                    ('v_cu_bwd', v_view),
                    ('o_cu_bwd', o),
                    ('dO_cu_bwd', dO),
                    ('stats_cu_bwd', stats),
                    ('seqlens_q_cu_bwd', seqlens_q),
                    ('seqlens_kv_cu_bwd', seqlens_kv)
                ]
            }

            cache['compiled_graph_bwd'].execute(variant_pack_backward, cache['workspace'])
            assert cache['dQ'].shape == (B, H, N, D)
            dQ = cache['dQ'].permute(0, 2, 1, 3) # B H N D -> B N H D

            assert cache['dK'].shape == (B, H, N + L, D)
            assert cache['dV'].shape == (B, H, N + L, D)

            dKV = torch.stack([cache['dK'], cache['dV']], dim=2)
            assert dKV.shape == (B, H, 2, N + L, D)

            dKV = dKV.permute(0, 3, 2, 1, 4) # B H 2 (N+L) D -> B (N+L) 2 H D

            return None, None, None, dQ, dKV, None

    return CuDNNAttention

The problem is that when I do this, I get massive numerical error. Do you have thoughts on why making dO contiguous might cause issues?
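
For reference, a hypothetical harness (my own sketch, not part of the wrapper above) that quantifies the error against PyTorch's own SDPA; it uses full sequence lengths so the padding mask is effectively a no-op, and the shapes/dtype/device are arbitrary choices:

import torch
import torch.nn.functional as F

# Hypothetical comparison harness: with full sequence lengths the padding
# mask has no effect, so F.scaled_dot_product_attention can serve as the
# reference for both the output and the gradients.
B, N, L, H, D = 2, 16, 0, 4, 64
attn = make_cudnn_autograd(num_heads=H, head_dim=D, dtype=torch.bfloat16)

q = torch.randn(B, N, H, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
kv = torch.randn(B, N + L, 2, H, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
seqlens_kv = torch.full((B,), N + L, device='cuda', dtype=torch.int32)

out = attn.apply(B, N, L, q, kv, seqlens_kv)  # (B, H, N, D)
out.sum().backward()  # dO arrives with stride (0, 0, 0, 0), hitting the workaround

# Reference path in the same dtype.
q_ref = q.detach().clone().requires_grad_()
kv_ref = kv.detach().clone().requires_grad_()
k_ref, v_ref = kv_ref.unbind(dim=2)  # each (B, N + L, H, D)
out_ref = F.scaled_dot_product_attention(
    q_ref.permute(0, 2, 1, 3),
    k_ref.permute(0, 2, 1, 3),
    v_ref.permute(0, 2, 1, 3),
)
out_ref.sum().backward()

print("max |o - o_ref|:  ", (out - out_ref).abs().max().item())
print("max |dQ - dQ_ref|:", (q.grad - q_ref.grad).abs().max().item())
print("max |dKV - dKV_ref|:", (kv.grad - kv_ref.grad).abs().max().item())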

@Anerudhan
Collaborator

Hi @vedantroy, can you check whether this persists in the 1.5.2 release?

We have identified some numerical issues in 1.6.0 that we will be addressing in 1.6.1.

Thanks,
Anerudhan

@mnicely
Collaborator

mnicely commented Aug 22, 2024

@vedantroy can you check in the latest 1.6.1 release?
