|
| 1 | +"""1D GPT-2 model compatible with HuggingFace weights.""" |
| 2 | +from typing import Dict, List, Optional, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch import nn |
| 6 | +from transformers import GPT2Config |
| 7 | + |
| 8 | +from cacheflow.models import InputMetadata |
| 9 | +from cacheflow.models.attention import GPTCacheFlowAttention |
| 10 | +from cacheflow.models.sample import Sampler |
| 11 | +from cacheflow.models.utils import (hf_model_weights_iterator, |
| 12 | + load_tensor_parallel_weights) |
| 13 | +from cacheflow.parallel_utils.parallel_state import ( |
| 14 | + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) |
| 15 | +from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding, |
| 16 | + ColumnParallelLinear, |
| 17 | + RowParallelLinear) |
| 18 | +from cacheflow.sequence import SequenceOutputs |
| 19 | + |
| 20 | +KVCache = Tuple[torch.Tensor, torch.Tensor] |
| 21 | + |
| 22 | + |
| 23 | +class GPT2Attention(nn.Module): |
| 24 | + |
| 25 | + def __init__(self, config: GPT2Config): |
| 26 | + super().__init__() |
| 27 | + self.hidden_size = config.hidden_size |
| 28 | + total_num_heads = config.num_attention_heads |
| 29 | + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() |
| 30 | + assert total_num_heads % tensor_model_parallel_world_size == 0 |
| 31 | + self.num_heads = total_num_heads // tensor_model_parallel_world_size |
| 32 | + self.head_dim = self.hidden_size // total_num_heads |
| 33 | + self.scale = self.head_dim ** -0.5 |
| 34 | + |
| 35 | + self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, bias=True, |
| 36 | + gather_output=False, |
| 37 | + perform_initialization=False) |
| 38 | + self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True, |
| 39 | + input_is_parallel=True, |
| 40 | + perform_initialization=False) |
| 41 | + self.attn = GPTCacheFlowAttention(scale=self.scale) |
| 42 | + |
| 43 | + def forward( |
| 44 | + self, |
| 45 | + hidden_states: torch.Tensor, |
| 46 | + kv_cache: KVCache, |
| 47 | + input_metadata: InputMetadata, |
| 48 | + cache_event: Optional[torch.cuda.Event], |
| 49 | + ) -> torch.Tensor: |
| 50 | + qkv, _ = self.c_attn(hidden_states) |
| 51 | + q, k, v = qkv.chunk(chunks=3, dim=-1) |
| 52 | + key_cache, value_cache = kv_cache |
| 53 | + attn_output = self.attn( |
| 54 | + q, k, v, key_cache, value_cache, input_metadata, cache_event) |
| 55 | + attn_output, _ = self.c_proj(attn_output) |
| 56 | + return attn_output |
| 57 | + |
| 58 | + |
| 59 | +class GPT2MLP(nn.Module): |
| 60 | + |
| 61 | + def __init__( |
| 62 | + self, |
| 63 | + intermediate_size: int, |
| 64 | + config: GPT2Config, |
| 65 | + ): |
| 66 | + super().__init__() |
| 67 | + hidden_size = config.hidden_size |
| 68 | + self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size, |
| 69 | + bias=True, gather_output=False, |
| 70 | + perform_initialization=False) |
| 71 | + self.c_proj = RowParallelLinear(intermediate_size, hidden_size, |
| 72 | + bias=True, input_is_parallel=True, |
| 73 | + perform_initialization=False) |
| 74 | + |
| 75 | + act_fn = config.activation_function |
| 76 | + if act_fn != "gelu_new": |
| 77 | + raise ValueError(f"Unsupported activation: {act_fn}. " |
| 78 | + "GPT-2 only supports gelu_new for now.") |
| 79 | + self.act = torch.nn.GELU(approximate="tanh") |
| 80 | + |
| 81 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 82 | + hidden_states, _ = self.c_fc(hidden_states) |
| 83 | + hidden_states = self.act(hidden_states) |
| 84 | + hidden_states, _ = self.c_proj(hidden_states) |
| 85 | + return hidden_states |
| 86 | + |
| 87 | + |
| 88 | +class GPT2Block(nn.Module): |
| 89 | + |
| 90 | + def __init__(self, config: GPT2Config): |
| 91 | + super().__init__() |
| 92 | + hidden_size = config.hidden_size |
| 93 | + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size |
| 94 | + |
| 95 | + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) |
| 96 | + self.attn = GPT2Attention(config) |
| 97 | + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) |
| 98 | + self.mlp = GPT2MLP(inner_dim, config) |
| 99 | + |
| 100 | + def forward( |
| 101 | + self, |
| 102 | + hidden_states: torch.Tensor, |
| 103 | + kv_cache: KVCache, |
| 104 | + input_metadata: InputMetadata, |
| 105 | + cache_event: Optional[torch.cuda.Event], |
| 106 | + ) -> torch.Tensor: |
| 107 | + residual = hidden_states |
| 108 | + hidden_states = self.ln_1(hidden_states) |
| 109 | + attn_output = self.attn( |
| 110 | + hidden_states=hidden_states, |
| 111 | + kv_cache=kv_cache, |
| 112 | + input_metadata=input_metadata, |
| 113 | + cache_event=cache_event, |
| 114 | + ) |
| 115 | + # residual connection |
| 116 | + hidden_states = attn_output + residual |
| 117 | + |
| 118 | + residual = hidden_states |
| 119 | + hidden_states = self.ln_2(hidden_states) |
| 120 | + feed_forward_hidden_states = self.mlp(hidden_states) |
| 121 | + # residual connection |
| 122 | + hidden_states = residual + feed_forward_hidden_states |
| 123 | + return hidden_states |
| 124 | + |
| 125 | + |
| 126 | +class GPT2Model(nn.Module): |
| 127 | + |
| 128 | + def __init__(self, config: GPT2Config): |
| 129 | + super().__init__() |
| 130 | + self.config = config |
| 131 | + assert config.add_cross_attention == False |
| 132 | + assert config.scale_attn_by_inverse_layer_idx == False |
| 133 | + assert config.reorder_and_upcast_attn == False |
| 134 | + self.embed_dim = config.hidden_size |
| 135 | + |
| 136 | + # Optimization: While the vocab size of GPT-2 is 50257, we extend it |
| 137 | + # to 50304 in order to make it divisible by 64. |
| 138 | + # This improves performance since GPUs are faster if the dimension |
| 139 | + # is divisible by 64. In addition, it allows us to shard the embedding |
| 140 | + # layer across 2, 4, 8, or more GPUs. |
| 141 | + vocab_size = ((config.vocab_size + 63) // 64) * 64 |
| 142 | + self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim) |
| 143 | + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) |
| 144 | + self.h = nn.ModuleList( |
| 145 | + [GPT2Block(config) for _ in range(config.num_hidden_layers)]) |
| 146 | + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) |
| 147 | + |
| 148 | + def forward( |
| 149 | + self, |
| 150 | + input_ids: torch.LongTensor, |
| 151 | + position_ids: torch.LongTensor, |
| 152 | + kv_caches: List[KVCache], |
| 153 | + input_metadata: InputMetadata, |
| 154 | + cache_events: Optional[List[torch.cuda.Event]], |
| 155 | + ) -> torch.Tensor: |
| 156 | + inputs_embeds = self.wte(input_ids) |
| 157 | + position_embeds = self.wpe(position_ids) |
| 158 | + hidden_states = inputs_embeds + position_embeds |
| 159 | + |
| 160 | + for i in range(len(self.h)): |
| 161 | + if cache_events is None: |
| 162 | + cache_event = None |
| 163 | + else: |
| 164 | + cache_event = cache_events[i] |
| 165 | + layer = self.h[i] |
| 166 | + hidden_states = layer( |
| 167 | + hidden_states, kv_caches[i], input_metadata, cache_event) |
| 168 | + |
| 169 | + hidden_states = self.ln_f(hidden_states) |
| 170 | + return hidden_states |
| 171 | + |
| 172 | + |
| 173 | +class GPT2LMHeadModel(nn.Module): |
| 174 | + |
| 175 | + def __init__(self, config: GPT2Config): |
| 176 | + super().__init__() |
| 177 | + self.config = config |
| 178 | + self.transformer = GPT2Model(config) |
| 179 | + # TODO(zhuohan): create a new weight after implementing pipeline |
| 180 | + # parallelism |
| 181 | + self.lm_head_weight = self.transformer.wte.weight |
| 182 | + self.sampler = Sampler(config.vocab_size) |
| 183 | + |
| 184 | + def forward( |
| 185 | + self, |
| 186 | + input_ids: torch.LongTensor, |
| 187 | + positions: torch.LongTensor, |
| 188 | + kv_caches: List[KVCache], |
| 189 | + input_metadata: InputMetadata, |
| 190 | + cache_events: Optional[List[torch.cuda.Event]], |
| 191 | + ) -> Dict[int, SequenceOutputs]: |
| 192 | + hidden_states = self.transformer( |
| 193 | + input_ids, positions, kv_caches, input_metadata, cache_events) |
| 194 | + next_tokens = self.sampler( |
| 195 | + self.lm_head_weight, hidden_states, input_metadata) |
| 196 | + return next_tokens |
| 197 | + |
| 198 | + _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"] |
| 199 | + _row_parallel_weights = ["c_proj.weight"] |
| 200 | + |
| 201 | + def load_weights(self, model_name_or_path: str, |
| 202 | + cache_dir: Optional[str] = None, |
| 203 | + use_np_cache: bool = False): |
| 204 | + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() |
| 205 | + tensor_model_parallel_rank = get_tensor_model_parallel_rank() |
| 206 | + state_dict = self.state_dict() |
| 207 | + |
| 208 | + for name, loaded_weight in hf_model_weights_iterator( |
| 209 | + model_name_or_path, cache_dir, use_np_cache): |
| 210 | + if "lm_head.weight" in name: |
| 211 | + # GPT-2 ties the weights of the embedding layer and the final |
| 212 | + # linear layer. |
| 213 | + continue |
| 214 | + if ".attn.bias" in name: |
| 215 | + # Skip attention mask. |
| 216 | + # NOTE: "c_attn.bias" should not be skipped. |
| 217 | + continue |
| 218 | + name = "transformer." + name |
| 219 | + |
| 220 | + # The HF's GPT-2 implementation uses Conv1D instead of Linear. |
| 221 | + # Because of this, we need to transpose the weights. |
| 222 | + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: |
| 223 | + if conv1d_weight_name not in name: |
| 224 | + continue |
| 225 | + if not name.endswith(".weight"): |
| 226 | + continue |
| 227 | + loaded_weight = loaded_weight.t() |
| 228 | + param = state_dict[name] |
| 229 | + |
| 230 | + if name == "transformer.wte.weight": |
| 231 | + # Consider padding in the vocab size. |
| 232 | + padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size |
| 233 | + num_extra_rows = padded_vocab_size - self.config.vocab_size |
| 234 | + extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) |
| 235 | + extra_rows = extra_rows.to(loaded_weight) |
| 236 | + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) |
| 237 | + |
| 238 | + # For the fused QKV linear layer, manually shard the weights. |
| 239 | + if "c_attn" in name: |
| 240 | + # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size]. |
| 241 | + # When tensor parallelism is used, we shard the weights along the head dimension. |
| 242 | + total_num_heads = self.config.num_attention_heads |
| 243 | + hidden_size = self.config.hidden_size |
| 244 | + head_size = hidden_size // total_num_heads |
| 245 | + num_heads = total_num_heads // tensor_model_parallel_world_size |
| 246 | + head_start = tensor_model_parallel_rank * num_heads |
| 247 | + head_end = (tensor_model_parallel_rank + 1) * num_heads |
| 248 | + |
| 249 | + if name.endswith(".weight"): |
| 250 | + loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size) |
| 251 | + loaded_weight = loaded_weight[:, head_start:head_end, :, :] |
| 252 | + loaded_weight = loaded_weight.reshape(-1, hidden_size) |
| 253 | + elif name.endswith(".bias"): |
| 254 | + loaded_weight = loaded_weight.view(3, total_num_heads, head_size) |
| 255 | + loaded_weight = loaded_weight[:, head_start:head_end, :] |
| 256 | + loaded_weight = loaded_weight.reshape(-1) |
| 257 | + else: |
| 258 | + raise ValueError(f"Unexpected parameter name {name}") |
| 259 | + load_tensor_parallel_weights(param, loaded_weight, name, |
| 260 | + self._column_parallel_weights, |
| 261 | + self._row_parallel_weights) |
| 262 | + |
| 263 | + def initialize_dummy_weights(self) -> None: |
| 264 | + for param in self.state_dict().values(): |
| 265 | + param.data.uniform_(-1e-3, 1e-3) |
0 commit comments