save cpu mem by leveraging FSDP rank0 broadcasting #77

Merged: 14 commits, Aug 11, 2023
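For orientation, below is a minimal sketch of the pattern this PR introduces, written against the public FSDP API rather than the repo's code: rank 0 loads real weights on CPU, every other rank builds the same architecture on the meta device, and FSDP broadcasts rank 0's weights while wrapping. The tiny nn.Sequential model, the torchrun launch, and the recent-PyTorch assumption are illustrative stand-ins for the Llama-specific code in the diff below.

# Minimal sketch of the rank-0-load + meta-init + FSDP broadcast pattern
# (assumes a torchrun launch and a recent PyTorch with meta-device and
# param_init_fn support, matching the PR's nightly check). A small
# nn.Sequential stands in for LlamaForCausalLM.
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

def build_model() -> nn.Module:
    return nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

if rank == 0:
    model = build_model()                # rank 0 holds the real weights on CPU
else:
    with torch.device("meta"):           # other ranks: shapes/dtypes only, no storage
        model = build_model()

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,             # broadcast rank 0's weights during wrapping
    param_init_fn=None if rank == 0 else
        (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)),
)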
configs/training.py (2 additions, 1 deletion)
@@ -7,7 +7,8 @@
 @dataclass
 class train_config:
     model_name: str="PATH/to/LLAMA/7B"
-    enable_fsdp: bool= False
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
     num_epochs: int=3
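Usage note (illustrative sketch, not part of the diff): the new flag only takes effect together with enable_fsdp, since the training script checks both, for example when setting the config programmatically:

# Hypothetical driver snippet; the training script also accepts these fields as keyword overrides.
from configs.training import train_config

cfg = train_config()
cfg.enable_fsdp = True     # required; low_cpu_fsdp is checked only when FSDP is on
cfg.low_cpu_fsdp = True    # load pretrained weights on rank 0 only, meta-init elsewhere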
llama_finetuning.py (32 additions, 6 deletions)
@@ -17,6 +17,7 @@
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
+    LlamaConfig,
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -62,8 +63,10 @@
 from torch.optim.lr_scheduler import StepLR
 from pkg_resources import packaging
 import torch
+import torch.nn as nn
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
+from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer


@@ -90,11 +93,32 @@ def main(**kwargs):
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
 
     # Load the pre-trained model and setup its configuration
-    model = LlamaForCausalLM.from_pretrained(
-        train_config.model_name,
-        load_in_8bit=True if train_config.quantization else None,
-        device_map="auto" if train_config.quantization else None,
-    )
+    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
+        # for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
+        # this avoids cpu oom when loading large models like llama 70B, in which case
+        # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
+        # overhead and currently requires latest nightly.
+        v = packaging.version.parse(torch.__version__)
+        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        if not verify_latest_nightly:
+            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+                            "please install latest nightly.")
+        if rank == 0:
+            model = LlamaForCausalLM.from_pretrained(
+                train_config.model_name,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+            )
+        else:
+            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            with torch.device("meta"):
+                model = LlamaForCausalLM(llama_config)
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+        )
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

Review thread on the rank-0 LlamaForCausalLM.from_pretrained call:

Reviewer: can we figure out why torch.device("meta") init doesn't work here?

Author (lchu-ibm): @rohan-varma for non-0 ranks, we are using torch.device("meta") init.
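For reference, the "2+TB (70 * 4 * 8)" estimate in the new comment works out as follows: a 70B-parameter model in full precision needs about 70e9 × 4 bytes, roughly 280 GB per copy, and without this change each of the 8 local ranks on a node loads its own copy, about 2.24 TB of host RAM in total. Loading on rank 0 only keeps it to a single copy and lets FSDP broadcast the weights afterwards.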

@@ -127,14 +151,16 @@ def main(**kwargs):

         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
 
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
+            sync_module_states=True,
+            param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
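The two added arguments are what make the rank-0-only load usable: sync_module_states=True has FSDP broadcast parameters and buffers from rank 0 to the other ranks while wrapping, and param_init_fn lets every non-zero rank materialize its meta-device modules into empty CUDA storage via to_empty, so the broadcast has real tensors to write into. recurse=False is sufficient because FSDP applies param_init_fn to each submodule individually.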
model_checkpointing/checkpoint_handler.py (2 additions, 2 deletions)
@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg):
     reader = FileSystemReader(load_dir)
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        checkpoint = model.state_dict()
+        checkpoint = {"model": model.state_dict()}
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
@@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg):
print(f"checkpoint after load_state_dict()")
ck = checkpoint.keys()
print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
model.load_state_dict(checkpoint)
model.load_state_dict(checkpoint["model"])
if rank == 0:
print(f"Sharded state checkpoint loaded from {load_dir}")

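Context for the two-line change above (not part of the diff): torch.distributed.checkpoint matches stored shards against the keys of the state dict passed to its loader, and the repo's sharded save path appears to nest the model weights under a "model" key, so the loading code must build the same nesting and then unwrap it again before calling model.load_state_dict. A hedged sketch of the matching save side, with model and save_dir assumed to already exist:

import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Sketch only: nest the sharded state dict under "model" so the loader's
# checkpoint = {"model": model.state_dict()} lines up with the stored keys.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {"model": model.state_dict()}
    dist_cp.save_state_dict(
        state_dict=state_dict,
        storage_writer=dist_cp.FileSystemWriter(save_dir),  # save_dir: placeholder path
    )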