Refactor GPT2 (#11225)
* refactor GPT2

* fix mlp and head pruning

* address Sylvain's comments

* apply suggestion from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
2 people authored and Rocketknight1 committed Apr 21, 2021
1 parent fc6322c commit f6533ae
Showing 2 changed files with 132 additions and 92 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/gpt2/configuration_gpt2.py
@@ -102,6 +102,8 @@ class GPT2Config(PretrainedConfig):
and :class:`~transformers.TFGPT2DoubleHeadsModel`.
The dropout ratio to be used after the projection and activation.
scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
Scale attention weights by dividing by sqrt(hidden_size).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -144,6 +146,7 @@ def __init__(
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
gradient_checkpointing=False,
use_cache=True,
bos_token_id=50256,
@@ -171,6 +174,7 @@ def __init__(
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
self.gradient_checkpointing = gradient_checkpointing
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache

self.bos_token_id = bos_token_id
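
For illustration only (not part of the commit), the new scale_attn_weights option is an ordinary constructor argument on GPT2Config; a minimal sketch, assuming a transformers version that includes this change:

from transformers import GPT2Config

config = GPT2Config(scale_attn_weights=False)  # defaults to True
print(config.scale_attn_weights)               # False
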
220 changes: 128 additions & 92 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -122,87 +122,100 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
return model


class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
class GPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False):
super().__init__()

n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0
max_positions = config.max_position_embeddings
self.register_buffer(
"bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale

self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
)

self.scale_attn_weights = config.scale_attn_weights
self.is_cross_attention = is_cross_attention

if self.is_cross_attention:
self.c_attn = Conv1D(2 * n_state, nx)
self.q_attn = Conv1D(n_state, nx)
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
else:
self.c_attn = Conv1D(3 * n_state, nx)
self.c_proj = Conv1D(n_state, nx)
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.pruned_heads = set()

def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
)
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

# Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads)
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
self.num_heads = self.num_heads - len(heads)
self.pruned_heads = self.pruned_heads.union(heads)
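
For illustration only (not part of the diff), head pruning is normally driven through the model-level prune_heads API, which dispatches to the per-layer method above; a minimal sketch on a small randomly initialized model, assuming the refactored attribute names (num_heads, h[i].attn):

from transformers import GPT2Config, GPT2Model

model = GPT2Model(GPT2Config(n_layer=2, n_head=12))
model.prune_heads({0: [0, 1], 1: [2]})   # layer index -> heads to remove
print(model.h[0].attn.num_heads)         # 10
print(model.h[1].attn.num_heads)         # 11
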

def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
w = torch.matmul(q, k)
if self.scale:
w = w / (float(v.size(-1)) ** 0.5)
nd, ns = w.size(-2), w.size(-1)
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))

if self.scale_attn_weights:
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

if attention_mask is not None:
# Apply the attention mask
w = w + attention_mask
attn_weights = attn_weights + attention_mask

w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w)
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = self.attn_dropout(attn_weights)

# Mask heads if we want to
if head_mask is not None:
w = w * head_mask
attn_weights = attn_weights * head_mask

outputs = (torch.matmul(w, v),)
if output_attentions:
outputs += (w,)
return outputs

def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states

def split_heads(self, x, k=False):
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
if k:
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
else:
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
attn_output = torch.matmul(attn_weights, value)

return attn_output, attn_weights
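
A standalone sketch (illustrative only, tensor sizes made up) of the causal-mask slicing in _attn above, showing how key_length - query_length lines the mask up when cached keys outnumber the incoming queries:

import torch

max_positions = 8
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
    1, 1, max_positions, max_positions
)
masked_bias = torch.tensor(-1e4)

query_length, key_length = 2, 6    # 2 new tokens attending over 4 cached + 2 new keys
attn_weights = torch.randn(1, 1, query_length, key_length)
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
print(attn_weights.shape)           # torch.Size([1, 1, 2, 6])
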

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
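
A round-trip sketch (illustrative only, with made-up sizes) of the reshapes performed by _split_heads and _merge_heads:

import torch

batch, seq_len, num_heads, head_dim = 2, 5, 12, 64
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

split = hidden.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)  # (batch, head, seq, head_dim)
merged = split.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_heads * head_dim)

assert torch.equal(hidden, merged)   # splitting then merging is lossless
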

def forward(
self,
@@ -216,65 +229,77 @@ def forward(
output_attentions=False,
):
if encoder_hidden_states is not None:
assert hasattr(
self, "q_attn"
), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if layer_past is not None:
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
key = torch.cat((past_key, key), dim=-1)
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

if use_cache is True:
present = (key.transpose(-2, -1), value) # transpose to have same shapes
present = (key, value)
else:
present = None

attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
a = attn_outputs[0]
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

a = self.merge_heads(a)
a = self.c_proj(a)
a = self.resid_dropout(a)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)

return (a, present) + attn_outputs[1:] # a, present, (attentions)
return outputs # a, present, (attentions)
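
For illustration only (not part of the diff): with this change the cached key is no longer stored transposed, so past_key and past_value share the (batch, head, seq_len, head_dim) layout and both grow along dim=-2. A sketch with made-up sizes:

import torch

batch, num_heads, head_dim = 1, 12, 64
past_key = torch.randn(batch, num_heads, 4, head_dim)     # 4 cached positions
past_value = torch.randn(batch, num_heads, 4, head_dim)
new_key = torch.randn(batch, num_heads, 1, head_dim)      # 1 new position
new_value = torch.randn(batch, num_heads, 1, head_dim)

key = torch.cat((past_key, new_key), dim=-2)
value = torch.cat((past_value, new_value), dim=-2)
print(key.shape, value.shape)                             # both (1, 12, 5, 64)
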


class MLP(nn.Module):
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
nx = config.n_embd
self.c_fc = Conv1D(n_state, nx)
self.c_proj = Conv1D(nx, n_state)
embed_dim = config.hidden_size
self.c_fc = Conv1D(intermediate_size, embed_dim)
self.c_proj = Conv1D(embed_dim, intermediate_size)
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)

def forward(self, x):
h = self.act(self.c_fc(x))
h2 = self.c_proj(h)
return self.dropout(h2)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
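
An illustrative stand-in for GPT2MLP (not part of the diff), using nn.Linear and a fixed GELU in place of Conv1D and ACT2FN[config.activation_function]; the class name and sizes are hypothetical, the point is the c_fc -> act -> c_proj -> dropout order and the default inner dimension of 4 * hidden_size:

import torch
from torch import nn

class TinyGPT2MLP(nn.Module):   # hypothetical name, for illustration only
    def __init__(self, hidden_size=768, intermediate_size=None, dropout=0.1):
        super().__init__()
        intermediate_size = intermediate_size or 4 * hidden_size  # GPT-2 default inner dim
        self.c_fc = nn.Linear(hidden_size, intermediate_size)
        self.c_proj = nn.Linear(intermediate_size, hidden_size)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return self.dropout(hidden_states)

print(TinyGPT2MLP()(torch.randn(2, 5, 768)).shape)   # torch.Size([2, 5, 768])
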


class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False):
class GPT2Block(nn.Module):
def __init__(self, config):
super().__init__()
hidden_size = config.n_embd
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = Attention(hidden_size, n_ctx, config, scale)
self.attn = GPT2Attention(config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

if config.add_cross_attention:
self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
self.crossattention = GPT2Attention(config, is_cross_attention=True)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = MLP(inner_dim, config)

self.mlp = GPT2MLP(inner_dim, config)
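
For illustration only (not part of the diff): when config.add_cross_attention is set, the block instantiates a second GPT2Attention in cross-attention mode; a quick check, assuming a transformers version that includes this refactor:

from transformers import GPT2Config, GPT2Model

config = GPT2Config(n_layer=1, add_cross_attention=True)
model = GPT2Model(config)
print(hasattr(model.h[0], "crossattention"))          # True
print(model.h[0].crossattention.is_cross_attention)   # True
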

def forward(
self,
@@ -287,8 +312,10 @@ def forward(
use_cache=False,
output_attentions=False,
):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
self.ln_1(hidden_states),
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
Expand All @@ -298,15 +325,19 @@ def forward(
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + hidden_states
hidden_states = attn_output + residual

if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
"cross-attention layers by setting `config.add_cross_attention=True`"
)
residual = hidden_states
hidden_states = self.ln_cross_attn(hidden_states)
cross_attn_outputs = self.crossattention(
self.ln_cross_attn(hidden_states),
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
Expand All @@ -315,12 +346,14 @@ def forward(
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = hidden_states + attn_output
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights

feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = hidden_states + feed_forward_hidden_states
hidden_states = residual + feed_forward_hidden_states

if use_cache:
outputs = (hidden_states,) + outputs
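
For illustration only (not part of the diff): the block's forward pass now spells out the pre-LayerNorm residual pattern three times (self-attention, optional cross-attention, MLP); a minimal sketch with an nn.Linear standing in for the sub-layer:

import torch
from torch import nn

hidden_size = 768
ln = nn.LayerNorm(hidden_size)
sublayer = nn.Linear(hidden_size, hidden_size)   # stands in for attn / mlp

hidden_states = torch.randn(2, 5, hidden_size)
residual = hidden_states                         # save the input
hidden_states = ln(hidden_states)                # normalize first (pre-LN)
hidden_states = sublayer(hidden_states)          # run the sub-layer
hidden_states = residual + hidden_states         # residual connection
print(hidden_states.shape)                       # torch.Size([2, 5, 768])
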
@@ -390,8 +423,8 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""

loss: Optional[torch.FloatTensor] = None
@@ -539,11 +572,14 @@ class GPT2Model(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
self.embed_dim = config.hidden_size

self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

self.init_weights()
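
For illustration only (not part of the diff): the stack is now built from GPT2Block and sized from config.hidden_size / config.num_hidden_layers (which alias n_embd / n_layer on GPT2Config); a quick check on a small random model:

from transformers import GPT2Config, GPT2Model

config = GPT2Config(n_layer=2, n_embd=128, n_head=4)
model = GPT2Model(config)
print(len(model.h))                 # 2 blocks
print(model.wte.embedding_dim)      # 128
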

@@ -654,7 +690,7 @@ def forward(
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

# Attention mask.
# GPT2Attention mask.
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
