
Commit e548c14

Add support for GPT-2 (#60)
1 parent 130d5fd commit e548c14

7 files changed: +350 −8 lines changed


cacheflow/models/gpt2.py

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
"""1D GPT-2 model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPT2Config

from cacheflow.models import InputMetadata
from cacheflow.models.attention import GPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
                                    load_tensor_parallel_weights)
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPT2Attention(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim ** -0.5

        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, bias=True,
                                           gather_output=False,
                                           perform_initialization=False)
        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.attn = GPTCacheFlowAttention(scale=self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPT2MLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
                                         bias=True, gather_output=False,
                                         perform_initialization=False)
        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
                                        bias=True, input_is_parallel=True,
                                        perform_initialization=False)

        act_fn = config.activation_function
        if act_fn != "gelu_new":
            raise ValueError(f"Unsupported activation: {act_fn}. "
                             "GPT-2 only supports gelu_new for now.")
        self.act = torch.nn.GELU(approximate="tanh")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPT2Block(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPT2Model(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.config = config
        assert config.add_cross_attention == False
        assert config.scale_attn_by_inverse_layer_idx == False
        assert config.reorder_and_upcast_attn == False
        self.embed_dim = config.hidden_size

        # Optimization: While the vocab size of GPT-2 is 50257, we extend it
        # to 50304 in order to make it divisible by 64.
        # This improves performance since GPUs are faster if the dimension
        # is divisible by 64. In addition, it allows us to shard the embedding
        # layer across 2, 4, 8, or more GPUs.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList(
            [GPT2Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPT2LMHeadModel(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.config = config
        self.transformer = GPT2Model(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head_weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            name = "transformer." + name

            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                loaded_weight = loaded_weight.t()
            param = state_dict[name]

            if name == "transformer.wte.weight":
                # Consider padding in the vocab size.
                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
                extra_rows = extra_rows.to(loaded_weight)
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

            # For the fused QKV linear layer, manually shard the weights.
            if "c_attn" in name:
                # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
                # When tensor parallelism is used, we shard the weights along the head dimension.
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tensor_model_parallel_world_size
                head_start = tensor_model_parallel_rank * num_heads
                head_end = (tensor_model_parallel_rank + 1) * num_heads

                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected parameter name {name}")
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights)

    def initialize_dummy_weights(self) -> None:
        for param in self.state_dict().values():
            param.data.uniform_(-1e-3, 1e-3)
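The weight-loading path above performs three tensor manipulations that are easy to check in isolation: transposing HF's Conv1D-style weights, padding the embedding table up to the rounded vocab size, and slicing the fused c_attn weight by head for the local tensor-parallel rank. Below is a minimal standalone sketch of those steps using only torch; the dimensions, tensor_parallel_world_size, and rank are toy values chosen for illustration, not values from the commit.

import torch

# Toy dimensions (illustrative only; not taken from the commit).
hidden_size = 8
total_num_heads = 4
head_size = hidden_size // total_num_heads
vocab_size = 10
tensor_parallel_world_size = 2
rank = 1  # this worker's tensor-parallel rank

# 1) HF's Conv1D stores c_attn.weight as [hidden_size, 3 * hidden_size];
#    the Linear-style layers here expect the transpose.
hf_c_attn_weight = torch.randn(hidden_size, 3 * hidden_size)
loaded_weight = hf_c_attn_weight.t()  # -> [3 * hidden_size, hidden_size]

# 2) Pad the embedding table so the padded vocab is a multiple of 64.
padded_vocab_size = ((vocab_size + 63) // 64) * 64  # 64 for this toy vocab
wte = torch.randn(vocab_size, hidden_size)
extra_rows = torch.empty(padded_vocab_size - vocab_size, hidden_size)
wte_padded = torch.cat([wte, extra_rows], dim=0)  # [padded_vocab_size, hidden_size]

# 3) Shard the fused QKV weight along the head dimension for this rank.
num_heads = total_num_heads // tensor_parallel_world_size
head_start = rank * num_heads
head_end = (rank + 1) * num_heads
shard = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
shard = shard[:, head_start:head_end, :, :]
shard = shard.reshape(-1, hidden_size)

print(wte_padded.shape)  # torch.Size([64, 8])
print(shard.shape)       # torch.Size([12, 8]) == [3 * num_heads * head_size, hidden_size]

Because whole heads are sliced, each rank's shard is the concatenation of its own q, k, and v blocks, so qkv.chunk(chunks=3, dim=-1) in GPT2Attention.forward still splits the fused projection correctly.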

cacheflow/models/gpt_neox.py

Lines changed: 5 additions & 5 deletions
@@ -173,7 +173,7 @@ def __init__(self, config):
         self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
                                               bias=False, gather_output=False,
                                               perform_initialization=False)
-        self.sampler = Sampler()
+        self.sampler = Sampler(config.vocab_size)
 
     def forward(
         self,
@@ -205,8 +205,8 @@ def load_weights(self, model_name_or_path: str,
             param = state_dict[name]
             if "query_key_value" in name:
                 # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
-                # [num_heads * 3 * head_size, num_heads * head_size], while the
-                # required shape is [3 * num_heads * head_size, num_heads * head_size].
+                # [num_heads * 3 * head_size, hidden_size], while the
+                # required shape is [3 * num_heads * head_size, hidden_size].
                 # Thus, we need weight conversion.
                 shard_size = param.shape[0]
                 loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
@@ -218,11 +218,11 @@ def load_weights(self, model_name_or_path: str,
                 if 'query_key_value.weight' in name:
                     loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
                     loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size).contiguous()
+                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                 elif 'query_key_value.bias' in name:
                     loaded_weight = loaded_weight.view(-1, 3, head_size)
                     loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1).contiguous()
+                    loaded_weight = loaded_weight.reshape(-1)
                 else:
                     raise ValueError(f"Unexpected weight name: {name}")
             load_tensor_parallel_weights(param, loaded_weight, name,
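To make the corrected comment concrete, the sketch below replays the view/transpose/reshape from the hunk above on a small dummy tensor (all dimensions are made up for illustration) and checks that GPT-NeoX's interleaved per-head q/k/v rows end up grouped as all-q, all-k, all-v. The dropped .contiguous() calls are redundant because PyTorch's reshape already returns a copy whenever its input is not contiguous.

import torch

# Toy shard standing in for one rank's slice of GPT-NeoX's fused QKV weight.
num_heads = 2    # heads owned by this rank
head_size = 3
hidden_size = 4

# NeoX layout: for each head, its q, k, and v rows are stored back to back,
# i.e. the weight is [num_heads * 3 * head_size, hidden_size].
neox_qkv = torch.arange(num_heads * 3 * head_size * hidden_size,
                        dtype=torch.float32).view(num_heads * 3 * head_size, hidden_size)

# Required layout: all q rows, then all k rows, then all v rows,
# i.e. [3 * num_heads * head_size, hidden_size].
converted = neox_qkv.view(num_heads, 3, head_size, hidden_size)  # view(-1, 3, ...) in the diff
converted = converted.transpose(0, 1)                            # [3, num_heads, head_size, hidden]
converted = converted.reshape(-1, hidden_size)                   # reshape copies; no .contiguous()

# Head 0's q block now occupies rows 0..head_size-1, and its k block starts
# at row num_heads * head_size.
assert torch.equal(converted[:head_size], neox_qkv[:head_size])
assert torch.equal(converted[num_heads * head_size:num_heads * head_size + head_size],
                   neox_qkv[head_size:2 * head_size])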

cacheflow/models/llama.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ def __init__(self, config):
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False)
-        self.sampler = Sampler()
+        self.sampler = Sampler(config.vocab_size)
 
     def forward(
         self,

cacheflow/models/memory_analyzer.py

Lines changed: 70 additions & 0 deletions
@@ -72,6 +72,76 @@ def get_max_num_gpu_blocks(
         return max_num_blocks
 
 
+class GPT2MemoryAnalyzer(CacheFlowMemoryAnalyzer):
+
+    def __init__(
+        self,
+        model_name: str,
+        block_size: int,
+        dtype: torch.dtype,
+        gpu_memory: int,
+        cpu_memory: int,
+        tensor_parallel_size: int,
+    ) -> None:
+        self.model_name = model_name
+        self.block_size = block_size
+        self.dtype = dtype
+        self.gpu_memory = gpu_memory
+        self.cpu_memory = cpu_memory
+        self.tensor_parallel_size = tensor_parallel_size
+
+        config = AutoConfig.from_pretrained(model_name)
+        self.num_layers = config.num_hidden_layers
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // self.num_heads
+        self.ffn_size = config.n_inner if config.n_inner is not None else 4 * self.hidden_size
+        self.vocab_size = config.vocab_size
+        self.max_position = config.max_position_embeddings
+
+    def get_param_size(self) -> int:
+        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
+        position_embedding = self.max_position * self.hidden_size
+
+        ln1 = 2 * self.hidden_size
+        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        mha = ln1 + q + k + v + out
+
+        ln2 = 2 * self.hidden_size
+        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
+        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
+        ffn = ln2 + ffn1 + ffn2
+
+        total = (word_embedding + position_embedding +
+                 self.num_layers * (mha + ffn))
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_act_size(
+        self,
+        max_num_batched_tokens: int,
+    ) -> int:
+        # NOTE: We approximately calculate the maximum activation size by
+        # estimating
+        # 1) the maximum activation tensor size during inference
+        # 2) the residual tensor size during inference
+        # Here, we assume that FlashAttention is used and
+        # thus the attention maps are never materialized in GPU DRAM.
+        residual = max_num_batched_tokens * self.hidden_size
+        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
+        ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
+        # Double the activation size for input and output.
+        max_act = 2 * (max(qkv, ffn) + residual)
+        # Size of output logits.
+        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
+        max_act = max(max_act, output_logits)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * max_act
+
+
 class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
 
     def __init__(
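As a rough sanity check on GPT2MemoryAnalyzer.get_param_size above, plugging in the standard hyperparameters of the smallest GPT-2 (assumed here for illustration; at runtime the analyzer reads them from AutoConfig) reproduces the well-known ~124M parameter count. Note that the formula skips the final LayerNorm, which only contributes 2 * hidden_size parameters. A minimal sketch with tensor_parallel_size = 1:

# Standard GPT-2 small hyperparameters (assumed for this check).
hidden_size = 768
num_layers = 12
ffn_size = 4 * hidden_size   # 3072; GPT-2's config leaves n_inner unset
vocab_size = 50257
max_position = 1024

word_embedding = vocab_size * hidden_size        # 38,597,376
position_embedding = max_position * hidden_size  # 786,432

ln1 = 2 * hidden_size
qkv_and_out = 4 * (hidden_size * hidden_size + hidden_size)  # q, k, v, out projections + biases
mha = ln1 + qkv_and_out                          # 2,363,904 per layer

ln2 = 2 * hidden_size
ffn1 = hidden_size * ffn_size + ffn_size         # c_fc weight + bias
ffn2 = ffn_size * hidden_size + hidden_size      # c_proj weight + bias
ffn = ln2 + ffn1 + ffn2                          # 4,723,968 per layer

total = word_embedding + position_embedding + num_layers * (mha + ffn)
print(total)  # 124,438,272 parameters -> ~249 MB at 2 bytes per parameter (fp16)

Together with get_max_act_size, this estimate feeds the analyzer's block-count calculation (the get_max_num_gpu_blocks context shown above) when sizing the KV cache.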
