Extendability refactors (#1290)
dakinggg authored Jun 20, 2024
1 parent 4b1fecb commit 8241f9c
Showing 15 changed files with 266 additions and 53 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/dataloader.py
@@ -46,7 +46,7 @@

def build_finetuning_dataloader(
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
device_batch_size: Union[int, float],
dataset: Dict[str, Any],
num_workers: int,
drop_last: bool = False,
2 changes: 1 addition & 1 deletion llmfoundry/data/text_data.py
@@ -300,7 +300,7 @@ def build_streams(streams: Optional[Dict[str, Any]] = None,):

def build_text_dataloader(
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
device_batch_size: Union[int, float],
dataset: Dict[str, Any],
drop_last: bool,
num_workers: int,
19 changes: 19 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -212,13 +212,32 @@ def apply_ffn(
indices = None
if not self.use_pad_tok_in_ffn and attention_mask is not None:
assert unpad_input is not None
attention_mask = self.slice_attention_mask(attention_mask, seq_len)
m, indices, _, _ = unpad_input(m, attention_mask)
n = self.ffn(m)
if not self.use_pad_tok_in_ffn and attention_mask is not None:
assert pad_input is not None
n = pad_input(n, indices, batch_size, seq_len)
return n

def slice_attention_mask(
self,
attention_mask: torch.ByteTensor,
seq_len: int,
) -> torch.ByteTensor:
"""Slice attention mask to the correct size.
Can be overridden by subclasses to apply different slicing logic.
Args:
attention_mask (torch.ByteTensor): The attention mask.
seq_len (int): The sequence length.
Returns:
torch.ByteTensor: The sliced attention mask.
"""
return attention_mask
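
The new slice_attention_mask hook is an override point for block subclasses. A minimal sketch of how it could be used (the subclass name and the slicing rule below are hypothetical, not part of this commit):

import torch

from llmfoundry.models.layers.blocks import MPTBlock


class TruncatedMaskBlock(MPTBlock):
    # Hypothetical subclass: trim the mask to seq_len before apply_ffn
    # unpads its input; the base implementation returns the mask unchanged.

    def slice_attention_mask(
        self,
        attention_mask: torch.ByteTensor,
        seq_len: int,
    ) -> torch.ByteTensor:
        return attention_mask[:, :seq_len]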


class FusedNormAttentionNorm(nn.Module):

25 changes: 11 additions & 14 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -14,22 +14,12 @@
check_alibi_support,
is_flash_v2_installed,
)

# NOTE: All utils are imported directly even if unused so that
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm, build_fc, build_ffn # type: ignore (see note)
from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)
from llmfoundry.models.utils.config_defaults import (
attn_config_defaults,
fc_type_defaults,
ffn_config_defaults,
init_config_defaults,
fc_type_defaults,
) # type: ignore (see note)
)


class MPTConfig(PretrainedConfig):
@@ -196,6 +186,13 @@ def _set_config_defaults(
)
return config

def validate_attention_config(self) -> None:
if 'seq_parallel_world_size' in self.attn_config and self.attn_config[
'seq_parallel_world_size'] is None:
del self.attn_config['seq_parallel_world_size']
if self.attn_config.get('seq_parallel_world_size', 1) > 1:
raise NotImplementedError('Sequence Parallelism is not supported.')
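
Factoring this check into validate_attention_config lets MPTConfig subclasses adjust attention validation without copying _validate_config. A hedged sketch, assuming a downstream config that performs its own sequence-parallelism checks elsewhere (the subclass is hypothetical):

from llmfoundry.models.mpt.configuration_mpt import MPTConfig


class CustomAttnMPTConfig(MPTConfig):
    # Hypothetical subclass: skip the base check and accept attn_config
    # as-is, e.g. when a plugin validates it separately.

    def validate_attention_config(self) -> None:
        pass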

def _validate_config(self) -> None:
# set config defaults
self.attn_config = self._set_config_defaults(
@@ -336,5 +333,5 @@ def _validate_config(self) -> None:
raise ImportError(
'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6',
)
if (self.attn_config.get('seq_parallel_world_size', 1) or 1) > 1:
raise NotImplementedError('Sequence Parallelism is not supported.')

self.validate_attention_config()
7 changes: 6 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -80,6 +80,7 @@
# isort: off
from llmfoundry.models.layers.fc import fcs # type: ignore
from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ # type: ignore
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore
# isort: on

log = logging.getLogger(__name__)
@@ -425,6 +426,10 @@ def __init__(self, config: MPTConfig):
log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

@property
def block_class(self) -> Type[MPTBlock]:
return MPTBlock

def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
"""Construct the nn.ModuleList with the Transformer blocks.
@@ -437,7 +442,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
block_args = self.extract_block_args(config.to_dict())

return nn.ModuleList([
MPTBlock(
self.block_class(
device=config.init_device,
**block_args,
) for _ in range(config.n_layers)
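
With construct_blocks now instantiating self.block_class instead of a hard-coded MPTBlock, a model subclass can swap in a custom block. A sketch reusing the hypothetical TruncatedMaskBlock from the blocks.py example above (both names are illustrative only):

from typing import Type

from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.mpt.modeling_mpt import MPTModel


class TruncatedMaskMPTModel(MPTModel):
    # Hypothetical model: construct_blocks() will build config.n_layers
    # instances of the class returned here instead of MPTBlock.

    @property
    def block_class(self) -> Type[MPTBlock]:
        return TruncatedMaskBlock  # the MPTBlock subclass sketched earlier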
22 changes: 22 additions & 0 deletions llmfoundry/registry.py
@@ -222,6 +222,27 @@
description=_icl_datasets_description,
)

_config_transforms_description = (
"""The config_transforms registry is used to register functions that transform the training config
The config will be transformed before it is used anywhere else. Note: By default ALL registered transforms will be applied to the train config
and NONE to the eval config. Each transform should return the modified config.
Args:
cfg (Dict[str, Any]): The training config.
Returns:
cfg (Dict[str, Any]): The modified training config.
"""
)
config_transforms = create_registry(
'llmfoundry',
'config_transforms',
generic_type=Callable[[Dict[str, Any]], Dict[str, Any]],
entry_points=True,
description=_config_transforms_description,
)
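
A transform is any callable that takes the config dict and returns it (possibly modified); registering it by name makes it addressable by that name, and it is picked up when all registered transforms are requested. A minimal sketch following the register call this commit adds in llmfoundry/utils/__init__.py; the transform name and behavior are illustrative only:

from typing import Any, Dict

from llmfoundry.registry import config_transforms


def ensure_run_name(cfg: Dict[str, Any]) -> Dict[str, Any]:
    # Illustrative transform: make sure a run_name is always present.
    cfg.setdefault('run_name', 'llm-foundry-run')
    return cfg


config_transforms.register('ensure_run_name', func=ensure_run_name)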

__all__ = [
'loggers',
'callbacks',
@@ -245,4 +266,5 @@
'attention_implementations',
'fcs',
'icl_datasets',
'config_transforms',
]
6 changes: 6 additions & 0 deletions llmfoundry/utils/__init__.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.registry import config_transforms
from llmfoundry.utils.builders import (
build_algorithm,
build_callback,
@@ -59,6 +60,11 @@
experimental_function,
)

config_transforms.register(
'update_batch_size_info',
func=update_batch_size_info,
)

__all__ = [
'build_algorithm',
'build_callback',
11 changes: 8 additions & 3 deletions llmfoundry/utils/builders.py
@@ -63,7 +63,7 @@ def build_evaluators(
eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]],
*,
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
device_eval_batch_size: Union[int, float],
icl_seq_len: int,
icl_subset_num_batches: Optional[int],
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
@@ -79,6 +79,10 @@
logger_keys = []
eval_gauntlet_callback = None
if icl_tasks_config is not None:
if not isinstance(device_eval_batch_size, int):
raise ValueError(
'device_eval_batch_size should be an int for icl tasks.',
)
icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks_config,
eval_gauntlet_config,
@@ -95,7 +99,7 @@
def build_eval_loaders(
eval_loader_config: Union[Dict[str, Any], List[Dict[str, Any]]],
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
device_eval_batch_size: Union[int, float],
) -> List[Evaluator]:
evaluators: List[Evaluator] = []
if isinstance(eval_loader_config, list):
@@ -122,7 +126,8 @@ def build_eval_loaders(
# Load the eval data to fail fast. metrics will get added
# later in add_metrics_to_eval_loaders, after the model is loaded
metric_names=[],
device_eval_microbatch_size=device_eval_batch_size,
# TODO: Fix type in Composer
device_eval_microbatch_size=device_eval_batch_size, # type: ignore
)
evaluators.append(eval_loader)
return evaluators
82 changes: 66 additions & 16 deletions llmfoundry/utils/config_utils.py
@@ -30,6 +30,7 @@

from llmfoundry.layers_registry import ffns_with_megablocks
from llmfoundry.models.utils import init_empty_weights
from llmfoundry.registry import config_transforms

log = logging.getLogger(__name__)

@@ -48,7 +49,7 @@ class EvalConfig:
# Eval Config required parameters:
models: List[Dict[str, Any]] = MISSING
max_seq_len: int = MISSING
device_eval_batch_size: int = MISSING
device_eval_batch_size: Union[int, float] = MISSING

# Eval Config optional parameters:
code_paths: Optional[List[str]] = None
@@ -101,7 +102,7 @@ class TrainConfig:
scheduler: Dict[str, Any] = MISSING
train_loader: Dict[str, Any] = MISSING
device_train_batch_size: Union[int, float] = MISSING
device_eval_batch_size: int = MISSING
device_eval_batch_size: Union[int, float] = MISSING
max_duration: Union[int, str] = MISSING
eval_interval: Union[int, str] = MISSING
max_seq_len: int = MISSING
@@ -160,7 +161,7 @@ class TrainConfig:
save_ignore_keys: Optional[List[str]] = None

# Dataloader
device_train_microbatch_size: Union[str, int] = 'auto'
device_train_microbatch_size: Union[str, int, float] = 'auto'
global_train_batch_size: Optional[int] = None

# Eval dataloader
@@ -238,12 +239,60 @@ def to_container(
T = TypeVar('T')


def apply_transforms_to_config(
cfg: Dict[str, Any],
transforms: Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]],
List[str], str]],
) -> Dict[str, Any]:
"""Applies a list of transforms to a config.
Args:
cfg (Dict[str, Any]): The config to transform.
transforms (Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]], List[str], str]]): A list of
transform functions or strings representing transform functions to apply to the config. If a single string
with the value ``all`` is provided, all registered transforms will be applied.
Returns:
Dict[str, Any]: The transformed config.
"""
if transforms is None or (
isinstance(transforms, list) and len(transforms) == 0
):
return cfg

transform_functions = []
if isinstance(transforms, list):
for transform in transforms:
if isinstance(transform, str):
transform_functions.append(config_transforms.get(transform))
elif callable(transform):
transform_functions.append(transform)
else:
raise ValueError(
f'Invalid transform: {transform}. Must be a string or callable.',
)
elif isinstance(transforms, str) and transforms == 'all':
transform_functions = [
config_transforms.get(transform)
for transform in config_transforms.get_all()
]
else:
raise ValueError(
f'Invalid transforms: {transforms}. Must be a list of strings or callables, or ``all``.',
)

for transform in transform_functions:
cfg = transform(cfg)

return cfg
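
A short usage sketch of the forms accepted above. The config values are placeholders; 'update_batch_size_info' is the transform this commit registers in llmfoundry/utils/__init__.py:

from llmfoundry.utils.config_utils import apply_transforms_to_config

cfg = {'global_train_batch_size': 512, 'device_train_microbatch_size': 'auto'}

# Apply a single registered transform by name ...
cfg = apply_transforms_to_config(cfg, ['update_batch_size_info'])

# ... or apply every registered transform, as is done by default for the
# train config.
cfg = apply_transforms_to_config(cfg, 'all')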


def make_dataclass_and_log_config(
cfg: DictConfig,
dataclass_constructor: Callable[..., T],
dataclass_fields: Set[str],
transforms: Optional[List[Callable[[Dict[str, Any]], Dict[str,
Any]]]] = None,
transforms: Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]],
List[str], str]] = None,
icl_tasks_required: bool = False,
) -> Tuple[Dict[str, Any], T]:
"""Converts a DictConfig to a dataclass and creates a logged config."""
@@ -281,8 +330,10 @@ def make_dataclass_and_log_config(
logged_cfg: Dict[str, Any] = copy.deepcopy(unstructured_config)

# Apply transforms to the unstructured config before constructing dataclass
for transform in transforms or []:
unstructured_config = transform(unstructured_config)
unstructured_config = apply_transforms_to_config(
unstructured_config,
transforms,
)

logged_cfg.update(unstructured_config, merge=True)

@@ -367,20 +418,20 @@ def calculate_batch_size_info(
data_replication_degree: int = 1,
) -> Tuple[Union[int, float], Union[int, float, Literal['auto']], Union[
int, Literal['auto']]]:
if dist.get_world_size() % data_replication_degree != 0:

world_size = dist.get_world_size()
if world_size % data_replication_degree != 0:
raise ValueError(
f'World size {dist.get_world_size()} is not divisible by data replication degree {data_replication_degree}.',
f'World size {world_size} is not divisible by data replication degree {data_replication_degree}.',
)
if global_batch_size % (
dist.get_world_size() // data_replication_degree
) != 0:
if global_batch_size % (world_size // data_replication_degree) != 0:
raise ValueError(
f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} '
f'Global batchsize {global_batch_size} is not divisible by {(world_size // data_replication_degree)=} '
+
'as a result, the batch size would be truncated, please adjust `global_batch_size` '
+ f'to be divisible by world size, {dist.get_world_size()}.',
+ f'to be divisible by world size, {world_size}.',
)
device_batch_size = global_batch_size / dist.get_world_size()
device_batch_size = global_batch_size / world_size
if device_batch_size == round(device_batch_size):
device_batch_size = round(device_batch_size)
if device_microbatch_size == 'auto':
@@ -401,7 +452,6 @@
return device_batch_size, device_microbatch_size, device_grad_accum
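
A worked example of the arithmetic above, assuming 8 ranks (WORLD_SIZE=8) and the default data_replication_degree of 1; the numbers are illustrative only:

from llmfoundry.utils.config_utils import calculate_batch_size_info

# global 512 over 8 ranks -> device_batch_size = 512 / 8 = 64 (an exact int);
# device_microbatch_size = 16 fits in 64, so device_grad_accum = ceil(64 / 16) = 4.
device_bs, device_mbs, grad_accum = calculate_batch_size_info(
    512,
    16,
    data_replication_degree=1,
)
# -> (64, 16, 4) under the 8-rank assumption above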


# Coming soon: this conversion math will be done inside Composer Trainer
def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
data_replication_degree = 1
device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(