Commit: better dtype handle in loading
hiyouga committed May 16, 2024
1 parent ddec9e1 commit d9f190f
Showing 3 changed files with 15 additions and 8 deletions.
5 changes: 4 additions & 1 deletion src/llamafactory/model/adapter.py
@@ -44,7 +44,7 @@ def init_adapter(
         raise ValueError("You can only use lora for quantized models.")

     if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
+        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
         logger.info("Upcasting trainable params to float32.")
@@ -122,6 +122,9 @@ def init_adapter(
             else:
                 param.requires_grad_(False)

+        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
+            model.vision_tower.requires_grad_(False)
+
         logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))

     if finetuning_args.finetuning_type == "lora":
1 change: 1 addition & 0 deletions src/llamafactory/model/loader.py
@@ -170,6 +170,7 @@ def load_model(
         )
     else:
         param_stats = "all params: {:d}".format(all_param)
+
     logger.info(param_stats)

     if model_args.print_param_status:
17 changes: 10 additions & 7 deletions src/llamafactory/model/patcher.py
@@ -5,7 +5,7 @@
 import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
-from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

 from ..extras.logging import get_logger
@@ -66,13 +66,16 @@ def patch_config(
     for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
         setattr(config, dtype_name, model_args.compute_dtype == dtype)

-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
+    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
+        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

-    init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
-        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
-        if init_kwargs["low_cpu_mem_usage"]:
+    # deepspeed zero3 is not compatible with low_cpu_mem_usage
+    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
+
+    if deepspeed_config() is None and not is_fsdp_enabled():  # set dtype and device map if not use deepspeed or fsdp
+        init_kwargs["torch_dtype"] = model_args.compute_dtype
+
+        if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
             if "device_map" not in init_kwargs and model_args.device_map:
                 init_kwargs["device_map"] = model_args.device_map


8 comments on commit d9f190f

@hiyouga (Owner Author) commented on d9f190f May 16, 2024

Basically, we expect the trainable params to be float32 in mixed-precision training, i.e., when fp16 or bf16 is enabled.

The DeepSpeed or FSDP engine automatically casts the dtype of trainable and non-trainable params during training, so we do not need to set torch_dtype when initializing models under DeepSpeed or FSDP. The models should be loaded in float32 on the CPU. (It seems that DeepSpeed ZeRO-3 initializes models in float16 or bfloat16 on the CUDA devices.)
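
In pseudo-code, the loading decision looks roughly like the sketch below (mirroring the patcher.py change above, not the exact LLaMA-Factory code; the model name and compute dtype are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def build_init_kwargs(compute_dtype: torch.dtype, device_map: str = "auto") -> dict:
    init_kwargs = {}
    # DeepSpeed ZeRO-3 shards the model at init time and is not compatible
    # with low_cpu_mem_usage.
    init_kwargs["low_cpu_mem_usage"] = not is_deepspeed_zero3_enabled()

    if deepspeed_config() is None and not is_fsdp_enabled():
        # No DeepSpeed/FSDP: load directly in half precision and let the
        # device map place the weights on CUDA devices.
        init_kwargs["torch_dtype"] = compute_dtype
        if init_kwargs["low_cpu_mem_usage"]:  # device_map requires low_cpu_mem_usage=True
            init_kwargs["device_map"] = device_map
    # Under DeepSpeed/FSDP, leave torch_dtype/device_map unset: the model is
    # materialized in float32 on the CPU and the engine handles the casting.
    return init_kwargs


# hypothetical usage
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", **build_init_kwargs(torch.bfloat16)
)
```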

Conversely, if we do not use DeepSpeed or FSDP, the trainer cannot automatically cast the dtype of model params, so we need to handle the dtype and device map manually when initializing models (DoRA initialization needs CUDA devices for float16 models). The models should be loaded in float16 or bfloat16 on the CUDA devices. We then cast the trainable params to float32 for training stability.
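
The manual upcast amounts to something like this sketch (a simplification of the cast_trainable_params_to_fp32 path in adapter.py, not the exact implementation):

```python
import torch


def cast_trainable_params_to_fp32(model: torch.nn.Module) -> None:
    # Upcast only the trainable (e.g. LoRA adapter) weights to float32 for
    # training stability; the frozen base weights stay in fp16/bf16 to save VRAM.
    # This step is skipped when DeepSpeed/FSDP/PureBF16/BAdam is detected
    # (see the adapter.py hunk above).
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)
```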

Empirical results:

Load dtype    DeepSpeed/FSDP    VRAM
float32       No                34 GB
float16       No                17 GB
float32       Yes               18 GB
float16       Yes               18 GB
  • model: Llama3 8B
  • lora target modules: q, k, v, o, gate, up, down (adapter weights are in float32)
  • device: NVIDIA A100 40GB
  • token batch size: 512
  • activation checkpointing: enabled
  • flash attention: torch SDPA attention

Related materials:
[1] huggingface/peft#1249
[2] huggingface/peft#1336
[3] huggingface/peft#1706
[4] huggingface/trl#1644

@hiyouga (Owner Author)

cc: @BenjaminBossan

We ran a small experiment on the dtype of adapter weights; it may contain some useful information.

@hiyouga (Owner Author)

Btw, using cast_trainable_params_to_fp32 (i.e., param.data = param.data.to(torch.float32)) may cause hangs in the DeepSpeed ZeRO-3 and FSDP examples.

@BenjaminBossan

Thanks for the ping. Also pinging @pacman100

I'm not sure I 100% understand the conclusion. So do you think the changes in huggingface/peft#1706 are correct or is there something else we should do?

@hiyouga (Owner Author)

@BenjaminBossan Just to share my findings, I think the changes in huggingface/peft#1706 should be fine. The information above consists merely of experimental observations, and we don't have a clear conclusion (since DeepSpeed's behaviour is too complex to fully understand).

@hiyouga (Owner Author)

Loading the model on the CPU consumes a large amount of RAM in distributed training, so we rolled back the dtype and device settings for the non-ZeRO-3 DeepSpeed case: 31a0564

@BenjaminBossan

Okay, thanks for letting me know. Don't hesitate to open an issue if you think we should adjust something on the PEFT side.

@hiyouga (Owner Author) commented on d9f190f Jun 3, 2024