diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index f955cf5ebcad..85506a1532cc 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -369,7 +369,9 @@ def is_activation_to_checkpoint(item): Is an activation to be checkpointed """ global mp_size - return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size + extra_flag = (not hasattr(item, 'no_checkpointing')) or (hasattr(item, 'no_checkpointing') + and item.no_checkpointing == False) + return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size and extra_flag def partition_activations(args, cpu_checkpoint, contiguous_checkpoint): diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py new file mode 100644 index 000000000000..611d992d809d --- /dev/null +++ b/deepspeed/sequence/fpdt_layer.py @@ -0,0 +1,1133 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from typing import Optional, Any, Tuple +from torch import Tensor + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + +from packaging import version +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from einops import rearrange +from .layer import single_all_to_all, apply_rotary_pos_emb + + +def _rotate_half_backward(x): + x = rearrange(x, '... (j d) -> ... j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((x2, -x1), dim=-1) + + +def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): + rot_dim = freqs_cos.shape[-1] + grad, grad_pass = grad_output[..., :rot_dim], grad_output[..., rot_dim:] + grad_t = (grad * freqs_cos) + (_rotate_half_backward(grad * freqs_sin)) + grad = grad_t if grad_pass.shape[-1] == 0 else torch.cat((grad_t, grad_pass), dim=-1) + return grad + + +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + + out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + + lse = new_lse + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.permute(0, 2, 1).contiguous().unsqueeze(dim=-1).contiguous() + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +class FPDT_InputConstruct(torch.nn.Module): + + def __init__(self, tokens, labels, loss_mask, attention_mask, position_ids, args, sp_size, sp_rank) -> None: + + super(FPDT_InputConstruct, self).__init__() + self.tokens = tokens + self.labels = labels + self.loss_mask = loss_mask + self.attention_mask = attention_mask + self.position_ids = position_ids + global_seq_len = tokens.shape[1] + batch_size = tokens.shape[0] + assert global_seq_len % sp_size == 0 + assert global_seq_len % args.ds_sequence_parallel_fpdt_chunk_size == 0 + num_chunk_per_gpu = global_seq_len // args.ds_sequence_parallel_fpdt_chunk_size + local_seq_len = global_seq_len // sp_size + assert local_seq_len % num_chunk_per_gpu == 0 + + self.num_chunk_per_gpu = num_chunk_per_gpu + self.chunk_size = local_seq_len // num_chunk_per_gpu + self.sp_size = sp_size + self.sp_rank = sp_rank + self.global_seq_len = global_seq_len + self.local_seq_len = local_seq_len + self.batch_size = batch_size + self.device = tokens.device + + def generate(self): + device = self.device + totalChunks = self.global_seq_len // self.chunk_size + token_chunk_idx = torch.arange(self.global_seq_len, device=device, dtype=torch.int) // self.chunk_size + chunk_to_gpu = torch.arange(totalChunks, device=device, dtype=torch.int) + chunk_to_gpu = chunk_to_gpu.reshape(self.num_chunk_per_gpu, -1).t().contiguous() + + gather_chunk = chunk_to_gpu.flatten().unsqueeze(1).contiguous() + mask = gather_chunk == token_chunk_idx + + indices = mask.nonzero(as_tuple=False) + gather_indices = indices[:, 0] + token_chunk_indices = indices[:, 1] + indices = torch.cat([token_chunk_indices[gather_indices == i] for i in range(gather_chunk.shape[0])]) + load_balanced_loss_mask = self.loss_mask[:, indices] + + indices = indices.reshape(-1, self.chunk_size)[self.num_chunk_per_gpu * self.sp_rank:self.num_chunk_per_gpu * + (self.sp_rank + 1)].flatten().contiguous() + load_balanced_tokens = self.tokens[:, indices] + load_balanced_labels = self.labels[:, indices] + + load_balanced_attention_mask = self.attention_mask if self.attention_mask is not None else None + load_balanced_position_ids = self.position_ids[:, indices] + + return load_balanced_tokens, load_balanced_labels, load_balanced_loss_mask, load_balanced_attention_mask, load_balanced_position_ids + + +class _FPDTGPUAttentionImpl_(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + spg, + scatter_idx, + gather_idx, + hidden_size, + projection_size, + hidden_size_per_attention_head, + kv_projection_size, + qkv_linear_weight, + qkv_linear_bias, + dropout, + num_chunks=8, + cpu_offloading=True): + + do_save = layernorm_output.requires_grad + + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + ctx.pos_emb_cos = pos_emb_cos + ctx.pos_emb_sin = pos_emb_sin + + with torch.no_grad(): + per_gpu_seq_len = layernorm_output.shape[0] + chunk_size = per_gpu_seq_len // num_chunks + assert chunk_size * num_chunks == per_gpu_seq_len + assert attention_mask is None + ctx.num_chunks = num_chunks + ctx.cpu_offloading = cpu_offloading + ctx.spg = spg + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + device = get_accelerator().current_device_name() + ctx.device = device + ctx.dtype = layernorm_output.dtype + ctx.projection_size = projection_size + ctx.kv_projection_size = kv_projection_size + + global_q = [] + global_k = [] + global_v = [] + + ctx.softmax_scale = hidden_size_per_attention_head**(-0.5) + + ctx.dropout_p = dropout + ctx.window_size = (-1, -1) + ctx.alibi_slopes = None + + batch_size = layernorm_output.shape[1] + + global_o = [None for _ in range(num_chunks)] + global_lse = [None for _ in range(num_chunks)] + + for i in range(num_chunks): + + st = chunk_size * i + ed = st + chunk_size + + qkv_chunk = torch.matmul(layernorm_output[st:ed], qkv_linear_weight.t()) + qkv_linear_bias + + q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) + global_q_chunk_len = q_chunk.shape[1] + q_chunk = apply_rotary_pos_emb(q_chunk, + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + global_q.append(q_chunk) + + k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) + k_chunk = apply_rotary_pos_emb(k_chunk, + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + global_k.append(k_chunk) + + v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) + global_v.append(v_chunk) + + for k_i in range(len(global_k)): + causal_chunk = i == k_i + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], + global_k[k_i], + global_v[k_i], + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + global_o[i], global_lse[i] = update_out_and_lse(global_o[i], global_lse[i], block_out, block_lse) + + global_o[i] = global_o[i].to(q_chunk.dtype) + + output = [None for i in range(num_chunks)] + + for i in range(num_chunks): + global_lse[i] = global_lse[i][:, :, :, 0].permute(0, 2, 1).contiguous() + output[i] = single_all_to_all(global_o[i].to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) + output = torch.cat(output, dim=1) + + head_dim = output.shape[-1] + + if do_save: + ctx.save_for_backward(layernorm_output) + ctx.global_q = global_q + ctx.global_k = global_k + ctx.global_v = global_v + ctx.attn_output = global_o + ctx.attn_lse = global_lse + ctx.head_dim = head_dim + ctx.batch_size = batch_size + + ctx.qkv_linear_weight = qkv_linear_weight + ctx.qkv_linear_bias = qkv_linear_bias + + return output + + @staticmethod + def backward(ctx, grad_output): + + num_chunks = ctx.num_chunks + device = ctx.device + dtype = ctx.dtype + spg = ctx.spg + scatter_idx = ctx.scatter_idx + gather_idx = ctx.gather_idx + softmax_scale = ctx.softmax_scale + dropout_p = ctx.dropout_p + window_size = ctx.window_size + alibi_slopes = ctx.alibi_slopes + + projection_size = ctx.projection_size + kv_projection_size = ctx.kv_projection_size + + layernorm_output = ctx.saved_tensors[0] + + global_q = ctx.global_q + global_k = ctx.global_k + global_v = ctx.global_v + attn_output = ctx.attn_output + lse = ctx.attn_lse + + qkv_linear_weight = ctx.qkv_linear_weight + qkv_linear_bias = ctx.qkv_linear_bias + + input_chunk_size = layernorm_output.shape[0] // num_chunks + grad_layernorm_output = [ + torch.zeros((input_chunk_size, layernorm_output.shape[1], layernorm_output.shape[2]), + device=device, + dtype=dtype) for _ in range(num_chunks) + ] + + grad_global_attn_output = [] + chunk_size = grad_output.shape[1] // num_chunks + + for i in range(num_chunks): + st = chunk_size * i + ed = st + chunk_size + grad_global_attn_output.append( + single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, 0, spg)) + + del grad_output + + dq = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + dk = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + dv = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + + grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, + device=qkv_linear_weight.device, + dtype=torch.float) + grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) + + for i in range(num_chunks): + k_chunk = global_k[i] + v_chunk = global_v[i] + + for q_i in range(num_chunks): + no_computation = q_i < i + if no_computation: + continue + + causal_chunk = q_i == i + + q_chunk = global_q[q_i] + attn_output_chunk = attn_output[q_i] + lse_chunk = lse[q_i] + dout = grad_global_attn_output[q_i] + + dq_this = torch.zeros(global_q[0].shape, dtype=dtype, device=device) + dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device) + dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device) + + _flash_attn_backward(dout, + q_chunk, + k_chunk, + v_chunk, + attn_output_chunk, + lse_chunk, + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + + dq[q_i].add_(dq_this.to(torch.float)) + dk[i].add_(dk_this.to(torch.float)) + dv[i].add_(dv_this.to(torch.float)) + + dk_seq_len = dk[i].shape[1] + dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dv[i] = dv[i].to(dtype) + dk[i] = single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, 0, spg) + dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, 0, spg) + + input_st = i * input_chunk_size + input_ed = input_st + input_chunk_size + + input_chunk = layernorm_output[input_st:input_ed].reshape(-1, layernorm_output.shape[-1]) + + dk[i] = dk[i].flatten(2).permute(1, 0, 2) + dv[i] = dv[i].flatten(2).permute(1, 0, 2) + l, b = dk[i].shape[0], dk[i].shape[1] + grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_( + torch.matmul(dk[i].reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size + kv_projection_size:].add_( + torch.matmul(dv[i].reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk[i].sum(0).sum(0)) + grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv[i].sum(0).sum(0)) + + grad_layernorm_output[i].add_( + torch.matmul(dk[i], qkv_linear_weight[projection_size:projection_size + kv_projection_size])) + grad_layernorm_output[i].add_(torch.matmul(dv[i], + qkv_linear_weight[projection_size + kv_projection_size:])) + + dk[i] = None + dv[i] = None + + for i in range(num_chunks): + dq_seq_len = dq[i].shape[1] + dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), + ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) + + dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, 0, spg) + + input_chunk = layernorm_output[:input_chunk_size].reshape(-1, layernorm_output.shape[-1]) + layernorm_output = layernorm_output[input_chunk_size:] + + dq[i] = dq[i].flatten(2).permute(1, 0, 2) + l, b = dq[i].shape[0], dq[i].shape[1] + grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq[i].reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_bias[:projection_size].add_(dq[i].sum(0).sum(0)) + + grad_layernorm_output[i].add_(torch.matmul(dq[i], qkv_linear_weight[:projection_size])) + + dq[i] = None + + return torch.cat( + grad_layernorm_output, + dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to( + dtype), grad_qkv_linear_bias.to(dtype), None, None, None + + +class SequenceChunk: + + def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): + + self.chunk_shape = chunk.shape + self.chunk_dtype = chunk.dtype + self.device = chunk.device if device is None else device + + cpu_chunk = torch.empty(chunk.shape, dtype=chunk.dtype, device='cpu', pin_memory=True) + if chunk.is_cuda: + cpu_chunk.copy_(chunk, non_blocking=True) + else: + cpu_chunk = chunk + + self.cpu_chunk = cpu_chunk + + self.gpu_chunk = chunk if is_in_use else None + + def load_to_gpu(self): + assert self.gpu_chunk is None + if self.gpu_chunk is not None: + pass + else: + gpu_chunk = torch.empty(self.chunk_shape, device=self.device, dtype=self.chunk_dtype) + gpu_chunk.copy_(self.cpu_chunk, non_blocking=True) + self.gpu_chunk = gpu_chunk + + def get_gpu_chunk(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + return self.gpu_chunk + + def check_gpu_chunk(self, ): + assert (self.gpu_chunk is not None) and ( + self.gpu_chunk.device == self.device + ), f"gpu_chunk {self.gpu_chunk is not None} shound be on {self.device}, but it is now on {self.gpu_chunk.device}" + return True + + def offload(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + del self.gpu_chunk + self.gpu_chunk = None + + def overwrite_to_cpu(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + self.cpu_chunk.copy_(self.gpu_chunk, non_blocking=True) + + +class _FPDTGPUOffloadingAttentionImpl_(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + spg, + scatter_idx, + gather_idx, + hidden_size, + projection_size, + hidden_size_per_attention_head, + kv_projection_size, + qkv_linear_weight, + qkv_linear_bias, + dropout, + num_chunks=8, + cpu_offloading=True): + + do_save = layernorm_output.requires_grad + + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + ctx.pos_emb_cos = pos_emb_cos + ctx.pos_emb_sin = pos_emb_sin + with torch.no_grad(): + per_gpu_seq_len = layernorm_output.shape[0] + chunk_size = per_gpu_seq_len // num_chunks + assert chunk_size * num_chunks == per_gpu_seq_len + assert attention_mask is None + ctx.num_chunks = num_chunks + ctx.cpu_offloading = cpu_offloading + ctx.spg = spg + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + ctx.chunk_size = chunk_size + device = get_accelerator().current_device_name() + ctx.device = device + ctx.dtype = layernorm_output.dtype + ctx.projection_size = projection_size + ctx.kv_projection_size = kv_projection_size + + global_q = [] + global_k = [] + global_v = [] + + ctx.softmax_scale = hidden_size_per_attention_head**(-0.5) + + ctx.dropout_p = dropout + ctx.window_size = (-1, -1) + ctx.alibi_slopes = None + + batch_size = layernorm_output.shape[1] + + global_o = [] + global_lse = [] + + layernorm_output_cpu = [] + final_output = [] + + offload_stream = get_accelerator().Stream() + general_offload_stream = get_accelerator().Stream() + compute_stream = get_accelerator().default_stream() + + q_compute_chunk_idx = 0 + kv_compute_chunk_idx = 0 + for i in range(num_chunks): + + qkv_chunk = torch.matmul(layernorm_output[:chunk_size], + qkv_linear_weight.t()) + qkv_linear_bias # torch.Size([18126, 1, 12288]) + + with get_accelerator().stream(general_offload_stream): + layernorm_output_cpu.append(SequenceChunk(layernorm_output[:chunk_size])) + + layernorm_output = layernorm_output[chunk_size:] + + q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) + global_q_chunk_len = q_chunk.shape[1] + + k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) + + v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) + + torch.distributed.barrier() + + pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] + pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] + + q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) + k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + with get_accelerator().stream(offload_stream): + global_q.append(SequenceChunk(q_chunk, is_in_use=True)) + global_k.append(SequenceChunk(k_chunk, is_in_use=True)) + global_v.append(SequenceChunk(v_chunk, is_in_use=True)) + + del qkv_chunk + + cur_attn_output = None + cur_attn_lse = None + for k_i in range(len(global_k)): + causal_chunk = i == k_i + with get_accelerator().stream(compute_stream): + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out, + block_lse) + + can_offload_kv = True + if k_i != (len(global_k) - 1) or i != (num_chunks - 1): + if k_i != (len(global_k) - 1): + next_kv_compute_chunk_idx = k_i + 1 + else: + next_kv_compute_chunk_idx = 0 + + if next_kv_compute_chunk_idx == kv_compute_chunk_idx: + can_offload_kv = False + else: + if next_kv_compute_chunk_idx != (len(global_k) - 1): + with get_accelerator().stream(offload_stream): + global_k[next_kv_compute_chunk_idx].load_to_gpu() + global_v[next_kv_compute_chunk_idx].load_to_gpu() + + if i == num_chunks - 1 and k_i == num_chunks - 1: + with get_accelerator().stream(offload_stream): + global_q[0].load_to_gpu() + global_k[0].load_to_gpu() + global_v[0].load_to_gpu() + global_o[0].load_to_gpu() + global_lse[0].load_to_gpu() + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + if can_offload_kv: + global_k[kv_compute_chunk_idx].offload() + global_v[kv_compute_chunk_idx].offload() + kv_compute_chunk_idx = next_kv_compute_chunk_idx + + global_q[q_compute_chunk_idx].offload() + q_compute_chunk_idx += 1 + + all2all_output = single_all_to_all( + cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) + final_output.append(all2all_output) + with get_accelerator().stream(general_offload_stream): + global_o.append(SequenceChunk(cur_attn_output.to(ctx.dtype))) + global_lse.append(SequenceChunk(cur_attn_lse[:, :, :, 0].permute(0, 2, 1).contiguous())) + + compute_stream.wait_stream(general_offload_stream) + compute_stream.synchronize() + + final_output = torch.cat(final_output, dim=1) + + head_dim = final_output.shape[-1] + + if do_save: + ctx.layernorm_output = layernorm_output_cpu + ctx.global_q = global_q + ctx.global_k = global_k + ctx.global_v = global_v + ctx.attn_output = global_o + ctx.attn_lse = global_lse + ctx.head_dim = head_dim + ctx.batch_size = batch_size + + ctx.qkv_linear_weight = qkv_linear_weight + ctx.qkv_linear_bias = qkv_linear_bias + + return final_output + + @staticmethod + def backward(ctx, grad_output): + num_chunks = ctx.num_chunks + device = grad_output.device + dtype = ctx.dtype + spg = ctx.spg + scatter_idx = ctx.scatter_idx + gather_idx = ctx.gather_idx + softmax_scale = ctx.softmax_scale + dropout_p = ctx.dropout_p + window_size = ctx.window_size + alibi_slopes = ctx.alibi_slopes + + projection_size = ctx.projection_size + kv_projection_size = ctx.kv_projection_size + + layernorm_output = ctx.layernorm_output + + global_q = ctx.global_q + global_k = ctx.global_k + global_v = ctx.global_v + attn_output = ctx.attn_output + lse = ctx.attn_lse + + qkv_linear_weight = ctx.qkv_linear_weight + qkv_linear_bias = ctx.qkv_linear_bias + + offload_stream = get_accelerator().Stream() + general_offload_stream = get_accelerator().Stream() + compute_stream = get_accelerator().default_stream() + + chunk_size = grad_output.shape[1] // num_chunks + assert chunk_size == layernorm_output[0].cpu_chunk.shape[0] + + grad_layernorm_output = [ + torch.zeros(layernorm_output[0].chunk_shape, device=device, dtype=dtype) for _ in range(num_chunks) + ] + + grad_global_attn_output = [None for _ in range(num_chunks)] + + q_compute_chunk_idx = 0 + kv_compute_chunk_idx = 0 + last_q_accum_idx = 0 + + with get_accelerator().stream(general_offload_stream): + layernorm_output[0].load_to_gpu() + grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, + device=qkv_linear_weight.device, + dtype=torch.float) + grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, + device=qkv_linear_weight.device, + dtype=torch.float) + + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, + gather_idx, 0, spg) + get_accelerator().synchronize() + grad_output = grad_output[:, chunk_size:] + + with get_accelerator().stream(offload_stream): + grad_global_attn_output[0] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) + dq = [ + SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device=device), is_in_use=True) + ] + [ + SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device='cpu', pin_memory=True), + device) for _ in range(num_chunks - 1) + ] + dk_accum = torch.zeros(global_k[0].chunk_shape, dtype=torch.float, device=device) + dv_accum = torch.zeros(global_v[0].chunk_shape, dtype=torch.float, device=device) + + for i in range(num_chunks): + for q_i in range(num_chunks): + no_computation = q_i < i + if no_computation: + continue + + causal_chunk = q_i == i + + dq_this = torch.zeros(global_q[0].chunk_shape, dtype=dtype, device=device) + dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device) + dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device) + + with get_accelerator().stream(compute_stream): + _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + attn_output[q_compute_chunk_idx].get_gpu_chunk(), + lse[q_compute_chunk_idx].get_gpu_chunk(), + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + + if i != (len(global_k) - 1): + if q_i != (len(global_q) - 1): + next_q_compute_chunk_idx = q_i + 1 + else: + next_q_compute_chunk_idx = i + 1 + + can_offload_q = True + + if next_q_compute_chunk_idx == q_compute_chunk_idx: + can_offload_q = False + else: + with get_accelerator().stream(offload_stream): + if i > 0 or q_i > 0: + if can_offload_q and last_q_accum_idx != i: # the first q chunk calculate in the loop will be sent out, therefore we do not offload it + dq[last_q_accum_idx].offload() + dq[next_q_compute_chunk_idx].load_to_gpu() + global_q[next_q_compute_chunk_idx].load_to_gpu() + attn_output[next_q_compute_chunk_idx].load_to_gpu() + lse[next_q_compute_chunk_idx].load_to_gpu() + if grad_global_attn_output[next_q_compute_chunk_idx] is not None: + grad_global_attn_output[next_q_compute_chunk_idx].load_to_gpu() + + if grad_global_attn_output[next_q_compute_chunk_idx] is None: + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), + scatter_idx, gather_idx, 0, spg) + torch.distributed.barrier() + grad_output = grad_output[:, chunk_size:] + grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk( + grad_global_attn_output_chunk, is_in_use=True) + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + with get_accelerator().stream(compute_stream): + dq[q_compute_chunk_idx].check_gpu_chunk() + dq[q_compute_chunk_idx].gpu_chunk.add_(dq_this) + dk_accum.add_(dk_this) + dv_accum.add_(dv_this) + + offload_stream.wait_stream(compute_stream) + with get_accelerator().stream(offload_stream): + dq[q_compute_chunk_idx].overwrite_to_cpu() + + if can_offload_q: + global_q[q_compute_chunk_idx].offload() + attn_output[q_compute_chunk_idx].offload() + lse[q_compute_chunk_idx].offload() + grad_global_attn_output[q_compute_chunk_idx].offload() + + last_q_accum_idx = q_compute_chunk_idx + q_compute_chunk_idx = next_q_compute_chunk_idx + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + dk_seq_len = dk_accum.shape[1] + dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dv_accum = dv_accum.to(dtype) + + dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, scatter_idx, 0, spg) + dk_accum = single_all_to_all(dk_accum.contiguous(), gather_idx, scatter_idx, 0, spg) + dv_accum = single_all_to_all(dv_accum.contiguous(), gather_idx, scatter_idx, 0, spg) + + general_offload_stream.synchronize() + compute_stream.wait_stream(general_offload_stream) + torch.distributed.barrier() + + with get_accelerator().stream(compute_stream): + input_chunk = layernorm_output[i].get_gpu_chunk().reshape(-1, layernorm_output[i].chunk_shape[-1]) + + dq_accum = dq_accum.flatten(2).permute(1, 0, 2) + dk_accum = dk_accum.flatten(2).permute(1, 0, 2) + dv_accum = dv_accum.flatten(2).permute(1, 0, 2) + + l, b = dk_accum.shape[0], dk_accum.shape[1] + + grad_qkv_linear_weight[:projection_size].add_( + torch.matmul(dq_accum.reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_( + torch.matmul(dk_accum.reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size + kv_projection_size:].add_( + torch.matmul(dv_accum.reshape(l * b, -1).t(), input_chunk)) + + grad_qkv_linear_bias[:projection_size].add_(dq_accum.sum(0).sum(0)) + grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk_accum.sum(0).sum(0)) + grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv_accum.sum(0).sum(0)) + + grad_layernorm_output[i].add_(torch.matmul(dq_accum, qkv_linear_weight[:projection_size])) + grad_layernorm_output[i].add_( + torch.matmul(dk_accum, qkv_linear_weight[projection_size:projection_size + kv_projection_size])) + grad_layernorm_output[i].add_( + torch.matmul(dv_accum, qkv_linear_weight[projection_size + kv_projection_size:])) + + del dq_accum, dk_accum, dv_accum + dk_accum = torch.zeros(global_k[i].chunk_shape, dtype=torch.float, device=device) + dv_accum = torch.zeros(global_v[i].chunk_shape, dtype=torch.float, device=device) + dq[kv_compute_chunk_idx].offload() + dq[kv_compute_chunk_idx] = None + + if i != (len(global_k) - 1): + next_kv_compute_chunk_idx = kv_compute_chunk_idx + 1 + with get_accelerator().stream(offload_stream): + global_k[next_kv_compute_chunk_idx].load_to_gpu() + global_v[next_kv_compute_chunk_idx].load_to_gpu() + + with get_accelerator().stream(general_offload_stream): + layernorm_output[next_kv_compute_chunk_idx].load_to_gpu() + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + layernorm_output[kv_compute_chunk_idx].offload() + global_k[kv_compute_chunk_idx].offload() + global_v[kv_compute_chunk_idx].offload() + kv_compute_chunk_idx = next_kv_compute_chunk_idx + + return torch.cat( + grad_layernorm_output, + dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to( + dtype), grad_qkv_linear_bias.to(dtype), None, None, None + + +class FPDT_Attention(torch.nn.Module): + + def __init__(self, + config, + first_weight, + first_bias, + second_weight, + second_bias, + sequence_process_group, + gather_idx: int = 0, + scatter_idx: int = 2, + return_bias=True, + chunk_size=65536, + enable_offloading=True) -> None: + + super(FPDT_Attention, self).__init__() + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.config = config + + self.projection_size = config.kv_channels * config.num_attention_heads + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.kv_projection_size = config.kv_channels * config.num_key_value_heads + self.hidden_size = config.hidden_size + + self.qkv_linear_weight = first_weight + self.qkv_linear_bias = first_bias + self.qkv_dense_weight = second_weight + self.qkv_dense_bias = second_bias + + self.reture_bias = return_bias + self.dropout = config.attention_dropout + + self.chunk_size = chunk_size + self.double_buffer = enable_offloading + + def forward(self, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + cpu_offloading=True) -> Tensor: + self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size + + if not cpu_offloading: + output = _FPDTGPUAttentionImpl_.apply(layernorm_output, attention_mask, inference_params, rotary_pos_emb, + self.spg, self.scatter_idx, self.gather_idx, self.hidden_size, + self.projection_size, self.hidden_size_per_attention_head, + self.kv_projection_size, self.qkv_linear_weight, + self.qkv_linear_bias, self.dropout, self.num_chunks_attn, + cpu_offloading) + else: + output = _FPDTGPUOffloadingAttentionImpl_.apply( + layernorm_output, attention_mask, inference_params, rotary_pos_emb, self.spg, self.scatter_idx, + self.gather_idx, self.hidden_size, self.projection_size, self.hidden_size_per_attention_head, + self.kv_projection_size, self.qkv_linear_weight, self.qkv_linear_bias, self.dropout, + self.num_chunks_attn, cpu_offloading) + + output = output.flatten(2).permute(1, 0, 2).contiguous() + + output = torch.matmul(output, self.qkv_dense_weight.t()) + if not self.reture_bias: + output += self.qkv_dense_bias + return output, self.qkv_dense_bias if self.reture_bias else None + + +@torch.jit.script +def bias_gelu(x): + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + +@torch.jit.script +def bias_gelu_back(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * g + + +class FPDT_FFN(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, x, w1, b1, w2, b2, add_bias, chunk_size): + do_save = x.requires_grad + ctx.add_bias = add_bias + device = x.device + + with torch.no_grad(): + num_chunk = x.shape[0] // chunk_size + ctx.num_chunk = num_chunk + result = torch.empty(x.shape, device=device, dtype=x.dtype) + assert chunk_size * num_chunk == x.shape[0] + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + x_ = torch.matmul(x[st:ed], w1.t()) + b1 + x_ = bias_gelu(x_) + if add_bias: + result[st:ed] = torch.matmul(x_, w2.t()) + b2 + else: + result[st:ed] = torch.matmul(x_, w2.t()) + + del x_ + + if do_save: + ctx.device = device + ctx.dtype = x.dtype + ctx.save_for_backward(x, w1, b1, w2, b2) + ctx.grad_x_shape = x.shape + return result.to(x.dtype), b2 if not add_bias else None + + @staticmethod + def backward(ctx, grad_output, grad_bias): + x, w1, b1, w2, b2 = ctx.saved_tensors + device = ctx.device + dtype = ctx.dtype + add_bias = ctx.add_bias + + num_chunk = ctx.num_chunk + chunk_size = x.shape[0] // num_chunk + assert chunk_size * num_chunk == grad_output.shape[0] + + grad_w2 = torch.zeros(w2.shape, device=device, dtype=torch.float) + grad_b2 = torch.zeros(b2.shape, device=device, dtype=torch.float) + grad_w1 = torch.zeros(w1.shape, device=device, dtype=torch.float) + grad_b1 = torch.zeros(b1.shape, device=device, dtype=torch.float) + + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + x_chunk = x[st:ed] + + before_act = (torch.matmul(x_chunk, w1.t()) + b1) + before_act_2 = before_act**2 + tanh_out = torch.tanh(0.79788456 * before_act * (1 + 0.044715 * before_act_2)) + ff = 0.5 * before_act * ((1 - tanh_out * tanh_out) * + (0.79788456 + 0.1070322243 * before_act_2)) + 0.5 * (1 + tanh_out) + grad_w2.add_( + torch.matmul(grad_output[st:ed].reshape(-1, grad_output.shape[2]).t(), + (before_act * 0.5 * (1 + tanh_out)).reshape(-1, before_act.shape[2]))) + del before_act, before_act_2, tanh_out + + grad_inter = torch.matmul(grad_output[st:ed], w2) * ff + del ff + + grad_w1.add_(torch.matmul( + grad_inter.reshape(-1, grad_inter.shape[2]).t(), x_chunk.reshape(-1, x.shape[2]))) + grad_b1.add_(grad_inter.sum(0).sum(0)) + + x[st:ed].copy_(torch.matmul(grad_inter, w1)) + + del grad_inter + + if add_bias: + grad_b2.add_(grad_output[st:ed].sum(0).sum(0)) + + return x, grad_w1.to(dtype), grad_b1.to(dtype), grad_w2.to(dtype), grad_b2.to(dtype), None, None + + +class FPDT_LogitsLoss(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, lm_output, labels, logit_weights, rank, spg_size, spg, num_chunk): + labels = labels.t() + chunk_size = lm_output.shape[0] // num_chunk + assert chunk_size * num_chunk == lm_output.shape[0] + batch_size, local_seq_len = lm_output.shape[1], lm_output.shape[0] + loss = torch.empty((batch_size, local_seq_len), dtype=torch.float, device=lm_output.device) + + ctx.num_chunk = num_chunk + ctx.chunk_size = chunk_size + ctx.device = lm_output.device + ctx.dtype = lm_output.dtype + + ctx.rank = rank + ctx.local_seq_len = local_seq_len + with torch.no_grad(): + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + logits_chunk = torch.matmul(lm_output[st:ed], logit_weights.t()).float() + + vocab_size = logits_chunk.size(2) + # nll + softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) + loss_chunk = torch.nn.functional.nll_loss(softmax.log().reshape(-1, vocab_size).contiguous(), + labels[st:ed, :].reshape(-1).contiguous(), + reduction='none') + loss[:, st:ed] = loss_chunk.reshape(chunk_size, batch_size).t() + + del logits_chunk + ctx.save_for_backward(lm_output.to('cpu'), labels) + ctx.logit_weights = logit_weights + + seqlen = local_seq_len * spg_size + batch_size = loss.size(0) + loss = loss.t().contiguous() + loss_all = torch.empty(seqlen, batch_size, dtype=loss.dtype, device=loss.device).contiguous() + + if version.parse(torch.__version__) >= version.parse('1.13'): + torch.distributed.all_gather_into_tensor(loss_all, loss, group=spg) + else: + torch.distributed._all_gather_base(loss_all, loss, group=spg) + + return loss_all + + @staticmethod + def backward(ctx, grad_output): + lm_output, labels = ctx.saved_tensors + logit_weights = ctx.logit_weights + device = ctx.device + dtype = ctx.dtype + num_chunk = ctx.num_chunk + chunk_size = ctx.chunk_size + + rank = ctx.rank + local_seq_len = ctx.local_seq_len + + grad_output = grad_output[rank * local_seq_len:(rank + 1) * local_seq_len] + grad_lm_output = [None for _ in range(num_chunk)] + grad_logit_weights = torch.zeros(logit_weights.shape, device=grad_output.device, dtype=torch.float) + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + lm_output_chunk = lm_output[st:ed].to(device) + logits_chunk = torch.matmul(lm_output_chunk, logit_weights.t()).float() + + # nll + softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) + vocab_size = logits_chunk.size(2) + + grad_input = softmax + grad_2d = grad_input.reshape(-1, vocab_size).contiguous() + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=device) + + grad_2d[arange_1d, labels[st:ed, :].reshape(-1).contiguous()] -= 1 + grad_input.mul_(grad_output[:chunk_size, :].unsqueeze(dim=-1)) + grad_input = grad_input.to(dtype) + + grad_output = grad_output[chunk_size:].contiguous() + + grad_lm_output_chunk = torch.matmul(grad_input, logit_weights) + grad_lm_output[i] = grad_lm_output_chunk + + grad_logit_weights.add_( + torch.matmul( + grad_input.reshape(-1, grad_input.shape[2]).t(), + lm_output_chunk.reshape(-1, lm_output_chunk.shape[2]))) + + return torch.cat(grad_lm_output, dim=0).to(dtype), None, grad_logit_weights.to(dtype), None, None, None, None diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index e809fe1118b5..7e86582df595 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -10,6 +10,34 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator +from einops import rearrange + + +def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = rearrange(x, '... (j d) -> ... j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs_cos, freqs_sin): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ + rot_dim = freqs_cos.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin) + + res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1) + return res def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim): @@ -178,7 +206,14 @@ def layer_sync(self, layer): if self.sp_overlap_comm and hasattr(layer, 'done_event'): self.dafult_stream.wait_event(layer.done_event) - def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, *args: Any, **kwargs) -> Tensor: + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + batch_dim_idx: int, + rotary_pos_emb=None, + *args: Any, + **kwargs) -> Tensor: """ forward Arguments: @@ -233,6 +268,10 @@ def pre_hook_fun(grad): grad_fn_k.register_prehook(bwd_hook(layer_type='k')) #out shape : e.g., [s:h/p:] + if rotary_pos_emb is not None: + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + query_layer = apply_rotary_pos_emb(query_layer, pos_emb_cos, pos_emb_sin) + key_layer = apply_rotary_pos_emb(key_layer, pos_emb_cos, pos_emb_sin) context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)