refactor attn layers (#240)
* refactor attn layers

* enable cross attn

* adding unpacked flash and torch

* cleanup

* lint

* pass attn fn thru __init__

* refac

* mv bias subset

* pr cmts & enable non-causal alibi

* fool proof tests
vchiley authored Mar 18, 2023
1 parent a1d063e commit 37e4f4c
Showing 12 changed files with 503 additions and 702 deletions.
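The diffs shown below cover the public-API side of the change: the separate TorchCausalAttention / FlashCausalAttention / TritonFlashCausalAttention classes (and the flash_attention module) are replaced by a single MultiheadAttention that receives its attention implementation as a function ("pass attn fn thru __init__"). The diff does not show the new module's actual signature, so the following is only a minimal sketch of that dispatch pattern; SimpleMultiheadAttention and torch_attn_fn are hypothetical stand-ins, not the repository's code.

```python
# Minimal sketch of "pass attn fn thru __init__": one attention module that
# delegates to whichever attention function it was constructed with.
# Names here (SimpleMultiheadAttention, torch_attn_fn) are illustrative only.
import math

import torch
import torch.nn as nn


def torch_attn_fn(q, k, v, attn_bias=None):
    # q, k, v: (batch, n_heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_bias is not None:
        scores = scores + attn_bias  # e.g. causal mask and/or ALiBi bias
    return torch.softmax(scores, dim=-1) @ v


class SimpleMultiheadAttention(nn.Module):

    def __init__(self, d_model, n_heads, attn_fn=torch_attn_fn):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.attn_fn = attn_fn  # swap in a flash/triton kernel here

    def forward(self, x, attn_bias=None):
        b, s, d = x.shape
        qkv = self.Wqkv(x).view(b, s, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (b, n_heads, s, head_dim)
        out = self.attn_fn(q, k, v, attn_bias=attn_bias)
        return self.out_proj(out.transpose(1, 2).reshape(b, s, d))
```

Swapping in a flash or triton kernel then means passing a different callable at construction time, with no change to the surrounding transformer block.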
18 changes: 9 additions & 9 deletions examples/llm/__init__.py
@@ -11,10 +11,9 @@
 from examples.llm.src.models.hf import (ComposerHFCausalLM,
                                         ComposerHFPrefixLM, ComposerHFT5)
 from examples.llm.src.models.layers.attention import (
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
-    alibi_bias)
-from examples.llm.src.models.layers.flash_attention import (FlashAttention,
-                                                             FlashMHA)
+    MultiheadAttention, alibi_bias, attn_bias, attn_bias_shape,
+    flash_attn_fn, scaled_multihead_dot_product_attention,
+    triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 from examples.llm.src.models.mosaic_gpt import ComposerMosaicGPT, MosaicGPT
 from examples.llm.src.tokenizer import (TOKENIZER_REGISTRY, HFTokenizer,
@@ -31,15 +30,16 @@
     ) from e
 
 __all__ = [
-    'FlashAttention',
-    'FlashMHA',
     'ComposerHFCausalLM',
     'ComposerHFPrefixLM',
     'ComposerHFT5',
     'COMPOSER_MODEL_REGISTRY',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'flash_attn_fn',
+    'triton_flash_attn_fn',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
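alibi_bias remains in the public API above, and the commit message calls out "enable non-causal alibi". The helper's real signature is not visible in this diff, so the snippet below is only a hedged sketch of how an ALiBi additive bias is commonly built, in both a causal and a symmetric (non-causal) form; simple_alibi_bias is a made-up name.

```python
# Sketch of an ALiBi-style additive attention bias (not the repository's
# actual alibi_bias helper). For a power-of-two head count, ALiBi gives
# head h (1-indexed) the slope 2**(-8 * h / n_heads) and biases each
# query/key pair by -slope * distance.
import torch


def simple_alibi_bias(n_heads, seq_len, causal=True, device=None):
    pos = torch.arange(seq_len, dtype=torch.float32, device=device)
    if causal:
        # (j - i) clamped to <= 0: past keys get a negative bias, future keys
        # get 0 and are assumed to be blocked by a separate causal mask.
        dist = (pos[None, :] - pos[:, None]).clamp(max=0)
    else:
        # non-causal: symmetric penalty by absolute distance
        dist = -(pos[None, :] - pos[:, None]).abs()
    slopes = 2.0 ** (-8.0 * torch.arange(1, n_heads + 1, dtype=torch.float32,
                                         device=device) / n_heads)
    # (1, n_heads, seq_len, seq_len), ready to add to the attention scores
    return dist[None, None, :, :] * slopes[None, :, None, None]
```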
15 changes: 2 additions & 13 deletions examples/llm/scripts/export_for_inference.py
@@ -44,7 +44,6 @@
 from composer.utils import get_device, maybe_create_object_store_from_uri
 from omegaconf import OmegaConf as om
 
-from examples.llm import TorchCausalAttention
 from examples.llm.src.model_registry import COMPOSER_MODEL_REGISTRY


@@ -127,18 +126,8 @@ def main(cfg):
                                   load_weights_only=True)
         # replace flash/triton attention with torch causal attention
         for idx in range(cfg.model.n_layers):
-            torch_causal_attn = TorchCausalAttention(cfg.model)
-            torch_causal_attn.mhsa.in_proj_weight = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.Wqkv.weight
-            torch_causal_attn.mhsa.in_proj_bias = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.Wqkv.bias
-            torch_causal_attn.mhsa.out_proj.weight = (
-                orig_model.model.transformer.blocks[idx].causal_attn.mhsa.
-                out_proj.weight)
-            torch_causal_attn.mhsa.out_proj.bias = orig_model.model.transformer.blocks[
-                idx].causal_attn.mhsa.out_proj.bias
-            export_model.model.transformer.blocks[
-                idx].causal_attn = torch_causal_attn
+            export_model.model.transformer.blocks[idx].attn.load_state_dict(
+                orig_model.model.transformer.blocks[idx].attn.state_dict())
     else:
         export_model = orig_model

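The new export path above copies attention weights with a single state_dict round trip instead of reassigning four parameters by hand. load_state_dict only requires that the source and destination modules expose the same parameter names and shapes, which the refactored attention module presumably provides across implementations. A generic, self-contained illustration of the mechanism (plain PyTorch, not the repository's model classes):

```python
# Weight transfer via state_dict / load_state_dict between two modules
# that share a parameter layout.
import torch
import torch.nn as nn

src = nn.Linear(8, 8)
dst = nn.Linear(8, 8)

dst.load_state_dict(src.state_dict())  # copies weight and bias in one call

x = torch.randn(2, 8)
assert torch.allclose(src(x), dst(x))  # identical outputs after the copy
```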
17 changes: 8 additions & 9 deletions examples/llm/src/__init__.py
@@ -7,27 +7,26 @@
 from examples.llm.src.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
                                         ComposerHFT5)
 from examples.llm.src.models.layers.attention import (
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
-    alibi_bias)
-from examples.llm.src.models.layers.flash_attention import (FlashAttention,
-                                                             FlashMHA)
+    MultiheadAttention, alibi_bias, attn_bias, attn_bias_shape, flash_attn_fn,
+    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock
 from examples.llm.src.models.mosaic_gpt import ComposerMosaicGPT, MosaicGPT
 from examples.llm.src.tokenizer import (TOKENIZER_REGISTRY, HFTokenizer,
                                         LLMTokenizer)

 __all__ = [
     'build_text_denoising_dataloader',
+    'flash_attn_fn',
+    'triton_flash_attn_fn',
     'MixtureOfDenoisersCollator',
-    'FlashAttention',
-    'FlashMHA',
     'ComposerHFCausalLM',
     'ComposerHFPrefixLM',
     'ComposerHFT5',
     'COMPOSER_MODEL_REGISTRY',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
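The exports above pair attn_bias_shape with attn_bias, which reads like a two-step pattern: first compute the shape of the additive bias a given configuration needs (possibly none at all), then allocate and fill that buffer once. Neither function's body or signature appears in this diff, so the sketch below only guesses at the pattern, with made-up names and arguments (toy_attn_bias_shape, toy_attn_bias).

```python
# Hypothetical sketch of an attn_bias_shape / attn_bias style pair; the real
# functions' arguments and behavior may differ. This only illustrates
# "compute the shape first, then materialize the bias".
from typing import Optional, Tuple

import torch


def toy_attn_bias_shape(n_heads: int, seq_len: int,
                        alibi: bool) -> Optional[Tuple[int, ...]]:
    # no ALiBi -> no additive bias tensor is needed at all
    return (1, n_heads, seq_len, seq_len) if alibi else None


def toy_attn_bias(shape: Optional[Tuple[int, ...]],
                  device=None) -> Optional[torch.Tensor]:
    if shape is None:
        return None
    # allocate once; ALiBi / padding terms would be filled in here
    return torch.zeros(shape, device=device)


shape = toy_attn_bias_shape(n_heads=8, seq_len=128, alibi=True)
bias = toy_attn_bias(shape)
```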
17 changes: 8 additions & 9 deletions examples/llm/src/models/layers/__init__.py
@@ -2,18 +2,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from examples.llm.src.models.layers.attention import (
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention,
-    alibi_bias)
-from examples.llm.src.models.layers.flash_attention import (FlashAttention,
-                                                             FlashMHA)
+    MultiheadAttention, alibi_bias, attn_bias, attn_bias_shape, flash_attn_fn,
+    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
 from examples.llm.src.models.layers.gpt_blocks import GPTMLP, GPTBlock

 __all__ = [
-    'FlashAttention',
-    'FlashMHA',
-    'TorchCausalAttention',
-    'FlashCausalAttention',
-    'TritonFlashCausalAttention',
+    'scaled_multihead_dot_product_attention',
+    'flash_attn_fn',
+    'triton_flash_attn_fn',
+    'MultiheadAttention',
+    'attn_bias_shape',
+    'attn_bias',
     'alibi_bias',
     'GPTMLP',
     'GPTBlock',
