Registry docs update #1323

Merged 5 commits on Jul 1, 2024
README.md (6 changes: 5 additions & 1 deletion)
@@ -282,15 +282,18 @@ We provide two commands currently:

Use `--help` on any of these commands for more information.

These commands can also help you understand what each registry is composed of, as each registry contains a docstring that will be printed out. The general concept is that each registry defines an interface, and components registered to that registry must implement that interface. If there is a part of the library that is not currently extendable, but you think it should be, please open an issue!

## How to register

There are a few ways to register a new component:

### Python entrypoints

You can specify registered components via a Python entrypoint if you are building your own package with registered components.
This would be the expected usage if you are building a large extension to LLM Foundry and will be overriding many components. Note that components registered via entrypoints will override components registered directly in code.

For example, the following would register the `WandBLogger` class, under the key `wandb`, in the `llm_foundry.loggers` registry:
For example, the following would register the `MyLogger` class, under the key `my_logger`, in the `llm_foundry.loggers` registry:

<!--pytest.mark.skip-->
```yaml
@@ -359,6 +362,7 @@ code_paths:
...
```

One of these would be the expected usage if you are building a small extension to LLM Foundry, only overriding a few components, and thus don't want to create an entire package.
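
For instance, a module loaded via `code_paths` might register a component directly in code. The sketch below is illustrative: it reuses the `MyLogger`/`my_logger` names from the entrypoint example and assumes the top-level `llmfoundry.registry.loggers` registry with its catalogue-style `register(name, func=...)` call.

<!--pytest.mark.skip-->
```python
# my_package/my_logger.py, referenced from the training YAML via code_paths
from composer.loggers import LoggerDestination

from llmfoundry.registry import loggers


class MyLogger(LoggerDestination):
    """A placeholder logger destination; implement the Composer logger callbacks here."""


# Register the class under the key used in the YAML `loggers` section.
loggers.register('my_logger', func=MyLogger)
```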

# Learn more about LLM Foundry!

llmfoundry/layers_registry.py (152 changes: 130 additions & 22 deletions)
@@ -7,32 +7,65 @@

from llmfoundry.utils.registry_utils import create_registry

_norm_description = (
'The norms registry is used to register classes that implement normalization layers.'
_norms_description = (
"""The norms registry is used to register classes that implement normalization layers.

One example of this is torch.nn.LayerNorm. See norm.py for examples.

Args:
    normalized_shape: Union[int, List[int], torch.Size]: The shape of the input tensor.
device: Optional[torch.device]: The device to use for the normalization layer.

Returns:
torch.nn.Module: The normalization layer.
"""
)
norms = create_registry(
'llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description,
description=_norms_description,
)
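
# Hypothetical usage sketch: registering a custom normalization layer under a
# placeholder key. The constructor takes the normalized_shape and device arguments
# documented above; torch is already imported at the top of this module, and the
# catalogue-style register(name, func=...) call is assumed.
from typing import List, Optional, Union


class SimpleRMSNorm(torch.nn.Module):

    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = [normalized_shape]
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the inverse root-mean-square over the last dimension, then apply the gain.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(1e-6).rsqrt()
        return x * rms * self.weight


norms.register('simple_rmsnorm', func=SimpleRMSNorm)
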
_fc_description = (
'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).'
+
'These classes should take in_features and out_features in as args, at a minimum.'

_fcs_description = (
"""The fcs registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).

See fc.py for examples.

Args:
in_features: int: The number of input features.
out_features: int: The number of output features.
kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer.

Returns:
torch.nn.Module: The fully connected layer.
"""
)
fcs = create_registry(
'llmfoundry',
'fcs',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_fc_description,
description=_fcs_description,
)
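
# Hypothetical usage sketch: registering a fully connected layer class under a
# placeholder key. torch.nn.Linear already accepts the documented in_features and
# out_features arguments (plus extra kwargs such as bias and device), so a thin
# subclass is enough here.
class NoisyLinear(torch.nn.Linear):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add a small amount of Gaussian noise to the activations during training.
        out = super().forward(x)
        if self.training:
            out = out + 1e-3 * torch.randn_like(out)
        return out


fcs.register('noisy_linear', func=NoisyLinear)
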

_ffns_description = (
'The ffns registry is used to register functions that build ffn layers.' +
'See ffn.py for examples.'
"""The ffns registry is used to register functions that build FFN layers.

These layers are generally composed of fc layers and activation functions.
One example is MPTMLP. See ffn.py for examples.

Args:
d_model: int: The size of the input and output tensors.
expansion_ratio: float: The expansion ratio for the hidden layer.
device: Optional[str]: The device to use for the layer.
bias: bool: Whether or not to include a bias term.
kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer.

Returns:
torch.nn.Module: The FFN layer.
"""
)
ffns = create_registry(
'llmfoundry',
@@ -43,8 +76,21 @@
)
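
# Hypothetical usage sketch: registering a builder function for a simple two-layer MLP
# under a placeholder key, following the d_model / expansion_ratio / device / bias /
# kwargs arguments documented above.
from typing import Any, Optional


def build_simple_mlp(
    d_model: int,
    expansion_ratio: float,
    device: Optional[str] = None,
    bias: bool = True,
    **kwargs: Any,
) -> torch.nn.Module:
    del kwargs  # this simple builder ignores any extra config it is handed
    hidden_size = int(d_model * expansion_ratio)
    return torch.nn.Sequential(
        torch.nn.Linear(d_model, hidden_size, bias=bias, device=device),
        torch.nn.GELU(),
        torch.nn.Linear(hidden_size, d_model, bias=bias, device=device),
    )


ffns.register('simple_mlp', func=build_simple_mlp)
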

_ffns_with_norm_description = (
'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.'
+ 'See ffn.py for examples.'
"""The ffns_with_norm registry is used to register functions that build FFN layers with normalization.

The resulting layer will have ._has_norm set on it.
One example is te.LayerNormMLP. See ffn.py for examples.

Args:
d_model: int: The size of the input and output tensors.
expansion_ratio: float: The expansion ratio for the hidden layer.
device: Optional[str]: The device to use for the layer.
bias: bool: Whether or not to include a bias term.
kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer.

Returns:
torch.nn.Module: The FFN layer.
"""
)
ffns_with_norm = create_registry(
'llmfoundry',
@@ -58,6 +104,16 @@
'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.'
+ 'See ffn.py for examples.'
)
_ffns_with_megablocks_description = (
"""The ffns_with_megablocks registry is used to register functions that build FFN layers using MegaBlocks.

The resulting layer will have ._uses_megablocks set on it.
One example is megablocks.layers.dmoe.dMoE. See ffn.py for examples.

Returns:
torch.nn.Module: The FFN layer.
"""
)
ffns_with_megablocks = create_registry(
'llmfoundry',
'ffns_with_megablocks',
@@ -67,8 +123,17 @@
)

_attention_classes_description = (
'The attention_classes registry is used to register classes that implement attention layers. See '
+ 'attention.py for expected constructor signature.'
"""The attention_classes registry is used to register classes that implement attention layers.

The kwargs are passed directly to the constructor of the class.
One example is GroupedQueryAttention. See attention.py for examples.

Args:
kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer.

Returns:
torch.nn.Module: The attention layer.
"""
)
attention_classes = create_registry(
'llmfoundry',
@@ -79,8 +144,29 @@
)

_attention_implementations_description = (
'The attention_implementations registry is used to register functions that implement the attention operation.'
+ 'See attention.py for expected function signature.'
"""The attention_implementations registry is used to register functions that implement the attention operation.

One example is 'flash'. See attention.py for examples.

Args:
query (torch.Tensor): The query tensor.
key (torch.Tensor): The key tensor.
value (torch.Tensor): The value tensor.
n_heads (int): The number of attention heads.
kv_n_heads (int): The number of attention heads for the key and value tensors.
past_key_value (Optional[tuple[torch.Tensor, torch.Tensor]]): The past key and value tensors.
softmax_scale (Optional[float]) = None
attn_bias (Optional[torch.Tensor]) = None
is_causal (bool) = False
dropout_p (float) = 0.0
training (bool) = True
needs_weights (bool) = False
kwargs: Dict[str, Any]: Additional keyword arguments the implementation accepts.

Returns:
tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
The output tensor, the attention weights, and the past key and value tensors.
"""
)
attention_implementations = create_registry(
'llmfoundry',
@@ -91,9 +177,17 @@
)
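
# Hypothetical usage sketch: registering a naive attention implementation under a
# placeholder key, following the signature documented above. For brevity it only
# supports n_heads == kv_n_heads and leaves the KV cache untouched; the real
# implementations in attention.py handle much more.
from typing import Any, Optional, Tuple


def my_naive_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    n_heads: int,
    kv_n_heads: int,
    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    softmax_scale: Optional[float] = None,
    attn_bias: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    dropout_p: float = 0.0,
    training: bool = True,
    needs_weights: bool = False,
    **kwargs: Any,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
        torch.Tensor, torch.Tensor]]]:
    assert n_heads == kv_n_heads, 'this toy implementation does not support GQA'
    b, s, d = query.shape
    head_dim = d // n_heads
    q = query.view(b, s, n_heads, head_dim).transpose(1, 2)
    k = key.view(b, s, n_heads, head_dim).transpose(1, 2)
    v = value.view(b, s, n_heads, head_dim).transpose(1, 2)
    scale = softmax_scale if softmax_scale is not None else head_dim**-0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    if attn_bias is not None:
        scores = scores + attn_bias
    if is_causal:
        causal_mask = torch.ones(s, s, device=scores.device).triu(1).bool()
        scores = scores.masked_fill(causal_mask, float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    if training and dropout_p > 0.0:
        weights = torch.nn.functional.dropout(weights, p=dropout_p)
    out = (weights @ v).transpose(1, 2).reshape(b, s, d)
    return out, (weights if needs_weights else None), past_key_value


attention_implementations.register('my_naive_attention', func=my_naive_attention)
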

_param_init_fns_description = (
'The param_init_fns registry is used to register functions that initialize parameters.'
+
'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.'
"""The param_init_fns registry is used to register functions that initialize parameters.

These functions should take in a torch.nn.Module and additional kwargs, and initialize the parameters of the module.
Generally they can call generic_param_init_fn_ with an appropriate partial function. See param_init_fns.py for examples.

Note: These functions should take in arbitrary kwargs, and discard any they don't need.

Args:
module: torch.nn.Module: The module to initialize.
kwargs: Dict[str, Any]: Additional keyword arguments to use for initialization.
"""
)
param_init_fns = create_registry(
'llmfoundry',
@@ -103,9 +197,23 @@
description=_param_init_fns_description,
)
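
# Hypothetical usage sketch: registering a parameter init function under a placeholder
# key. As documented above, it takes the module plus arbitrary kwargs and discards any
# it does not need.
from typing import Any


def my_xavier_init_fn_(module: torch.nn.Module, **kwargs: Any) -> None:
    del kwargs  # discard config options this simple initializer does not use
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)


param_init_fns.register('my_xavier_init_fn_', func=my_xavier_init_fn_)
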

_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules.
These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents.
They should take in the module, init_div_is_residual, and div_is_residual arguments."""
_module_init_fns_description = (
"""The module_init_fns registry is used to register functions that initialize specific modules.

These functions should return True if they initialize the module, and False otherwise.
This allows them to be called without knowing their contents. They should take in the module and additional kwargs.
If multiple functions can initialize the module, the one that is registered first will be used, so it is recommended to
override an existing function if you want to change existing initialization behavior, and add new functions if you have new
layer types. See param_init_fns.py for details.

Args:
module: torch.nn.Module: The module to initialize.
kwargs: Dict[str, Any]: Additional keyword arguments to use for initialization.

Returns:
bool: Whether or not the module was initialized.
"""
)
module_init_fns = create_registry(
'llmfoundry',
'module_init_fns',
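
# Hypothetical usage sketch: registering a module init function under a placeholder key.
# As documented above, it initializes one specific layer type and returns True only when
# it actually handled the module, so other registered functions can cover the rest.
from typing import Any


def my_embedding_init_fn(module: torch.nn.Module, **kwargs: Any) -> bool:
    del kwargs  # discard config options this simple initializer does not use
    if isinstance(module, torch.nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        return True
    return False


module_init_fns.register('my_embedding_init_fn', func=my_embedding_init_fn)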