From daccf1a19d55206b86e6c3e6facf6feb1b2f18b3 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 7 May 2024 17:46:16 -0700 Subject: [PATCH 01/12] refactoring --- llmfoundry/models/layers/attention.py | 228 +++++++++++++-------- llmfoundry/models/layers/blocks.py | 71 ++++--- llmfoundry/models/mpt/configuration_mpt.py | 6 +- llmfoundry/models/mpt/modeling_mpt.py | 70 +++++-- llmfoundry/models/utils/config_moe_args.py | 14 +- llmfoundry/utils/config_utils.py | 40 ++-- 6 files changed, 273 insertions(+), 156 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 82fee68af6..4884b568fd 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,10 +5,9 @@ import math import warnings -from typing import Any, Optional +from typing import Any, Dict, Optional, Tuple import torch -import torch.nn as nn import transformers from einops import rearrange from packaging import version @@ -233,7 +232,6 @@ def flash_attn_fn( dropout_p: float = 0.0, training: bool = False, needs_weights: bool = False, - multiquery: bool = False, should_repeat_kv_for_gqa: Optional[bool] = True, sliding_window_size: int = -1, alibi_slopes: Optional[torch.Tensor] = None, @@ -506,6 +504,54 @@ def forward( flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: + query, key, value = self.get_qkv(x) + + if rotary_emb_w_meta_info is not None: + query, key, value = self._apply_rotary_embeddings( + rotary_emb_w_meta_info, + query, + key, + value, + ) + + extra_attn_kwargs = self.get_implementation_specific_args( + attention_mask, + alibi_slopes, + flash_attn_padding_info, + ) + + context, attn_weights, past_key_value = self.attn_fn( + query, + key, + value, + n_heads=self.n_heads, + kv_n_heads=self.kv_n_heads, + past_key_value=past_key_value, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + is_causal=is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + **extra_attn_kwargs, + ) + + return self.out_proj(context), attn_weights, past_key_value + + def get_qkv( + self, + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes and returns the query, key, and value tensors. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + query (torch.Tensor): The query tensor. + key (torch.Tensor): The key tensor. + value (torch.Tensor): The value tensor. + """ qkv = self.Wqkv(x) if self.clip_qkv: @@ -520,8 +566,6 @@ def forward( dim=2, ) - key_padding_mask = attention_mask - if self.qk_ln or self.qk_gn: # Applying layernorm to qk q_shape, k_shape = query.shape, key.shape @@ -533,97 +577,105 @@ def forward( query = self.q_ln(query).to(dtype).view(q_shape) key = self.k_ln(key).to(dtype).view(k_shape) - if rotary_emb_w_meta_info is not None: - rotary_emb = rotary_emb_w_meta_info['rotary_emb'] - seq_len = rotary_emb_w_meta_info['seq_len'] - offset_info = rotary_emb_w_meta_info['offset_info'] - bsz, seqlen = query.shape[:2] - query = query.view(bsz, seqlen, -1, self.head_dim) - key = key.view(bsz, seqlen, -1, self.head_dim) - - if rotary_emb_w_meta_info['impl'] == 'dail': - value = value.view(bsz, seqlen, -1, self.head_dim) - - kv = torch.stack([key, value], dim=2) - query, kv = rotary_emb( - query, - kv, - seqlen_offset=offset_info, - max_seqlen=seq_len, + return query, key, value + + def _apply_rotary_embeddings( + self, + rotary_emb_w_meta_info: Dict[str, Any], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] + seq_len = rotary_emb_w_meta_info['seq_len'] + offset_info = rotary_emb_w_meta_info['offset_info'] + bsz, seqlen = query.shape[:2] + query = query.view(bsz, seqlen, -1, self.head_dim) + key = key.view(bsz, seqlen, -1, self.head_dim) + + if rotary_emb_w_meta_info['impl'] == 'dail': + value = value.view(bsz, seqlen, -1, self.head_dim) + + kv = torch.stack([key, value], dim=2) + query, kv = rotary_emb( + query, + kv, + seqlen_offset=offset_info, + max_seqlen=seq_len, + ) + [key, value] = torch.unbind(kv, dim=2) + + value = value.view(bsz, seqlen, -1) + elif rotary_emb_w_meta_info['impl'] == 'hf': + if is_transformers_version_gte('4.38'): + (cos, sin) = rotary_emb( + x=value, + position_ids=offset_info, + ) + else: + (cos, sin) = rotary_emb(x=value, seq_len=seq_len) + if is_transformers_version_gte('4.38'): + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=None, + unsqueeze_dim=2, + ) + elif is_transformers_version_gte('4.36'): + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=offset_info, + unsqueeze_dim=2, + ) + else: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=offset_info, ) - [key, value] = torch.unbind(kv, dim=2) - - value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) - elif rotary_emb_w_meta_info['impl'] == 'hf': - if is_transformers_version_gte('4.38'): - (cos, sin) = rotary_emb( - x=value, - position_ids=offset_info, - ) - else: - (cos, sin) = rotary_emb(x=value, seq_len=seq_len) - if is_transformers_version_gte('4.38'): - query, key = apply_rotary_pos_emb( - q=query, - k=key, - cos=cos, - sin=sin, - position_ids=None, - unsqueeze_dim=2, - ) - elif is_transformers_version_gte('4.36'): - query, key = apply_rotary_pos_emb( - q=query, - k=key, - cos=cos, - sin=sin, - position_ids=offset_info, - unsqueeze_dim=2, - ) - else: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - query, key = apply_rotary_pos_emb( - q=query, - k=key, - cos=cos, - sin=sin, - position_ids=offset_info, - ) - query = query.transpose(1, 2) - key = key.transpose(1, 2) - - query = query.view(bsz, seqlen, self.d_model) - key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim) - - extra_attn_kwargs = {} + query = query.transpose(1, 2) + key = key.transpose(1, 2) + + query = query.view(bsz, seqlen, -1) + key = key.view(bsz, seqlen, -1) + return query, key, value + + def get_implementation_specific_args( + self, + attention_mask: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, Any]: + """Returns attention implementation specific args. + + Args: + attention_mask (Optional[torch.Tensor]): The attention mask. + alibi_slopes (Optional[torch.Tensor]): The alibi slopes. + flash_attn_padding_info (Optional[dict[str, torch.Tensor]]): The padding information, only required for flash attention. + + Returns: + extra_attn_kwargs (dict[str, Any]): Implementation specific args. + """ if self.attn_impl == 'flash': - key_padding_mask = None extra_attn_kwargs = { 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info, + 'key_padding_mask': None, } - - context, attn_weights, past_key_value = self.attn_fn( - query, - key, - value, - self.n_heads, - self.kv_n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - **extra_attn_kwargs, - ) - - return self.out_proj(context), attn_weights, past_key_value + else: + extra_attn_kwargs = {'key_padding_mask': attention_mask} + return extra_attn_kwargs @attention_classes.register_class('multihead_attention') diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 494bdcdff1..d88f311811 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -3,7 +3,7 @@ """GPT Blocks used for the GPT Model.""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Set, Tuple import torch import torch.nn as nn @@ -67,6 +67,7 @@ def __init__( device: Optional[str] = None, no_bias: bool = False, use_pad_tok_in_ffn: bool = True, + extra_args_to_exclude_in_attn_class: Optional[Set[str]] = None, **kwargs: Any, ): if attn_config is None: @@ -84,10 +85,30 @@ def __init__( ffn_type = ffn_config['ffn_type'] ffn_has_norm = ffn_type in ffns_with_norm + # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs + base_args_to_exclude_in_attn_class = { + 'attn_type', + 'alibi', + 'attn_uses_sequence_id', + 'alibi_bias_max', + 'rope', + 'rope_theta', + 'rope_impl', + 'rope_dail_config', + 'rope_hf_config', + } + if extra_args_to_exclude_in_attn_class is None: + extra_args_to_exclude_in_attn_class = set() + self.args_to_exclude_in_attn_class = base_args_to_exclude_in_attn_class.union( + extra_args_to_exclude_in_attn_class, + ) + if self.fuse_norm_attn_norm: self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, n_heads=n_heads, + args_to_exclude_in_attn_class=self. + args_to_exclude_in_attn_class, attn_config=attn_config, ffn_has_norm=ffn_has_norm, fc_type=fc_type, @@ -98,22 +119,10 @@ def __init__( ) else: assert isinstance(attn_config['attn_type'], str) - # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs - args_to_exclude_in_attn_class = { - 'attn_type', - 'alibi', - 'attn_uses_sequence_id', - 'alibi_bias_max', - 'rope', - 'rope_theta', - 'rope_impl', - 'rope_dail_config', - 'rope_hf_config', - } attn_config_subset_for_attn_class = { k: v for k, v in attn_config.items() - if k not in args_to_exclude_in_attn_class + if k not in self.args_to_exclude_in_attn_class } self.norm_1 = build_norm( @@ -196,6 +205,24 @@ def forward( if self.norm_2 is not None: m = self.norm_2(x) + n = self.apply_ffn(attention_mask, m) + x = x + self.resid_ffn_dropout(n) + return x, attn_weights, past_key_value + + def apply_ffn( + self, + attention_mask: Optional[torch.ByteTensor], + m: torch.Tensor, + ) -> torch.Tensor: + """Apply feed forward layers to the input. + + Args: + attention_mask (Optional[torch.ByteTensor]): The attention mask. + m (torch.Tensor): The input. + + Returns: + n (torch.Tensor): The output. + """ batch_size, seq_len = m.size()[:2] indices = None if not self.use_pad_tok_in_ffn: @@ -205,8 +232,7 @@ def forward( if not self.use_pad_tok_in_ffn: assert pad_input is not None n = pad_input(n, indices, batch_size, seq_len) - x = x + self.resid_ffn_dropout(n) - return x, attn_weights, past_key_value + return n class FusedNormAttentionNorm(nn.Module): @@ -215,6 +241,7 @@ def __init__( self, d_model: int, n_heads: int, + args_to_exclude_in_attn_class: Set[str], attn_config: Optional[Dict] = None, ffn_has_norm: bool = False, fc_type: str = 'torch', @@ -228,18 +255,6 @@ def __init__( assert attn_config is not None assert isinstance(attn_config['attn_type'], str) - # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs - args_to_exclude_in_attn_class = { - 'attn_type', - 'alibi', - 'attn_uses_sequence_id', - 'alibi_bias_max', - 'rope', - 'rope_theta', - 'rope_impl', - 'rope_dail_config', - 'rope_hf_config', - } attn_config_subset_for_attn_class = { k: v for k, v in attn_config.items() diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 78653fabdc..a1716fa214 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -140,7 +140,9 @@ def __init__( self.n_heads = n_heads self.n_layers = n_layers self.expansion_ratio = expansion_ratio - self.max_seq_len = max_seq_len + if max_seq_len != int(max_seq_len): + raise ValueError('max_seq_len must be an integer') + self.max_seq_len = int(max_seq_len) self.vocab_size = vocab_size self.resid_pdrop = resid_pdrop self.emb_pdrop = emb_pdrop @@ -327,3 +329,5 @@ def _validate_config(self) -> None: raise ImportError( 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6', ) + if (self.attn_config.get('seq_parallel_world_size', 1) or 1) > 1: + raise NotImplementedError('Sequence Parallelism is not supported.') diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index bdf6cff925..61d588c8ab 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -367,15 +367,9 @@ def __init__(self, config: MPTConfig): ) self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None - block_args = config.to_dict() - if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: - block_args['ffn_config'] = config_moe_args( - block_args['ffn_config'], - config.d_model, - config.expansion_ratio, - config.n_layers, - ) - self.mb_args = block_args['ffn_config'].get('args') + self.shift_labels = True + + block_args = self.extract_block_args(config.to_dict()) self.blocks = nn.ModuleList([ MPTBlock( @@ -444,6 +438,19 @@ def __init__(self, config: MPTConfig): log.debug(self) log.debug(f'Using {self.config.init_config["name"]} initialization.') + def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]: + """Sets the block args.""" + + if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: + block_args['ffn_config'] = config_moe_args( + block_args['ffn_config'], + block_args['d_model'], + block_args['expansion_ratio'], + block_args['n_layers'], + ) + self.mb_args = block_args['ffn_config'].get('args') + return block_args + def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]: return self.wte @@ -583,17 +590,17 @@ def forward( ) elif input_ids is not None: bsz = input_ids.size(0) - S = input_ids.size(1) x = self.wte(input_ids) input_device = input_ids.device elif inputs_embeds is not None: bsz = inputs_embeds.size(0) - S = inputs_embeds.size(1) x = inputs_embeds input_device = inputs_embeds.device else: raise ValueError('You must specify input_ids or inputs_embeds') + S = self.get_sequence_length(x) + assert ( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' @@ -746,6 +753,17 @@ def forward( attentions=all_self_attns, ) + def get_sequence_length(self, x: torch.Tensor) -> int: + """Returns the sequence length. + + Args: + x (torch.Tensor): The input Tensor. + + Returns: + S (int): The sequence length. + """ + return x.size(1) + # Param Initialization, needed for device='meta' fast initialization def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] @@ -1092,7 +1110,7 @@ def __init__( use_logits=True, metrics=train_metrics, eval_metrics=eval_metrics, - shift_labels=True, + shift_labels=model.transformer.shift_labels, allow_embedding_resizing=True, ) @@ -1148,7 +1166,11 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> Union[dict, torch.Tensor]: - targets = self.get_targets(batch) + if self.model.transformer.shift_labels: + targets = self.get_targets(batch) + else: + targets = batch['labels'] + losses = self.loss_fn( outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1), @@ -1158,6 +1180,12 @@ def loss(self, outputs: CausalLMOutputWithPast, loss = losses.sum() else: loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() + if 'sample_weighing_factor' in batch: + if batch['sample_weighing_factor'].shape[0] > 1: + raise ValueError( + 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', + ) + loss = loss * batch['sample_weighing_factor'][0].item() if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss @@ -1195,9 +1223,19 @@ def flops_per_batch(self, batch: Mapping): params = self.n_active_params params_flops_per_token = 2 * params params_flops_per_seq = params_flops_per_token * msl - attn_flops_per_seq = ( + attn_flops_per_seq = self.get_attention_flops(msl) + return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs + + def get_attention_flops(self, msl: int) -> int: + """Computes the attention flops for the batch. + + Args: + msl (int): The batch sequence length. + + Returns: + attn_flopts (int): The attention flops. + """ + return ( self.model.config.n_layers * 2 * 2 * (self.model.config.d_model * (msl**2)) ) - - return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index 2d9a8cadd4..29f8d1bfcc 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -3,6 +3,7 @@ """Helper function to configure MPT with MoEs.""" +import inspect from typing import Union import torch @@ -143,7 +144,10 @@ def config_megablocks_moe_args( elif lbl_process_group == 'global_group': lbl_process_group = distributed.group.WORLD elif isinstance(lbl_process_group, int): - lbl_process_group = create_set_process_group(lbl_process_group) + if lbl_process_group > 1: + lbl_process_group = create_set_process_group(lbl_process_group) + else: + lbl_process_group = None elif lbl_process_group is not None: raise ValueError( f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | .', @@ -153,6 +157,14 @@ def config_megablocks_moe_args( ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio) ffn_config.setdefault('ffn_hidden_size', ffn_hidden_size) + args_to_keep_in_ffn_config = inspect.signature( + megablocks.layers.arguments.Arguments, + ).parameters + + ffn_config = { + k: v for k, v in ffn_config.items() if k in args_to_keep_in_ffn_config + } + args = megablocks.layers.arguments.Arguments( hidden_size=d_model, **ffn_config, diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index db180e3168..3c4eab3e42 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -12,7 +12,6 @@ from composer.utils import dist, parse_uri from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from transformers import PretrainedConfig from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.utils import init_empty_weights @@ -60,36 +59,31 @@ def pop_config( return default_value -def get_hf_config_value(config: Union[dict, PretrainedConfig], key: str) -> Any: - """Get a value from a Hugging Face config. - - Args: - config (Union[dict, PretrainedConfig]): The Hugging Face config object. - key (str): The key to get from the config. - - Returns: - Any: The value from the config. None if the key does not exist. - """ - if isinstance(config, dict): - return config.get(key) - return getattr(config, key, None) - - def calculate_batch_size_info( global_batch_size: int, - device_microbatch_size: Union[int, Literal['auto']], -) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]: - if global_batch_size % dist.get_world_size() != 0: + device_microbatch_size: Union[int, float, Literal['auto']], + data_replication_degree: int = 1, +) -> Tuple[Union[int, float], Union[int, float, Literal['auto']], Union[ + int, Literal['auto']]]: + if dist.get_world_size() % data_replication_degree != 0: + raise ValueError( + f'World size {dist.get_world_size()} is not divisible by data replication degree {data_replication_degree}.', + ) + if global_batch_size % ( + dist.get_world_size() // data_replication_degree + ) != 0: raise ValueError( - f'Global batch size {global_batch_size} is not divisible by {dist.get_world_size()} ' + f'Global batch size {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.', ) - device_batch_size = global_batch_size // dist.get_world_size() + device_batch_size = global_batch_size / dist.get_world_size() + if device_batch_size == round(device_batch_size): + device_batch_size = round(device_batch_size) if device_microbatch_size == 'auto': device_grad_accum = 'auto' - elif isinstance(device_microbatch_size, int): + elif isinstance(device_microbatch_size, (int, float)): if device_microbatch_size > device_batch_size: log.warn( f'device_microbatch_size > device_batch_size, ' + @@ -107,9 +101,11 @@ def calculate_batch_size_info( # Coming soon: this conversion math will be done inside Composer Trainer def update_batch_size_info(cfg: DictConfig) -> DictConfig: + data_replication_degree = 1 device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( cfg.global_train_batch_size, cfg.device_train_microbatch_size, + data_replication_degree=data_replication_degree, ) cfg.n_gpus = dist.get_world_size() cfg.device_train_batch_size = device_train_batch_size From 769f3841772c7c53b767c4e6da735fffa062a315 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 7 May 2024 18:04:35 -0700 Subject: [PATCH 02/12] adding back a function that got deleted by mistake --- llmfoundry/utils/config_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 3c4eab3e42..3620b9f564 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -12,6 +12,7 @@ from composer.utils import dist, parse_uri from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from transformers import PretrainedConfig from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.utils import init_empty_weights @@ -59,6 +60,21 @@ def pop_config( return default_value +def get_hf_config_value(config: Union[dict, PretrainedConfig], key: str) -> Any: + """Get a value from a Hugging Face config. + + Args: + config (Union[dict, PretrainedConfig]): The Hugging Face config object. + key (str): The key to get from the config. + + Returns: + Any: The value from the config. None if the key does not exist. + """ + if isinstance(config, dict): + return config.get(key) + return getattr(config, key, None) + + def calculate_batch_size_info( global_batch_size: int, device_microbatch_size: Union[int, float, Literal['auto']], From 83623b2c045b77b522e7346d9d6b7c97811a8d13 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 7 May 2024 18:18:32 -0700 Subject: [PATCH 03/12] adding co-authors Co-Authored-By: Vitaliy Chiley Co-Authored-By: Cheng Li --- llmfoundry/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 3620b9f564..ea79c81ac6 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -89,7 +89,7 @@ def calculate_batch_size_info( dist.get_world_size() // data_replication_degree ) != 0: raise ValueError( - f'Global batch size {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.', From ccb76cc7abb09bf0d7a40ac0fb7c7e130b8fafee Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 7 May 2024 18:21:33 -0700 Subject: [PATCH 04/12] adding co-authors Co-Authored-By: Vitaliy Chiley --- llmfoundry/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index ea79c81ac6..3620b9f564 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -89,7 +89,7 @@ def calculate_batch_size_info( dist.get_world_size() // data_replication_degree ) != 0: raise ValueError( - f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + f'Global batch size {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.', From 85d27f83cb9679199abfa8336de9f0bfa8f5a617 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 7 May 2024 18:23:43 -0700 Subject: [PATCH 05/12] adding co-authors Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley --- llmfoundry/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 3620b9f564..ea79c81ac6 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -89,7 +89,7 @@ def calculate_batch_size_info( dist.get_world_size() // data_replication_degree ) != 0: raise ValueError( - f'Global batch size {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.', From 9a251ec448a9da919ee52dd696a315ef27b44470 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Tue, 7 May 2024 18:31:12 -0700 Subject: [PATCH 06/12] Update config_utils.py adding co-authors Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley Co-authored-by: Cheng Li Co-authored-by: Cheng Li <@cli99> --- llmfoundry/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index ea79c81ac6..3620b9f564 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -89,7 +89,7 @@ def calculate_batch_size_info( dist.get_world_size() // data_replication_degree ) != 0: raise ValueError( - f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + f'Global batch size {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.', From ae3bb9bc2fb0e7e51bafdee151e1f86a52e3ddae Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 7 May 2024 18:36:39 -0700 Subject: [PATCH 07/12] lint Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley --- llmfoundry/models/mpt/modeling_mpt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 61d588c8ab..e24e8b052d 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -440,7 +440,6 @@ def __init__(self, config: MPTConfig): def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]: """Sets the block args.""" - if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], From 8479b966dbd6a083532415dc0e0ab3bf4512d189 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 7 May 2024 18:43:51 -0700 Subject: [PATCH 08/12] Adding_co_authors Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley Co-authored-by: Cheng Li --- llmfoundry/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 3620b9f564..ea79c81ac6 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -89,7 +89,7 @@ def calculate_batch_size_info( dist.get_world_size() // data_replication_degree ) != 0: raise ValueError( - f'Global batch size {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.', From eeff23d23357f6ec89b71b3cedc06ed86831cd4a Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Wed, 8 May 2024 10:03:53 -0700 Subject: [PATCH 09/12] Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index e24e8b052d..6a0dd60b6e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -1232,7 +1232,7 @@ def get_attention_flops(self, msl: int) -> int: msl (int): The batch sequence length. Returns: - attn_flopts (int): The attention flops. + attn_flops (int): The attention flops. """ return ( self.model.config.n_layers * 2 * 2 * From 8ee7409c7652637b1e47787ce78ae7b39ea5e0d2 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 8 May 2024 10:19:53 -0700 Subject: [PATCH 10/12] addressing comments --- llmfoundry/models/layers/blocks.py | 34 +++++++++++++----------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index d88f311811..c335a07824 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -67,7 +67,6 @@ def __init__( device: Optional[str] = None, no_bias: bool = False, use_pad_tok_in_ffn: bool = True, - extra_args_to_exclude_in_attn_class: Optional[Set[str]] = None, **kwargs: Any, ): if attn_config is None: @@ -85,24 +84,6 @@ def __init__( ffn_type = ffn_config['ffn_type'] ffn_has_norm = ffn_type in ffns_with_norm - # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs - base_args_to_exclude_in_attn_class = { - 'attn_type', - 'alibi', - 'attn_uses_sequence_id', - 'alibi_bias_max', - 'rope', - 'rope_theta', - 'rope_impl', - 'rope_dail_config', - 'rope_hf_config', - } - if extra_args_to_exclude_in_attn_class is None: - extra_args_to_exclude_in_attn_class = set() - self.args_to_exclude_in_attn_class = base_args_to_exclude_in_attn_class.union( - extra_args_to_exclude_in_attn_class, - ) - if self.fuse_norm_attn_norm: self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, @@ -162,6 +143,21 @@ def __init__( self.resid_ffn_dropout = nn.Dropout(resid_pdrop) self.use_pad_tok_in_ffn = use_pad_tok_in_ffn + @property + def args_to_exclude_in_attn_class(self): + # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs + return { + 'attn_type', + 'alibi', + 'attn_uses_sequence_id', + 'alibi_bias_max', + 'rope', + 'rope_theta', + 'rope_impl', + 'rope_dail_config', + 'rope_hf_config', + } + def forward( self, x: torch.Tensor, From 122862d0d60047be98509fa9ba96ca4b3b26afd5 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 8 May 2024 10:55:45 -0700 Subject: [PATCH 11/12] adding_co_authors Co-authored-by: Cheng Li --- llmfoundry/models/layers/blocks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c335a07824..3ff65fd8b3 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -100,6 +100,7 @@ def __init__( ) else: assert isinstance(attn_config['attn_type'], str) + # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs attn_config_subset_for_attn_class = { k: v for k, v in attn_config.items() @@ -145,7 +146,6 @@ def __init__( @property def args_to_exclude_in_attn_class(self): - # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs return { 'attn_type', 'alibi', @@ -251,6 +251,7 @@ def __init__( assert attn_config is not None assert isinstance(attn_config['attn_type'], str) + # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs attn_config_subset_for_attn_class = { k: v for k, v in attn_config.items() From d3493fb42f92318a6881e079cf3c5ecc96213308 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 8 May 2024 16:32:26 -0400 Subject: [PATCH 12/12] Update llmfoundry/utils/config_utils.py --- llmfoundry/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 92fc89cda8..9470ce2ac6 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -399,7 +399,7 @@ def calculate_batch_size_info( # Coming soon: this conversion math will be done inside Composer Trainer -def update_batch_size_info(cfg: DictConfig) -> DictConfig: +def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]: data_replication_degree = 1 device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( cfg['global_train_batch_size'],