Selective checkpointing #507

Draft: wants to merge 1 commit into base: main
7 changes: 5 additions & 2 deletions configs/bert_large_pretrain.py
@@ -6,8 +6,10 @@
from .common.optim import optim
from .common.data.bert_dataset import dataloader, tokenization

vocab_file = "./data_test/bert_data/bert-base-chinese-vocab.txt"
data_prefix = "./data_test/bert_data/loss_compara_content_sentence"
# vocab_file = "./data_test/bert_data/bert-base-chinese-vocab.txt"
# data_prefix = "./data_test/bert_data/loss_compara_content_sentence"
vocab_file = "/data/home/liupeihong/datasets/libai_dataset/bert-base-chinese-vocab.txt"
data_prefix = "/data/home/liupeihong/datasets/libai_dataset/loss_compara_content_sentence"

tokenization.tokenizer.vocab_file = vocab_file
dataloader.train.dataset[0].data_prefix = data_prefix
@@ -34,3 +36,4 @@
train.evaluation.evaluator = LazyCall(PPLEvaluator)()

train.output_dir = "output/bert_output"
graph.enabled = False
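
The /data/home/liupeihong paths above are specific to one machine. A minimal sketch of an environment-variable fallback that keeps the repo-relative test data as the default (the LIBAI_DATA_ROOT name is hypothetical, not part of this PR):

import os

# Falls back to the in-repo test data when LIBAI_DATA_ROOT is unset.
data_root = os.environ.get("LIBAI_DATA_ROOT", "./data_test/bert_data")
vocab_file = f"{data_root}/bert-base-chinese-vocab.txt"
data_prefix = f"{data_root}/loss_compara_content_sentence"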
15 changes: 11 additions & 4 deletions configs/gpt2_pretrain.py
@@ -7,9 +7,15 @@

from .common.models.graph import graph

vocab_file = "./data_test/gpt_data/gpt2-vocab.json"
merge_files = "./data_test/gpt_data/gpt2-merges.txt"
data_prefix = "./data_test/gpt_data/loss_compara_content_sentence"
# vocab_file = "./data_test/gpt_data/gpt2-vocab.json"
# merge_files = "./data_test/gpt_data/gpt2-merges.txt"
# data_prefix = "./data_test/gpt_data/loss_compara_content_sentence"
vocab_file = "/data/home/liupeihong/datasets/libai_dataset/gpt2-vocab.json"
merge_files = "/data/home/liupeihong/datasets/libai_dataset/gpt2-merges.txt"
data_prefix = (
"/data/home/liupeihong/datasets/libai_dataset/loss_compara_content_sentence"
)


tokenization.tokenizer.vocab_file = vocab_file
tokenization.tokenizer.merges_file = merge_files
@@ -37,8 +43,9 @@
optim.lr = 1.5e-4

train.train_micro_batch_size = 4
train.amp.enabled = True
train.amp.enabled = False

train.evaluation.evaluator = LazyCall(PPLEvaluator)()

train.output_dir = "./output/gpt2_output"
graph.enabled = False
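
Both configs can also be overridden at load time instead of edited in place; a minimal sketch, assuming LiBai's LazyConfig loader mirrors detectron2's (the /path/to values are placeholders):

from libai.config import LazyConfig

cfg = LazyConfig.load("configs/gpt2_pretrain.py")
cfg.train.amp.enabled = False   # same effect as the in-file edit above
cfg.graph.enabled = False       # run in eager mode
cfg.tokenization.tokenizer.vocab_file = "/path/to/gpt2-vocab.json"
cfg.dataloader.train.dataset[0].data_prefix = "/path/to/loss_compara_content_sentence"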
45 changes: 45 additions & 0 deletions libai/engine/default.py
@@ -43,7 +43,51 @@
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/defaults.py
# --------------------------------------------------------
def count_all_parameters(model, verbose=False):
"""
Count total, trainable, and non-trainable parameters in a PyTorch model.
Args:
model (nn.Module): The model to count parameters for.
verbose (bool, optional): Print detailed information if True.
Returns:
Tuple containing total, trainable, and non-trainable parameters, and percent trainable parameters.
"""
train_params, all_params = 0, 0
for _, param in model.named_parameters():
num_params = param.numel()
all_params += num_params
if param.requires_grad:
train_params += num_params
nontrain_params = all_params - train_params
pct_train_params = train_params / all_params * 100
if verbose:
logger = logging.getLogger(__name__)
logger.info(f"Total params: {format_size(all_params)}")
logger.info(f"Trainable params: {format_size(train_params)}")
logger.info(f"Non-trainable params: {format_size(nontrain_params)}")
logger.info(f"Trainable params %: {pct_train_params:.4f}")
return all_params, train_params, nontrain_params, pct_train_params
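
A minimal usage sketch for the helper above (the toy module is illustrative, not part of this diff):

import oneflow.nn as nn

toy = nn.Linear(4, 2)  # 2x4 weight + 2 bias = 10 parameters
all_p, train_p, nontrain_p, pct = count_all_parameters(toy)
# all_p == 10, train_p == 10, nontrain_p == 0, pct == 100.0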


def format_size(size):
"""
Convert bytes to a human-readable string with appropriate units.
Args:
size (int): The number of bytes.
Returns:
String representing the number of bytes with appropriate units.
"""
k, m, b, t = 1024, 1024**2, 10**9, 10**12
if size > t:
return f"{round(size / t, 4)}T"
elif size > b:
return f"{round(size / b, 4)}B"
elif size > m:
return f"{round(size / m, 4)}M"
elif size > k:
return f"{round(size / k, 4)}K"
else:
return f"{size}"

def _highlight(code, filename):
try:
@@ -563,6 +607,7 @@ def build_model(cls, cfg):
model = build_model(cfg.model)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
count_all_parameters(model, verbose=True)
model._apply(dist.convert_to_distributed_default_setting)
return model

164 changes: 104 additions & 60 deletions libai/layers/attention.py
@@ -19,6 +19,7 @@

import oneflow as flow
from oneflow import nn
from oneflow.utils import checkpoint

from .linear import Linear

@@ -28,6 +29,88 @@ class AttnMaskType(enum.Enum):
causal = 2


class CoreAttention(nn.Module):
def __init__(
self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
scale_mask_softmax_fusion,
attn_mask_type,
apply_query_key_layer_scaling=False,
layer_idx=0,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.head_size = hidden_size // num_attention_heads
self.norm_factor = 1.0 / math.sqrt(float(self.head_size))
self.use_cache = False
self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
self.attn_mask_type = attn_mask_type
self.coeff = None
if apply_query_key_layer_scaling:
self.coeff = layer_idx + 1
self.norm_factor /= self.coeff
self.attention_dropout_prob = attention_dropout_prob
self.dropout = nn.Dropout(p=attention_dropout_prob)

def forward(self, query, key, value, attention_mask):
# [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
attention_scores = flow.matmul(
query, key, transpose_b=True, alpha=self.norm_factor
)

# [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
if attention_mask is not None:
if self.scale_mask_softmax_fusion:
if self.attn_mask_type == AttnMaskType.padding:
attention_mask = (
attention_mask.expand_as(attention_scores)
if self.use_cache
else attention_mask
)
attention_weights = flow._C.fused_scale_mask_softmax_dropout(
attention_scores,
attention_mask,
fill_value=-10000.0,
scale=self.coeff,
p=self.attention_dropout_prob,
)[0]
else:
if self.coeff is not None:
attention_scores *= self.coeff
attention_scores = flow.mul(attention_scores, attention_mask)
attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
# TODO(xingyu.liao): graph will occur `where_scalar` errors
# when using `masked_fill`
# attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)
else:
if (
self.scale_mask_softmax_fusion
and self.attn_mask_type == AttnMaskType.causal
):
attention_weights = flow._C.fused_scale_tril_softmax_mask_scale(
attention_scores,
p=self.attention_dropout_prob,
diagonal=0,
tril_scale_value=self.coeff,
tril_fill_value=-10000.0,
)[0]
else:
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)

# Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)]
context = flow.matmul(attention_weights, value)
# Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
context = context.transpose(1, 2)
return context.flatten(2)
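
Factoring the score, softmax, dropout, and context computation into CoreAttention is what lets the backward pass recompute just this block instead of storing its activations. A minimal standalone sketch (shapes and hyperparameters are made up for illustration):

import oneflow as flow
from oneflow.utils import checkpoint

attn = CoreAttention(
    hidden_size=768,
    num_attention_heads=12,
    attention_dropout_prob=0.1,
    scale_mask_softmax_fusion=False,
    attn_mask_type=AttnMaskType.padding,
)
bsz, heads, seq, head_size = 2, 12, 16, 64
query = flow.randn(bsz, heads, seq, head_size, requires_grad=True)
key = flow.randn(bsz, heads, seq, head_size)
value = flow.randn(bsz, heads, seq, head_size)
mask = flow.ones(bsz, 1, seq, seq)  # all-ones padding mask
# The intermediate attention tensors are recomputed during backward, not cached.
context = checkpoint.checkpoint(attn, query, key, value, mask)
context.sum().backward()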


class MultiheadAttention(nn.Module):
"""Multi-head attention layer, support self attention and cross attention.

@@ -80,18 +163,9 @@ def __init__(

self.num_heads = num_attention_heads
self.head_size = hidden_size // num_attention_heads
self.attn_mask_type = attn_mask_type

self.attention_dropout_prob = attention_dropout_prob
self.dropout = nn.Dropout(p=attention_dropout_prob)
self.norm_factor = 1.0 / math.sqrt(float(self.head_size))
self.coeff = None
if apply_query_key_layer_scaling:
self.coeff = layer_idx + 1
self.norm_factor /= self.coeff

self.is_cross_attention = is_cross_attention
self.scale_mask_softmax_fusion = scale_mask_softmax_fusion

self.bias_dropout_fusion = bias_dropout_fusion

if self.bias_dropout_fusion:
@@ -131,6 +205,15 @@ def __init__(
skip_bias_add=self.bias_dropout_fusion,
layer_idx=layer_idx,
)
self.core_attention = CoreAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
attn_mask_type=attn_mask_type,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
layer_idx=layer_idx,
)

def forward(
self,
@@ -158,7 +241,7 @@ def forward(
use_cache (bool, optional): it will be set to True, when the model is in the inference
phase and used for incremental decoding. Defaults to False.
"""

self.core_attention.use_cache = use_cache
# hidden_states, encoder_states: [S(0), B]
# attention_mask: [S(0), B]

@@ -193,7 +276,9 @@ def forward(
# hidden_states is the last-added state,
# the full key and value could be obtained by concatenating with past_key_value.
query_key_value = self.query_key_value(hidden_states)
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
query_key_value = query_key_value.view(
bsz, -1, self.num_heads, 3 * self.head_size
)
query_key_value = query_key_value.permute(
0, 2, 1, 3
) # [bsz, num_heads, src_len, 3 * head_size]
@@ -207,58 +292,17 @@ def forward(
if use_cache:
past_key_value = (key, value)

# [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor)

# [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
if attention_mask is not None:
if self.scale_mask_softmax_fusion:
if self.attn_mask_type == AttnMaskType.padding:
attention_mask = (
attention_mask.expand_as(attention_scores) if use_cache else attention_mask
)
attention_weights = flow._C.fused_scale_mask_softmax_dropout(
attention_scores,
attention_mask,
fill_value=-10000.0,
scale=self.coeff,
p=self.attention_dropout_prob,
)[0]
else:
if self.coeff is not None:
attention_scores *= self.coeff
attention_scores = flow.mul(attention_scores, attention_mask)
attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
# TODO(xingyu.liao): graph will occur `where_scalar` errors
# when using `masked_fill`
# attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)
else:
if self.scale_mask_softmax_fusion and self.attn_mask_type == AttnMaskType.causal:
attention_weights = flow._C.fused_scale_tril_softmax_mask_scale(
attention_scores,
p=self.attention_dropout_prob,
diagonal=0,
tril_scale_value=self.coeff,
tril_fill_value=-10000.0,
)[0]
else:
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)

# Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)]
context = flow.matmul(attention_weights, value)
# Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
context = context.transpose(1, 2)

# Concat multi-head results from
# [bsz, tgt_len, num_heads, head_size] -> [bsz, tgt_len, num_heads * head_size]
# SBP sign: [S(0), S(2)]
# [S(0), S(2)] x [B, S(0)] = [S(0), P] -> [S(0), B]
output = self.dense(context.flatten(2))

# context = self.core_attention(query, key, value, attention_mask)

context = checkpoint.checkpoint(
self.core_attention, query, key, value, attention_mask
)
output = self.dense(context)

if self.bias_dropout_fusion:
output, bias = output
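
Since recomputation costs an extra forward pass through CoreAttention, a natural follow-up would be to make it opt-in per layer. A hypothetical sketch (the checkpoint_core_attention flag does not exist in this diff):

if getattr(self, "checkpoint_core_attention", False):
    # Trade compute for memory: re-run CoreAttention during backward.
    context = checkpoint.checkpoint(
        self.core_attention, query, key, value, attention_mask
    )
else:
    context = self.core_attention(query, key, value, attention_mask)
output = self.dense(context)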