HF FSDP wrap BLOOM and OPT as well #83

Merged
merged 12 commits on Jan 27, 2023
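In short, this PR replaces the hard-coded tuple of supported GPT-style models/blocks with attribute-probing helpers, so any decoder-only HuggingFace model that follows one of the common layer/weight naming conventions (now including BLOOM and OPT) can be FSDP-wrapped. A minimal usage sketch, assuming the `prepare_hf_causal_lm_model_for_fsdp` helper from the diff below and an example OPT checkpoint name taken from the new opt.yaml:

from transformers import AutoConfig, AutoModelForCausalLM

# Build the model from its config only (no pretrained weights), mirroring the
# `pretrained: false` path added to the YAMLs in this PR.
config = AutoConfig.from_pretrained('facebook/opt-2.7b')
model = AutoModelForCausalLM.from_config(config)

# Tag the tied embeddings / lm_head so they stay in the top-level FSDP module,
# and register per-block wrap and activation-checkpointing predicates.
prepare_hf_causal_lm_model_for_fsdp(model)

# Composer's FSDP integration then consults these attributes when sharding.
assert callable(model.fsdp_wrap_fn)
assert callable(model.activation_checkpointing_fn)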
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# Data
my-copy-c4/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
212 changes: 157 additions & 55 deletions llm/src/hf_causal_lm.py
@@ -1,85 +1,187 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
# helper functions from https://github.com/CarperAI/trlx/blob/main/trlx/utils/modeling.py
# which is MIT licensed

import functools
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from composer.metrics.nlp import LanguageCrossEntropy, Perplexity
from composer.models.base import ComposerModel
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2LMHeadModel
from transformers.models.gpt_neo.modeling_gpt_neo import (GPTNeoBlock,
GPTNeoForCausalLM)
from transformers.models.gpt_neox.modeling_gpt_neox import (GPTNeoXForCausalLM,
GPTNeoXLayer)

_SUPPORTED_HF_MODELS = (
GPT2LMHeadModel,
GPTNeoForCausalLM,
GPTNeoXForCausalLM,
)

_HF_MODEL_BLOCKS = (
GPT2Block,
GPTNeoBlock,
GPTNeoXLayer,
)


def prepare_hf_causal_lm_model_for_fsdp(model):
assert isinstance(model, _SUPPORTED_HF_MODELS)
# When using the HF Causal LM models,
from composer.metrics.nlp import HFCrossEntropy, LanguageCrossEntropy, Perplexity
from composer.models.huggingface import HuggingFaceModel
from omegaconf import DictConfig
from torch import Tensor
from torchmetrics import Metric
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


# helper functions

def rhasattr(obj: Any, attr: str) -> bool:
"""A chain-able attribute version of hasattr.

For example, to check if
`obj` has the attribute `foo.bar.baz`, you can use:
`rhasattr(obj, "foo.bar.baz")`
Reference: https://stackoverflow.com/a/67303315
"""
_nested_attrs = attr.split(".")
_curr_obj = obj
for _a in _nested_attrs[:-1]:
if hasattr(_curr_obj, _a):
_curr_obj = getattr(_curr_obj, _a)
else:
return False
return hasattr(_curr_obj, _nested_attrs[-1])


def rgetattr(obj: Any, attr: str, *args: List[Any]) -> object:
"""A chain-able attribute version of getattr.

For example, to get the attribute `foo.bar.baz` from `obj`, you can use:
`rgetattr(obj, "foo.bar.baz")`
Reference: https://stackoverflow.com/a/31174427
"""

def _getattr(obj: Any, attr: str):
return getattr(obj, attr, *args)

return functools.reduce(_getattr, [obj] + attr.split("."))


def findattr(obj: Any, attrs: Tuple[str, ...]) -> Union[object, None]:
for attr in attrs:
if rhasattr(obj, attr):
return rgetattr(obj, attr)
return None


def hf_get_causal_base_model(model: AutoModelForCausalLM) -> torch.nn.Module:
"""Returns the causal decoder backbone of the specified HuggingFace model.

NOTE: Different model configurations have different causal decoder attribute
names.
- transformer: (GPT2LMHeadModel, GPTJForCausalLM, BloomForCausalLM)
- model.decoder: (OPTForCausalLM)
- gpt_neox: (GPTNeoXForCausalLM)
"""
decoder_attrs = ("transformer", "model.decoder", "gpt_neox")
return findattr(model, decoder_attrs)


def hf_get_lm_head(model: AutoModelForCausalLM) -> torch.nn.Module:
"""Returns the lm head of the specified HuggingFace model.

NOTE: Different model configurations have different `lm_head` attribute names.
- lm_head: (GPT2LMHeadModel, BloomForCausalLM)
- embed_out: (GPTNeoXForCausalLM)
"""
return model.get_output_embeddings()


def hf_get_causal_hidden_layers(model: torch.nn.Module) -> Tuple[torch.nn.Module]:
"""Returns the hidden layers of the specified model.

NOTE: Different model configurations have different hidden layer attribute names.
- transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
- model.decoder.layers: (OPTForCausalLM)
- gpt_neox.layers: (GPTNeoXForCausalLM)
"""
hidden_layers_attrs = (
"transformer.h",
"model.decoder.layers",
"gpt_neox.layers",
)
return findattr(model, hidden_layers_attrs)


def hf_get_tied_embedding_weights(model: torch.nn.Module) -> torch.nn.Module:
"""Returns the embeddings, which are weight tied layers.

NOTE: Different model configurations have different embedding attribute names.
- wte: (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM)
- word_embeddings: (BloomForCausalLM)
- embed_tokens: (OPTForCausalLM)
- GPT NeoX doesn't weight tie
"""
tied_embedding_attrs = (
"wte",
"word_embeddings",
"embed_tokens",
)
return findattr(model, tied_embedding_attrs)

# /end helper functions
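# Illustrative only (not part of this diff): for a BloomForCausalLM instance `m`,
# the helpers above resolve roughly as follows:
#   hf_get_causal_base_model(m)                   -> m.transformer
#   hf_get_causal_hidden_layers(m)                -> m.transformer.h
#   hf_get_lm_head(m)                             -> m.lm_head
#   hf_get_tied_embedding_weights(m.transformer)  -> m.transformer.word_embeddings
# For OPTForCausalLM the same calls resolve to m.model.decoder,
# m.model.decoder.layers, m.lm_head, and m.model.decoder.embed_tokens.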


def prepare_hf_causal_lm_model_for_fsdp(model: AutoModelForCausalLM):
"""FSDP wrap a HuggingFace model

Wrap any model for FSDP which follows one of the 3 existing conventions from
HuggingFace for decoder-only LLMs.
"""
causal_base_model = hf_get_causal_base_model(model)
model_block = hf_get_causal_hidden_layers(model)[0]
block_type = type(model_block)
lm_head = hf_get_lm_head(model)
tied_embeddings = hf_get_tied_embedding_weights(causal_base_model)
modules = [causal_base_model, model_block, block_type, lm_head, tied_embeddings]
if not all(module is not None for module in modules):
raise ValueError('Unable to FSDP-wrap this model! It does not follow '
                 'common layer/weight naming conventions.')
# When using the HF LM models,
# the weights of self.lm_head and self.transformer.wte are tied.
# This tying occurs inside the `self.post_init()` function.
# This is a hurdle for FSDP because they need to be in the same FSDP block.
# These lines ensure that both modules stay together in the top-most block
model.transformer._fsdp_wrap = False
model.transformer.wte._fsdp_wrap = False
model.lm_head._fsdp_wrap = False
# These lines ensure that both modules stay together in the top-most block when
# the model has this tying enabled (almost all models do; this property defaults to True)
if model.config.tie_word_embeddings:
causal_base_model._fsdp_wrap = False
tied_embeddings._fsdp_wrap = False
lm_head._fsdp_wrap = False

# FSDP Wrap and Activation Checkpoint every GPT2Block
model.fsdp_wrap_fn = lambda module: isinstance(module, _HF_MODEL_BLOCKS)
# FSDP Wrap and Activation Checkpoint every model block
model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
model.activation_checkpointing_fn = lambda module: isinstance(
module, _HF_MODEL_BLOCKS)
module, block_type)


class ComposerHFCausalLM(ComposerModel):

def __init__(self, cfg):
super().__init__()
class ComposerHFCausalLM(HuggingFaceModel):
def __init__(self, cfg: DictConfig):
config = AutoConfig.from_pretrained(cfg.hf_config_name_or_path)
self.model = AutoModelForCausalLM.from_config(config)
self.train_metrics = {
'LanguageCrossEntropy': LanguageCrossEntropy(config.vocab_size),
'Perplexity': Perplexity(),
}
self.eval_metrics = {
'LanguageCrossEntropy': LanguageCrossEntropy(config.vocab_size),
'Perplexity': Perplexity(),
}
prepare_hf_causal_lm_model_for_fsdp(self.model)

def get_targets(self, batch):
tokenizer = AutoTokenizer.from_pretrained(cfg.hf_config_name_or_path)

if cfg.pretrained:
model = AutoModelForCausalLM.from_pretrained(cfg.hf_config_name_or_path, config=config)
metrics = [HFCrossEntropy(), Perplexity()]
else:
model = AutoModelForCausalLM.from_config(config)
metrics = [LanguageCrossEntropy(len(tokenizer)), Perplexity()]

prepare_hf_causal_lm_model_for_fsdp(model)

super().__init__(model=model, tokenizer=tokenizer, metrics=metrics, use_logits=True)

def get_targets(self, batch: dict):
targets = torch.roll(batch['labels'], shifts=-1)
targets[:, -1] = -100
return targets

def forward(self, batch):
def forward(self, batch: dict):
return self.model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'].bool()).logits

def eval_forward(self, batch, outputs=None):
def eval_forward(self, batch: dict, outputs: Optional[Tensor] = None):
return outputs if outputs is not None else self.forward(batch)

def loss(self, outputs, batch):
def loss(self, outputs: Tensor, batch: dict):
targets = self.get_targets(batch)
return F.cross_entropy(outputs.view(-1, outputs.size(-1)),
targets.view(-1),
ignore_index=-100)

def get_metrics(self, is_train=False):
return self.train_metrics if is_train else self.eval_metrics

def update_metric(self, batch, outputs, metric):
def update_metric(self, batch: dict, outputs: Tensor, metric: Metric):
outputs = outputs.view(-1, outputs.size(-1))
targets = self.get_targets(batch).view(-1)
metric.update(outputs, targets)
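For reference, a minimal sketch (assumed, not part of the diff) of constructing the new `ComposerHFCausalLM` from an OmegaConf config whose keys mirror the `model:` blocks in the YAMLs below:

from omegaconf import OmegaConf

# Hypothetical config; only the keys read by __init__ above are included.
cfg = OmegaConf.create({
    'hf_config_name_or_path': 'gpt2',
    'pretrained': False,  # False: architecture only; True: load pretrained weights
})
composer_model = ComposerHFCausalLM(cfg)  # class defined in hf_causal_lm.py above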
3 changes: 3 additions & 0 deletions llm/yamls/hf_causal_lm/gpt-neo-125m.yaml
@@ -11,6 +11,8 @@ run_name: gpt-neo-125m
model:
name: hf_causal_lm
hf_config_name_or_path: EleutherAI/gpt-neo-125M
device: cpu # this is where to initialize weights; cpu memory or meta tensor
pretrained: false # false: only use the architecture; true: initialize with pretrained weights

# Tokenizer
tokenizer:
@@ -19,6 +21,7 @@ tokenizer:
tokenizer_name: *tokenizer_name
max_seq_len: *max_seq_len


# Dataloaders
train_loader:
name: text
2 changes: 2 additions & 0 deletions llm/yamls/hf_causal_lm/gpt2-small.yaml
@@ -11,6 +11,8 @@ run_name: gpt2-small
model:
name: hf_causal_lm
hf_config_name_or_path: gpt2
device: cpu # this is where to initialize weights; cpu memory or meta tensor
pretrained: false # false: only use the architecture; true: initialize with pretrained weights

# Tokenizer
tokenizer:
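The `device` key added to the `model:` blocks above indicates where weights are first materialized: regular CPU memory, or meta tensors that defer allocation until the model is sharded. A hypothetical illustration of the meta-tensor idea (requires a recent PyTorch; how the key is actually consumed is outside this diff):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('gpt2')
# Construct the module tree without allocating real weight storage.
with torch.device('meta'):
    model = AutoModelForCausalLM.from_config(config)
assert all(p.device.type == 'meta' for p in model.parameters())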
115 changes: 115 additions & 0 deletions llm/yamls/hf_causal_lm/opt.yaml
@@ -0,0 +1,115 @@
data_local: &data_local ./my-copy-c4
data_remote: &data_remote null
tokenizer_name: &tokenizer_name facebook/opt-2.7b
max_seq_len: &max_seq_len 256
global_seed: &global_seed 17

# Run Name
run_name: opt-2.7b

# Model
model:
name: hf_causal_lm
hf_config_name_or_path: facebook/opt-2.7b
device: cpu
pretrained: true

# Tokenizer
tokenizer:
type: hftokenizer
args:
tokenizer_name: ${tokenizer_name}
max_seq_len: ${max_seq_len}

# Dataloaders
train_loader:
name: text
dataset:
local: ${data_local}
remote: ${data_remote}
split: train
shuffle: true
tokenizer_name: ${tokenizer_name}
max_seq_len: ${max_seq_len}
group_method: concat
shuffle_seed: ${global_seed}
drop_last: true
num_workers: 8

eval_loader:
name: text
dataset:
local: ${data_local}
remote: ${data_remote}
split: val
shuffle: false
tokenizer_name: ${tokenizer_name}
max_seq_len: ${max_seq_len}
group_method: truncate
shuffle_seed: ${global_seed}
drop_last: false
num_workers: 8

# Optimization
scheduler:
name: cosine_with_warmup
t_warmup: 100ba
alpha_f: 0.1

optimizer:
name: decoupled_adamw
lr: 6.0e-4
betas:
- 0.9
- 0.95
eps: 1.0e-08
weight_decay: 0.0

algorithms:
gradient_clipping:
clipping_type: norm
clipping_threshold: 1.0

max_duration: 4800ba # ~ 2.5B tokens
eval_interval: 500ba
global_train_batch_size: 256

# System
seed: 17
device_eval_batch_size: 4
device_train_microbatch_size: 4
# device_train_microbatch_size: auto
precision: amp_bf16

# FSDP
fsdp_config:
sharding_strategy: FULL_SHARD
min_params: 1e8
mixed_precision: DEFAULT
activation_checkpointing: false
activation_cpu_offload: false
verbose: true

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
speed_monitor:
window_size: 10
lr_monitor: {}
memory_monitor: {}

# loggers:
# wandb: {}

# Checkpoint to local filesystem or remote object store
# save_interval: 500ba
# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
# save_folder: ./{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints

# Load from local filesystem or remote object store
# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt
# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt