Models registry (#1057)
dakinggg authored and KuuCi committed Apr 18, 2024
1 parent 88215fd commit 9cb8d19
Showing 27 changed files with 307 additions and 177 deletions.
2 changes: 0 additions & 2 deletions llmfoundry/__init__.py
@@ -31,7 +31,6 @@
flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
- from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
MPTForCausalLM, MPTModel, MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
@@ -53,7 +52,6 @@
'ComposerHFCausalLM',
'ComposerHFPrefixLM',
'ComposerHFT5',
- 'COMPOSER_MODEL_REGISTRY',
'scaled_multihead_dot_product_attention',
'flash_attn_fn',
'triton_flash_attn_fn',
4 changes: 2 additions & 2 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -16,12 +16,12 @@
from torch.utils.data import DataLoader

from llmfoundry.interfaces import CallbackWithConfig
- from llmfoundry.utils.warnings import experimental
+ from llmfoundry.utils.warnings import experimental_class

log = logging.getLogger(__name__)


- @experimental('CurriculumLearning callback')
+ @experimental_class('CurriculumLearning callback')
class CurriculumLearning(CallbackWithConfig):
"""Starts an epoch with a different dataset when resuming from a checkpoint.
18 changes: 18 additions & 0 deletions llmfoundry/models/__init__.py
@@ -3,8 +3,22 @@

from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
ComposerHFT5)
+ from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper,
+                                                      FMAPIChatAPIEvalWrapper,
+                                                      OpenAICausalLMEvalWrapper,
+                                                      OpenAIChatAPIEvalWrapper)
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
MPTForCausalLM, MPTModel, MPTPreTrainedModel)
+ from llmfoundry.registry import models
+
+ models.register('mpt_causal_lm', func=ComposerMPTCausalLM)
+ models.register('hf_causal_lm', func=ComposerHFCausalLM)
+ models.register('hf_prefix_lm', func=ComposerHFPrefixLM)
+ models.register('hf_t5', func=ComposerHFT5)
+ models.register('openai_causal_lm', func=OpenAICausalLMEvalWrapper)
+ models.register('fmapi_causal_lm', func=FMAPICasualLMEvalWrapper)
+ models.register('openai_chat', func=OpenAIChatAPIEvalWrapper)
+ models.register('fmapi_chat', func=FMAPIChatAPIEvalWrapper)

__all__ = [
'ComposerHFCausalLM',
@@ -15,4 +29,8 @@
'MPTModel',
'MPTForCausalLM',
'ComposerMPTCausalLM',
+ 'OpenAICausalLMEvalWrapper',
+ 'FMAPICasualLMEvalWrapper',
+ 'OpenAIChatAPIEvalWrapper',
+ 'FMAPIChatAPIEvalWrapper',
]
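Taken together, this file's changes replace the deleted COMPOSER_MODEL_REGISTRY dict: model classes are now registered by name and looked up through llmfoundry.registry. A minimal sketch of the new lookup path, assuming the registry exposes the catalogue-style `get` accessor:

from llmfoundry.models import ComposerMPTCausalLM
from llmfoundry.registry import models

# Before this commit: model_cls = COMPOSER_MODEL_REGISTRY['mpt_causal_lm']
# After: the class is looked up by its registered name.
model_cls = models.get('mpt_causal_lm')
assert model_cls is ComposerMPTCausalLM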
4 changes: 2 additions & 2 deletions llmfoundry/models/hf/hf_t5.py
@@ -17,12 +17,12 @@
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
init_empty_weights)
- from llmfoundry.utils.warnings import experimental
+ from llmfoundry.utils.warnings import experimental_class

__all__ = ['ComposerHFT5']


- @experimental('ComposerHFT5')
+ @experimental_class('ComposerHFT5')
class ComposerHFT5(HuggingFaceModelWithZLoss):
"""Configures a :class:`.HuggingFaceModel` around a T5.
16 changes: 8 additions & 8 deletions llmfoundry/models/inference_api_wrapper/fmapi.py
@@ -4,9 +4,9 @@
import logging
import os
import time
- from typing import Dict

import requests
+ from omegaconf import DictConfig
from transformers import AutoTokenizer

from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
@@ -25,7 +25,7 @@ class FMAPIEvalInterface(OpenAIEvalInterface):
def block_until_ready(self, base_url: str):
"""Block until the endpoint is ready."""
sleep_s = 5
- timout_s = 5 * 60  # At max, wait 5 minutes
+ timeout_s = 5 * 60  # At max, wait 5 minutes

ping_url = f'{base_url}/ping'

@@ -42,25 +42,25 @@ def block_until_ready(self, base_url: str):
time.sleep(sleep_s)
waited_s += sleep_s

- if waited_s >= timout_s:
+ if waited_s >= timeout_s:
raise TimeoutError(
f'Endpoint {ping_url} did not become ready after {waited_s:,} seconds, exiting'
)

- def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
-     is_local = model_cfg.pop('local', False)
+ def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer):
+     is_local = om_model_config.pop('local', False)
if is_local:
base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT',
'http://0.0.0.0:8080/v2')
- model_cfg['base_url'] = base_url
+ om_model_config['base_url'] = base_url
self.block_until_ready(base_url)

- if 'base_url' not in model_cfg:
+ if 'base_url' not in om_model_config:
raise ValueError(
'Must specify base_url or use local=True in model_cfg for FMAPIsEvalWrapper'
)

- super().__init__(model_cfg, tokenizer)
+ super().__init__(om_model_config, tokenizer)


class FMAPICasualLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper):
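The renamed om_model_config also drives the wrapper's two modes: with local: true it targets MOSAICML_MODEL_ENDPOINT (default http://0.0.0.0:8080/v2) and blocks on /ping; otherwise base_url must be set explicitly. A hedged sketch of the remote case, where the endpoint, model name, and tokenizer are placeholders and the openai extra must be installed:

from omegaconf import DictConfig
from transformers import AutoTokenizer

from llmfoundry.models.inference_api_wrapper import FMAPICasualLMEvalWrapper

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
wrapper = FMAPICasualLMEvalWrapper(
    om_model_config=DictConfig({
        'base_url': 'https://example.com/v2',  # placeholder endpoint
        'name': 'my-hosted-model',  # placeholder model name
    }),
    tokenizer=tokenizer,
)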
8 changes: 5 additions & 3 deletions llmfoundry/models/inference_api_wrapper/interface.py
@@ -1,22 +1,24 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

- from typing import Any, Dict, Optional
+ from typing import Any, Optional

import torch
from composer.core.types import Batch
from composer.metrics import InContextLearningMetric
from composer.models import ComposerModel
+ from omegaconf import DictConfig
from torchmetrics import Metric
from transformers import AutoTokenizer

from llmfoundry.metrics import DEFAULT_CAUSAL_LM_EVAL_METRICS
- from llmfoundry.utils.builders import build_metric


class InferenceAPIEvalWrapper(ComposerModel):

- def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
+ def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer):
+     from llmfoundry.utils.builders import build_metric

self.tokenizer = tokenizer
self.labels = None
eval_metrics = [
28 changes: 16 additions & 12 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -12,6 +12,7 @@
import torch
from composer.core.types import Batch
from composer.utils.import_helpers import MissingConditionalImportError
+ from omegaconf import DictConfig
from transformers import AutoTokenizer

log = logging.getLogger(__name__)
@@ -34,8 +35,9 @@

class OpenAIEvalInterface(InferenceAPIEvalWrapper):

- def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
-     super().__init__(model_cfg, tokenizer)
+ def __init__(self, om_model_config: DictConfig,
+              tokenizer: AutoTokenizer) -> None:
+     super().__init__(om_model_config, tokenizer)
try:
import openai
except ImportError as e:
@@ -45,7 +47,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
conda_channel='conda-forge') from e

api_key = os.environ.get('OPENAI_API_KEY')
- base_url = model_cfg.get('base_url')
+ base_url = om_model_config.get('base_url')
if base_url is None:
# Using OpenAI default, where the API key is required
if api_key is None:
@@ -61,10 +63,10 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
api_key = 'placeholder' # This cannot be None

self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
- if 'version' in model_cfg:
-     self.model_name = model_cfg['version']
+ if 'version' in om_model_config:
+     self.model_name = om_model_config['version']
else:
-     self.model_name = model_cfg['name']
+     self.model_name = om_model_config['name']

def generate_completion(self, prompt: str, num_tokens: int):
raise NotImplementedError()
@@ -109,17 +111,18 @@ def try_generate_completion(self, prompt: str, num_tokens: int):

class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface):

- def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
-     super().__init__(model_cfg, tokenizer)
+ def __init__(self, om_model_config: DictConfig,
+              tokenizer: AutoTokenizer) -> None:
+     super().__init__(om_model_config, tokenizer)

self.generate_completion = lambda prompt, num_tokens: self.client.chat.completions.create(
model=self.model_name,
messages=[{
'role':
'system',
'content':
- model_cfg.get('system_role_prompt',
-               'Please complete the following text: ')
+ om_model_config.get('system_role_prompt',
+                     'Please complete the following text: ')
}, {
'role': 'user',
'content': prompt
@@ -244,8 +247,9 @@ def process_result(self, completion: Optional['ChatCompletion']):

class OpenAICausalLMEvalWrapper(OpenAIEvalInterface):

- def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
-     super().__init__(model_cfg, tokenizer)
+ def __init__(self, om_model_config: DictConfig,
+              tokenizer: AutoTokenizer) -> None:
+     super().__init__(om_model_config, tokenizer)
self.generate_completion = lambda prompt, num_tokens: self.client.completions.create(
model=self.model_name,
prompt=prompt,
21 changes: 0 additions & 21 deletions llmfoundry/models/model_registry.py

This file was deleted.

4 changes: 2 additions & 2 deletions llmfoundry/optim/scheduler.py
@@ -11,7 +11,7 @@
from composer.optim import ComposerScheduler, LinearScheduler
from composer.optim.scheduler import _convert_time

- from llmfoundry.utils.warnings import experimental
+ from llmfoundry.utils.warnings import experimental_class

__all__ = ['InverseSquareRootWithWarmupScheduler']

@@ -34,7 +34,7 @@ def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
raise ValueError(f'{name} cannot be in units of "dur".')


- @experimental('InverseSquareRootWithWarmupScheduler')
+ @experimental_class('InverseSquareRootWithWarmupScheduler')
class InverseSquareRootWithWarmupScheduler(ComposerScheduler):
r"""Inverse square root LR decay with warmup and optional linear cooldown.
12 changes: 12 additions & 0 deletions llmfoundry/registry.py
@@ -4,6 +4,7 @@

from composer.core import Algorithm, Callback, DataSpec
from composer.loggers import LoggerDestination
+ from composer.models import ComposerModel
from composer.optim import ComposerScheduler
from omegaconf import DictConfig
from torch.optim import Optimizer
@@ -83,6 +84,15 @@
entry_points=True,
description=_schedulers_description)

+ _models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model
+ constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`.
+ Note: This will soon be updated to take in named kwargs instead of a config directly."""
+ models = create_registry('llmfoundry',
+                          'models',
+                          generic_type=Type[ComposerModel],
+                          entry_points=True,
+                          description=_models_description)
+
_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take
a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec."""
dataloaders = create_registry(
@@ -106,5 +116,7 @@
'optimizers',
'algorithms',
'schedulers',
+ 'models',
+ 'metrics',
'dataloaders',
]
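Per the registry description above, any entry must be a ComposerModel subclass whose constructor accepts an om_model_config DictConfig and a tokenizer. A hypothetical sketch of a conforming registration; TinyComposerLM and its hyperparameters are invented for illustration:

import torch
from composer.models import ComposerModel
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.registry import models


class TinyComposerLM(ComposerModel):
    """Toy model satisfying the documented constructor contract."""

    def __init__(self, om_model_config: DictConfig,
                 tokenizer: PreTrainedTokenizerBase):
        super().__init__()
        hidden = om_model_config.get('hidden_size', 64)  # illustrative default
        self.embed = torch.nn.Embedding(len(tokenizer), hidden)
        self.head = torch.nn.Linear(hidden, len(tokenizer))

    def forward(self, batch):
        return self.head(self.embed(batch['input_ids']))

    def loss(self, outputs, batch):
        # Cross entropy expects the vocab dimension second: [batch, vocab, seq].
        return torch.nn.functional.cross_entropy(outputs.transpose(1, 2),
                                                 batch['labels'])


models.register('tiny_lm', func=TinyComposerLM)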
58 changes: 57 additions & 1 deletion llmfoundry/utils/builders.py
@@ -1,18 +1,21 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

+ import contextlib
import functools
import logging
import os
import re
from collections import OrderedDict
- from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+ from typing import (Any, ContextManager, Dict, Iterable, List, Optional, Tuple,
+                     Union)

import torch
from composer.core import Algorithm, Callback, Evaluator
from composer.datasets.in_context_learning_evaluation import \
get_icl_task_dataloader
from composer.loggers import LoggerDestination
+ from composer.models import ComposerModel
from composer.optim.scheduler import ComposerScheduler
from composer.utils import dist
from omegaconf import DictConfig, ListConfig
@@ -39,6 +42,7 @@
'build_optimizer',
'build_scheduler',
'build_tokenizer',
+ 'build_composer_model',
'build_metric',
]

@@ -155,6 +159,58 @@ def build_icl_data_and_gauntlet(
return icl_evaluators, logger_keys, eval_gauntlet_cb


+ def build_composer_model(
+     name: str,
+     cfg: DictConfig,
+     tokenizer: PreTrainedTokenizerBase,
+     init_context: Optional[ContextManager] = None,
+     master_weights_dtype: Optional[str] = None,
+ ) -> ComposerModel:
+     """Builds a ComposerModel from the registry.
+
+     Args:
+         name (str): Name of the model to build.
+         cfg (DictConfig): Configuration for the model.
+         tokenizer (PreTrainedTokenizerBase): Tokenizer to use.
+         init_context (Optional[ContextManager], optional): Context manager to use for initialization. Defaults to None.
+         master_weights_dtype (Optional[str], optional): Master weights dtype. Defaults to None.
+
+     Returns:
+         ComposerModel: The built ComposerModel.
+     """
+     if init_context is None:
+         init_context = contextlib.nullcontext()
+
+     with init_context:
+         model = construct_from_registry(
+             name=name,
+             registry=registry.models,
+             pre_validation_function=ComposerModel,
+             post_validation_function=None,
+             kwargs={
+                 'om_model_config': cfg,
+                 'tokenizer': tokenizer
+             },
+         )
+
+     str_dtype_to_torch_dtype = {
+         'f16': torch.float16,
+         'float16': torch.float16,
+         'bf16': torch.bfloat16,
+         'bfloat16': torch.bfloat16,
+     }
+
+     if master_weights_dtype is not None:
+         if master_weights_dtype not in str_dtype_to_torch_dtype:
+             raise ValueError(
+                 f'Invalid master_weights_dtype: {master_weights_dtype}. ' +
+                 f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.')
+         dtype = str_dtype_to_torch_dtype[master_weights_dtype]
+         model = model.to(dtype=dtype)
+
+     return model
+
+
def build_callback(
name: str,
kwargs: Optional[Dict[str, Any]] = None,
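Putting the pieces together, a minimal sketch of calling the new build_composer_model; the registered name is real, but the MPT hyperparameters and tokenizer below are illustrative placeholders:

from omegaconf import DictConfig
from transformers import AutoTokenizer

from llmfoundry.utils.builders import build_composer_model

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
model = build_composer_model(
    name='mpt_causal_lm',  # resolved through registry.models
    cfg=DictConfig({
        'd_model': 128,
        'n_heads': 2,
        'n_layers': 2,
        'expansion_ratio': 2,
        'max_seq_len': 256,
    }),
    tokenizer=tokenizer,
    master_weights_dtype='bf16',  # cast to torch.bfloat16 after construction
)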
The remaining 16 changed files are not shown.
