|
| 1 | +from functools import reduce |
| 2 | +from typing import Any, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch import Tensor |
| 6 | +from torch.cuda.amp import custom_bwd, custom_fwd |
| 7 | + |
| 8 | +try: |
| 9 | + import triton |
| 10 | + import triton.language as tl |
| 11 | + HAS_TRITON = True |
| 12 | +except ImportError: |
| 13 | + HAS_TRITON = False |
| 14 | + print("please install triton from https://github.com/openai/triton") |
| 15 | + |
| 16 | +if HAS_TRITON: |
| 17 | + PRECISION_MAP = { |
| 18 | + "fp32": (0, torch.float32), |
| 19 | + "fp16": (1, torch.float16), |
| 20 | + "bf16": (2, torch.bfloat16), |
| 21 | + } |
| 22 | + |
| 23 | + @triton.jit |
| 24 | + def _llama_act_combine_forward( |
| 25 | + X_GATE1, |
| 26 | + X_GATE2, |
| 27 | + X_UP, |
| 28 | + Y, |
| 29 | + stride, # how much to increase the pointer when moving by 1 row |
| 30 | + N, # number of columns in X |
| 31 | + BLOCK_SIZE: tl.constexpr, |
| 32 | + ): |
| 33 | + # Map the program id to the row of X and Y it should compute. |
| 34 | + row = tl.program_id(0) |
| 35 | + X_GATE1 += row * stride |
| 36 | + X_GATE2 += row * stride |
| 37 | + X_UP += row * stride |
| 38 | + Y += row * stride |
| 39 | + |
| 40 | + # do activation and combine, and store in y |
| 41 | + for off in range(0, N, BLOCK_SIZE): |
| 42 | + cols = off + tl.arange(0, BLOCK_SIZE) |
| 43 | + mask = cols < N |
| 44 | + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) |
| 45 | + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) |
| 46 | + x_up = tl.load(X_UP + cols, mask=mask, other=0.) |
| 47 | + x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) |
| 48 | + y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up |
| 49 | + # Write output |
| 50 | + tl.store(Y + cols, y, mask=mask) |
| 51 | + |
| 52 | + @triton.jit |
| 53 | + def _llama_act_combine_backward( |
| 54 | + X_GATE1, |
| 55 | + X_GATE2, |
| 56 | + X_UP, |
| 57 | + X_GATE1_GRAD, |
| 58 | + X_GATE2_GRAD, |
| 59 | + X_UP_GRAD, |
| 60 | + Y_GRAD, |
| 61 | + stride, # how much to increase the pointer when moving by 1 row |
| 62 | + N, # number of columns in X |
| 63 | + BLOCK_SIZE: tl.constexpr, |
| 64 | + ): |
| 65 | + # Map the program id to the row of X and Y it should compute. |
| 66 | + row = tl.program_id(0) |
| 67 | + X_GATE1 += row * stride |
| 68 | + X_GATE2 += row * stride |
| 69 | + X_UP += row * stride |
| 70 | + X_GATE1_GRAD += row * stride |
| 71 | + X_GATE2_GRAD += row * stride |
| 72 | + X_UP_GRAD += row * stride |
| 73 | + Y_GRAD += row * stride |
| 74 | + |
| 75 | + # do activation and combine, and store in y |
| 76 | + for off in range(0, N, BLOCK_SIZE): |
| 77 | + cols = off + tl.arange(0, BLOCK_SIZE) |
| 78 | + mask = cols < N |
| 79 | + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) |
| 80 | + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) |
| 81 | + x_up = tl.load(X_UP + cols, mask=mask, other=0.) |
| 82 | + y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.) |
| 83 | + |
| 84 | + # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up |
| 85 | + x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) |
| 86 | + x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid |
| 87 | + x_up_grad = x_gate2_act * x_gate1 |
| 88 | + x_gate1_grad = x_gate2_act * x_up |
| 89 | + # grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 − sigmoid(x)] |
| 90 | + # = sigmoid(x) * {1 + x * [(1 − sigmoid(x)]} |
| 91 | + x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid)) |
| 92 | + |
| 93 | + # Write output |
| 94 | + tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask) |
| 95 | + tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask) |
| 96 | + tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask) |
| 97 | + |
| 98 | + class LlamaActCombine(torch.autograd.Function): |
| 99 | + """ |
| 100 | + act(x_gate) * x_up |
| 101 | +
|
| 102 | + Args: |
| 103 | + x_gate (torch.Tensor): (b, l, 2d) x_gate |
| 104 | + x_up (torch.Tensor): (b, l, d) x_up |
| 105 | + activation (str): only support swiglu |
| 106 | + precision (str): fp32, fp16, bf16 |
| 107 | + """ |
| 108 | + |
| 109 | + @staticmethod |
| 110 | + @custom_fwd |
| 111 | + def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor: |
| 112 | + """ |
| 113 | + act(x_gate) * x_up |
| 114 | +
|
| 115 | + Args: |
| 116 | + x_gate (torch.Tensor): (b, l, 2d) x gate |
| 117 | + x_up (torch.Tensor): (b, l, d) x up |
| 118 | + activation (str): only support swiglu |
| 119 | + """ |
| 120 | + assert activation == "swiglu", "Only swiglu is supported" |
| 121 | + |
| 122 | + # split x gate |
| 123 | + assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2" |
| 124 | + x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1) |
| 125 | + x_gate1 = x_gate1.contiguous() |
| 126 | + x_gate2 = x_gate2.contiguous() |
| 127 | + if not x_up.is_contiguous(): |
| 128 | + x_up = x_up.contiguous() |
| 129 | + # assert shape |
| 130 | + assert x_gate1.shape == x_gate2.shape == x_up.shape |
| 131 | + |
| 132 | + # add ctx for backward |
| 133 | + if x_gate.requires_grad: |
| 134 | + ctx.save_for_backward(x_gate1, x_gate2, x_up) |
| 135 | + |
| 136 | + # allocate output |
| 137 | + y = torch.empty_like(x_up) |
| 138 | + M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1] |
| 139 | + |
| 140 | + # Less than 64KB per feature: enqueue fused kernel |
| 141 | + MAX_FUSED_SIZE = 65536 // x_gate.element_size() |
| 142 | + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) |
| 143 | + if N > BLOCK_SIZE: |
| 144 | + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") |
| 145 | + # heuristics for number of warps |
| 146 | + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) |
| 147 | + # restore setting |
| 148 | + ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps |
| 149 | + # enqueue kernel |
| 150 | + _llama_act_combine_forward[(M,)](x_gate1, |
| 151 | + x_gate2, |
| 152 | + x_up, |
| 153 | + y, |
| 154 | + x_up.stride(-2), |
| 155 | + N, |
| 156 | + BLOCK_SIZE=BLOCK_SIZE, |
| 157 | + num_warps=num_warps) |
| 158 | + return y |
| 159 | + |
| 160 | + @staticmethod |
| 161 | + @custom_bwd |
| 162 | + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]: |
| 163 | + # restore from ctx |
| 164 | + (x_gate1, x_gate2, x_up) = ctx.saved_tensors |
| 165 | + M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps |
| 166 | + |
| 167 | + # init grad |
| 168 | + y_grad = grad_outputs[0] |
| 169 | + x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like( |
| 170 | + x_gate2), torch.empty_like(x_up) |
| 171 | + |
| 172 | + # enqueue kernel |
| 173 | + _llama_act_combine_backward[(M,)](x_gate1, |
| 174 | + x_gate2, |
| 175 | + x_up, |
| 176 | + x_gate1_grad, |
| 177 | + x_gate2_grad, |
| 178 | + x_up_grad, |
| 179 | + y_grad, |
| 180 | + x_up.stride(-2), |
| 181 | + N, |
| 182 | + BLOCK_SIZE=BLOCK_SIZE, |
| 183 | + num_warps=num_warps) |
| 184 | + x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1) |
| 185 | + return x_gate_grad, x_up_grad, None, None |
0 commit comments