diff --git a/configs/bert_large_pretrain.py b/configs/bert_large_pretrain.py
index 28520d2ab..867a1b4b0 100644
--- a/configs/bert_large_pretrain.py
+++ b/configs/bert_large_pretrain.py
@@ -6,8 +6,10 @@
 from .common.optim import optim
 from .common.data.bert_dataset import dataloader, tokenization
 
-vocab_file = "./data_test/bert_data/bert-base-chinese-vocab.txt"
-data_prefix = "./data_test/bert_data/loss_compara_content_sentence"
+# vocab_file = "./data_test/bert_data/bert-base-chinese-vocab.txt"
+# data_prefix = "./data_test/bert_data/loss_compara_content_sentence"
+vocab_file = "/data/home/liupeihong/datasets/libai_dataset/bert-base-chinese-vocab.txt"
+data_prefix = "/data/home/liupeihong/datasets/libai_dataset/loss_compara_content_sentence"
 
 tokenization.tokenizer.vocab_file = vocab_file
 dataloader.train.dataset[0].data_prefix = data_prefix
@@ -34,3 +36,4 @@
 train.evaluation.evaluator = LazyCall(PPLEvaluator)()
 
 train.output_dir = "output/bert_output"
+graph.enabled = False
\ No newline at end of file
diff --git a/configs/gpt2_pretrain.py b/configs/gpt2_pretrain.py
index 313c814e1..75ba07693 100644
--- a/configs/gpt2_pretrain.py
+++ b/configs/gpt2_pretrain.py
@@ -7,9 +7,15 @@
 from .common.models.graph import graph
 
-vocab_file = "./data_test/gpt_data/gpt2-vocab.json"
-merge_files = "./data_test/gpt_data/gpt2-merges.txt"
-data_prefix = "./data_test/gpt_data/loss_compara_content_sentence"
+# vocab_file = "./data_test/gpt_data/gpt2-vocab.json"
+# merge_files = "./data_test/gpt_data/gpt2-merges.txt"
+# data_prefix = "./data_test/gpt_data/loss_compara_content_sentence"
+vocab_file = "/data/home/liupeihong/datasets/libai_dataset/gpt2-vocab.json"
+merge_files = "/data/home/liupeihong/datasets/libai_dataset/gpt2-merges.txt"
+data_prefix = (
+    "/data/home/liupeihong/datasets/libai_dataset/loss_compara_content_sentence"
+)
+
 
 tokenization.tokenizer.vocab_file = vocab_file
 tokenization.tokenizer.merges_file = merge_files
@@ -37,8 +43,9 @@
 optim.lr = 1.5e-4
 
 train.train_micro_batch_size = 4
-train.amp.enabled = True
+train.amp.enabled = False
 
 train.evaluation.evaluator = LazyCall(PPLEvaluator)()
 
 train.output_dir = "./output/gpt2_output"
+graph.enabled = False
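The two config diffs above only swap the sample-data paths for absolute dataset paths and turn off AMP and graph (static `nn.Graph`) mode. A quick way to sanity-check edits like these before launching a run is to load the config the same way the training entry point does. The sketch below assumes LiBai's `LazyConfig` loader (which mirrors detectron2's lazy config API) and uses an illustrative config path:

```python
from libai.config import LazyConfig

# Load the edited config as tools/train_net.py would (path is illustrative).
cfg = LazyConfig.load("configs/gpt2_pretrain.py")

print(cfg.train.amp.enabled)  # expected: False after this change
print(cfg.graph.enabled)      # expected: False after this change
print(cfg.dataloader.train.dataset[0].data_prefix)

# Dotted command-line style overrides can flip these back without editing the file,
# e.g. LazyConfig.apply_overrides(cfg, ["train.amp.enabled=True", "graph.enabled=True"])
```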
+ """ + train_params, all_params = 0, 0 + for _, param in model.named_parameters(): + num_params = param.numel() + all_params += num_params + if param.requires_grad: + train_params += num_params + nontrain_params = all_params - train_params + pct_train_params = train_params / all_params * 100 + if verbose: + logger = logging.getLogger(__name__) + logger.info(f"Total params: {format_size(all_params)}") + logger.info(f"Trainable params: {format_size(train_params)}") + logger.info(f"Non-trainable params: {format_size(nontrain_params)}") + logger.info(f"Trainable params %: {pct_train_params:.4f}") + return all_params, train_params, nontrain_params, pct_train_params + +def format_size(size): + """ + Convert bytes to a human-readable string with appropriate units. + Args: + size (int): The number of bytes. + Returns: + String representing the number of bytes with appropriate units. + """ + k, m, b, t = 1024, 1024**2, 10**9, 10**12 + if size > t: + return f"{round(size / t, 4)}T" + elif size > b: + return f"{round(size / b, 4)}B" + elif size > m: + return f"{round(size / m, 4)}M" + elif size > k: + return f"{round(size / k, 4)}K" + else: + return f"{size}" def _highlight(code, filename): try: @@ -563,6 +607,7 @@ def build_model(cls, cfg): model = build_model(cfg.model) logger = logging.getLogger(__name__) logger.info("Model:\n{}".format(model)) + count_all_parameters(model, verbose=True) model._apply(dist.convert_to_distributed_default_setting) return model diff --git a/libai/layers/attention.py b/libai/layers/attention.py index 0bec6ebc1..45fcb9bc4 100644 --- a/libai/layers/attention.py +++ b/libai/layers/attention.py @@ -19,6 +19,7 @@ import oneflow as flow from oneflow import nn +from oneflow.utils import checkpoint from .linear import Linear @@ -28,6 +29,88 @@ class AttnMaskType(enum.Enum): causal = 2 +class CoreAttention(nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + attention_dropout_prob, + scale_mask_softmax_fusion, + attn_mask_type, + apply_query_key_layer_scaling=False, + layer_idx=0, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.head_size = hidden_size // num_attention_heads + self.norm_factor = 1.0 / math.sqrt(float(self.head_size)) + self.use_cache = False + self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + self.attn_mask_type = attn_mask_type + self.coeff = None + if apply_query_key_layer_scaling: + self.coeff = layer_idx + 1 + self.norm_factor /= self.coeff + self.attention_dropout_prob = attention_dropout_prob + self.dropout = nn.Dropout(p=attention_dropout_prob) + + def forward(self, query, key, value, attention_mask): + # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)] + attention_scores = flow.matmul( + query, key, transpose_b=True, alpha=self.norm_factor + ) + + # [S(0), S(1)] x [S(0), B] = [S(0), S(1)] + if attention_mask is not None: + if self.scale_mask_softmax_fusion: + if self.attn_mask_type == AttnMaskType.padding: + attention_mask = ( + attention_mask.expand_as(attention_scores) + if self.use_cache + else attention_mask + ) + attention_weights = flow._C.fused_scale_mask_softmax_dropout( + attention_scores, + attention_mask, + fill_value=-10000.0, + scale=self.coeff, + p=self.attention_dropout_prob, + )[0] + else: + if self.coeff is not None: + attention_scores *= self.coeff + attention_scores = flow.mul(attention_scores, attention_mask) + attention_scores = attention_scores - 10000.0 * (1 - attention_mask) + # TODO(xingyu.liao): graph will occur `where_scalar` errors + # when using `masked_fill` + 
diff --git a/libai/layers/attention.py b/libai/layers/attention.py
index 0bec6ebc1..45fcb9bc4 100644
--- a/libai/layers/attention.py
+++ b/libai/layers/attention.py
@@ -19,6 +19,7 @@
 
 import oneflow as flow
 from oneflow import nn
+from oneflow.utils import checkpoint
 
 from .linear import Linear
 
@@ -28,6 +29,88 @@ class AttnMaskType(enum.Enum):
     causal = 2
 
 
+class CoreAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size,
+        num_attention_heads,
+        attention_dropout_prob,
+        scale_mask_softmax_fusion,
+        attn_mask_type,
+        apply_query_key_layer_scaling=False,
+        layer_idx=0,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.head_size = hidden_size // num_attention_heads
+        self.norm_factor = 1.0 / math.sqrt(float(self.head_size))
+        self.use_cache = False
+        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
+        self.attn_mask_type = attn_mask_type
+        self.coeff = None
+        if apply_query_key_layer_scaling:
+            self.coeff = layer_idx + 1
+            self.norm_factor /= self.coeff
+        self.attention_dropout_prob = attention_dropout_prob
+        self.dropout = nn.Dropout(p=attention_dropout_prob)
+
+    def forward(self, query, key, value, attention_mask):
+        # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
+        attention_scores = flow.matmul(
+            query, key, transpose_b=True, alpha=self.norm_factor
+        )
+
+        # [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
+        if attention_mask is not None:
+            if self.scale_mask_softmax_fusion:
+                if self.attn_mask_type == AttnMaskType.padding:
+                    attention_mask = (
+                        attention_mask.expand_as(attention_scores)
+                        if self.use_cache
+                        else attention_mask
+                    )
+                    attention_weights = flow._C.fused_scale_mask_softmax_dropout(
+                        attention_scores,
+                        attention_mask,
+                        fill_value=-10000.0,
+                        scale=self.coeff,
+                        p=self.attention_dropout_prob,
+                    )[0]
+            else:
+                if self.coeff is not None:
+                    attention_scores *= self.coeff
+                attention_scores = flow.mul(attention_scores, attention_mask)
+                attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
+                # TODO(xingyu.liao): graph will occur `where_scalar` errors
+                # when using `masked_fill`
+                # attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
+                attention_weights = flow.softmax(attention_scores, dim=-1)
+                # [bsz, num_heads, tgt_len, src_len]
+                attention_weights = self.dropout(attention_weights)
+        else:
+            if (
+                self.scale_mask_softmax_fusion
+                and self.attn_mask_type == AttnMaskType.causal
+            ):
+                attention_weights = flow._C.fused_scale_tril_softmax_mask_scale(
+                    attention_scores,
+                    p=self.attention_dropout_prob,
+                    diagonal=0,
+                    tril_scale_value=self.coeff,
+                    tril_fill_value=-10000.0,
+                )[0]
+            else:
+                attention_weights = flow.softmax(attention_scores, dim=-1)
+                # [bsz, num_heads, tgt_len, src_len]
+                attention_weights = self.dropout(attention_weights)
+
+        # Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)]
+        context = flow.matmul(attention_weights, value)
+        # Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
+        context = context.transpose(1, 2)
+        return context.flatten(2)
+
+
 class MultiheadAttention(nn.Module):
     """Multi-head attention layer, support self attention and cross attention.
 
@@ -80,18 +163,9 @@ def __init__(
         self.num_heads = num_attention_heads
         self.head_size = hidden_size // num_attention_heads
 
-        self.attn_mask_type = attn_mask_type
-
-        self.attention_dropout_prob = attention_dropout_prob
-        self.dropout = nn.Dropout(p=attention_dropout_prob)
-        self.norm_factor = 1.0 / math.sqrt(float(self.head_size))
-        self.coeff = None
-        if apply_query_key_layer_scaling:
-            self.coeff = layer_idx + 1
-            self.norm_factor /= self.coeff
 
         self.is_cross_attention = is_cross_attention
-        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
+
         self.bias_dropout_fusion = bias_dropout_fusion
 
         if self.bias_dropout_fusion:
@@ -131,6 +205,15 @@ def __init__(
             skip_bias_add=self.bias_dropout_fusion,
             layer_idx=layer_idx,
         )
+        self.core_attention = CoreAttention(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            attention_dropout_prob=attention_dropout_prob,
+            scale_mask_softmax_fusion=scale_mask_softmax_fusion,
+            attn_mask_type=attn_mask_type,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            layer_idx=layer_idx,
+        )
 
     def forward(
         self,
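With the score/softmax/dropout/context math factored into `CoreAttention`, that part of the layer can be run (and checkpointed) as a unit; the forward-pass hunks below wrap it in `checkpoint.checkpoint`, so its activations are recomputed during backward instead of being kept alive. A minimal local-tensor smoke test of the new module, assuming eager mode, dropout disabled, and illustrative shapes (in LiBai proper these tensors are global/SBP-annotated):

```python
import oneflow as flow
from oneflow.utils import checkpoint

from libai.layers.attention import AttnMaskType, CoreAttention

bsz, num_heads, seq_len, head_size = 2, 4, 8, 16

core = CoreAttention(
    hidden_size=num_heads * head_size,
    num_attention_heads=num_heads,
    attention_dropout_prob=0.0,       # deterministic so the two paths can be compared
    scale_mask_softmax_fusion=False,  # take the plain softmax branch
    attn_mask_type=AttnMaskType.padding,
)

query = flow.randn(bsz, num_heads, seq_len, head_size, requires_grad=True)
key = flow.randn(bsz, num_heads, seq_len, head_size)
value = flow.randn(bsz, num_heads, seq_len, head_size)
mask = flow.ones(bsz, 1, seq_len, seq_len)  # 1 = keep, matching the masking convention above

out = core(query, key, value, mask)  # [bsz, seq_len, num_heads * head_size]
out_ckpt = checkpoint.checkpoint(core, query, key, value, mask)  # same values, recomputed in backward
print(out.shape, flow.allclose(out, out_ckpt))
```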
@@ -158,7 +241,7 @@ def forward(
             use_cache (bool, optional): it will be set to True, when the model is in the inference
                 phase and used for incremental decoding. Defaults to False.
         """
-
+        self.core_attention.use_cache = use_cache
         # hidden_states, encoder_states: [S(0), B]
         # attention_mask: [S(0), B]
 
@@ -193,7 +276,9 @@ def forward(
             # hidden_states is the last-added state,
             # the full key and value could be obtained by concatenating with past_key_value.
             query_key_value = self.query_key_value(hidden_states)
-            query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
+            query_key_value = query_key_value.view(
+                bsz, -1, self.num_heads, 3 * self.head_size
+            )
             query_key_value = query_key_value.permute(
                 0, 2, 1, 3
             )  # [bsz, num_heads, src_len, 3 * head_size]
@@ -207,58 +292,17 @@ def forward(
         if use_cache:
             past_key_value = (key, value)
 
-        # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
-        attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor)
-
-        # [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
-        if attention_mask is not None:
-            if self.scale_mask_softmax_fusion:
-                if self.attn_mask_type == AttnMaskType.padding:
-                    attention_mask = (
-                        attention_mask.expand_as(attention_scores) if use_cache else attention_mask
-                    )
-                    attention_weights = flow._C.fused_scale_mask_softmax_dropout(
-                        attention_scores,
-                        attention_mask,
-                        fill_value=-10000.0,
-                        scale=self.coeff,
-                        p=self.attention_dropout_prob,
-                    )[0]
-            else:
-                if self.coeff is not None:
-                    attention_scores *= self.coeff
-                attention_scores = flow.mul(attention_scores, attention_mask)
-                attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
-                # TODO(xingyu.liao): graph will occur `where_scalar` errors
-                # when using `masked_fill`
-                # attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
-                attention_weights = flow.softmax(attention_scores, dim=-1)
-                # [bsz, num_heads, tgt_len, src_len]
-                attention_weights = self.dropout(attention_weights)
-        else:
-            if self.scale_mask_softmax_fusion and self.attn_mask_type == AttnMaskType.causal:
-                attention_weights = flow._C.fused_scale_tril_softmax_mask_scale(
-                    attention_scores,
-                    p=self.attention_dropout_prob,
-                    diagonal=0,
-                    tril_scale_value=self.coeff,
-                    tril_fill_value=-10000.0,
-                )[0]
-            else:
-                attention_weights = flow.softmax(attention_scores, dim=-1)
-                # [bsz, num_heads, tgt_len, src_len]
-                attention_weights = self.dropout(attention_weights)
-
-        # Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)]
-        context = flow.matmul(attention_weights, value)
-        # Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
-        context = context.transpose(1, 2)
-
         # Concat multi-head results from
         # [bsz, tgt_len, num_heads, head_size] -> [bsz, tgt_len, num_heads * head_size]
         # SBP sign: [S(0), S(2)]
         # [S(0), S(2)] x [B, S(0)] = [S(0), P] -> [S(0), B]
-        output = self.dense(context.flatten(2))
+
+        # context = self.core_attention(query, key, value, attention_mask)
+
+        context = checkpoint.checkpoint(
+            self.core_attention, query, key, value, attention_mask
+        )
+        output = self.dense(context)
 
         if self.bias_dropout_fusion:
             output, bias = output