I've noticed that when using PyTorch's custom autograd functions, the stride of `dO` can sometimes be `(0, 0, 0, 0)`. Here's a very simple example: https://discuss.pytorch.org/t/getting-unusual-strides-when-using-pytorchs-autograd/208093.
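For reference, here's a minimal sketch of how an all-zero-stride gradient can arise (this mirrors the linked forum post; the exact example there may differ). The gradient that `sum()` passes back is a scalar expanded to the output's shape, so every stride is 0:

```python
import torch

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, dO):
        # sum()'s backward expands a scalar to x's shape, so dO
        # arrives here with stride (0, 0, 0, 0)
        print(dO.stride())
        return dO * 2

x = torch.randn(2, 3, 4, 5, requires_grad=True)
Double.apply(x).sum().backward()  # prints (0, 0, 0, 0)
```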
In my custom wrapper for cuDNN, I solve this by making `dO` contiguous if the stride is all zeros. Code (ctrl-f for "CHECK FOR WEIRD STRIDE"):
```python
import math

import cudnn
import torch


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")


def make_cudnn_autograd(*, num_heads, head_dim, dtype):
    assert dtype in [torch.float16, torch.bfloat16], f"Invalid dtype {dtype}"
    dtype = convert_to_cudnn_type(dtype)
    # Match cuDNN's docs
    H, D = num_heads, head_dim
    del num_heads, head_dim

    cache = {}

    def assert_cudnn_shape(tensor, expected_shape):
        assert tuple(tensor.get_dim()) == expected_shape, \
            f"Expected shape {expected_shape} but got {tensor.get_dim()}"

    def init_or_check_tensor_attrs(tensor_name, tensor):
        nonlocal cache
        for attr in ['shape', 'stride', 'dtype', 'device']:
            key = f'{tensor_name}_{attr}'
            if key not in cache:
                cache[key] = getattr(tensor, attr)
                if callable(cache[key]):
                    cache[key] = cache[key]()
            else:
                # Re-fetch the attribute from the tensor and compare it
                # against the value cached on the first call
                v = getattr(tensor, attr)
                if callable(v):
                    v = v()
                assert cache[key] == v, f"Expected {cache[key]} but got {v}"

    class CuDNNAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, N, L, q, kv, seqlens_kv):
            assert q.shape == (B, N, H, D)
            assert kv.shape == (B, N + L, 2, H, D)
            assert seqlens_kv.shape == (B,)

            # cuDNN plans are compiled for a specific shape, stride, dtype,
            # so we need to verify those attributes
            init_or_check_tensor_attrs('q', q)
            init_or_check_tensor_attrs('kv', kv)
            init_or_check_tensor_attrs('seqlens_kv', seqlens_kv)

            q = q.permute(0, 2, 1, 3)  # B N H D -> B H N D
            kv_view = kv.permute(2, 0, 3, 1, 4)  # B S KV H D -> KV B H S D
            k_view, v_view = torch.unbind(kv_view, dim=0)
            assert not k_view.is_contiguous() and not v_view.is_contiguous(), \
                "kv should not be contiguous (unnecessary copy)"
            assert k_view.shape == (B, H, (N + L), D), \
                f"Got shape {k_view.shape} instead of {(B, H, (N + L), D)}"
            assert v_view.shape == (B, H, (N + L), D)

            # TODO: Is this safe?
            if 'stats' not in cache:
                cache['stats'] = torch.empty(B, H, N, 1, dtype=torch.float32, device=q.device)
                cache['seqlens_q'] = torch.tensor([N] * B, device=q.device, dtype=torch.int32).view(B, 1, 1, 1)
                cache['o'] = torch.empty_like(q)
            stats = cache['stats']
            seqlens_q = cache['seqlens_q']
            o = cache['o']
            seqlens_kv = seqlens_kv.view(B, 1, 1, 1)

            if 'compiled_graph_fwd' not in cache:
                print("Compiling CuDNN forward graph ...")
                g_fwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cache['name_to_cu_tensor'] = {
                    'q_cu': g_fwd.tensor_like(q.detach()),
                    'k_cu': g_fwd.tensor_like(k_view.detach()),
                    'v_cu': g_fwd.tensor_like(v_view.detach()),
                    'seqlens_q_cu': g_fwd.tensor_like(seqlens_q.detach()),
                    'seqlens_kv_cu': g_fwd.tensor_like(seqlens_kv.detach()),
                }
                cu_tens = cache['name_to_cu_tensor']
                o_forward, stats_forward = g_fwd.sdpa(
                    name="sdpa",
                    q=cu_tens['q_cu'],
                    k=cu_tens['k_cu'],
                    v=cu_tens['v_cu'],
                    is_inference=False,
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu'],
                    seq_len_kv=cu_tens['seqlens_kv_cu'],
                )
                o_forward.set_output(True).set_dim(o.shape).set_stride(o.stride()).set_data_type(dtype)
                stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_dim(stats.shape).set_stride(stats.stride())
                cu_tens['o_forward_cu'] = o_forward
                cu_tens['stats_forward_cu'] = stats_forward

                assert_cudnn_shape(cu_tens['q_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_forward_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_forward_cu'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu'], (B, 1, 1, 1))

                g_fwd.validate()
                g_fwd.build_operation_graph()
                g_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_fwd.check_support()
                g_fwd.build_plans()
                cache['compiled_graph_fwd'] = g_fwd

                # TODO: Is this safe?
                cache['workspace'] = torch.empty(
                    g_fwd.get_workspace_size(), device=q.device, dtype=torch.uint8
                )

            name_to_cu_tensor = cache['name_to_cu_tensor']
            variant_pack_forward = {
                name_to_cu_tensor[name]: tensor
                for name, tensor in [
                    ('q_cu', q),
                    ('k_cu', k_view),
                    ('v_cu', v_view),
                    ('o_forward_cu', o),
                    ('stats_forward_cu', stats),
                    ('seqlens_q_cu', seqlens_q),
                    ('seqlens_kv_cu', seqlens_kv),
                ]
            }
            cache['compiled_graph_fwd'].execute(variant_pack_forward, cache['workspace'])
            ctx.save_for_backward(q, k_view, v_view, o, stats, seqlens_kv)
            ctx.B, ctx.N, ctx.L = B, N, L
            ctx.dtype = dtype
            return o

        @staticmethod
        def backward(ctx, dO):
            q, k_view, v_view, o, stats, seqlens_kv = ctx.saved_tensors
            B, N, L = ctx.B, ctx.N, ctx.L
            seqlens_q = cache['seqlens_q']
            cu_tens = cache['name_to_cu_tensor']
            init_or_check_tensor_attrs('dO', dO)

            # CHECK FOR WEIRD STRIDE
            # If dO's strides are all zero, materialize it with .contiguous()
            if all(s == 0 for s in dO.stride()):
                dO = dO.contiguous()
            assert dO.shape == (B, H, N, D)
            # dO = dO.contiguous()

            if 'dQ' not in cache:
                cache['dQ'] = torch.empty_like(q)
                cache['dK'] = torch.empty_like(k_view)
                cache['dV'] = torch.empty_like(v_view)
            dQ, dK, dV = cache['dQ'], cache['dK'], cache['dV']

            if 'compiled_graph_bwd' not in cache:
                print("Compiling CuDNN backward graph ...")
                g_bwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cu_tens['q_cu_bwd'] = g_bwd.tensor_like(q.detach())
                cu_tens['k_cu_bwd'] = g_bwd.tensor_like(k_view.detach())
                cu_tens['v_cu_bwd'] = g_bwd.tensor_like(v_view.detach())
                cu_tens['o_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['dO_cu_bwd'] = g_bwd.tensor_like(dO.detach())
                cu_tens['stats_cu_bwd'] = g_bwd.tensor_like(stats.detach())
                cu_tens['seqlens_q_cu_bwd'] = g_bwd.tensor_like(seqlens_q.detach())
                cu_tens['seqlens_kv_cu_bwd'] = g_bwd.tensor_like(seqlens_kv.detach())
                dQ_bwd_cu, dK_bwd_cu, dV_bwd_cu = g_bwd.sdpa_backward(
                    name="sdpa_backward",
                    q=cu_tens['q_cu_bwd'],
                    k=cu_tens['k_cu_bwd'],
                    v=cu_tens['v_cu_bwd'],
                    o=cu_tens['o_cu_bwd'],
                    dO=cu_tens['dO_cu_bwd'],
                    stats=cu_tens['stats_cu_bwd'],
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu_bwd'],
                    seq_len_kv=cu_tens['seqlens_kv_cu_bwd'],
                )
                dQ_bwd_cu.set_output(True).set_dim(dQ.size()).set_stride(dQ.stride())
                dK_bwd_cu.set_output(True).set_dim(dK.size()).set_stride(dK.stride())
                dV_bwd_cu.set_output(True).set_dim(dV.size()).set_stride(dV.stride())
                cu_tens['dQ_cu_bwd'] = dQ_bwd_cu
                cu_tens['dK_cu_bwd'] = dK_bwd_cu
                cu_tens['dV_cu_bwd'] = dV_bwd_cu

                assert_cudnn_shape(cu_tens['q_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dQ_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dK_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dV_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dO_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_cu_bwd'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu_bwd'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu_bwd'], (B, 1, 1, 1))

                g_bwd.validate()
                g_bwd.build_operation_graph()
                g_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_bwd.check_support()
                g_bwd.build_plans()
                cache['compiled_graph_bwd'] = g_bwd

                cache['workspace'] = torch.empty(
                    max(cache['compiled_graph_fwd'].get_workspace_size(),
                        cache['compiled_graph_bwd'].get_workspace_size()),
                    device=q.device, dtype=torch.uint8
                )

            variant_pack_backward = {
                cu_tens[name]: tensor
                for name, tensor in [
                    ('dQ_cu_bwd', cache['dQ']),
                    ('dK_cu_bwd', cache['dK']),
                    ('dV_cu_bwd', cache['dV']),
                    ('q_cu_bwd', q),
                    ('k_cu_bwd', k_view),
                    ('v_cu_bwd', v_view),
                    ('o_cu_bwd', o),
                    ('dO_cu_bwd', dO),
                    ('stats_cu_bwd', stats),
                    ('seqlens_q_cu_bwd', seqlens_q),
                    ('seqlens_kv_cu_bwd', seqlens_kv),
                ]
            }
            cache['compiled_graph_bwd'].execute(variant_pack_backward, cache['workspace'])

            assert cache['dQ'].shape == (B, H, N, D)
            dQ = cache['dQ'].permute(0, 2, 1, 3)  # B H N D -> B N H D
            assert cache['dK'].shape == (B, H, N + L, D)
            assert cache['dV'].shape == (B, H, N + L, D)
            dKV = torch.stack([cache['dK'], cache['dV']], dim=2)
            assert dKV.shape == (B, H, 2, N + L, D)
            dKV = dKV.permute(0, 3, 2, 1, 4)  # B H 2 S D -> B S 2 H D
            return None, None, None, dQ, dKV, None

    return CuDNNAttention
```
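For context, the wrapper is invoked roughly like this (a hypothetical sketch; the concrete sizes are just for illustration, but the shapes follow the asserts above). Note that reducing the output with `sum()` is exactly the kind of downstream op that delivers a stride-`(0, 0, 0, 0)` `dO`:

```python
# Hypothetical sizes, for illustration only
B, N, L, H, D = 2, 128, 0, 8, 64
attn = make_cudnn_autograd(num_heads=H, head_dim=D, dtype=torch.bfloat16)

q = torch.randn(B, N, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
kv = torch.randn(B, N + L, 2, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
seqlens_kv = torch.full((B,), N + L, device="cuda", dtype=torch.int32)

o = attn.apply(B, N, L, q, kv, seqlens_kv)  # o: (B, H, N, D)
o.sum().backward()  # backward() sees dO with stride (0, 0, 0, 0)
```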
The problem is, when I make `dO` contiguous like this, I get massive numerical error. Do you have thoughts on why making `dO` contiguous might cause issues?
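One thing I've checked: `.contiguous()` on an all-zero-stride tensor should be value-preserving, since it just materializes the broadcast. A quick sanity check:

```python
import torch

g = torch.ones(1, 1, 1, 1).expand(2, 3, 4, 5)
assert g.stride() == (0, 0, 0, 0)
gc = g.contiguous()        # stride becomes (60, 20, 5, 1)
assert torch.equal(g, gc)  # same values, different memory layout
```

So the copy itself shouldn't change the math; my suspicion is the interaction between the materialized layout and the strides the backward graph was compiled against.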
Hi @vedantroy, can you check whether this persists in the 1.5.2 release? We have identified some numerical issues in 1.6.0 that we will be addressing in 1.6.1.
Thanks, Anerudhan
@vedantroy, can you check whether this persists in the latest 1.6.1 release?
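If it helps when verifying the upgrade: the installed frontend version can be read from package metadata. This assumes the frontend was installed from PyPI; the distribution name `nvidia_cudnn_frontend` is my assumption here.

```python
from importlib.metadata import version

# Distribution name assumed; adjust if the frontend was installed differently
print(version("nvidia_cudnn_frontend"))  # expect "1.6.1" after upgrading
```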