From 0bfa6f7aa56c8191d9c0724d89e4ae4dba2c0ff0 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Fri, 19 Jul 2024 03:00:57 +0000
Subject: [PATCH] Add Ascend NPU support

1. add Ascend NPU backend support
2. refactor func load_model in src/axolotl/utils/models.py
3. refactor load_in_8bit as a kwarg
---
 src/axolotl/utils/config/__init__.py |  25 +-
 src/axolotl/utils/distributed.py     |  43 +-
 src/axolotl/utils/models.py          | 955 +++++++++++++++------
 3 files changed, 577 insertions(+), 446 deletions(-)

diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index ed165e89ca..6664201330 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -7,6 +7,7 @@
 import torch
 from transformers.utils import is_torch_bf16_gpu_available
+from transformers.utils.import_utils import is_torch_npu_available
 
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config.models.input.v0_4_1 import (
@@ -28,8 +29,10 @@ def get_device():
             if torch.backends.mps.is_available():
                 return "mps"
+            if is_torch_npu_available():
+                return f"npu:{cfg.local_rank}"
 
-            raise SystemError("No CUDA/mps device found")
+            raise SystemError("No CUDA/mps/npu device found")
         except Exception:  # pylint: disable=broad-exception-caught
             return "cpu"
@@ -39,6 +42,8 @@ def get_device():
     else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": torch.cuda.current_device()}
+        elif cfg.device.startswith("npu"):
+            cfg.device_map = {"": torch.npu.current_device()}
         else:
             cfg.device_map = {"": cfg.device}
@@ -91,6 +96,24 @@ def normalize_config(cfg):
         if cfg.bf16:
             cfg.fp16 = True
         cfg.bf16 = False
+    elif cfg.device.startswith("npu"):
+        if cfg.load_in_8bit or cfg.load_in_4bit:
+            LOG.warning("Quantization is currently not supported on npu, disabling for this configuration.")
+            cfg.load_in_8bit = False
+            cfg.load_in_4bit = False
+
+        if cfg.tf32:
+            LOG.warning("tf32 dtype is currently not supported on npu, disabling for this configuration.")
+            cfg.tf32 = False
+
+        if cfg.bf16:
+            LOG.warning("bf16 is currently not supported on npu, casting to fp16.")
+            cfg.fp16 = True
+            cfg.bf16 = False
+
+        if "bit" in cfg.optimizer:
+            LOG.error("{} is currently not supported on npu, choose another one.".format(cfg.optimizer))
+
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
         if cfg.bf16:
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index ecb1bcc9ec..10a783f262 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -9,9 +9,24 @@
 import torch
 import torch.distributed as dist
 from accelerate import PartialState
+from transformers.utils.import_utils import (
+    is_torch_npu_available,
+    is_torch_cuda_available,
+    is_torch_mps_available
+)
 
 distributed_state = None  # pylint: disable=invalid-name
 
+def get_device():
+    device = torch.device("cpu")
+    if is_torch_cuda_available():
+        device = torch.device("cuda")
+    elif is_torch_mps_available():
+        device = torch.device("mps")
+    elif is_torch_npu_available():
+        device = torch.device("npu")
+    return device
+
 
 def is_distributed():
     """
@@ -83,12 +98,11 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
     Returns:
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
""" + device = get_device() value_scalar = fn() if not is_distributed(): return [value_scalar] - value_tensor = torch.tensor( - value_scalar, device=torch.cuda.current_device() - ).float() + value_tensor = torch.tensor(value_scalar, device=device).float() if not is_main_process(): dist.gather(value_tensor, dst=0) @@ -111,13 +125,14 @@ def broadcast_dict(vals: dict): if not is_distributed(): return vals + device = get_device() if is_main_process(): data_byte = pickle.dumps(vals) - data_tensor = torch.ByteTensor(list(data_byte)).to("cuda") - data_size = torch.IntTensor([len(data_byte)]).to("cuda") + data_tensor = torch.ByteTensor(list(data_byte)).to(device) + data_size = torch.IntTensor([len(data_byte)]).to(device) else: - data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda") - data_size = torch.IntTensor([0]).to("cuda") + data_tensor = torch.empty([1024], dtype=torch.uint8, device=device) + data_size = torch.IntTensor([0]).to(device) dist.broadcast(data_size, 0) if not is_main_process(): @@ -146,15 +161,12 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name Returns: - The computed value (int or float). """ + device = get_device() if is_main_process(): value_scalar = fn() - value_tensor = torch.tensor( - value_scalar, device=torch.cuda.current_device() - ).float() + value_tensor = torch.tensor(value_scalar, device=device).float() else: - value_tensor = torch.tensor( - 0.0, device=torch.cuda.current_device() - ) # Placeholder tensor + value_tensor = torch.tensor(0.0, device=device) # Placeholder tensor # Broadcast the tensor to all processes. barrier() @@ -178,10 +190,9 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name Returns: - A list of computed values from all ranks if on the gathering rank, otherwise None. """ + device = get_device() value_scalar = fn() - value_tensor = torch.tensor( - value_scalar, device=torch.cuda.current_device() - ).float() + value_tensor = torch.tensor(value_scalar, device=device).float() # Placeholder tensor for gathering results if is_main_process(): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6185f0102f..9f7ac5f71c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -45,7 +45,7 @@ from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.chat_templates import chat_templates from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import zero_only +from axolotl.utils.distributed import zero_only, get_device from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant @@ -297,608 +297,705 @@ def load_tokenizer(cfg): return tokenizer -def load_model( - cfg: DictDefault, - tokenizer: PreTrainedTokenizerBase, - inference: bool = False, - reference_model: bool = False, -) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: - """ - Load a model for a given configuration and tokenizer. 
- """ - base_model = cfg.base_model - model_type = cfg.type_of_model - model_config = load_model_config(cfg) +def get_device_count(): + device = get_device() + if "cuda" in device.__str__(): + return torch.cuda.device_count() + elif "npu" in device.__str__(): + return torch.npu.device_count() + return 1 + + +class ModelLoader: + def __init__(self, cfg: DictDefault, + tokenizer: PreTrainedTokenizerBase, + inference: bool = False, + reference_model: bool = False,) -> None: + self.cfg = cfg + self.tokenizer = tokenizer + self.inference: bool = inference + self.reference_model: bool = reference_model + + # init model kwargs + self.model_kwargs: Dict[str, Any] = {} + if cfg.model_kwargs: + for key, val in cfg.model_kwargs.items(): + self.model_kwargs[key] = val + + # init model config + self.model_config = load_model_config(cfg) + # init model device + self.device = get_device() + self.model_type = cfg.type_of_model + self.base_model = cfg.base_model + + + def apply_patches(self) -> None: + if self.cfg.gradient_checkpointing == "unsloth": + transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper + + if self.cfg.flash_attention: + self.patch_attention() + + if self.cfg.sample_packing and self.cfg.s2_attention: + raise ValueError( + "Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not currently support sample packing." + ) - # TODO refactor as a kwarg - load_in_8bit = cfg.load_in_8bit + if ( + self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES + and self.cfg.flash_attention + and self.cfg.sample_packing + ): + patch_for_multipack(self.cfg.model_config_type, model_name=self.cfg.base_model) - if cfg.gradient_checkpointing == "unsloth": - transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper + if self.cfg.is_llama_derived_model: + self.patch_loss() + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora - if hasattr(model_config, "model_type") and model_config.model_type == "btlm": - if cfg.flash_attention: - from axolotl.monkeypatch.btlm_attn_hijack_flash import ( - replace_btlm_attn_with_flash_attn, - ) + patch_self_attn_lora() + elif self.cfg.is_llama_derived_model: + self.patch_llama_derived_model() - replace_btlm_attn_with_flash_attn(cfg.base_model) + if ( + self.cfg.model_config_type == "mistral" + and self.cfg.flash_attention + and self.cfg.sample_packing + ): + self.patch_mistral_derived_model() - if ( - hasattr(model_config, "model_type") - and model_config.model_type == "stablelm_epoch" - ): - if cfg.flash_attention and cfg.sample_packing: - from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( - replace_stablelm_attn_with_flash_attn, - ) - replace_stablelm_attn_with_flash_attn(cfg.base_model) + def patch_attention(self) -> None: + if hasattr(self.model_config, "model_type"): + if self.model_config.model_type == "btlm": + from axolotl.monkeypatch.btlm_attn_hijack_flash import ( + replace_btlm_attn_with_flash_attn, + ) - if cfg.sample_packing and cfg.s2_attention: - raise ValueError( - "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not currently support sample packing." 
- ) + replace_btlm_attn_with_flash_attn(self.cfg.base_model) - if ( - cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES - and cfg.flash_attention - and cfg.sample_packing - ): - patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model) + if self.model_config.model_type == "stablelm_epoch" and self.cfg.sample_packing: + from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( + replace_stablelm_attn_with_flash_attn, + ) - if cfg.is_llama_derived_model: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - patch_llama_cross_entropy, - patch_llama_rms_norm, + replace_stablelm_attn_with_flash_attn(self.cfg.base_model) + + + def patch_loss(self) -> None: + """ + Patch loss functions + """ + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_llama_cross_entropy, + patch_llama_rms_norm, + ) + + if self.cfg.flash_attn_cross_entropy: + patch_llama_cross_entropy() + if self.cfg.flash_attn_rms_norm: + patch_llama_rms_norm() + elif self.cfg.unsloth_rms_norm: + from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm + + patch_unsloth_layernorm() + if self.cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import ( + integrate_cross_entropy_loss_patch, ) - if cfg.flash_attn_cross_entropy: - patch_llama_cross_entropy() - if cfg.flash_attn_rms_norm: - patch_llama_rms_norm() - elif cfg.unsloth_rms_norm: - from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm - - patch_unsloth_layernorm() - if cfg.unsloth_cross_entropy_loss: - from axolotl.monkeypatch.unsloth_ import ( - integrate_cross_entropy_loss_patch, - ) + integrate_cross_entropy_loss_patch() - integrate_cross_entropy_loss_patch() - if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora - patch_self_attn_lora() - elif cfg.is_llama_derived_model: - # Modify all llama derived models in one block + def patch_llama_derived_model(self) -> None: + """ + Modify all llama derived models in one block + """ - if cfg.flash_attention: + if self.cfg.flash_attention: from axolotl.monkeypatch.llama_attn_hijack_flash import ( replace_llama_attn_with_flash_attn, ) - if cfg.sample_packing: - if cfg.device not in ["mps", "cpu"] and not inference: + if self.cfg.sample_packing: + if self.cfg.device not in ["mps", "cpu"] and not self.inference: LOG.info("patching with flash attention for sample packing") replace_llama_attn_with_flash_attn( packed=True, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, ) - elif cfg.s2_attention: + elif self.cfg.s2_attention: LOG.info("patching w/ flash-enabled, shifted-sparse attention") replace_llama_attn_with_flash_attn( packed=False, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, use_shifted_sparse_attn=True, ) - elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm: + elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: replace_llama_attn_with_flash_attn( packed=False, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, ) - elif cfg.xformers_attention: + elif self.cfg.xformers_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_attention, ) LOG.info("patching with xformers 
attention") hijack_llama_attention() - elif cfg.sample_packing: + elif self.cfg.sample_packing: from axolotl.monkeypatch.llama_patch_multipack import ( hijack_llama_prepare_4d_mask, ) LOG.info("patching llama _prepare_4d_causal_attention_mask*") hijack_llama_prepare_4d_mask() - elif cfg.s2_attention: + elif self.cfg.s2_attention: raise NotImplementedError( "Shifted-sparse attention not currently implemented without flash attention." ) - if cfg.unsloth_cross_entropy_loss: + if self.cfg.unsloth_cross_entropy_loss: from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch integrate_cross_entropy_loss_patch() - if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora patch_self_attn_lora() - # Modify mistral derived models - if ( - cfg.model_config_type == "mistral" - and cfg.flash_attention - and cfg.sample_packing - ): + if self.cfg.sample_packing and not self.inference: + from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask + + LOG.info("patching _expand_mask") + hijack_expand_mask() + + + def patch_mistral_derived_model(self) -> None: from axolotl.monkeypatch.mistral_attn_hijack_flash import ( replace_mistral_attn_with_flash_attn, ) LOG.info("patching mistral with flash attention") - replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) - - if cfg.is_llama_derived_model and cfg.sample_packing and not inference: - from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask - - LOG.info("patching _expand_mask") - hijack_expand_mask() - - model_kwargs: Dict[str, Any] = {} + replace_mistral_attn_with_flash_attn(packed=self.cfg.sample_packing) - if cfg.model_kwargs: - for key, val in cfg.model_kwargs.items(): - model_kwargs[key] = val - max_memory = cfg.max_memory - device_map = cfg.device_map + def set_device_map_config(self) -> None: + device_map = self.cfg.device_map + max_memory = self.cfg.max_memory - if cfg.gpu_memory_limit: - gpu_memory_limit = ( - str(cfg.gpu_memory_limit) + "GiB" - if isinstance(cfg.gpu_memory_limit, int) - else cfg.gpu_memory_limit - ) - - max_memory = {} - for i in range(torch.cuda.device_count()): - max_memory[i] = gpu_memory_limit - max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything - - if max_memory is not None: - # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py - from accelerate import infer_auto_device_map - - with init_empty_weights(): - model_canvas = AutoModelForCausalLM.from_config( - model_config, trust_remote_code=cfg.trust_remote_code or False + if self.cfg.gpu_memory_limit: + gpu_memory_limit = ( + str(self.cfg.gpu_memory_limit) + "GiB" + if isinstance(self.cfg.gpu_memory_limit, int) + else self.cfg.gpu_memory_limit ) - model_canvas.tie_weights() - device_map = infer_auto_device_map( - model_canvas, - max_memory=max_memory, - dtype=cfg.torch_dtype, - ) - # We can discard max_memory now as we have a device map set up for us - max_memory = None - - model_kwargs["device_map"] = device_map - model_kwargs["torch_dtype"] = cfg.torch_dtype - if torch.backends.mps.is_available(): - model_kwargs["device_map"] = "mps:0" + self.max_memory = {} + for i in range(torch.cuda.device_count()): + max_memory[i] = gpu_memory_limit + max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything - # TODO can we put the reference model on it's own gpu? 
I think we have to move logits around to calculate loss - # if cfg.rl: - # if torch.cuda.device_count() > 1: - # if reference_model: - # model_kwargs["device_map"] = "cuda:" + str( - # torch.cuda.current_device() + 1 - # ) - # else: - # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) + if max_memory is not None: + # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py + from accelerate import infer_auto_device_map - if is_deepspeed_zero3_enabled(): - del model_kwargs["device_map"] - - if cfg.revision_of_model: - model_kwargs["revision"] = cfg.revision_of_model - - if cfg.gptq: - if not hasattr(model_config, "quantization_config"): - LOG.warning("model config does not contain quantization_config information") - else: - if cfg.gptq_disable_exllama is not None: - model_config.quantization_config[ - "disable_exllama" - ] = cfg.gptq_disable_exllama - model_kwargs["quantization_config"] = GPTQConfig( - **model_config.quantization_config + with init_empty_weights(): + model_canvas = AutoModelForCausalLM.from_config( + self.model_config, trust_remote_code=self.cfg.trust_remote_code or False + ) + model_canvas.tie_weights() + device_map = infer_auto_device_map( + model_canvas, + max_memory=max_memory, + dtype=self.cfg.torch_dtype, ) - if cfg.adapter == "qlora" and cfg.load_in_4bit: - bnb_config = { - "load_in_4bit": True, - "llm_int8_threshold": 6.0, - "llm_int8_has_fp16_weight": False, - "bnb_4bit_compute_dtype": cfg.torch_dtype, - "bnb_4bit_use_double_quant": True, - "bnb_4bit_quant_type": "nf4", - "bnb_4bit_quant_storage": torch.bfloat16, - } - if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed: - # for some reason, this causes the loss to be off by an order of magnitude - # but deepspeed needs this still in bfloat16 - bnb_config["bnb_4bit_quant_storage"] = torch.float32 - - if cfg.bnb_config_kwargs: - bnb_config.update(cfg.bnb_config_kwargs) - - model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - elif cfg.adapter == "lora" and cfg.load_in_8bit: - bnb_config = { - "load_in_8bit": True, - } - # Exclude mamba blocks from int8 quantization for jamba - if cfg.model_config_type == "jamba": - bnb_config["llm_int8_skip_modules"] = ["mamba"] - model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - - if cfg.load_in_8bit and cfg.adapter is not None: - model_kwargs["load_in_8bit"] = True - if cfg.load_in_4bit and cfg.adapter is not None: - model_kwargs["load_in_4bit"] = True - - # no longer needed per https://github.com/huggingface/transformers/pull/26610 - if "quantization_config" in model_kwargs or cfg.gptq: - if "load_in_8bit" in model_kwargs: - del model_kwargs["load_in_8bit"] - if "load_in_4bit" in model_kwargs: - del model_kwargs["load_in_4bit"] - - # sample packing uses custom FA2 patch - if cfg.flash_attention: - if not cfg.sample_packing: - if cfg.s2_attention: - pass - # most other models support flash attention, we can define exceptions as they come up - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" + # We can discard max_memory now as we have a device map set up for us + max_memory = None + + self.model_kwargs["device_map"] = device_map + self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype + + if "mps" in self.device.__str__(): + self.model_kwargs["device_map"] = "mps:0" + elif "npu" in self.device.__str__(): + self.model_kwargs["device_map"] = "npu:0" + + 
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) + + if is_deepspeed_zero3_enabled(): + del self.model_kwargs["device_map"] + + + def set_quantization_config(self) -> None: + if self.cfg.gptq: + if not hasattr(self.model_config, "quantization_config"): + LOG.warning("model config does not contain quantization_config information") + else: + if self.cfg.gptq_disable_exllama is not None: + self.model_config.quantization_config[ + "disable_exllama" + ] = self.cfg.gptq_disable_exllama + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + if self.cfg.adapter == "qlora" and self.cfg.load_in_4bit: + bnb_config = { + "load_in_4bit": True, + "llm_int8_threshold": 6.0, + "llm_int8_has_fp16_weight": False, + "bnb_4bit_compute_dtype": self.cfg.torch_dtype, + "bnb_4bit_use_double_quant": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_quant_storage": torch.bfloat16, + } + if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not self.cfg.deepspeed: + # for some reason, this causes the loss to be off by an order of magnitude + # but deepspeed needs this still in bfloat16 + bnb_config["bnb_4bit_quant_storage"] = torch.float32 + + if self.cfg.bnb_config_kwargs: + bnb_config.update(self.cfg.bnb_config_kwargs) + + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, ) - else: - if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES: - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access + elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit: + bnb_config = { + "load_in_8bit": True, + } + # Exclude mamba blocks from int8 quantization for jamba + if self.cfg.model_config_type == "jamba": + bnb_config["llm_int8_skip_modules"] = ["mamba"] + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, + ) + + if (self.cfg.adapter is not None + and not ("quantization_config" in self.model_kwargs or self.cfg.gptq) + ): + # no longer needed per https://github.com/huggingface/transformers/pull/26610 + self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit + self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit + + + def set_attention_config(self) -> None: + """ + sample packing uses custom FA2 patch + """ + if self.cfg.flash_attention: + if not self.cfg.sample_packing: + if self.cfg.s2_attention: + pass + # most other models support flash attention, we can define exceptions as they come up + self.model_kwargs["attn_implementation"] = "flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: - model_kwargs["attn_implementation"] = "eager" - model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) - elif cfg.sdp_attention: - model_kwargs["attn_implementation"] = "sdpa" - model_config._attn_implementation = "sdpa" # pylint: disable=protected-access - elif cfg.eager_attention: - model_kwargs["attn_implementation"] = "eager" - model_config._attn_implementation = "eager" # pylint: disable=protected-access - - if cfg.low_cpu_mem_usage: - model_kwargs["low_cpu_mem_usage"] = True + if self.model_config.model_type in 
SUPPORTED_MULTIPACK_MODEL_TYPES: + self.model_kwargs["attn_implementation"] = "flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) + else: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "eager" + ) + elif self.cfg.sdp_attention: + self.model_kwargs["attn_implementation"] = "sdpa" + self.model_config._attn_implementation = "sdpa" # pylint: disable=protected-access + elif self.cfg.eager_attention: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = "eager" # pylint: disable=protected-access - qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora" + if self.cfg.low_cpu_mem_usage: + self.model_kwargs["low_cpu_mem_usage"] = True - try: + def build_model(self, qlora_fsdp) -> Tuple[PreTrainedModel, bool]: skip_move_to_device = False if ( # pylint: disable=condition-evals-to-constant) - (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) + (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) and not qlora_fsdp and False ): model = load_sharded_model( - base_model, - model_config, - cfg, - torch_dtype=cfg.torch_dtype, + self.base_model, + self.model_config, + self.cfg, + torch_dtype=self.cfg.torch_dtype, ) skip_move_to_device = True elif ( qlora_fsdp - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and cfg.model_config_type == "dbrx" + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and self.cfg.model_config_type == "dbrx" ): - quant_storage = cfg.torch_dtype + quant_storage = self.cfg.torch_dtype model = load_sharded_model_quant( - base_model, - model_config, - cfg, + self.base_model, + self.model_config, + self.cfg, quant_storage=quant_storage, ) skip_move_to_device = True elif ( - model_config.model_type == "llama" - and not cfg.trust_remote_code - and not cfg.gptq + self.model_config.model_type == "llama" + and not self.cfg.trust_remote_code + and not self.cfg.gptq ): - if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if qlora_fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: skip_move_to_device = True - if "device_map" in model_kwargs: - del model_kwargs["device_map"] + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] model = AutoModelForCausalLM.from_pretrained( - base_model, - config=model_config, - **model_kwargs, + self.base_model, + config=self.model_config, + **self.model_kwargs, ) - if cfg.flash_attention and not inference: + # TODO (MengqingCao) split these patches seperately + if self.cfg.flash_attention and not self.inference: from axolotl.monkeypatch.llama_attn_hijack_flash import ( is_xformers_swiglu_available, replace_llama_mlp_with_swiglu, replace_llama_qkv_with_fused, ) - if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): + if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): LOG.info("patching with SwiGLU") replace_llama_mlp_with_swiglu(model) - if cfg.flash_attn_fuse_qkv: + if self.cfg.flash_attn_fuse_qkv: LOG.info("patching with fused QKV") replace_llama_qkv_with_fused(model) - elif model_type == "MambaLMHeadModel": + elif self.model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name - model_kwargs["dtype"] = model_kwargs["torch_dtype"] - model_kwargs["device"] = torch.cuda.current_device() - del model_kwargs["torch_dtype"] - del 
model_kwargs["device_map"] + self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"] + self.model_kwargs["device"] = torch.cuda.current_device() + del self.model_kwargs["torch_dtype"] + del self.model_kwargs["device_map"] model = MambaLMHeadModel.from_pretrained( - base_model, - **model_kwargs, + self.base_model, + **self.model_kwargs, ) elif ( - model_type - and model_type != "AutoModelForCausalLM" - and not cfg.trust_remote_code + self.model_type + and self.model_type != "AutoModelForCausalLM" + and not self.cfg.trust_remote_code ): - if cfg.gptq: + if self.cfg.gptq: model = AutoModelForCausalLM.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: - model = getattr(transformers, model_type).from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + model = getattr(transformers, self.model_type).from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # when training starts if ( - hasattr(model_config, "max_seq_len") - and model_config.max_seq_len - and cfg.sequence_len > model_config.max_seq_len + hasattr(self.model_config, "max_seq_len") + and self.model_config.max_seq_len + and self.cfg.sequence_len > self.model_config.max_seq_len ): - model_config.max_seq_len = cfg.sequence_len - LOG.warning(f"increasing context length to {cfg.sequence_len}") + self.model_config.max_seq_len = self.cfg.sequence_len + LOG.warning(f"increasing context length to {self.cfg.sequence_len}") elif ( - hasattr(model_config, "max_sequence_length") - and model_config.max_sequence_length - and cfg.sequence_len > model_config.max_sequence_length + hasattr(self.model_config, "max_sequence_length") + and self.model_config.max_sequence_length + and self.cfg.sequence_len > self.model_config.max_sequence_length ): - model_config.max_sequence_length = cfg.sequence_len - LOG.warning(f"increasing context length to {cfg.sequence_len}") - if cfg.gptq: + self.model_config.max_sequence_length = self.cfg.sequence_len + LOG.warning(f"increasing context length to {self.cfg.sequence_len}") + if self.cfg.gptq: model = AutoModelForCausalLM.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: - if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if qlora_fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: # disabling either of these two still leads to VRAM spike before setting back down skip_move_to_device = True - if "device_map" in model_kwargs: - del model_kwargs["device_map"] + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] model = AutoModelForCausalLM.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) - except Exception as err: # pylint: disable=broad-exception-caught - LOG.exception(err) - raise err - - if 
isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: - model = model.merge_and_unload() - - embeddings_len = ( - math.ceil(len(tokenizer) / 32) * 32 - if cfg.resize_token_embeddings_to_32x - else len(tokenizer) - ) - if ( - hasattr(model, "get_input_embeddings") - and model.get_input_embeddings().num_embeddings < embeddings_len - ): - model.resize_token_embeddings(embeddings_len) - else: - model.tie_weights() - if ( - hasattr(model, "config") - and hasattr(model.config, "max_position_embeddings") - and model.config.max_position_embeddings - and cfg.sequence_len > model.config.max_position_embeddings - ): - LOG.warning( - f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}" - ) - model.config.max_position_embeddings = cfg.sequence_len + return model, skip_move_to_device - if ( - hasattr(model, "config") - and hasattr(model.config, "bos_token_id") - and model.config.bos_token_id - and model.config.bos_token_id != tokenizer.bos_token_id - ): - model.config.bos_token_id = tokenizer.bos_token_id - if ( - hasattr(model, "config") - and hasattr(model.config, "eos_token_id") - and model.config.eos_token_id - and model.config.eos_token_id != tokenizer.eos_token_id - ): - model.config.eos_token_id = tokenizer.eos_token_id + def ajust_model_config(self, model) -> None: + if ( + hasattr(model, "config") + and hasattr(model.config, "max_position_embeddings") + and model.config.max_position_embeddings + and self.cfg.sequence_len > model.config.max_position_embeddings + ): + LOG.warning( + f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {self.cfg.sequence_len}" + ) + model.config.max_position_embeddings = self.cfg.sequence_len - if hasattr(model, "device") and model.device.type in ("cuda", "mps"): - log_gpu_memory_usage(LOG, "after model load", model.device) + if ( + hasattr(model, "config") + and hasattr(model.config, "bos_token_id") + and model.config.bos_token_id + and model.config.bos_token_id != self.tokenizer.bos_token_id + ): + model.config.bos_token_id = self.tokenizer.bos_token_id - # make sure these are fp32 per Ramesh et al. 
(2021) - embedding_modules = get_linear_embedding_layers(cfg.model_config_type) - if not cfg.fsdp: - # FSDP doesn't like mixed Float and BFloat16 - for name, module in model.named_modules(): - if "norm" in name or name.endswith(".gate"): - module.to(torch.float32) - if model_config.model_type == "btlm": - # don't upcast lm_head for btlm - continue - if any(m in name for m in embedding_modules): - if hasattr(module, "weight"): - module.to(torch.float32) + if ( + hasattr(model, "config") + and hasattr(model.config, "eos_token_id") + and model.config.eos_token_id + and model.config.eos_token_id != self.tokenizer.eos_token_id + ): + model.config.eos_token_id = self.tokenizer.eos_token_id - needs_fa2_dtype = cfg.adapter or cfg.fsdp - skip_prepare_model_for_kbit_training = False - if is_deepspeed_zero3_enabled(): + def set_z3_leaf_modules(self, model) -> None: from deepspeed.utils import ( # pylint: disable=no-name-in-module set_z3_leaf_modules, ) - if cfg.model_config_type == "mixtral": + if self.cfg.model_config_type == "mixtral": moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock") set_z3_leaf_modules(model, [moe_block]) - elif cfg.model_config_type == "dbrx": + elif self.cfg.model_config_type == "dbrx": moe_block = get_module_class_from_name(model, "DbrxFFN") set_z3_leaf_modules(model, [moe_block]) - if cfg.model_config_type == "qwen" and cfg.adapter == "lora": - # Qwen doesn't play nicely with LoRA if this is enabled - skip_prepare_model_for_kbit_training = True - loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits - if cfg.adapter == "lora" and loftq_bits: - skip_prepare_model_for_kbit_training = True + def prepare_model(self, qlora_fsdp) -> None: + skip_prepare_model_for_kbit_training = False + if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora": + # Qwen doesn't play nicely with LoRA if this is enabled + skip_prepare_model_for_kbit_training = True - if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading): - # make sure everything is in the same dtype - skip_prepare_model_for_kbit_training = True + loftq_bits = self.cfg.peft and self.cfg.peft.loftq_config and self.cfg.peft.loftq_config.loftq_bits + if self.cfg.adapter == "lora" and loftq_bits: + skip_prepare_model_for_kbit_training = True - if cfg.adapter in ["lora", "qlora"]: - if cfg.gradient_checkpointing: - model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs - ) - if ( - cfg.load_in_8bit or cfg.load_in_4bit - ) and not skip_prepare_model_for_kbit_training: + if qlora_fsdp or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading): + # make sure everything is in the same dtype + skip_prepare_model_for_kbit_training = True + + if (not skip_prepare_model_for_kbit_training + and self.cfg.adapter in ["lora", "qlora"] + and (self.model_kwargs["load_in_8bit"] or self.model_kwargs["load_in_4bit"]) + ): LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") model = prepare_model_for_kbit_training( - model, use_gradient_checkpointing=cfg.gradient_checkpointing + model, use_gradient_checkpointing=self.cfg.gradient_checkpointing ) - needs_fa2_dtype = True - # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to - # convert them back to fp16/bf16 for flash-attn compatibility. 
- if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp: - LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) + + def convert_embedding_modules_dtype( + self, + model, + embedding_modules, + dist_dtype, + before_kbit_train_or_finetune + ) -> None: for name, module in model.named_modules(): if "norm" in name: - module.to(cfg.torch_dtype) + module.to(dist_dtype) + if before_kbit_train_or_finetune: + if name.endswith(".gate"): + module.to(dist_dtype) + if self.model_config.model_type == "btlm": + # don't upcast lm_head for btlm + continue if any(m in name for m in embedding_modules): if hasattr(module, "weight"): - module.to(cfg.torch_dtype) - - lora_config = None - if not reference_model or cfg.lora_model_dir: - # if we're not loading the reference model, then we're loading the model for training - # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config - if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto"] and not cfg.merge_lora: - _, lora_config = load_lora(model, cfg, inference=False, config_only=True) + module.to(dist_dtype) + + + def apply_lora_patch(self, model) -> None: + if self.cfg.unsloth_lora_mlp: + from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch + + integrate_lora_mlp_patch(model) + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import integrate_lora_patch + + integrate_lora_patch(model, self.cfg) + if self.cfg.unsloth_rope: + from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings + + integrate_rope_embeddings() + + + def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: + self.apply_patches() + self.set_device_map_config() + if self.cfg.revision_of_model: + self.model_kwargs["revision"] = self.cfg.revision_of_model + self.set_quantization_config() + self.set_attention_config() + + qlora_fsdp = self.cfg.fsdp and self.cfg.adapter == "qlora" + skip_move_to_device = False + + try: + model, skip_move_to_device = self.build_model(qlora_fsdp) + except Exception as err: # pylint: disable=broad-exception-caught + LOG.exception(err) + raise err + + if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: + model = model.merge_and_unload() + + embeddings_len = ( + math.ceil(len(self.tokenizer) / 32) * 32 + if self.cfg.resize_token_embeddings_to_32x + else len(self.tokenizer) + ) + if (hasattr(model, "get_input_embeddings") + and model.get_input_embeddings().num_embeddings < embeddings_len + ): + model.resize_token_embeddings(embeddings_len) else: - model, lora_config = load_adapter(model, cfg, cfg.adapter) + model.tie_weights() + + self.ajust_model_config(model) + + # log device memory usage + if hasattr(model, "device") and model.device.type in ("cuda", "mps"): + log_gpu_memory_usage(LOG, "after model load", model.device) + + # make sure these are fp32 per Ramesh et al. 
(2021) + embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) + if not self.cfg.fsdp: + # FSDP doesn't like mixed Float and BFloat16 + self.convert_embedding_modules_dtype(model, embedding_modules, + dist_dtype=torch.float32, before_kbit_train_or_finetune=True) + + if is_deepspeed_zero3_enabled(): + self.set_z3_leaf_modules(model) + + needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp + if self.cfg.adapter in ["lora", "qlora"]: + needs_fa2_dtype = True + if self.cfg.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs + ) - if ( - cfg.ddp - and not load_in_8bit - and not (cfg.rl and cfg.load_in_4bit) - and not skip_move_to_device - ): - # TODO revaldate this conditional - model.to(f"cuda:{cfg.local_rank}") + self.prepare_model(qlora_fsdp) + + # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to + # convert them back to fp16/bf16 for flash-attn compatibility. + if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp: + LOG.info("converting modules to %s for flash attention", self.cfg.torch_dtype) + self.convert_embedding_modules_dtype(model, embedding_modules, + dist_dtype=self.cfg.torch_dtype, before_kbit_train_or_finetune=False) + + # --------------------------------------------------------- + # load lora or adapter + # --------------------------------------------------------- + lora_config = None + if not self.reference_model or self.cfg.lora_model_dir: + # if we're not loading the reference model, then we're loading the model for training + # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config + if self.cfg.adapter and self.cfg.rl in ["dpo", "ipo", "kto"] and not self.cfg.merge_lora: + _, lora_config = load_lora(model, self.cfg, inference=False, config_only=True) + else: + model, lora_config = load_adapter(model, self.cfg, self.cfg.adapter) - if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: - setattr(model, "is_parallelizable", True) - setattr(model, "model_parallel", True) + # --------------------------------------------------------- + # put model to accelerator + # --------------------------------------------------------- + if ( + self.cfg.ddp + and not self.model_kwargs["load_in_8bit"] + and not (self.cfg.rl and self.model_kwargs["load_in_4bit"]) + and not skip_move_to_device + ): + # TODO revaldate this conditional + model.to(f"{self.device.__str__()}:{self.cfg.local_rank}") - requires_grad = [] - for name, param in model.named_parameters(recurse=True): - if param.requires_grad: - requires_grad.append(f"{name}: {param.requires_grad}") - if len(requires_grad) == 0: - LOG.warning("there are no parameters that require gradient updates") - if hasattr(model, "config"): - model.config.use_cache = False + if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: + setattr(model, "is_parallelizable", True) + setattr(model, "model_parallel", True) - if cfg.flash_optimum: - from optimum.bettertransformer import BetterTransformer + # --------------------------------------------------------- + # parameters that require gradient updates + # --------------------------------------------------------- + requires_grad = [] + for name, param in model.named_parameters(recurse=True): + if param.requires_grad: + requires_grad.append(f"{name}: {param.requires_grad}") + if len(requires_grad) == 0: + LOG.warning("there are no parameters that require gradient updates") + if 
hasattr(model, "config"):
+            model.config.use_cache = False
-    if cfg.flash_optimum:
-        from optimum.bettertransformer import BetterTransformer
-
-        model = BetterTransformer.transform(model)
-
-    if cfg.adapter is not None:
-        log_gpu_memory_usage(LOG, "after adapters", model.device)
-
-    if cfg.unsloth_lora_mlp:
-        from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
-
-        integrate_lora_mlp_patch(model)
-    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
-        from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
-
-        integrate_lora_patch(model, cfg)
-
-    if cfg.unsloth_rope:
-        from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
-
-        integrate_rope_embeddings()
-
-    for _ in range(3):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    # TODO resume_from_checkpoint handling
-    return model, lora_config
+
+        if self.cfg.flash_optimum:
+            from optimum.bettertransformer import BetterTransformer
+
+            model = BetterTransformer.transform(model)
+
+        if self.cfg.adapter is not None:
+            log_gpu_memory_usage(LOG, "after adapters", model.device)
+
+        self.apply_lora_patch(model)
+
+        if "cuda" in str(self.device):
+            for _ in range(3):
+                gc.collect()
+                torch.cuda.empty_cache()
+
+        # TODO resume_from_checkpoint handling
+        return model, lora_config
+
+
+def load_model(
+    cfg: DictDefault,
+    tokenizer: PreTrainedTokenizerBase,
+    inference: bool = False,
+    reference_model: bool = False,
+) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
+    """
+    Load a model for a given configuration and tokenizer.
+    """
+    loader = ModelLoader(cfg, tokenizer, inference, reference_model)
+    return loader.load_model()
 
 
 def load_adapter(model, cfg, adapter, inference=False):