From 7b81ca5f5bbbea5949419c51eddfd74ee539a93f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 22 Aug 2024 01:38:24 -0700 Subject: [PATCH 001/110] Update _utils.py --- unsloth/models/_utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1c48e8e5..cd8841e4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -284,18 +284,18 @@ def patch_mistral_nemo_config(config): # ============================================= # Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout' if Version(xformers_version) >= Version("0.0.27"): - import accelerate.utils.operations - if hasattr(accelerate.utils.operations, "send_to_device") and \ - accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": - from accelerate.utils.operations import * - send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) - send_to_device = re.sub( - r"([ ]{4,})return tensor\.to\(device\)", - r"\1try: return tensor.to(device)\n\1except: return tensor", - send_to_device, - ).replace("def send_to_device", "def _fixed_send_to_device") - exec(send_to_device) - accelerate.utils.operations.send_to_device = _fixed_send_to_device + # import accelerate.utils.operations + # if hasattr(accelerate.utils.operations, "send_to_device") and \ + # accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": + # from accelerate.utils.operations import * + # send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) + # send_to_device = re.sub( + # r"([ ]{4,})return tensor\.to\(device\)", + # r"\1try: return tensor.to(device)\n\1except: return tensor", + # send_to_device, + # ).replace("def send_to_device", "def _fixed_send_to_device") + # exec(send_to_device) + # accelerate.utils.operations.send_to_device = _fixed_send_to_device pass pass # ============================================= From 94f2d3477ce8a3ef43182472c0c321527b10102d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 22 Aug 2024 01:42:43 -0700 Subject: [PATCH 002/110] Update _utils.py --- unsloth/models/_utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cd8841e4..d4e504bd 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -284,18 +284,18 @@ def patch_mistral_nemo_config(config): # ============================================= # Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout' if Version(xformers_version) >= Version("0.0.27"): - # import accelerate.utils.operations - # if hasattr(accelerate.utils.operations, "send_to_device") and \ - # accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": - # from accelerate.utils.operations import * - # send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) - # send_to_device = re.sub( - # r"([ ]{4,})return tensor\.to\(device\)", - # r"\1try: return tensor.to(device)\n\1except: return tensor", - # send_to_device, - # ).replace("def send_to_device", "def _fixed_send_to_device") - # exec(send_to_device) - # accelerate.utils.operations.send_to_device = _fixed_send_to_device + import accelerate.utils.operations + if hasattr(accelerate.utils.operations, "send_to_device") and \ + accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": + from 
accelerate.utils.operations import * + send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) + send_to_device = re.sub( + r"([ ]{4,})return tensor\.to\(device\)", + r"\1print(type(tensor)); return tensor.to(device)", + send_to_device, + ).replace("def send_to_device", "def _fixed_send_to_device") + exec(send_to_device) + accelerate.utils.operations.send_to_device = _fixed_send_to_device pass pass # ============================================= From 7c5222d594f698d8b7f16fc13444459eafe8949b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 22 Aug 2024 01:46:37 -0700 Subject: [PATCH 003/110] Update _utils.py --- unsloth/models/_utils.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index d4e504bd..97c12906 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -281,25 +281,6 @@ def patch_mistral_nemo_config(config): ) pass -# ============================================= -# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout' -if Version(xformers_version) >= Version("0.0.27"): - import accelerate.utils.operations - if hasattr(accelerate.utils.operations, "send_to_device") and \ - accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": - from accelerate.utils.operations import * - send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) - send_to_device = re.sub( - r"([ ]{4,})return tensor\.to\(device\)", - r"\1print(type(tensor)); return tensor.to(device)", - send_to_device, - ).replace("def send_to_device", "def _fixed_send_to_device") - exec(send_to_device) - accelerate.utils.operations.send_to_device = _fixed_send_to_device - pass -pass -# ============================================= - # ============================================= # Torch compile settings @@ -1110,3 +1091,23 @@ def test_mask_creation(): assert(torch.all(correct_mask == our_mask)) pass pass + + +# ============================================= +# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout' +if Version(xformers_version) >= Version("0.0.27"): + import accelerate.utils.operations + if hasattr(accelerate.utils.operations, "send_to_device") and \ + accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": + from accelerate.utils.operations import * + send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) + send_to_device = re.sub( + r"([ ]{4,})return tensor\.to\(device\)", + r"\1try: return tensor.to(device)\n\1except: return tensor", + send_to_device, + ).replace("def send_to_device", "def _fixed_send_to_device") + exec(send_to_device) + accelerate.utils.operations.send_to_device = _fixed_send_to_device + pass +pass +# ============================================= From 15d44179415d3c9da868a9a01a29894586b3a789 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 22 Aug 2024 01:52:28 -0700 Subject: [PATCH 004/110] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 97c12906..f5e1c988 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1103,7 +1103,7 @@ def test_mask_creation(): send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) send_to_device = re.sub( r"([ ]{4,})return tensor\.to\(device\)", - r"\1try: return 
tensor.to(device)\n\1except: return tensor", + r"\1try: return tensor.to(device)\n\1except: print(tensor)", send_to_device, ).replace("def send_to_device", "def _fixed_send_to_device") exec(send_to_device) From 1ea463c72195b753e4ea69d75cd1998ed2a46359 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 22 Aug 2024 01:57:57 -0700 Subject: [PATCH 005/110] Update _utils.py --- unsloth/models/_utils.py | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f5e1c988..1c48e8e5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -281,6 +281,25 @@ def patch_mistral_nemo_config(config): ) pass +# ============================================= +# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout' +if Version(xformers_version) >= Version("0.0.27"): + import accelerate.utils.operations + if hasattr(accelerate.utils.operations, "send_to_device") and \ + accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": + from accelerate.utils.operations import * + send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) + send_to_device = re.sub( + r"([ ]{4,})return tensor\.to\(device\)", + r"\1try: return tensor.to(device)\n\1except: return tensor", + send_to_device, + ).replace("def send_to_device", "def _fixed_send_to_device") + exec(send_to_device) + accelerate.utils.operations.send_to_device = _fixed_send_to_device + pass +pass +# ============================================= + # ============================================= # Torch compile settings @@ -1091,23 +1110,3 @@ def test_mask_creation(): assert(torch.all(correct_mask == our_mask)) pass pass - - -# ============================================= -# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout' -if Version(xformers_version) >= Version("0.0.27"): - import accelerate.utils.operations - if hasattr(accelerate.utils.operations, "send_to_device") and \ - accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": - from accelerate.utils.operations import * - send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) - send_to_device = re.sub( - r"([ ]{4,})return tensor\.to\(device\)", - r"\1try: return tensor.to(device)\n\1except: print(tensor)", - send_to_device, - ).replace("def send_to_device", "def _fixed_send_to_device") - exec(send_to_device) - accelerate.utils.operations.send_to_device = _fixed_send_to_device - pass -pass -# ============================================= From cf929e29e64de98ac3f818de678a9654ff047843 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 22 Aug 2024 02:03:59 -0700 Subject: [PATCH 006/110] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index b677f864..e8f2e8b5 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1187,6 +1187,18 @@ def patch_sft_trainer_tokenizer(): "pass\n"\ "\n" + # Also DPO weirdly tokenizes non numeric columns? Delete them! 
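As an aside, the cleanup this comment describes (and which the injected snippet just below performs) can be illustrated on a toy tokenized DPO dataset. The example is hypothetical and not part of the patch; the column names follow TRL's DPO convention, everything else is made up:

```python
# Hedged illustration of the cleanup injected just below: once a DPO dataset
# carries tokenized columns, the raw text columns can be dropped so only
# numeric features reach the collator. Toy data for illustration only.
from datasets import Dataset

train_dataset = Dataset.from_dict({
    "prompt":             ["Hi"],
    "chosen":             ["Hello!"],
    "rejected":           ["Go away."],
    "prompt_input_ids":   [[1, 2]],
    "chosen_input_ids":   [[1, 2, 3]],
    "rejected_input_ids": [[1, 2, 4]],
})

if {"chosen", "rejected", "prompt"} <= set(train_dataset.column_names):
    train_dataset = train_dataset.remove_columns(["chosen", "rejected", "prompt"])

print(train_dataset.column_names)  # only the *_input_ids columns remain
```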
+ check_text += \ + "\n"\ + "column_names = set(self.train_dataset.column_names)\n"\ + "check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ + " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ + " 'prompt_input_ids', 'prompt_attention_mask']\n"\ + "if all(x in column_names for x in check):\n"\ + " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ + "del check, column_names\n"\ + "\n" + check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) From 5a7be989fb473100d24fd9294ea469385b129d0b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 22 Aug 2024 02:07:17 -0700 Subject: [PATCH 007/110] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index e8f2e8b5..2b4d6c44 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1195,6 +1195,7 @@ def patch_sft_trainer_tokenizer(): " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ " 'prompt_input_ids', 'prompt_attention_mask']\n"\ "if all(x in column_names for x in check):\n"\ + " print(1)\n"\ " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ "del check, column_names\n"\ "\n" From 2590b4c9509213f3a69c2553132a27ad35f40b6e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 22 Aug 2024 02:13:47 -0700 Subject: [PATCH 008/110] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 131 +++++++++++++++++++------------------ 1 file changed, 67 insertions(+), 64 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 2b4d6c44..044629ea 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1143,70 +1143,73 @@ def patch_sft_trainer_tokenizer(): pass # Patch train with fix_untrained_tokens - function_name, replacer = "train", "if resume_from_checkpoint is False:" - function = getsource(eval(f"trl.trainer.sft_trainer.SFTTrainer.{function_name}")) - where = function.find("def") - function = function.split("\n") - function = "\n".join(x[where:] for x in function) - - check_text = \ - "\n"\ - "if self._inner_training_loop.__name__ != '_fast_inner_training_loop':\n"\ - " raise RuntimeError(\n"\ - " 'Please do not edit specific areas of the Unsloth codebase or you will get CUDA segfaults.'\n"\ - " )\n"\ - "pass\n"\ - "import subprocess, re, gc, numpy as np\n"\ - "a = np.array([0,])\n"\ - "try:\n"\ - " a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\ - " a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)\n"\ - " a = np.array([int(x.decode('utf-8'))/1024 for x in a])\n"\ - "except:\n"\ - " if not torch.cuda.is_available():\n"\ - " raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\n"\ - "if ((a - PRE_CHECK) >= 1).sum() > 1:\n"\ - " raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\n"\ - "for _ in range(3):\n"\ - " gc.collect()\n"\ - " torch.cuda.empty_cache()\n"\ - "pass\n"\ - "\n"\ - "fix_untrained_tokens(self.model, self.tokenizer, self.train_dataset, eps = 1e-16)\n\n" - - # Add NEFTune since it doesn't seem to work?? 
We need to manually inject it - check_text += \ - "\n"\ - "if hasattr(self, 'neftune_hook_handle'):\n"\ - " self.neftune_hook_handle.remove()\n"\ - " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\ - "\n"\ - "if getattr(self, 'neftune_noise_alpha', None) is not None:\n"\ - " self.model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\ - " self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\ - "pass\n"\ - "\n" - - # Also DPO weirdly tokenizes non numeric columns? Delete them! - check_text += \ - "\n"\ - "column_names = set(self.train_dataset.column_names)\n"\ - "check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ - " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ - " 'prompt_input_ids', 'prompt_attention_mask']\n"\ - "if all(x in column_names for x in check):\n"\ - " print(1)\n"\ - " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ - "del check, column_names\n"\ - "\n" - - check_text = check_text.split("\n") - check_text = "\n".join(" "*where + x for x in check_text) - - function = function.replace(replacer, check_text + replacer) - exec(function, globals()) - - exec(f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}", globals()) + for path_to_trainer in \ + ("sft_trainer.SFTTrainer", "dpo_trainer.DPOTrainer",): + + function_name, replacer = "train", "if resume_from_checkpoint is False:" + function = getsource(eval(f"trl.trainer.{path_to_trainer}.{function_name}")) + where = function.find("def") + function = function.split("\n") + function = "\n".join(x[where:] for x in function) + + check_text = \ + "\n"\ + "if self._inner_training_loop.__name__ != '_fast_inner_training_loop':\n"\ + " raise RuntimeError(\n"\ + " 'Please do not edit specific areas of the Unsloth codebase or you will get CUDA segfaults.'\n"\ + " )\n"\ + "pass\n"\ + "import subprocess, re, gc, numpy as np\n"\ + "a = np.array([0,])\n"\ + "try:\n"\ + " a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\ + " a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)\n"\ + " a = np.array([int(x.decode('utf-8'))/1024 for x in a])\n"\ + "except:\n"\ + " if not torch.cuda.is_available():\n"\ + " raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\n"\ + "if ((a - PRE_CHECK) >= 1).sum() > 1:\n"\ + " raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\n"\ + "for _ in range(3):\n"\ + " gc.collect()\n"\ + " torch.cuda.empty_cache()\n"\ + "pass\n"\ + "\n"\ + "fix_untrained_tokens(self.model, self.tokenizer, self.train_dataset, eps = 1e-16)\n\n" + + # Add NEFTune since it doesn't seem to work?? We need to manually inject it + check_text += \ + "\n"\ + "if hasattr(self, 'neftune_hook_handle'):\n"\ + " self.neftune_hook_handle.remove()\n"\ + " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\ + "\n"\ + "if getattr(self, 'neftune_noise_alpha', None) is not None:\n"\ + " self.model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\ + " self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\ + "pass\n"\ + "\n" + + # Also DPO weirdly tokenizes non numeric columns? Delete them! 
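Stepping back, the delivery mechanism for all of these injected snippets, here and in the earlier send_to_device workaround, is the same: fetch the target function's source with inspect.getsource, rewrite it textually, exec the result, and rebind it. A minimal self-contained sketch of that pattern, with hypothetical names (`shout`, `wrap_returns`) that belong to neither Unsloth, accelerate, nor TRL:

```python
# A minimal sketch of the source-rewriting monkey-patch pattern used throughout
# these commits: getsource -> regex edit -> exec -> rebind. Names are
# illustrative placeholders only.
import inspect, re

def shout(x):
    return x.upper()

def wrap_returns(func):
    src = inspect.getsource(func)
    # Guard the return statement in a try/except, the same shape of edit the
    # patches apply to accelerate's send_to_device.
    src = re.sub(
        r"([ ]{4,})return (.*)",
        r"\1try: return \2\n\1except: return None",
        src,
    ).replace(f"def {func.__name__}", f"def _fixed_{func.__name__}")
    scope = {}
    exec(src, func.__globals__, scope)       # compile the rewritten source
    return scope[f"_fixed_{func.__name__}"]  # hand back the patched function

shout = wrap_returns(shout)
print(shout("ok"), shout(123))  # -> OK None
```

The obvious trade-off is fragility: the textual edit must keep matching the upstream source, which is why these patches gate such rewrites behind version checks.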
+ check_text += \ + "\n"\ + "column_names = set(self.train_dataset.column_names)\n"\ + "check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ + " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ + " 'prompt_input_ids', 'prompt_attention_mask']\n"\ + "if all(x in column_names for x in check):\n"\ + " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ + "del check, column_names\n"\ + "\n" + + check_text = check_text.split("\n") + check_text = "\n".join(" "*where + x for x in check_text) + + function = function.replace(replacer, check_text + replacer) + exec(function, globals()) + + exec(f"trl.trainer.{path_to_trainer}.{function_name} = {function_name}", globals()) + pass pass patch_sft_trainer_tokenizer() From 621e65b3b8dec09860be613bd904574b136f344c Mon Sep 17 00:00:00 2001 From: Hafedh <70411813+not-lain@users.noreply.github.com> Date: Sat, 24 Aug 2024 00:41:59 +0100 Subject: [PATCH 009/110] update token retrieval logic (#952) * Fix DPO (#947) * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * update hf token retrieval logic --------- Co-authored-by: Daniel Han --- unsloth/models/llama.py | 8 +++----- unsloth/models/loader.py | 8 +++----- unsloth/save.py | 17 +++++------------ 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 376b4b4e..d832b3ef 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -61,6 +61,7 @@ from peft.tuners.lora import Linear4bit as Peft_Linear4bit from ..save import patch_saving_functions import re, os, inspect, math, sys +from huggingface_hub.utils._token import get_token def original_apply_qkv(self, X): @@ -1418,11 +1419,8 @@ def from_pretrained( ) pass - if token is None and "HF_TOKEN" in os.environ: - token = os.environ["HF_TOKEN"] - - if token is None and "HUGGINGFACE_TOKEN" in os.environ: - token = os.environ["HUGGINGFACE_TOKEN"] + if token is None : + token = get_token() if model_patcher is None: model_patcher = FastLlamaModel SUPPORTS_BFLOAT16 = is_bfloat16_supported() diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 02ed00f5..712d31d7 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -21,6 +21,7 @@ from peft import PeftConfig, PeftModel from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit import os +from huggingface_hub.utils._token import get_token # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! 
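For reference, the token fallback these hunks switch to reduces to a few lines. The helper name below (`resolve_token`) is hypothetical, and `huggingface_hub.utils._token` is a private module path (the one the patch itself imports), so it may move between huggingface_hub versions:

```python
# Sketch of the token-resolution fallback added in this patch: prefer an
# explicitly passed token, otherwise defer to huggingface_hub, which checks the
# HF_TOKEN environment variable and the cached login token for us.
from huggingface_hub.utils._token import get_token

def resolve_token(token=None):
    # Replaces the old manual os.environ["HF_TOKEN"] / ["HUGGINGFACE_TOKEN"] checks.
    return token if token is not None else get_token()
```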
from packaging.version import Version @@ -152,11 +153,8 @@ def from_pretrained( revision = None, *args, **kwargs, ): - if token is None and "HF_TOKEN" in os.environ: - token = os.environ["HF_TOKEN"] - - if token is None and "HUGGINGFACE_TOKEN" in os.environ: - token = os.environ["HUGGINGFACE_TOKEN"] + if token is None : + token = get_token() old_model_name = model_name model_name = get_model_name(model_name, load_in_4bit) diff --git a/unsloth/save.py b/unsloth/save.py index f45d8062..1d23a206 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -29,6 +29,7 @@ from transformers.models.llama.modeling_llama import logger from .tokenizer_utils import fix_sentencepiece_gguf from huggingface_hub import HfApi +from huggingface_hub.utils._token import get_token __all__ = [ "print_quantization_methods", @@ -207,12 +208,8 @@ def unsloth_save_model( temporary_location : str = "_unsloth_temporary_saved_buffers", maximum_memory_usage : float = 0.9, ): - if token is None and "HF_TOKEN" in os.environ: - token = os.environ["HF_TOKEN"] - elif token is None and "hf_token" in os.environ: - token = os.environ["hf_token"] - elif token is None and "HUGGINGFACE_TOKEN" in os.environ: - token = os.environ["HUGGINGFACE_TOKEN"] + if token is None : + token = get_token() if commit_message is None: commit_message = "" if "Unsloth" not in commit_message: @@ -1321,12 +1318,8 @@ def create_huggingface_repo( token = None, private = False, ): - if token is None and "HF_TOKEN" in os.environ: - token = os.environ["HF_TOKEN"] - elif token is None and "hf_token" in os.environ: - token = os.environ["hf_token"] - elif token is None and "HUGGINGFACE_TOKEN" in os.environ: - token = os.environ["HUGGINGFACE_TOKEN"] + if token is None : + token = get_token() pass save_directory, username = _determine_username(save_directory, "", token) From b62e5cd4b8155a15469d99c5a8bc07ebed3a2969 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 23 Aug 2024 16:51:48 -0700 Subject: [PATCH 010/110] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 376b4b4e..9a8571b4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1263,7 +1263,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): # in FP32. They are applied (multiplied) in FP32 as well. self.current_rope_size = seq_len - t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float() + t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float() # Long sequences freqs = torch.outer(t, self.long_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) From 3b49609a72d9175ec9fd61b93851759026771726 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 23 Aug 2024 17:36:04 -0700 Subject: [PATCH 011/110] get_token --- unsloth/models/llama.py | 5 +---- unsloth/models/loader.py | 5 ++--- unsloth/save.py | 3 +-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1017ded1..f62f0f11 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1418,10 +1418,7 @@ def from_pretrained( "Are you certain you want to do remote code execution?" 
) pass - - if token is None : - token = get_token() - + if token is None: token = get_token() if model_patcher is None: model_patcher = FastLlamaModel SUPPORTS_BFLOAT16 = is_bfloat16_supported() gpu_stats = torch.cuda.get_device_properties(0) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 712d31d7..e1f17aca 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -153,9 +153,8 @@ def from_pretrained( revision = None, *args, **kwargs, ): - if token is None : - token = get_token() - + if token is None: token = get_token() + old_model_name = model_name model_name = get_model_name(model_name, load_in_4bit) diff --git a/unsloth/save.py b/unsloth/save.py index 1d23a206..66e2ec6b 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -208,8 +208,7 @@ def unsloth_save_model( temporary_location : str = "_unsloth_temporary_saved_buffers", maximum_memory_usage : float = 0.9, ): - if token is None : - token = get_token() + if token is None: token = get_token() if commit_message is None: commit_message = "" if "Unsloth" not in commit_message: From 9c8875e81c1efb651a387e7ddd51d5533e8ccfc3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 23 Aug 2024 17:37:45 -0700 Subject: [PATCH 012/110] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5c4de1f5..6cd1be13 100644 --- a/README.md +++ b/README.md @@ -32,12 +32,13 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and | **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) | 1.9x faster | 43% less | - **Kaggle Notebooks** for [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Gemma 2 (9B)](https://www.kaggle.com/code/danielhanchen/kaggle-gemma-7b-unsloth-notebook/), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook) -- Run [Llama 3 conversational notebook](https://colab.research.google.com/drive/1XamvWYinY6FOSX9GLvnqSjjsNflxdhNc?usp=sharing) and [Mistral v0.3 ChatML](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing) +- Run [Llama 3.1 conversational notebook](https://colab.research.google.com/drive/15OyFkGoCImV9dSsewU1wa2JuKB4-mDE_?usp=sharing) and [Mistral v0.3 ChatML](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing) - This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text - This [continued pretraining notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) is for learning another language - Click [here](https://github.com/unslothai/unsloth/wiki) for detailed documentation for Unsloth. ## 🦥 Unsloth.ai News +- 📣 NEW! [Llama 3.1 Conversational notebook](https://colab.research.google.com/drive/15OyFkGoCImV9dSsewU1wa2JuKB4-mDE_?usp=sharing) includes training only on completions / outputs (increase accuracy), ShareGPT standardization and more! - 📣 NEW! [Phi-3.5 (mini)](https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing) now supported - 📣 NEW! `pip install unsloth` now works! Head over to [pypi](https://pypi.org/project/unsloth/) to check it out! This allows non git pull installs. Use `pip install unsloth[colab-new]` for non dependency installs. - 📣 NEW! 
[Gemma-2-2b](https://colab.research.google.com/drive/1weTpKOjBZxZJ5PQ-Ql8i6ptAY2x-FWVA?usp=sharing) now supported! Try out [Chat interface](https://colab.research.google.com/drive/1i-8ESvtLRGNkkUQQr_-z_rcSAIo9c3lM?usp=sharing)! From a44357d86fa170a1954d9549cb63a854b963d955 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 19:23:14 -0700 Subject: [PATCH 013/110] Update gemma2.py --- unsloth/models/gemma2.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 6858f525..c2f55855 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -18,6 +18,7 @@ GemmaFixedRotaryEmbedding, GemmaFixedLinearScalingRotaryEmbedding, fast_geglu_inference, + fast_rms_layernorm, ) try: from transformers.models.gemma2.modeling_gemma2 import ( @@ -204,7 +205,7 @@ def Gemma2DecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states - hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, @@ -215,14 +216,14 @@ def Gemma2DecoderLayer_fast_forward( use_cache=use_cache, padding_mask=padding_mask, ) - hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self. 
pre_feedforward_layernorm, hidden_states, gemma = True) hidden_states = self.mlp(hidden_states) - hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states pass From 7ed1c16abdd89e8664e00cc9038daf9eddab9d1c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:13:31 -0700 Subject: [PATCH 014/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f26e5965..f8aff859 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -41,7 +41,7 @@ def _rms_layernorm_forward( r += row_idx * r_row_stride X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) row_var = tl.sum(X_row * X_row, axis = 0) / n_cols inv_var = tl.math.rsqrt(row_var + eps) From d7ef49ebb2b8727700887eaed0ccc7ba0aac7183 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:18:55 -0700 Subject: [PATCH 015/110] synchronize --- unsloth/kernels/rms_layernorm.py | 2 +- unsloth/models/gemma2.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f8aff859..f26e5965 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -41,7 +41,7 @@ def _rms_layernorm_forward( r += row_idx * r_row_stride X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32) row_var = tl.sum(X_row * X_row, axis = 0) / n_cols inv_var = tl.math.rsqrt(row_var + eps) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index c2f55855..f9407a78 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -206,6 +206,7 @@ def Gemma2DecoderLayer_fast_forward( else: residual = hidden_states hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True) + torch.cuda.synchronize() hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, @@ -217,13 +218,16 @@ def Gemma2DecoderLayer_fast_forward( padding_mask=padding_mask, ) hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True) + torch.cuda.synchronize() hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm(self. 
pre_feedforward_layernorm, hidden_states, gemma = True) + torch.cuda.synchronize() hidden_states = self.mlp(hidden_states) hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True) + torch.cuda.synchronize() hidden_states = residual + hidden_states pass From 9a6954877e4f1c0b48126c0016eb0ca2284dde78 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:21:10 -0700 Subject: [PATCH 016/110] Update gemma2.py --- unsloth/models/gemma2.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index f9407a78..c2f55855 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -206,7 +206,6 @@ def Gemma2DecoderLayer_fast_forward( else: residual = hidden_states hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True) - torch.cuda.synchronize() hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, @@ -218,16 +217,13 @@ def Gemma2DecoderLayer_fast_forward( padding_mask=padding_mask, ) hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True) - torch.cuda.synchronize() hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True) - torch.cuda.synchronize() hidden_states = self.mlp(hidden_states) hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True) - torch.cuda.synchronize() hidden_states = residual + hidden_states pass From e6dadb42202ca2b726b79344fa3a6f80830da843 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:22:54 -0700 Subject: [PATCH 017/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f26e5965..937765f7 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -116,7 +116,7 @@ def _gemma_rms_layernorm_forward( r += row_idx * r_row_stride X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32) row_var = tl.sum(X_row * X_row, axis = 0) / n_cols inv_var = tl.math.rsqrt(row_var + eps) From f8e77cf0e597ee283b8a92c8609cafb422a8ee3f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:27:00 -0700 Subject: [PATCH 018/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 937765f7..00d25185 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -116,14 +116,15 @@ def _gemma_rms_layernorm_forward( r += row_idx * r_row_stride X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) row_var = tl.sum(X_row * X_row, axis = 0) / n_cols inv_var = tl.math.rsqrt(row_var + eps) tl.store(r, inv_var) normed = X_row * inv_var output = normed * (W_row + 1.0) - + output = output.to(X_row.dtype) + tl.store(Y + col_offsets, output, mask = mask) pass From 
cfbaa97cecd1c01798f8b8d963df9d81b0a52477 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:29:54 -0700 Subject: [PATCH 019/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 00d25185..a1158b5e 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -124,7 +124,7 @@ def _gemma_rms_layernorm_forward( normed = X_row * inv_var output = normed * (W_row + 1.0) output = output.to(X_row.dtype) - + tl.store(Y + col_offsets, output, mask = mask) pass @@ -141,6 +141,7 @@ def forward(ctx, X, W, eps, gemma = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + torch.cuda.synchronize() fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward fx[(n_rows,)]( Y, Y.stride(0), @@ -151,6 +152,7 @@ def forward(ctx, X, W, eps, gemma = False): BLOCK_SIZE = BLOCK_SIZE, num_warps = num_warps, ) + torch.cuda.synchronize() ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -168,6 +170,7 @@ def backward(ctx, dY): n_rows, n_cols = dY.shape dW = X + torch.cuda.synchronize() _rms_layernorm_backward[(n_rows,)]( dY, dY.stride(0), X, X .stride(0), @@ -179,6 +182,7 @@ def backward(ctx, dY): BLOCK_SIZE = ctx.BLOCK_SIZE, num_warps = ctx.num_warps, ) + torch.cuda.synchronize() dX = dY.view(*shape) return dX, None, None, None pass From 32b2f3f3b38b738a91a0c64f32f23ca934ae39f0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:33:22 -0700 Subject: [PATCH 020/110] layernorm --- unsloth/kernels/rms_layernorm.py | 5 ----- unsloth/models/_utils.py | 2 +- unsloth/models/gemma2.py | 9 ++++----- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index a1158b5e..f26e5965 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -123,7 +123,6 @@ def _gemma_rms_layernorm_forward( tl.store(r, inv_var) normed = X_row * inv_var output = normed * (W_row + 1.0) - output = output.to(X_row.dtype) tl.store(Y + col_offsets, output, mask = mask) pass @@ -141,7 +140,6 @@ def forward(ctx, X, W, eps, gemma = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - torch.cuda.synchronize() fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward fx[(n_rows,)]( Y, Y.stride(0), @@ -152,7 +150,6 @@ def forward(ctx, X, W, eps, gemma = False): BLOCK_SIZE = BLOCK_SIZE, num_warps = num_warps, ) - torch.cuda.synchronize() ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -170,7 +167,6 @@ def backward(ctx, dY): n_rows, n_cols = dY.shape dW = X - torch.cuda.synchronize() _rms_layernorm_backward[(n_rows,)]( dY, dY.stride(0), X, X .stride(0), @@ -182,7 +178,6 @@ def backward(ctx, dY): BLOCK_SIZE = ctx.BLOCK_SIZE, num_warps = ctx.num_warps, ) - torch.cuda.synchronize() dX = dY.view(*shape) return dX, None, None, None pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1c48e8e5..2fbeb4da 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -350,7 +350,7 @@ def is_big_gpu(index): "epilogue_fusion" : True, "max_autotune" : True, "shape_padding" : True, - "trace.enabled" : False, # Output Triton kernel outputs! 
+ "trace.enabled" : True, # Output Triton kernel outputs! "triton.cudagraphs" : False, } # ============================================= diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index c2f55855..6858f525 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -18,7 +18,6 @@ GemmaFixedRotaryEmbedding, GemmaFixedLinearScalingRotaryEmbedding, fast_geglu_inference, - fast_rms_layernorm, ) try: from transformers.models.gemma2.modeling_gemma2 import ( @@ -205,7 +204,7 @@ def Gemma2DecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states - hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, @@ -216,14 +215,14 @@ def Gemma2DecoderLayer_fast_forward( use_cache=use_cache, padding_mask=padding_mask, ) - hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True) hidden_states = self.mlp(hidden_states) - hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states pass From 9e7057d10c3c7f732870dc180a117d8749ca12e6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:54:22 -0700 Subject: [PATCH 021/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f26e5965..36e98045 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -120,6 +120,7 @@ def _gemma_rms_layernorm_forward( row_var = tl.sum(X_row * X_row, axis = 0) / n_cols inv_var = tl.math.rsqrt(row_var + eps) + tl.debug_barrier() tl.store(r, inv_var) normed = X_row * inv_var output = normed * (W_row + 1.0) From a19350834e0b19051137f6068175dc5e0f000cc3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 20:54:42 -0700 Subject: [PATCH 022/110] Update gemma2.py --- unsloth/models/gemma2.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 6858f525..c2f55855 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -18,6 +18,7 @@ GemmaFixedRotaryEmbedding, GemmaFixedLinearScalingRotaryEmbedding, fast_geglu_inference, + fast_rms_layernorm, ) try: from transformers.models.gemma2.modeling_gemma2 import ( @@ -204,7 +205,7 @@ def Gemma2DecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states - hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, 
causal_mask=causal_mask, @@ -215,14 +216,14 @@ def Gemma2DecoderLayer_fast_forward( use_cache=use_cache, padding_mask=padding_mask, ) - hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True) hidden_states = self.mlp(hidden_states) - hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states pass From 65eaa2d2a2319dd3cc8c53a54ead9fcf5753edd8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 21:07:06 -0700 Subject: [PATCH 023/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 36e98045..62b4aeab 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -16,7 +16,7 @@ import triton.language as tl import torch from .utils import calculate_settings - +from triton.language.extra import libdevice @triton.jit def _rms_layernorm_forward( @@ -119,7 +119,7 @@ def _gemma_rms_layernorm_forward( W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) row_var = tl.sum(X_row * X_row, axis = 0) / n_cols - inv_var = tl.math.rsqrt(row_var + eps) + inv_var = libdevice.rsqrt(row_var + eps) tl.debug_barrier() tl.store(r, inv_var) normed = X_row * inv_var From 1beeb22ed7fbf504de7ee31e1f4e78f6ce0508f7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 21:12:48 -0700 Subject: [PATCH 024/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 62b4aeab..f4f85a07 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -116,13 +116,13 @@ def _gemma_rms_layernorm_forward( r += row_idx * r_row_stride X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) row_var = tl.sum(X_row * X_row, axis = 0) / n_cols inv_var = libdevice.rsqrt(row_var + eps) tl.debug_barrier() tl.store(r, inv_var) normed = X_row * inv_var + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) output = normed * (W_row + 1.0) tl.store(Y + col_offsets, output, mask = mask) From 1eb770591625b7ab34ffb9af1dabda166e7881d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 29 Aug 2024 21:15:43 -0700 Subject: [PATCH 025/110] revert --- unsloth/kernels/rms_layernorm.py | 7 +++---- unsloth/models/gemma2.py | 9 ++++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f4f85a07..f26e5965 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -16,7 +16,7 @@ import triton.language as tl import torch from .utils import calculate_settings -from triton.language.extra import libdevice + @triton.jit 
def _rms_layernorm_forward( @@ -116,13 +116,12 @@ def _gemma_rms_layernorm_forward( r += row_idx * r_row_stride X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) row_var = tl.sum(X_row * X_row, axis = 0) / n_cols - inv_var = libdevice.rsqrt(row_var + eps) - tl.debug_barrier() + inv_var = tl.math.rsqrt(row_var + eps) tl.store(r, inv_var) normed = X_row * inv_var - W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) output = normed * (W_row + 1.0) tl.store(Y + col_offsets, output, mask = mask) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index c2f55855..6858f525 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -18,7 +18,6 @@ GemmaFixedRotaryEmbedding, GemmaFixedLinearScalingRotaryEmbedding, fast_geglu_inference, - fast_rms_layernorm, ) try: from transformers.models.gemma2.modeling_gemma2 import ( @@ -205,7 +204,7 @@ def Gemma2DecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states - hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, @@ -216,14 +215,14 @@ def Gemma2DecoderLayer_fast_forward( use_cache=use_cache, padding_mask=padding_mask, ) - hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self. 
pre_feedforward_layernorm, hidden_states, gemma = True) hidden_states = self.mlp(hidden_states) - hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states pass From c3fe972fc420ab28cd791b66e3d0a1ced3471f7b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 11:32:34 -0700 Subject: [PATCH 026/110] Gemma --- unsloth/kernels/rms_layernorm.py | 43 ++++++++++++++++++++++++++++++-- unsloth/models/gemma2.py | 9 ++++--- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f26e5965..2aa95f86 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -95,6 +95,43 @@ def _rms_layernorm_backward( pass +@triton.jit +def _gemma_rms_layernorm_backward( + dY, dY_row_stride, + X, X_row_stride, + W, W_row_stride, + r, r_row_stride, + dW, dW_row_stride, + n_cols, eps, + BLOCK_SIZE : tl.constexpr, +): + """ + Fast RMS Layernorm kernel for the backward pass + Inspiration from a Triton tutorial: + https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + """ + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dY += row_idx * dY_row_stride + X += row_idx * X_row_stride + r += row_idx * r_row_stride + + dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32) + X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) + + # Get saved row variance + inv_var = tl.load(r).to(tl.float32) + normed = X_row * inv_var + dY_W = dY_row * (W_row + 1.0) + + rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0) + output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) + tl.store(dY + col_offsets, output, mask = mask) +pass + @triton.jit def _gemma_rms_layernorm_forward( Y, Y_row_stride, @@ -167,7 +204,7 @@ def backward(ctx, dY): n_rows, n_cols = dY.shape dW = X - _rms_layernorm_backward[(n_rows,)]( + _gemma_rms_layernorm_backward[(n_rows,)]( dY, dY.stride(0), X, X .stride(0), W, W .stride(0), @@ -186,7 +223,9 @@ def backward(ctx, dY): def fast_rms_layernorm(layernorm, X, gemma = False): W = layernorm.weight - eps = layernorm.variance_epsilon + eps = layernorm.variance_epsilon if \ + hasattr(layernorm, "variance_epsilon") \ + else layernorm.eps out = Fast_RMS_Layernorm.apply(X, W, eps, gemma) return out pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 6858f525..c2f55855 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -18,6 +18,7 @@ GemmaFixedRotaryEmbedding, GemmaFixedLinearScalingRotaryEmbedding, fast_geglu_inference, + fast_rms_layernorm, ) try: from transformers.models.gemma2.modeling_gemma2 import ( @@ -204,7 +205,7 @@ def Gemma2DecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states - hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, @@ -215,14 +216,14 @@ def Gemma2DecoderLayer_fast_forward( use_cache=use_cache, padding_mask=padding_mask, ) - hidden_states = 
fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True) hidden_states = self.mlp(hidden_states) - hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states pass From 75dbfbab54404bcc2f76ab756a40d1c7c8b720c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 11:40:20 -0700 Subject: [PATCH 027/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 2aa95f86..7962049e 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -126,7 +126,7 @@ def _gemma_rms_layernorm_backward( inv_var = tl.load(r).to(tl.float32) normed = X_row * inv_var dY_W = dY_row * (W_row + 1.0) - + rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0) output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) tl.store(dY + col_offsets, output, mask = mask) @@ -211,7 +211,6 @@ def backward(ctx, dY): r, r .stride(0), dW, dW.stride(0), n_cols, ctx.eps, - GEMMA = ctx.GEMMA, BLOCK_SIZE = ctx.BLOCK_SIZE, num_warps = ctx.num_warps, ) From 332b091385e13acd580a4c30b1233cb625a90d3a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 11:44:46 -0700 Subject: [PATCH 028/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 7962049e..12fb1b56 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -125,7 +125,7 @@ def _gemma_rms_layernorm_backward( # Get saved row variance inv_var = tl.load(r).to(tl.float32) normed = X_row * inv_var - dY_W = dY_row * (W_row + 1.0) + dY_W = dY_row * W_row rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0) output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) From 4ecc1198851dd625274f5f90e41983dee3ecd917 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 13:44:28 -0700 Subject: [PATCH 029/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 12fb1b56..8c4d7f33 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -125,7 +125,7 @@ def _gemma_rms_layernorm_backward( # Get saved row variance inv_var = tl.load(r).to(tl.float32) normed = X_row * inv_var - dY_W = dY_row * W_row + dY_W = dY_row * (W_row + 1.0) rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0) output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) @@ -199,7 +199,7 @@ def forward(ctx, X, W, eps, gemma = False): def backward(ctx, dY): shape = dY.shape dim = shape[-1] - dY = dY.view(-1, dim) + dY = dY.contiguous().view(-1, dim) X, W, r = ctx.saved_tensors n_rows, n_cols = dY.shape dW = X From 
07a12465e6e1d196268c5dfbdb387b2604d7b44e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 14:46:18 -0700 Subject: [PATCH 030/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 8c4d7f33..150cd060 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -204,16 +204,22 @@ def backward(ctx, dY): n_rows, n_cols = dY.shape dW = X - _gemma_rms_layernorm_backward[(n_rows,)]( - dY, dY.stride(0), - X, X .stride(0), - W, W .stride(0), - r, r .stride(0), - dW, dW.stride(0), - n_cols, ctx.eps, - BLOCK_SIZE = ctx.BLOCK_SIZE, - num_warps = ctx.num_warps, - ) + inv_var = r.float() + normed = X * inv_var + dY_W = dY * (W.float() + 1.0) + rowsum_dY_normed = dY_W.sum(axis = 0) + dY = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) + + # _gemma_rms_layernorm_backward[(n_rows,)]( + # dY, dY.stride(0), + # X, X .stride(0), + # W, W .stride(0), + # r, r .stride(0), + # dW, dW.stride(0), + # n_cols, ctx.eps, + # BLOCK_SIZE = ctx.BLOCK_SIZE, + # num_warps = ctx.num_warps, + # ) dX = dY.view(*shape) return dX, None, None, None pass From e3239e437e09e693124cecabd8f8de2881735a3b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 15:06:32 -0700 Subject: [PATCH 031/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 150cd060..24498a16 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -205,6 +205,7 @@ def backward(ctx, dY): dW = X inv_var = r.float() + print(inv_var.shape, X.shape) normed = X * inv_var dY_W = dY * (W.float() + 1.0) rowsum_dY_normed = dY_W.sum(axis = 0) From 6ae1ac2a8838f1b6fcc5f010140840f7010c3974 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 15:23:48 -0700 Subject: [PATCH 032/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 24498a16..2e7dab21 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -206,7 +206,7 @@ def backward(ctx, dY): inv_var = r.float() print(inv_var.shape, X.shape) - normed = X * inv_var + normed = X * inv_var.unsqueeze(-1) dY_W = dY * (W.float() + 1.0) rowsum_dY_normed = dY_W.sum(axis = 0) dY = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) From 4d89f276ae47177704ef8febceee97d5172dbe9b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 15:26:36 -0700 Subject: [PATCH 033/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 2e7dab21..ca1aba2f 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -204,9 +204,8 @@ def backward(ctx, dY): n_rows, n_cols = dY.shape dW = X - inv_var = r.float() - print(inv_var.shape, X.shape) - normed = X * inv_var.unsqueeze(-1) + inv_var = r.float().unsqueeze(-1) + normed = X * inv_var dY_W = dY * (W.float() + 1.0) rowsum_dY_normed = dY_W.sum(axis = 0) dY = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) From c76be22d0fbdb1d07d7dbbe005772dd0e5be6a6e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 
Aug 2024 15:42:09 -0700 Subject: [PATCH 034/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index ca1aba2f..f2d55e1b 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -177,16 +177,22 @@ def forward(ctx, X, W, eps, gemma = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward - fx[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols, eps, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + + row_var = X.to(torch.float32).square().sum(axis = 0) / n_cols + inv_var = torch.rqsrt(row_var + eps) + normed = X.to(torch.float32) * inv_var.unsqueeze(-1) + Y = normed * (W_row.unsqueeze(-1) + 1.0) + + # fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward + # fx[(n_rows,)]( + # Y, Y.stride(0), + # X, X.stride(0), + # W, W.stride(0), + # r, r.stride(0), + # n_cols, eps, + # BLOCK_SIZE = BLOCK_SIZE, + # num_warps = num_warps, + # ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps From ace509c5aef8bebd5e0545f65e55768855a69b45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 15:44:28 -0700 Subject: [PATCH 035/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f2d55e1b..5c2595e4 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -179,7 +179,7 @@ def forward(ctx, X, W, eps, gemma = False): row_var = X.to(torch.float32).square().sum(axis = 0) / n_cols - inv_var = torch.rqsrt(row_var + eps) + inv_var = torch.rsqrt(row_var + eps) normed = X.to(torch.float32) * inv_var.unsqueeze(-1) Y = normed * (W_row.unsqueeze(-1) + 1.0) From e474cfe0fc7c277ce8a63bf467175451fdee2706 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 15:46:21 -0700 Subject: [PATCH 036/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 5c2595e4..98ccb0fe 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -178,7 +178,8 @@ def forward(ctx, X, W, eps, gemma = False): r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - row_var = X.to(torch.float32).square().sum(axis = 0) / n_cols + row_var = X.to(torch.float32).square().sum(axis = 1) / n_cols + print(row_var.shape) inv_var = torch.rsqrt(row_var + eps) normed = X.to(torch.float32) * inv_var.unsqueeze(-1) Y = normed * (W_row.unsqueeze(-1) + 1.0) From 1576a1ecac095a6c83d5930db8165891deee0cc3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 15:50:25 -0700 Subject: [PATCH 037/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 98ccb0fe..139b1a28 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -182,7 +182,7 @@ def forward(ctx, X, W, eps, gemma = False): print(row_var.shape) inv_var = torch.rsqrt(row_var + eps) normed = 
X.to(torch.float32) * inv_var.unsqueeze(-1) - Y = normed * (W_row.unsqueeze(-1) + 1.0) + Y = normed * (W + 1.0) # fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward # fx[(n_rows,)]( From a2c469103c986e95490d4253ee7234b0efd90012 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 15:56:19 -0700 Subject: [PATCH 038/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 139b1a28..d90b6b9c 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -183,7 +183,8 @@ def forward(ctx, X, W, eps, gemma = False): inv_var = torch.rsqrt(row_var + eps) normed = X.to(torch.float32) * inv_var.unsqueeze(-1) Y = normed * (W + 1.0) - + Y = Y.to(torch.bfloat16) + # fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward # fx[(n_rows,)]( # Y, Y.stride(0), From 1a02e75a895dd690bddc9aef3d85f2659442ccdc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 15:58:22 -0700 Subject: [PATCH 039/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index d90b6b9c..e18883db 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -179,7 +179,6 @@ def forward(ctx, X, W, eps, gemma = False): row_var = X.to(torch.float32).square().sum(axis = 1) / n_cols - print(row_var.shape) inv_var = torch.rsqrt(row_var + eps) normed = X.to(torch.float32) * inv_var.unsqueeze(-1) Y = normed * (W + 1.0) From a26e1d1155d3031844bcff3fb235cf4f6f191f19 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 31 Aug 2024 16:14:11 -0700 Subject: [PATCH 040/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index e18883db..e729aa2e 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -177,7 +177,7 @@ def forward(ctx, X, W, eps, gemma = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - + print(X.shape) row_var = X.to(torch.float32).square().sum(axis = 1) / n_cols inv_var = torch.rsqrt(row_var + eps) normed = X.to(torch.float32) * inv_var.unsqueeze(-1) From afdb443e7f345e0c1f53ef68a3b3358162acc4bf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 00:35:07 -0700 Subject: [PATCH 041/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index e729aa2e..83097437 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -177,8 +177,8 @@ def forward(ctx, X, W, eps, gemma = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - print(X.shape) - row_var = X.to(torch.float32).square().sum(axis = 1) / n_cols + # print(X.shape) + row_var = torch.mean(X.to(torch.float32).square(), axis = 0) inv_var = torch.rsqrt(row_var + eps) normed = X.to(torch.float32) * inv_var.unsqueeze(-1) Y = normed * (W + 1.0) @@ -214,8 +214,8 @@ def backward(ctx, dY): inv_var = r.float().unsqueeze(-1) normed = X * inv_var dY_W = dY * 
(W.float() + 1.0) - rowsum_dY_normed = dY_W.sum(axis = 0) - dY = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) + rowsum_dY_normed = dY_W.mean(axis = 0) + dY = inv_var*dY_W - normed*rowsum_dY_normed*inv_var # _gemma_rms_layernorm_backward[(n_rows,)]( # dY, dY.stride(0), From c3e14d8a1ac8d600a7e7a98eddbfb52e12b08730 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 00:39:17 -0700 Subject: [PATCH 042/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 83097437..ded872c2 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -180,7 +180,7 @@ def forward(ctx, X, W, eps, gemma = False): # print(X.shape) row_var = torch.mean(X.to(torch.float32).square(), axis = 0) inv_var = torch.rsqrt(row_var + eps) - normed = X.to(torch.float32) * inv_var.unsqueeze(-1) + normed = X.to(torch.float32) * inv_var Y = normed * (W + 1.0) Y = Y.to(torch.bfloat16) From 1830bdd5fa559d102d2cdd5ebfa990d21233a116 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 00:41:31 -0700 Subject: [PATCH 043/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index ded872c2..71566b21 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -214,8 +214,8 @@ def backward(ctx, dY): inv_var = r.float().unsqueeze(-1) normed = X * inv_var dY_W = dY * (W.float() + 1.0) - rowsum_dY_normed = dY_W.mean(axis = 0) - dY = inv_var*dY_W - normed*rowsum_dY_normed*inv_var + rowsum_dY_normed = dY_W.sum(axis = 0) + dY = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) # _gemma_rms_layernorm_backward[(n_rows,)]( # dY, dY.stride(0), From 6abf66a368896c1f48ee4e2e0b534ea3f63ec2c8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 00:46:20 -0700 Subject: [PATCH 044/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 71566b21..ff2bb969 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -176,11 +176,10 @@ def forward(ctx, X, W, eps, gemma = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - - # print(X.shape) - row_var = torch.mean(X.to(torch.float32).square(), axis = 0) + + row_var = X.to(torch.float32).square().sum(axis = 1) / n_cols inv_var = torch.rsqrt(row_var + eps) - normed = X.to(torch.float32) * inv_var + normed = X.to(torch.float32) * inv_var.unsqueeze(-1) Y = normed * (W + 1.0) Y = Y.to(torch.bfloat16) From f5cf796c9df05694b33a0c27e25beccb95add936 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 00:57:38 -0700 Subject: [PATCH 045/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index ff2bb969..7b340b02 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -176,8 +176,8 @@ def forward(ctx, X, W, eps, gemma = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = 
"cuda:0") - - row_var = X.to(torch.float32).square().sum(axis = 1) / n_cols + + row_var = X.to(torch.float32).square().mean(axis = 1) inv_var = torch.rsqrt(row_var + eps) normed = X.to(torch.float32) * inv_var.unsqueeze(-1) Y = normed * (W + 1.0) From b19153024c85440225157ac8eeca4f5b86c51363 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 01:00:45 -0700 Subject: [PATCH 046/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 7b340b02..27381440 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -213,8 +213,8 @@ def backward(ctx, dY): inv_var = r.float().unsqueeze(-1) normed = X * inv_var dY_W = dY * (W.float() + 1.0) - rowsum_dY_normed = dY_W.sum(axis = 0) - dY = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) + rowsum_dY_normed = dY_W.mean(axis = 0) + dY = inv_var * (dY_W - normed*rowsum_dY_normed) # _gemma_rms_layernorm_backward[(n_rows,)]( # dY, dY.stride(0), From 512c61f1704d60de379e8566831887ae209b9c91 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 01:04:06 -0700 Subject: [PATCH 047/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 27381440..79e8e36b 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -205,10 +205,11 @@ def forward(ctx, X, W, eps, gemma = False): def backward(ctx, dY): shape = dY.shape dim = shape[-1] - dY = dY.contiguous().view(-1, dim) + dY = dY.view(-1, dim) X, W, r = ctx.saved_tensors n_rows, n_cols = dY.shape dW = X + print(dY.shape, X.shape) inv_var = r.float().unsqueeze(-1) normed = X * inv_var From f5d50ef042a263f07753cb4403598e1a1081c41b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 01:07:46 -0700 Subject: [PATCH 048/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 79e8e36b..7688c43a 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -209,7 +209,6 @@ def backward(ctx, dY): X, W, r = ctx.saved_tensors n_rows, n_cols = dY.shape dW = X - print(dY.shape, X.shape) inv_var = r.float().unsqueeze(-1) normed = X * inv_var From d791bb9c03b1d1b362c32442fa1bdd71e8372c08 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 01:16:26 -0700 Subject: [PATCH 049/110] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 7688c43a..f1527839 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -212,7 +212,7 @@ def backward(ctx, dY): inv_var = r.float().unsqueeze(-1) normed = X * inv_var - dY_W = dY * (W.float() + 1.0) + dY_W = dY * (W.float()) rowsum_dY_normed = dY_W.mean(axis = 0) dY = inv_var * (dY_W - normed*rowsum_dY_normed) From 92256086831258c05dd3502c209223191d6b8a75 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 1 Sep 2024 12:33:28 -0700 Subject: [PATCH 050/110] Update gemma2.py --- unsloth/models/gemma2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index c2f55855..6cd537d1 100644 
--- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -205,7 +205,7 @@ def Gemma2DecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states - hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, @@ -216,14 +216,14 @@ def Gemma2DecoderLayer_fast_forward( use_cache=use_cache, padding_mask=padding_mask, ) - hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True) hidden_states = self.mlp(hidden_states) - hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True) hidden_states = residual + hidden_states pass From f61869c3921604f94679f389d14cfb96900d16ca Mon Sep 17 00:00:00 2001 From: Tuan Pham <82665400+vTuanpham@users.noreply.github.com> Date: Mon, 2 Sep 2024 15:01:15 +0700 Subject: [PATCH 051/110] Change UnslothTrainingArguments base class to SFTConfig (#979) --- unsloth/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/trainer.py b/unsloth/trainer.py index c8e00be2..883f6652 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -14,8 +14,7 @@ from dataclasses import dataclass, field from typing import Optional -from transformers import TrainingArguments -from trl import SFTTrainer +from trl import SFTTrainer, SFTConfig from . 
import is_bfloat16_supported __all__ = [ @@ -25,7 +24,7 @@ @dataclass -class UnslothTrainingArguments(TrainingArguments): +class UnslothTrainingArguments(SFTConfig): embedding_learning_rate : Optional[float] = field( default = None, metadata = {"help" : "Different learning rates for embeddings and lm_head."} From 73d49ad39aaa9100bf442719a9c2cab89feb5e8f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 01:27:44 -0700 Subject: [PATCH 052/110] Cohere --- unsloth/kernels/rms_layernorm.py | 90 ++--- unsloth/models/_utils.py | 2 +- unsloth/models/cohere.py | 604 +++++++++++++++++++++++++++++++ unsloth/models/gemma2.py | 1 - unsloth/models/llama.py | 13 +- 5 files changed, 638 insertions(+), 72 deletions(-) create mode 100644 unsloth/models/cohere.py diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f1527839..ac5beb5a 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -95,43 +95,6 @@ def _rms_layernorm_backward( pass -@triton.jit -def _gemma_rms_layernorm_backward( - dY, dY_row_stride, - X, X_row_stride, - W, W_row_stride, - r, r_row_stride, - dW, dW_row_stride, - n_cols, eps, - BLOCK_SIZE : tl.constexpr, -): - """ - Fast RMS Layernorm kernel for the backward pass - Inspiration from a Triton tutorial: - https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html - """ - row_idx = tl.program_id(0) - col_offsets = tl.arange(0, BLOCK_SIZE) - mask = col_offsets < n_cols - - dY += row_idx * dY_row_stride - X += row_idx * X_row_stride - r += row_idx * r_row_stride - - dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32) - X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) - - # Get saved row variance - inv_var = tl.load(r).to(tl.float32) - normed = X_row * inv_var - dY_W = dY_row * (W_row + 1.0) - - rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0) - output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) - tl.store(dY + col_offsets, output, mask = mask) -pass - @triton.jit def _gemma_rms_layernorm_forward( Y, Y_row_stride, @@ -177,22 +140,16 @@ def forward(ctx, X, W, eps, gemma = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - row_var = X.to(torch.float32).square().mean(axis = 1) - inv_var = torch.rsqrt(row_var + eps) - normed = X.to(torch.float32) * inv_var.unsqueeze(-1) - Y = normed * (W + 1.0) - Y = Y.to(torch.bfloat16) - - # fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward - # fx[(n_rows,)]( - # Y, Y.stride(0), - # X, X.stride(0), - # W, W.stride(0), - # r, r.stride(0), - # n_cols, eps, - # BLOCK_SIZE = BLOCK_SIZE, - # num_warps = num_warps, - # ) + fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward + fx[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -210,22 +167,17 @@ def backward(ctx, dY): n_rows, n_cols = dY.shape dW = X - inv_var = r.float().unsqueeze(-1) - normed = X * inv_var - dY_W = dY * (W.float()) - rowsum_dY_normed = dY_W.mean(axis = 0) - dY = inv_var * (dY_W - normed*rowsum_dY_normed) - - # _gemma_rms_layernorm_backward[(n_rows,)]( - # dY, dY.stride(0), - # X, X .stride(0), - # W, W .stride(0), - # r, r .stride(0), - # dW, dW.stride(0), - # 
n_cols, ctx.eps, - # BLOCK_SIZE = ctx.BLOCK_SIZE, - # num_warps = ctx.num_warps, - # ) + _rms_layernorm_backward[(n_rows,)]( + dY, dY.stride(0), + X, X .stride(0), + W, W .stride(0), + r, r .stride(0), + dW, dW.stride(0), + n_cols, ctx.eps, + GEMMA = ctx.GEMMA, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) dX = dY.view(*shape) return dX, None, None, None pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2fbeb4da..1c48e8e5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -350,7 +350,7 @@ def is_big_gpu(index): "epilogue_fusion" : True, "max_autotune" : True, "shape_padding" : True, - "trace.enabled" : True, # Output Triton kernel outputs! + "trace.enabled" : False, # Output Triton kernel outputs! "triton.cudagraphs" : False, } # ============================================= diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py new file mode 100644 index 00000000..53669836 --- /dev/null +++ b/unsloth/models/cohere.py @@ -0,0 +1,604 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .llama import * +from ._utils import __version__ +try: + from transformers.models.gemma2.modeling_gemma2 import ( + CohereAttention, + CohereDecoderLayer, + CohereModel, + CohereForCausalLM, + CohereRotaryEmbedding, + apply_rotary_pos_emb, + repeat_kv, + ) +except: + from packaging.version import Version + transformers_version = Version(transformers_version) + if not transformers_version >= Version("4.42"): + raise ImportError( + f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"\ + f"The minimum required version is 4.42.3.\n"\ + f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\ + f"to obtain the latest transformers build, then restart this session."\ + ) + pass +pass + +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask_for_sdpa, +) +# For Pytorch 2.1.1 +try: + from transformers.models.cohere.modeling_cohere import ( + CohereSdpaAttention, + CohereFlashAttention2, + ) +except: + CohereSdpaAttention = CohereAttention + CohereFlashAttention2 = CohereAttention +pass + + +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +def fast_layernorm_compiled(layernorm, X): + old_dtype = X.dtype + X = X.float() + mean = X.mean(-1, keepdim = True) + Xbar = X - mean + X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \ + layernorm.variance_epsilon) * \ + layernorm.weight.float() + return X.to(old_dtype) +pass + + +def fast_layernorm_inference(self, X, out_weight = None): + XX = X.to(torch.float32, copy = True) + XX -= X.mean(-1, keepdim = True) + variance = XX.square().mean(-1, keepdim = True) + variance += self.variance_epsilon + XX *= variance.rsqrt_() + out_weight[:] = self.weight + XX *= out_weight + return XX.to(X.dtype) +pass + + +# QK norm in Cohere +def CohereAttention_fast_forward( + self, + hidden_states: torch.Tensor, + 
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + *args, **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + # Clear inference + if hasattr(self, "paged_attention"): + del self.paged_attention_K + del self.paged_attention_V + del self.paged_attention + del self.temp_QA + del self.temp_KV + del self.RH_Q + del self.attention + del self.q_norm_out_weight + del self.k_norm_out_weight + pass + + bsz, q_len, _ = hidden_states.size() + + n_heads = self.num_heads + n_groups = self.num_key_value_groups + n_kv_heads = self.num_key_value_heads + head_dim = self.head_dim + assert(n_kv_heads * n_groups == n_heads) + + Q, K, V = self.apply_qkv(self, hidden_states) + Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) + K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + if self.use_qk_norm: + Q = fast_layernorm_compiled(self.q_norm, Q) + K = fast_layernorm_compiled(self.k_norm, K) + pass + + kv_seq_len = K.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if position_ids is None: + cos = self.rotary_emb.cos_cached + sin = self.rotary_emb.sin_cached + Q, K = fast_rope_embedding(Q, K, cos, sin) + else: + cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) + Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) + pass + + if past_key_value is not None: + K = torch.cat([past_key_value[0], K], dim = 2) + V = torch.cat([past_key_value[1], V], dim = 2) + pass + past_key_value = (K, V) if use_cache else None + + # Attention module + if (not HAS_FLASH_ATTENTION and attention_mask is None): + # Xformers memory efficient attention + # Also has Flash Attention v2 dispatching + Q = Q.transpose(1, 2) + K = K.transpose(1, 2) + V = V.transpose(1, 2) + + # Group query attention + if n_groups != 1: + K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) + V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) + K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) + V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) + if hidden_states.requires_grad: + K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) + V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) + else: + Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) + pass + A = xformers_attention(Q, K, V, attn_bias = causal_mask) + A = A.view(bsz, q_len, n_heads, head_dim) + + elif HAS_FLASH_ATTENTION and attention_mask is None: + Q = Q.transpose(1, 2) + K = K.transpose(1, 2) + V = V.transpose(1, 2) + A = flash_attn_func(Q, K, V, causal = True) + else: + # Grouped query attention + if n_groups != 1: + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) + V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) + pass + # Must be contiguous or else results are False! + # https://github.com/pytorch/pytorch/issues/112577 + Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() + # Needs (batch_size, n_heads, seq_len, head_dim) + # is_casual and attention_mask must not be both set! 
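A quick aside on the grouped-query branch just above: the expand-then-reshape is the standard zero-copy way of repeating each KV head n_groups times so K and V line up with the n_heads query heads before scaled_dot_product_attention. A self-contained sketch of the equivalence (the sizes are made up purely for illustration):

    import torch

    bsz, n_kv_heads, n_groups, kv_seq_len, head_dim = 2, 4, 2, 8, 16  # example sizes only
    K = torch.randn(bsz, n_kv_heads, kv_seq_len, head_dim)

    K_expanded = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
    K_expanded = K_expanded.reshape(bsz, n_kv_heads * n_groups, kv_seq_len, head_dim)

    # expand() only creates a broadcast view; the actual copy happens at reshape().
    assert torch.equal(K_expanded, K.repeat_interleave(n_groups, dim = 1))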
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) + # Go back to (batch_size, seq_len, n_heads, head_dim) + A = A.transpose(1, 2).contiguous() + pass + attn_output = A.reshape(bsz, q_len, n_heads*head_dim) + attn_output = self.apply_o(self, attn_output) + attn_weights = None + return attn_output, attn_weights, past_key_value +pass + + +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590 +def CohereDecoderLayer_fast_forward( + self, + hidden_states: torch.Tensor, + causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + *args, **kwargs, +): + if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None: + out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0") + + # Self Attention + residual = hidden_states + hidden_states = fast_layernorm_inference(self.input_layernorm, hidden_states, out_weight) + hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + + # Fully Connected + hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states) + residual += hidden_states_attention + residual += hidden_states_mlp + hidden_states = residual + else: + residual = hidden_states + hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states) + hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + + # Fully Connected + hidden_states_mlp = self.mlp(hidden_states) + hidden_states = residual + hidden_states_attention + hidden_states_mlp + pass + + outputs = (hidden_states,) + if output_attentions: outputs += (self_attn_weights,) + if use_cache: outputs += (present_key_value,) + return outputs +pass + + +from math import sqrt as math_sqrt +KV_CACHE_INCREMENT = 256 # KV Cache update size +torch_nn_functional_softmax = torch.nn.functional.softmax +torch_matmul = torch.matmul + +def CohereAttention_fast_forward_inference( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]], + position_ids, + do_prefill = False, + attention_mask = None, +): + Xn = hidden_states + bsz, _, hd = hidden_states.size() + K1, V1 = past_key_value + dtype = Xn.dtype + + n_heads = self.num_heads + n_groups = self.num_key_value_groups + n_kv_heads = self.num_key_value_heads + head_dim = self.head_dim + attention_size = n_heads*head_dim + # assert(n_kv_heads * n_groups == n_heads) + seq_len = K1.shape[-2] + kv_seq_len = seq_len + 1 + + # Prefill phase + # if not hasattr(self, "paged_attention"): + if do_prefill: + self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0") + self.paged_attention_K = 
self.paged_attention[:,0] + self.paged_attention_V = self.paged_attention[:,1] + self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) + self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) + self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") + self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + + # Mistral Nemo 12b has weird dimensions + if attention_size != self.hidden_size: + self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + else: + self.temp_O = self.temp_QA[1][:,:,:self.hidden_size] + pass + + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") + self.scalar = 1.0 / math_sqrt(self.head_dim) + self.half_head_dim = head_dim // 2 + # Cohere has QK layernorms + if self.use_qk_norm: + self.q_norm_out_weight = torch.empty_like(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0") + self.k_norm_out_weight = torch.empty_like(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0") + else: + self.q_norm_out_weight = None + self.k_norm_out_weight = None + pass + elif kv_seq_len >= self.paged_attention.shape[0]: + self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim)) + self.paged_attention_K = self.paged_attention[:,0] + self.paged_attention_V = self.paged_attention[:,1] + self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT)) + pass + + Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) + Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) + Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) + Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) + Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) + Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) + if self.use_qk_norm: + Q = fast_layernorm_inference(self.q_norm, Q, self.q_norm_out_weight) + K = fast_layernorm_inference(self.k_norm, K, self.k_norm_out_weight) + pass + + # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) + # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) + cos, sin = self.rotary_emb.get_cached(kv_seq_len) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + h = self.half_head_dim + + RH_Q = self.RH_Q + RH_Q[:,:,:,:h] = Qn[:,:,:,h:] + RH_Q[:,:,:,h:] = Qn[:,:,:,:h] + torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]) + Qn *= cos + Qn.addcmul_(RH_Q, sin) + + RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + RH_K[:,:,:,:h] = Kn[:,:,:,h:] + RH_K[:,:,:,h:] = Kn[:,:,:,:h] + torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) + Kn *= cos + Kn.addcmul_(RH_K, sin) + + # New KV cache + # Kn = torch.cat([K1, Kn], dim = 2) + # Vn = torch.cat([V1, Vn], dim = 2) + self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3) + self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3) + Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3) + Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3) + + # Handle sliding windows + sliding_window = getattr(self.config, "sliding_window", None) + if sliding_window is not None and kv_seq_len > sliding_window: + # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193 + slicing_tokens = 1 - 
sliding_window + Knn = Kn[:, :, slicing_tokens:, :]#.contiguous() + Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous() + else: + Knn, Vnn = Kn, Vn + pass + + # Grouped query attention + _, _, cached_len, _ = Knn.shape + if n_groups != 1: + Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) + Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim) + pass + # else: + # Knn, Vnn = Knn, Vnn + # pass + + # Attention + if bsz == 1: + Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 + # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows + A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) + # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched + A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) + A = torch_matmul(A, Vnn, out = Qn) + else: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + pass + A = A.transpose(1, 2) + A = A.reshape(bsz, 1, attention_size) + A = fast_linear_forward(self.o_proj, A, out = self.temp_O) + return A, (Kn, Vn) +pass + + +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 +# @torch.inference_mode +def CohereModel_fast_forward_inference( + self, + input_ids, + past_key_values, + position_ids, + attention_mask = None, +): + out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + input_ids = input_ids[:,:self.max_seq_length] + hidden_states = self.model.embed_tokens(input_ids) + hidden_states = hidden_states.to(self.config.torch_dtype) + # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32 + # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32 + hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype) + + bsz, q_len, hd = hidden_states.shape + seq_len = past_key_values[0][0].shape[-2] + if bsz != 1: + if HAS_FLASH_ATTENTION_SOFTCAPPING: + SWA = True + GA = False + else: + SWA = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (bsz, q_len), + hidden_states, + seq_len, + sliding_window = self.config.sliding_window, + ) + GA = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (bsz, q_len), + hidden_states, + seq_len, + ) + pass + else: + SWA = attention_mask + GA = attention_mask + pass + next_decoder_cache = [] + for idx, decoder_layer in enumerate(self.model.layers): + + use_sliding_window = idx % 2 == 0 + + residual = hidden_states + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states, present_key_value = Gemma2Attention_fast_forward_inference( + decoder_layer.self_attn, + hidden_states = hidden_states, + past_key_value = past_key_values[idx], + position_ids = position_ids, + attention_mask = SWA if use_sliding_window else GA, + do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), + use_sliding_window = use_sliding_window, + ) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight) + hidden_states += residual + + residual = hidden_states + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. 
pre_feedforward_layernorm, hidden_states, out_weight) + hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight) + hidden_states += residual + + next_decoder_cache.append(present_key_value) + pass + hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight) + + return BaseModelOutputWithPast( + last_hidden_state = hidden_states, + past_key_values = next_decoder_cache, + hidden_states = [], + attentions = [], + ) +pass + + +class FastCohereModel(FastLlamaModel): + + @staticmethod + def pre_patch(): + init_name, function = patch_linear_scaling( + model_name = "gemma2", + rope_module = GemmaFixedRotaryEmbedding, + scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding, + attention_module = Gemma2Attention, + ) + if init_name is not None: + exec(function, globals()) + Gemma2Attention.__init__ = eval(init_name) + pass + Gemma2Attention .forward = Gemma2Attention_fast_forward + Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward + Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward + Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward + Gemma2Model .forward = LlamaModel_fast_forward + Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference) + PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward + fix_prepare_inputs_for_generation(Gemma2ForCausalLM) + + # Solves https://github.com/unslothai/unsloth/issues/168 + # Static KV Cache was introduced in 4.38.0, causing training to be much slower. + # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. + # https://github.com/huggingface/transformers/pull/27931 + # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py + import transformers.models.gemma2.modeling_gemma2 + transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding + return + pass + + + @staticmethod + def post_patch(model): + # Patch model for Gemma + layers = model.model.layers + + # Torch.compile fails on embedding matrix?? + # Workaround randomnly fixes it for torch versions < 2.2 + model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) + model.config.update({"unsloth_version" : __version__}) + + # We also do this for the lm_head + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + lm_head.weight = model.lm_head.weight + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + model.lm_head = lm_head + + # Gemma has tied weights! This means lm_head == embed_tokens + if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr(): + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + lm_head.weight = model.model.embed_tokens.weight + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + model.lm_head = lm_head + pass + + # Also patch all dtypes - BnB seems to not allocate the correct type? + # BnB default dtype seems to be float16! + correct_dtype = lm_head.weight.dtype + + for name, module in model.named_modules(): + if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): + weight = module.weight + quant_state = weight.quant_state + + if type(quant_state) is list: + # BnB seems to have float16 as default! 
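The list-versus-attribute branch around this point exists because, as the comments note, older bitsandbytes releases stored quant_state as a plain Python list (with the dtype sitting at index 2), while the linked bitsandbytes PR replaced it with a QuantState object exposing a .dtype field. A condensed sketch of the same fix-up, assuming model and correct_dtype are defined as in the surrounding function:

    for module in model.modules():
        quant_state = getattr(getattr(module, "weight", None), "quant_state", None)
        if quant_state is None:
            continue
        if isinstance(quant_state, list):
            quant_state[2] = correct_dtype     # old bitsandbytes: dtype stored at index 2
        else:
            quant_state.dtype = correct_dtype  # newer bitsandbytes: QuantState object

The patch itself restricts this sweep to Bnb_Linear4bit and Peft_Linear4bit modules rather than every module, which is the safer choice.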
+ module.weight.quant_state[2] = correct_dtype # Cast to correct dtype + else: + # https://github.com/TimDettmers/bitsandbytes/pull/763/files + quant_state.dtype = correct_dtype + pass + pass + # Downcast RoPE embedding to correct data type + # RoPE must be done in float32 for Gemma + # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \ + # and (module.cos_cached.dtype != correct_dtype): + + # module.cos_cached = module.cos_cached.to(correct_dtype) + # module.sin_cached = module.sin_cached.to(correct_dtype) + # pass + # pass + pass + + # Add 1 to weight + # return output * (1 + self.weight) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89 + from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm + + # Freeze all parameters except LoRA + # We do this first since += 1 seems to not be liked by requires_grad = True + for name, param in model.named_parameters(): + if ".lora_A." in name or ".lora_B." in name: + param.requires_grad_(True) + else: + param.requires_grad_(False) + pass + + # Patch RMS Layernorm + for name, module in model.named_modules(): + if isinstance(module, Gemma2RMSNorm): + # Must be in float32 + # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36 + # module = module.to(torch.float32) + # Leave + 1 to Triton kernel itself + # module.weight += 1.0 # return output * (1 + self.weight) + if not hasattr(module, "variance_epsilon"): + module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon + pass + + # Clear deleted GPU items + import gc + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + return model + pass +pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 6cd537d1..6858f525 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -18,7 +18,6 @@ GemmaFixedRotaryEmbedding, GemmaFixedLinearScalingRotaryEmbedding, fast_geglu_inference, - fast_rms_layernorm, ) try: from transformers.models.gemma2.modeling_gemma2 import ( diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f62f0f11..66b059b4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -495,6 +495,16 @@ def LlamaDecoderLayer_fast_forward( pass +# https://github.com/unslothai/unsloth/issues/404#issuecomment-2323473452 +__DTYPE_MAP = { + "float32": torch.float32, + torch.float32: torch.float32, + "float16": torch.float16, + torch.float16: torch.float16, + "bfloat16": torch.bfloat16, + torch.bfloat16: torch.bfloat16, +} + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 def LlamaModel_fast_forward( self, @@ -576,7 +586,8 @@ def LlamaModel_fast_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = inputs_embeds.to(self.config.torch_dtype) + # inputs_embeds = inputs_embeds.to(self.config.torch_dtype) + inputs_embeds = inputs_embeds.to(__DTYPE_MAP[self.config.torch_dtype]) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") From edef5cac84a603c58c15176403564b82e4e82592 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 11:19:44 -0700 Subject: [PATCH 053/110] Update trainer.py --- unsloth/trainer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 883f6652..45616ca6 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -14,7 +14,13 @@ from dataclasses import dataclass, 
field from typing import Optional -from trl import SFTTrainer, SFTConfig + +from trl import SFTTrainer +try: + from trl import SFTConfig as TrainingArguments +except: + from transformers import TrainingArguments +pass from . import is_bfloat16_supported __all__ = [ @@ -24,7 +30,7 @@ @dataclass -class UnslothTrainingArguments(SFTConfig): +class UnslothTrainingArguments(TrainingArguments): embedding_learning_rate : Optional[float] = field( default = None, metadata = {"help" : "Different learning rates for embeddings and lm_head."} From 6d4300c06548150422982beb5aa09d8dcd393129 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 11:55:17 -0700 Subject: [PATCH 054/110] Cohere --- unsloth/models/cohere.py | 184 +++++++-------------------------------- unsloth/models/llama.py | 9 +- 2 files changed, 41 insertions(+), 152 deletions(-) diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 53669836..98116f43 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -304,8 +304,8 @@ def CohereAttention_fast_forward_inference( self.half_head_dim = head_dim // 2 # Cohere has QK layernorms if self.use_qk_norm: - self.q_norm_out_weight = torch.empty_like(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0") - self.k_norm_out_weight = torch.empty_like(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0") + self.q_norm_out_weight = torch.empty(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0") + self.k_norm_out_weight = torch.empty(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0") else: self.q_norm_out_weight = None self.k_norm_out_weight = None @@ -411,63 +411,41 @@ def CohereModel_fast_forward_inference( input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) - # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32 - # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32 - hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype) - bsz, q_len, hd = hidden_states.shape seq_len = past_key_values[0][0].shape[-2] if bsz != 1: - if HAS_FLASH_ATTENTION_SOFTCAPPING: - SWA = True - GA = False - else: - SWA = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (bsz, q_len), - hidden_states, - seq_len, - sliding_window = self.config.sliding_window, - ) - GA = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (bsz, q_len), - hidden_states, - seq_len, - ) - pass + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (bsz, q_len), + hidden_states, + seq_len, + sliding_window = getattr(self.config, "sliding_window", None), + ) else: - SWA = attention_mask - GA = attention_mask + attention_mask = None pass + next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - - use_sliding_window = idx % 2 == 0 - residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) - hidden_states, present_key_value = Gemma2Attention_fast_forward_inference( + hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( decoder_layer.self_attn, hidden_states = hidden_states, past_key_value = past_key_values[idx], position_ids = position_ids, - attention_mask = SWA if use_sliding_window else GA, + attention_mask = 
attention_mask, do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), - use_sliding_window = use_sliding_window, ) - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight) - hidden_states += residual - residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight) - hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states) - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight) - hidden_states += residual + hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states) + residual += hidden_states_attention + residual += hidden_states_mlp + hidden_states = residual next_decoder_cache.append(present_key_value) pass - hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight) + hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight) return BaseModelOutputWithPast( last_hidden_state = hidden_states, @@ -483,122 +461,26 @@ class FastCohereModel(FastLlamaModel): @staticmethod def pre_patch(): init_name, function = patch_linear_scaling( - model_name = "gemma2", - rope_module = GemmaFixedRotaryEmbedding, - scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding, - attention_module = Gemma2Attention, + model_name = "cohere", + rope_module = LlamaRotaryEmbedding, + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, + attention_module = CohereAttention, ) if init_name is not None: exec(function, globals()) - Gemma2Attention.__init__ = eval(init_name) + CohereAttention.__init__ = eval(init_name) pass - Gemma2Attention .forward = Gemma2Attention_fast_forward - Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward - Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward - Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward - Gemma2Model .forward = LlamaModel_fast_forward - Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference) + CohereAttention .forward = CohereAttention_fast_forward + CohereSdpaAttention .forward = CohereAttention_fast_forward + CohereFlashAttention2.forward = CohereAttention_fast_forward + CohereDecoderLayer .forward = CohereDecoderLayer_fast_forward + CohereModel .forward = LlamaModel_fast_forward + CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference) PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward - fix_prepare_inputs_for_generation(Gemma2ForCausalLM) + fix_prepare_inputs_for_generation(CohereForCausalLM) - # Solves https://github.com/unslothai/unsloth/issues/168 - # Static KV Cache was introduced in 4.38.0, causing training to be much slower. - # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. - # https://github.com/huggingface/transformers/pull/27931 - # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py - import transformers.models.gemma2.modeling_gemma2 - transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding + import transformers.models.cohere.modeling_cohere + transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding return pass - - - @staticmethod - def post_patch(model): - # Patch model for Gemma - layers = model.model.layers - - # Torch.compile fails on embedding matrix?? 
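One design point that explains the shape of CohereDecoderLayer_fast_forward and the inference loop above: Cohere (Command-R) layers apply a single input layernorm and then run the attention and MLP branches in parallel off that one normalised tensor, instead of the Llama-style sequential pre-norm blocks. In sketch form (function and argument names are illustrative, not from the patch):

    def cohere_style_layer(x, input_norm, attn, mlp):
        h = input_norm(x)
        return x + attn(h) + mlp(h)      # parallel residual: both branches see the same h

    def llama_style_layer(x, norm1, attn, norm2, mlp):
        x = x + attn(norm1(x))           # sequential residual
        return x + mlp(norm2(x))

That is why the fast paths above accumulate residual += hidden_states_attention and residual += hidden_states_mlp after a single layernorm call per layer.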
- # Workaround randomnly fixes it for torch versions < 2.2 - model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) - model.config.update({"unsloth_version" : __version__}) - - # We also do this for the lm_head - lm_head = torch.nn.Linear(1, 1, bias = None) - del lm_head.weight - lm_head.weight = model.lm_head.weight - lm_head.in_features = lm_head.weight.shape[1] - lm_head.out_features = lm_head.weight.shape[0] - model.lm_head = lm_head - - # Gemma has tied weights! This means lm_head == embed_tokens - if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr(): - lm_head = torch.nn.Linear(1, 1, bias = None) - del lm_head.weight - lm_head.weight = model.model.embed_tokens.weight - lm_head.in_features = lm_head.weight.shape[1] - lm_head.out_features = lm_head.weight.shape[0] - model.lm_head = lm_head - pass - - # Also patch all dtypes - BnB seems to not allocate the correct type? - # BnB default dtype seems to be float16! - correct_dtype = lm_head.weight.dtype - - for name, module in model.named_modules(): - if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): - weight = module.weight - quant_state = weight.quant_state - - if type(quant_state) is list: - # BnB seems to have float16 as default! - module.weight.quant_state[2] = correct_dtype # Cast to correct dtype - else: - # https://github.com/TimDettmers/bitsandbytes/pull/763/files - quant_state.dtype = correct_dtype - pass - pass - # Downcast RoPE embedding to correct data type - # RoPE must be done in float32 for Gemma - # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \ - # and (module.cos_cached.dtype != correct_dtype): - - # module.cos_cached = module.cos_cached.to(correct_dtype) - # module.sin_cached = module.sin_cached.to(correct_dtype) - # pass - # pass - pass - - # Add 1 to weight - # return output * (1 + self.weight) - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89 - from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm - - # Freeze all parameters except LoRA - # We do this first since += 1 seems to not be liked by requires_grad = True - for name, param in model.named_parameters(): - if ".lora_A." in name or ".lora_B." 
in name: - param.requires_grad_(True) - else: - param.requires_grad_(False) - pass - - # Patch RMS Layernorm - for name, module in model.named_modules(): - if isinstance(module, Gemma2RMSNorm): - # Must be in float32 - # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36 - # module = module.to(torch.float32) - # Leave + 1 to Triton kernel itself - # module.weight += 1.0 # return output * (1 + self.weight) - if not hasattr(module, "variance_epsilon"): - module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon - pass - - # Clear deleted GPU items - import gc - for _ in range(3): - gc.collect() - torch.cuda.empty_cache() - return model - pass pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 66b059b4..90bbb3ab 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -587,7 +587,12 @@ def LlamaModel_fast_forward( inputs_embeds = self.embed_tokens(input_ids) # inputs_embeds = inputs_embeds.to(self.config.torch_dtype) - inputs_embeds = inputs_embeds.to(__DTYPE_MAP[self.config.torch_dtype]) + torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) + if torch_dtype is not None: + inputs_embeds = inputs_embeds.to(torch_dtype) + else: + raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") + pass # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") @@ -888,6 +893,7 @@ def _CausalLM_fast_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + num_logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -927,6 +933,7 @@ def _CausalLM_fast_forward( bsz, q_len, hd = hidden_states.shape lm_head = self.lm_head.weight if bsz == 1 and q_len == 1: + print(num_logits_to_keep, hidden_states.shape) logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype)) logits = logits.unsqueeze(0).unsqueeze(0) else: From 754e670daf6b53bf8fe92c5f07bae25a96aa67f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 14:18:40 -0700 Subject: [PATCH 055/110] Cohere --- unsloth/kernels/cross_entropy_loss.py | 123 ++++++++++++++++++-------- unsloth/models/llama.py | 49 +++++++--- 2 files changed, 123 insertions(+), 49 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index b8473e60..aeec9184 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -19,17 +19,22 @@ from transformers.models.llama.modeling_llama import logger -@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _cross_entropy_forward( logits_ptr, logits_row_stride, loss_ptr, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING: tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -62,8 +67,11 @@ def _cross_entropy_forward( label_idx = tl.load(labels_ptr).to(tl.int32) logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + + # Go logit scaling for Cohere: t * x + if DO_LOGIT_SCALING: logits = 
LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) + if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) logits = logits.to(tl.float32) c = tl.max(logits, 0) @@ -71,8 +79,10 @@ def _cross_entropy_forward( if label_idx != -100: x = tl.load(logits_ptr + label_idx) + # Go logit scaling for Cohere: t * x + if DO_LOGIT_SCALING: x = LOGIT_SCALE * x # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) + if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) loss = logsumexp - x.to(tl.float32) else: loss = 0.0 @@ -81,18 +91,23 @@ def _cross_entropy_forward( pass -@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _chunked_cross_entropy_forward( logits_ptr, logits_row_stride, loss_ptr, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING: tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ 256K vocab divided in 4 chunks @@ -130,8 +145,11 @@ def _chunked_cross_entropy_forward( label_idx = tl.load(labels_ptr).to(tl.int32) logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + + # Go logit scaling for Cohere: t * x + if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) + if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) logits = logits.to(tl.float32) c = tl.max(logits, 0) @@ -142,8 +160,10 @@ def _chunked_cross_entropy_forward( # Do the -x separately if label_idx != -100: x = tl.load(logits_ptr + label_idx).to(tl.float32) + # Go logit scaling for Cohere: t * x + if DO_LOGIT_SCALING: x = LOGIT_SCALE * x # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) + if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) loss = -1.0 * x.to(tl.float32) else: loss = 0.0 @@ -153,17 +173,22 @@ def _chunked_cross_entropy_forward( pass -@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _cross_entropy_backward( logits_ptr, logits_row_stride, dloss_ptr, dloss_row_stride, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING: tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) @@ -210,6 +235,11 @@ def _cross_entropy_backward( y, # exp(x - logsumexp) ) + if DO_LOGIT_SCALING: + # d/dx [s * x] = s + y = y * LOGIT_SCALE + pass + if DO_SOFTCAPPING: # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) y = y * (1.0 - partial*partial) @@ -224,14 +254,15 @@ def _cross_entropy_backward( class 
Fast_CrossEntropyLoss(torch.autograd.Function): @staticmethod - def forward(ctx, logits, labels, logit_softcapping = 0): + def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): n_rows, vocab_size = logits.shape div, mod = divmod(vocab_size, MAX_FUSED_SIZE) n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - DO_SOFTCAPPING = (logit_softcapping != 0) + DO_SOFTCAPPING = (logit_softcapping != 0) + DO_LOGIT_SCALING = (logit_scaling != 0) if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral @@ -243,11 +274,13 @@ def forward(ctx, logits, labels, logit_softcapping = 0): losses, logsumexp, labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, - SOFTCAP = logit_softcapping, - num_warps = num_warps, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, + LOGIT_SCALE = logit_scaling, + num_warps = num_warps, ) else: # For large vocabs > 65336 like Gemma 256K @@ -258,12 +291,14 @@ def forward(ctx, logits, labels, logit_softcapping = 0): losses, logsumexp, labels, - VOCAB_SIZE = vocab_size, - N_CHUNKS = n_chunks, - BLOCK_SIZE = MAX_FUSED_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, - SOFTCAP = logit_softcapping, - num_warps = 32, + VOCAB_SIZE = vocab_size, + N_CHUNKS = n_chunks, + BLOCK_SIZE = MAX_FUSED_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, + LOGIT_SCALE = logit_scaling, + num_warps = 32, ) # logsumexp(chunked_logsumexp) - x # Do the -x separately @@ -275,6 +310,8 @@ def forward(ctx, logits, labels, logit_softcapping = 0): ctx.save_for_backward(logits, logsumexp, labels) ctx.DO_SOFTCAPPING = DO_SOFTCAPPING ctx.logit_softcapping = logit_softcapping + ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING + ctx.logit_scaling = logit_scaling return losses pass @@ -292,19 +329,26 @@ def backward(ctx, dlosses): dlosses, dlosses.stride(0), logsumexp, labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, - SOFTCAP = ctx.logit_softcapping, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, + SOFTCAP = ctx.logit_softcapping, + DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, + LOGIT_SCALE = ctx.logit_scaling, num_warps = 8, ) - return logits, None, None, + return logits, None, None, None, pass pass @torch._disable_dynamo -def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0): +def fast_cross_entropy_loss( + logits, + labels, + logit_softcapping = 0, + logit_scaling = 0, +): """ Arguments: logits: (batch, seq_len, vocab_size) @@ -319,6 +363,7 @@ def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0): logits.view(batch*seq_len, d), labels.view(-1), logit_softcapping, + logit_scaling, ) n_items = torch.count_nonzero(labels != -100) return loss.sum() / n_items diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 90bbb3ab..7284cdd8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -305,6 +305,20 @@ def fast_rms_layernorm_inference_gemma(self, X, out_weight = None): pass +# Normal layernorm with mean removal +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +def fast_layernorm_compiled(layernorm, X): + old_dtype = X.dtype + X = X.float() + mean = X.mean(-1, keepdim = True) + Xbar = X - mean + X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \ + 
layernorm.variance_epsilon) * \ + layernorm.weight.float() + return X.to(old_dtype) +pass + + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320 def LlamaAttention_fast_forward( self, @@ -597,6 +611,7 @@ def LlamaModel_fast_forward( # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") IS_GEMMA2 = self.config.model_type.startswith("gemma2") + IS_COHERE = self.config.model_type.startswith("cohere") train_embed_tokens = self.embed_tokens.weight.requires_grad if IS_GEMMA: @@ -802,8 +817,11 @@ def custom_forward(*inputs): # Final layernorm if use_cache: - hidden_states = (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\ + hidden_states = \ + (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\ (self.norm, hidden_states) + elif IS_COHERE: + hidden_states = fast_layernorm_compiled(self.norm, hidden_states) else: hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) pass @@ -943,6 +961,7 @@ def _CausalLM_fast_forward( loss = None logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) + logit_scaling = getattr(self.config, "logit_scale", 0) if labels is not None: shift_logits = logits if not hasattr(self, "extra_ignored_labels"): @@ -955,16 +974,26 @@ def _CausalLM_fast_forward( logits = shift_logits, labels = shift_labels, logit_softcapping = logit_softcapping, + logit_scaling = logit_scaling, ) - elif logit_softcapping != 0: - if logits.requires_grad: - logits = (1.0 / logit_softcapping) * logits - logits = torch.tanh(logits) - logits = logit_softcapping * logits - else: - logits *= (1.0 / logit_softcapping) - torch.tanh(logits, out = logits) - logits *= logit_softcapping + else: + if logit_scaling != 0: + if logits.requires_grad: + logits = logit_scaling * logits + else: + logits *= logit_scaling + pass + pass + if logit_softcapping != 0: + if logits.requires_grad: + logits = (1.0 / logit_softcapping) * logits + logits = torch.tanh(logits) + logits = logit_softcapping * logits + else: + logits *= (1.0 / logit_softcapping) + torch.tanh(logits, out = logits) + logits *= logit_softcapping + pass pass pass From d242866d24a68addad7574767bb6c0be08dc0c8d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 21:52:40 -0700 Subject: [PATCH 056/110] New models --- unsloth/models/loader.py | 7 +++++-- unsloth/models/mapper.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index e1f17aca..13710eed 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -13,9 +13,10 @@ # limitations under the License. 
from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING -from .llama import FastLlamaModel, logger +from .llama import FastLlamaModel, logger from .mistral import FastMistralModel -from .qwen2 import FastQwen2Model +from .qwen2 import FastQwen2Model +from .cohere import FastCohereModel from transformers import AutoConfig from transformers import __version__ as transformers_version from peft import PeftConfig, PeftModel @@ -278,6 +279,8 @@ def from_pretrained( dispatch_model = FastGemma2Model elif model_type == "qwen2": dispatch_model = FastQwen2Model + elif model_type == "cohere": + dispatch_model = FastCohereModel else: raise NotImplementedError( f"Unsloth: {model_name} not supported yet!\n"\ diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 3f49c965..bff7f025 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -227,6 +227,7 @@ "meta-llama/Meta-Llama-3.1-8B-Instruct", ), "unsloth/Meta-Llama-3.1-70B-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-70B", "meta-llama/Meta-Llama-3.1-70B", ), "unsloth/Meta-Llama-3.1-405B-bnb-4bit" : ( @@ -236,6 +237,7 @@ "meta-llama/Meta-Llama-3.1-405B-Instruct", ), "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-70B-Instruct", "meta-llama/Meta-Llama-3.1-70B-Instruct", ), "unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : ( @@ -253,6 +255,27 @@ "unsloth/Phi-3.5-mini-instruct", "microsoft/Phi-3.5-mini-instruct", ), + "unsloth/c4ai-command-r-08-2024-bnb-4bit" : ( + "CohereForAI/c4ai-command-r-08-2024", + ), + "unsloth/c4ai-command-r-plus-08-2024-bnb-4bit" : ( + "CohereForAI/c4ai-command-r-plus-08-2024", + ), + "unsloth/Llama-3.1-Storm-8B-bnb-4bit" : ( + "unsloth/Llama-3.1-Storm-8B", + "akjindal53244/Llama-3.1-Storm-8B", + ), + "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit" : ( + "unsloth/Hermes-3-Llama-3.1-8B", + "NousResearch/Hermes-3-Llama-3.1-8B", + ), + "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit" : ( + "unsloth/Hermes-3-Llama-3.1-70B", + "NousResearch/Hermes-3-Llama-3.1-70B", + ), + "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit" : ( + "NousResearch/Hermes-3-Llama-3.1-405B", + ), } INT_TO_FLOAT_MAPPER = {} From 0b7e973aff62099873e759655e0fed92d7c32435 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 21:57:43 -0700 Subject: [PATCH 057/110] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7284cdd8..bcb5cffb 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1025,6 +1025,7 @@ def PeftModelForCausalLM_fast_forward( output_hidden_states=None, return_dict=None, task_ids=None, + num_logits_to_keep=0, **kwargs, ): return self.base_model( @@ -1036,6 +1037,7 @@ def PeftModelForCausalLM_fast_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + num_logits_to_keep=num_logits_to_keep, **kwargs, ) pass From 19549f22c8144d1d75fafca2ce17614ffa620f4c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 22:15:44 -0700 Subject: [PATCH 058/110] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index bcb5cffb..9a837ac4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -951,7 +951,6 @@ def _CausalLM_fast_forward( bsz, q_len, hd = hidden_states.shape lm_head = self.lm_head.weight if bsz == 1 and q_len == 1: - print(num_logits_to_keep, hidden_states.shape) logits = 
torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype)) logits = logits.unsqueeze(0).unsqueeze(0) else: From 8823e134d8e72134638e7e941d2326ee97ec3b32 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 23:45:47 -0700 Subject: [PATCH 059/110] Update cohere.py --- unsloth/models/cohere.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 98116f43..aa0bcb55 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -15,7 +15,7 @@ from .llama import * from ._utils import __version__ try: - from transformers.models.gemma2.modeling_gemma2 import ( + from transformers.models.cohere.modeling_cohere import ( CohereAttention, CohereDecoderLayer, CohereModel, @@ -52,19 +52,6 @@ pass -@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) -def fast_layernorm_compiled(layernorm, X): - old_dtype = X.dtype - X = X.float() - mean = X.mean(-1, keepdim = True) - Xbar = X - mean - X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \ - layernorm.variance_epsilon) * \ - layernorm.weight.float() - return X.to(old_dtype) -pass - - def fast_layernorm_inference(self, X, out_weight = None): XX = X.to(torch.float32, copy = True) XX -= X.mean(-1, keepdim = True) From 90050b7f0dfad16ca88168d79594e2eb5048773d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 23:48:40 -0700 Subject: [PATCH 060/110] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9a837ac4..bd7c8256 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2229,6 +2229,7 @@ def patch_peft_model( elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx + elif model_type == "cohere": apply_lora_mlp = apply_lora_mlp_swiglu else: raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!") pass From 4c1ec3ab1f4858bda9599c5650bd3a4c5d1a023f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 2 Sep 2024 23:58:48 -0700 Subject: [PATCH 061/110] Update cohere.py --- unsloth/models/cohere.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index aa0bcb55..b9276bc1 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -104,6 +104,7 @@ def CohereAttention_fast_forward( K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) if self.use_qk_norm: + print("QK_norm") Q = fast_layernorm_compiled(self.q_norm, Q) K = fast_layernorm_compiled(self.k_norm, K) pass From 97b395655e83d88803dad6232e98242ff94165ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 00:04:22 -0700 Subject: [PATCH 062/110] retry --- unsloth/models/cohere.py | 1 - unsloth/models/llama.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index b9276bc1..aa0bcb55 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -104,7 +104,6 @@ def CohereAttention_fast_forward( K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) if self.use_qk_norm: - print("QK_norm") Q = fast_layernorm_compiled(self.q_norm, Q) K = fast_layernorm_compiled(self.k_norm, K) pass 
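For context on the hunk above: Cohere's attention optionally normalises the query and key heads (use_qk_norm) with a bias-free LayerNorm, and these patches route that through a compiled float32 layernorm (mean removal, rsqrt of the variance, learned weight, no bias) rather than the RMS variant used by the Llama path. A minimal reference sketch of that computation, checked against torch.nn.LayerNorm — names, shapes and the eps value are illustrative assumptions, not the patched kernel:

import torch

def reference_layernorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Compute in float32 for stability, then cast back (mirrors the compiled path)
    old_dtype = x.dtype
    x = x.float()
    xbar = x - x.mean(-1, keepdim = True)                                # remove the mean
    x = xbar * torch.rsqrt(xbar.square().mean(-1, keepdim = True) + eps) # unit variance
    return (x * weight.float()).to(old_dtype)                            # scale only, no bias

# Sanity check against PyTorch's LayerNorm (default weight is ones, default bias is zeros)
hidden_dim = 64
ln = torch.nn.LayerNorm(hidden_dim, eps = 1e-5)
x = torch.randn(2, 5, hidden_dim, dtype = torch.float16)
out_ref   = reference_layernorm(x, ln.weight, ln.eps)
out_torch = ln(x.float()).to(x.dtype)
assert torch.allclose(out_ref.float(), out_torch.float(), atol = 1e-2)

The q_norm / k_norm case is the same computation applied over the last dimension (head_dim) of the per-head query and key tensors.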
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index bd7c8256..280dff53 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -973,7 +973,7 @@ def _CausalLM_fast_forward( logits = shift_logits, labels = shift_labels, logit_softcapping = logit_softcapping, - logit_scaling = logit_scaling, + # logit_scaling = logit_scaling, ) else: if logit_scaling != 0: From fd615eafa09f8e72dfe74a57133acf9989551c6e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 00:08:33 -0700 Subject: [PATCH 063/110] Update fast_lora.py --- unsloth/kernels/fast_lora.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index 8f410179..42e8f9bc 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -229,7 +229,8 @@ class LoRA_QKV(torch.autograd.Function): def forward(ctx, X : torch.Tensor, QW, QW_quant, QA, QB, QS, KW, KW_quant, KA, KB, KS, - VW, VW_quant, VA, VB, VS,): + VW, VW_quant, VA, VB, VS, + inplace = True): dtype = X.dtype Q = matmul_lora(X, QW, QW_quant, QA, QB, QS) @@ -242,6 +243,7 @@ def forward(ctx, X : torch.Tensor, VW, VW_quant, VS, ) ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,) + ctx.inplace = inplace return Q, K, V pass @@ -286,7 +288,7 @@ def backward(ctx, dQ, dK, dV): # Combine derivatives to find dX # dQ QW = fast_dequantize(QW.t(), QW_quant) - dX = torch.matmul(dQ, QW.t(), out = X) + dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None) del QW dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) @@ -308,12 +310,13 @@ def backward(ctx, dQ, dK, dV): return dX.view(batch, seq_len, hd), \ None, None, d_QA.t(), d_QB.t(), None, \ None, None, d_KA.t(), d_KB.t(), None, \ - None, None, d_VA.t(), d_VB.t(), None + None, None, d_VA.t(), d_VB.t(), None, \ + None, pass pass -def apply_lora_qkv(self, X): +def apply_lora_qkv(self, X, inplace = True): QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj) @@ -321,6 +324,7 @@ def apply_lora_qkv(self, X): QW, QW_quant, QA, QB, QS, KW, KW_quant, KA, KB, KS, VW, VW_quant, VA, VB, VS, + inplace = inplace, ) return Q, K, V pass From fe45990721bae3f5b362674c376f2580ea9758b2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 00:11:01 -0700 Subject: [PATCH 064/110] Update llama.py --- unsloth/models/llama.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 280dff53..219d186e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2289,6 +2289,14 @@ def patch_peft_model( lora_dropout = model.peft_config[active_adapter].lora_dropout bias = model.peft_config[active_adapter].bias + # We also do not inplace edit QKV for Cohere! 
+ from functools import partial + _apply_lora_qkv = \ + partial(apply_lora_qkv, inplace = False) \ + if model_type == "cohere" else \ + apply_lora_qkv + pass + if lora_dropout == 0 and bias == "none": for idx, layer in enumerate(model.model.model.layers): @@ -2331,7 +2339,7 @@ def patch_peft_model( (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \ (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0): - layer.self_attn.apply_qkv = apply_lora_qkv + layer.self_attn.apply_qkv = _apply_lora_qkv n_qkv += 1 else: if model_type != "qwen2": From f564b8a397ad180584281bf5350f27cf058c50c3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 00:11:58 -0700 Subject: [PATCH 065/110] Update fast_lora.py --- unsloth/kernels/fast_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index 42e8f9bc..ef85e72c 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -324,7 +324,7 @@ def apply_lora_qkv(self, X, inplace = True): QW, QW_quant, QA, QB, QS, KW, KW_quant, KA, KB, KS, VW, VW_quant, VA, VB, VS, - inplace = inplace, + inplace, ) return Q, K, V pass From b26da8461132f02c8d317a4d37c5955506654415 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 00:13:44 -0700 Subject: [PATCH 066/110] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 219d186e..dc12db70 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2316,7 +2316,7 @@ def patch_peft_model( (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0): # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module - layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) + # layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) n_mlp += 1 else: logger.warning_once( From 61be6a335a8a54a2cef458287da4a5373756d2c2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 00:15:43 -0700 Subject: [PATCH 067/110] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index dc12db70..0dc25c75 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -973,7 +973,7 @@ def _CausalLM_fast_forward( logits = shift_logits, labels = shift_labels, logit_softcapping = logit_softcapping, - # logit_scaling = logit_scaling, + logit_scaling = logit_scaling, ) else: if logit_scaling != 0: From ea48761ce8db01310711133f97e659a634668569 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 00:19:26 -0700 Subject: [PATCH 068/110] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index aeec9184..24e8002b 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -220,6 +220,13 @@ def _cross_entropy_backward( dloss = 0.0 x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + + # Do logit scaling for Cohere + if DO_LOGIT_SCALING: + # d/dx [s * x] = s + x = x * LOGIT_SCALE + pass + # Do logit softcapping for Gemma 2: t * tanh(1/t * x) if DO_SOFTCAPPING: # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) From 6e795c6110a1b22e746e7285df69d93cb95b16c8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: 
Tue, 3 Sep 2024 00:25:45 -0700 Subject: [PATCH 069/110] _apply_lora_mlp --- unsloth/kernels/fast_lora.py | 18 +++++++++++------- unsloth/models/llama.py | 10 +++++----- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index ef85e72c..2177b43b 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -68,7 +68,8 @@ def forward(ctx, X : torch.Tensor, gateW, gateW_quant, gateA, gateB, gateS, upW, upW_quant, upA, upB, upS, downW, downW_quant, downA, downB, downS, - _forward_function, _backward_function,): + _forward_function, _backward_function, + inplace = True,): dtype = X.dtype e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS) @@ -84,6 +85,7 @@ def forward(ctx, X : torch.Tensor, ) ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB, X, e, g) + ctx.inplace = inplace return i pass @@ -131,7 +133,7 @@ def backward(ctx, dY : torch.Tensor): # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) upW = fast_dequantize(upW.t(), upW_quant) - dX = torch.matmul(df, upW.t(), out = X) + dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None) del upW dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) @@ -147,13 +149,13 @@ def backward(ctx, dY : torch.Tensor): None, None, d_gateA.t(), d_gateB.t(), None, \ None, None, d_upA.t(), d_upB.t(), None, \ None, None, d_downA.t(), d_downB.t(), None, \ - None, None, # _backward and _forward + None, None, None, # _backward and _forward and inplace pass pass from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel -def apply_lora_mlp_swiglu(self, X): +def apply_lora_mlp_swiglu(self, X, inplace = True): gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj) downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) @@ -161,13 +163,14 @@ def apply_lora_mlp_swiglu(self, X): gateW, gateW_quant, gateA, gateB, gateS, upW, upW_quant, upA, upB, upS, downW, downW_quant, downA, downB, downS, - swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,) + swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel, + inplace,) return out pass from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel -def apply_lora_mlp_geglu_exact(self, X): +def apply_lora_mlp_geglu_exact(self, X, inplace = True): gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj) downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) @@ -175,7 +178,8 @@ def apply_lora_mlp_geglu_exact(self, X): gateW, gateW_quant, gateA, gateB, gateS, upW, upW_quant, upA, upB, upS, downW, downW_quant, downA, downB, downS, - geglu_exact_forward_kernel, geglu_exact_backward_kernel,) + geglu_exact_forward_kernel, geglu_exact_backward_kernel, + inplace,) return out pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0dc25c75..5ccf906a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2291,10 +2291,10 @@ def patch_peft_model( # We also do not inplace edit QKV for Cohere! 
from functools import partial - _apply_lora_qkv = \ - partial(apply_lora_qkv, inplace = False) \ + _apply_lora_mlp = \ + partial(apply_lora_mlp, inplace = False) \ if model_type == "cohere" else \ - apply_lora_qkv + apply_lora_mlp pass if lora_dropout == 0 and bias == "none": @@ -2316,7 +2316,7 @@ def patch_peft_model( (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0): # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module - # layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) + layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp) n_mlp += 1 else: logger.warning_once( @@ -2339,7 +2339,7 @@ def patch_peft_model( (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \ (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0): - layer.self_attn.apply_qkv = _apply_lora_qkv + layer.self_attn.apply_qkv = apply_lora_qkv n_qkv += 1 else: if model_type != "qwen2": From dacba398545f52a9a5a39ed922bd6ca70d0ce6b7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 01:38:02 -0700 Subject: [PATCH 070/110] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1c48e8e5..ea9a0c53 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -295,7 +295,7 @@ def patch_mistral_nemo_config(config): send_to_device, ).replace("def send_to_device", "def _fixed_send_to_device") exec(send_to_device) - accelerate.utils.operations.send_to_device = _fixed_send_to_device + # accelerate.utils.operations.send_to_device = _fixed_send_to_device pass pass # ============================================= From 5074427a9a4686c12ec7b37040d04bd965e4e64f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 12:04:25 -0700 Subject: [PATCH 071/110] Gemma fixes --- unsloth/kernels/__init__.py | 6 ++++- unsloth/kernels/flex_attention.py | 39 ++++++++++++++++++++++++++++++- unsloth/models/_utils.py | 4 ++++ unsloth/models/gemma2.py | 6 ++++- unsloth/models/llama.py | 7 ++++++ 5 files changed, 59 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index c2de979a..26f632ee 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -33,7 +33,11 @@ ) from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora -from .flex_attention import HAS_FLEX_ATTENTION, slow_attention_softcapping +from .flex_attention import ( + HAS_FLEX_ATTENTION, + slow_attention_softcapping, + slow_inference_attention_softcapping, +) if HAS_FLEX_ATTENTION: from .flex_attention import ( diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index a992a023..9a2054c5 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -47,7 +47,7 @@ pass # Logit softcapping -@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim @@ -80,3 +80,40 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): return A pass + +torch_matmul = torch.matmul +torch_tanh = torch.tanh +torch_nn_functional_softmax = torch.nn.functional.softmax +def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): + n_heads = 
self.num_heads + head_dim = self.head_dim + n_kv_heads = self.num_key_value_heads + n_groups = self.num_key_value_groups + + # Grouped query attention + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + K = K.reshape(bsz, n_heads, q_len, head_dim) + V = V.reshape(bsz, n_heads, q_len, head_dim) + + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e + # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below + # We default to using the config file itself + # s = self.config.hidden_size // self.config.num_attention_heads + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping + + Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly + A = torch_matmul(Q, K.transpose(2, 3)) + + # Logit softcapping + A /= t; torch_tanh(A, out = A); A *= t; + A += causal_mask[:q_len, :q_len] + # Much slower in torch compile! + # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf")) + A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype) + A = torch_matmul(A, V) + A = A.transpose(1, 2).contiguous() + A = A.reshape(bsz, q_len, n_heads*head_dim) + return A +pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ea9a0c53..242d234d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -39,6 +39,8 @@ "create_boolean_mask", "torch_amp_custom_fwd", "torch_amp_custom_bwd", + "accelerate_old_send_to_device", + "accelerate_new_send_to_device", ] import torch @@ -287,6 +289,7 @@ def patch_mistral_nemo_config(config): import accelerate.utils.operations if hasattr(accelerate.utils.operations, "send_to_device") and \ accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": + accelerate_old_send_to_device = accelerate.utils.operations.send_to_device from accelerate.utils.operations import * send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) send_to_device = re.sub( @@ -296,6 +299,7 @@ def patch_mistral_nemo_config(config): ).replace("def send_to_device", "def _fixed_send_to_device") exec(send_to_device) # accelerate.utils.operations.send_to_device = _fixed_send_to_device + accelerate_new_send_to_device = _fixed_send_to_device pass pass # ============================================= diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 6858f525..218849ef 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -157,7 +157,10 @@ def Gemma2Attention_fast_forward( A = A.reshape(bsz, q_len, n_heads*head_dim) else: mask = causal_mask if attention_mask is None else attention_mask - A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len) + fx = slow_inference_attention_softcapping \ + if "_flag_for_generation" in kwargs else \ + slow_attention_softcapping + A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len) pass A = self.apply_o(self, A) return A, None, past_key_value @@ -192,6 +195,7 @@ def Gemma2DecoderLayer_fast_forward( output_attentions=output_attentions, use_cache=use_cache, padding_mask=padding_mask, + _flag_for_generation=True, ) hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight) hidden_states += residual diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5ccf906a..5b7ca163 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2413,6 +2413,10 @@ def 
for_inference(model): # return # pass + # Must patch accelerate for Xformers + import accelerate.utils.operations + accelerate.utils.operations.send_to_device = accelerate_new_send_to_device + internal_model = model internal_model.gradient_checkpointing = False internal_model.training = False @@ -2468,6 +2472,9 @@ def for_inference(model): if hasattr(embeddings, "training"): embeddings.training = False pass + # Return accelerate back + accelerate.utils.operations.send_to_device = accelerate_old_send_to_device + return model pass From 743ba55cf2d158e301f484667d46bd243b3cea5e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 13:38:59 -0700 Subject: [PATCH 072/110] Update llama.py --- unsloth/models/llama.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5b7ca163..3fcb8a76 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -953,6 +953,8 @@ def _CausalLM_fast_forward( if bsz == 1 and q_len == 1: logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype)) logits = logits.unsqueeze(0).unsqueeze(0) + elif num_logits_to_keep != 0: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype)) else: logits = self.lm_head(hidden_states.to(lm_head.dtype)) pass @@ -1368,8 +1370,14 @@ def _fast_generate(*args, **kwargs): pass internal_model._flag_for_generation = True + # Must patch accelerate for Xformers + import accelerate.utils.operations + accelerate.utils.operations.send_to_device = accelerate_new_send_to_device + # For newer HF kwargs["cache_implementation"] = "dynamic" + # For num_logits_to_keep + kwargs["num_logits_to_keep"] = 1 # Remove token_type_ids kwargs.pop("token_type_ids", None) @@ -1402,6 +1410,9 @@ def _fast_generate(*args, **kwargs): pass if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation + # Return accelerate back + accelerate.utils.operations.send_to_device = accelerate_old_send_to_device + return output pass return _fast_generate @@ -2413,10 +2424,6 @@ def for_inference(model): # return # pass - # Must patch accelerate for Xformers - import accelerate.utils.operations - accelerate.utils.operations.send_to_device = accelerate_new_send_to_device - internal_model = model internal_model.gradient_checkpointing = False internal_model.training = False @@ -2472,9 +2479,6 @@ def for_inference(model): if hasattr(embeddings, "training"): embeddings.training = False pass - # Return accelerate back - accelerate.utils.operations.send_to_device = accelerate_old_send_to_device - return model pass From 7ea63955c99c36d0559002970eeb967ad9a71033 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 13:51:31 -0700 Subject: [PATCH 073/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 9a2054c5..9cf999e2 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -47,7 +47,7 @@ pass # Logit softcapping -@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim From df06a040b240b419ef70cf3a8a4e0b03b1427cbf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 4 Sep 2024 01:04:03 -0700 
Subject: [PATCH 074/110] Update llama.py --- unsloth/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 39998127..4d8539e6 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2331,7 +2331,7 @@ def patch_peft_model( (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0): # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module - layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp) + # layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp) n_mlp += 1 else: logger.warning_once( @@ -2354,7 +2354,7 @@ def patch_peft_model( (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \ (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0): - layer.self_attn.apply_qkv = apply_lora_qkv + # layer.self_attn.apply_qkv = apply_lora_qkv n_qkv += 1 else: if model_type != "qwen2": @@ -2371,7 +2371,7 @@ def patch_peft_model( (getattr(o_proj, "base_layer", o_proj).bias is None) and \ (len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0): - layer.self_attn.apply_o = apply_lora_o + # layer.self_attn.apply_o = apply_lora_o n_o += 1 else: logger.warning_once( From 7f139f1559db29ba9e83e5cfd1644fa51f835207 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 4 Sep 2024 01:06:48 -0700 Subject: [PATCH 075/110] layernorm --- unsloth/models/cohere.py | 2 +- unsloth/models/llama.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index aa0bcb55..b0a9b653 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -216,7 +216,7 @@ def CohereDecoderLayer_fast_forward( hidden_states = residual else: residual = hidden_states - hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states) + hidden_states = self.input_layernorm(hidden_states) hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4d8539e6..ac3b9315 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -821,7 +821,7 @@ def custom_forward(*inputs): (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\ (self.norm, hidden_states) elif IS_COHERE: - hidden_states = fast_layernorm_compiled(self.norm, hidden_states) + hidden_states = self.norm(hidden_states) else: hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) pass @@ -2331,7 +2331,7 @@ def patch_peft_model( (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0): # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module - # layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp) + layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp) n_mlp += 1 else: logger.warning_once( @@ -2354,7 +2354,7 @@ def patch_peft_model( (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \ (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0): - # layer.self_attn.apply_qkv = apply_lora_qkv + layer.self_attn.apply_qkv = apply_lora_qkv n_qkv += 1 else: if model_type != "qwen2": @@ -2371,7 +2371,7 @@ def patch_peft_model( (getattr(o_proj, "base_layer", o_proj).bias is None) and \ (len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0): - # layer.self_attn.apply_o = apply_lora_o + 
layer.self_attn.apply_o = apply_lora_o n_o += 1 else: logger.warning_once( From 068fc0d2c3b70d18928ab41ff3aeda995b24ec44 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 4 Sep 2024 01:11:46 -0700 Subject: [PATCH 076/110] Update llama.py --- unsloth/models/llama.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ac3b9315..c92ece26 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -964,19 +964,31 @@ def _CausalLM_fast_forward( logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) if labels is not None: - shift_logits = logits - if not hasattr(self, "extra_ignored_labels"): - # Fixes https://github.com/unslothai/unsloth/issues/10 - self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") - pass + # shift_logits = logits + # if not hasattr(self, "extra_ignored_labels"): + # # Fixes https://github.com/unslothai/unsloth/issues/10 + # self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") + # pass - shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) - loss = fast_cross_entropy_loss( - logits = shift_logits, - labels = shift_labels, - logit_softcapping = logit_softcapping, - logit_scaling = logit_scaling, - ) + # shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) + # loss = fast_cross_entropy_loss( + # logits = shift_logits, + # labels = shift_labels, + # logit_softcapping = logit_softcapping, + # logit_scaling = logit_scaling, + # ) + logits = logits.float() + logits = logits * self.logit_scale + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) else: if logit_scaling != 0: if logits.requires_grad: From 4eaccb0fb230d0f9b5409612f11bbfcbf21208db Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 4 Sep 2024 01:13:13 -0700 Subject: [PATCH 077/110] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c92ece26..7ae60260 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -983,7 +983,7 @@ def _CausalLM_fast_forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() + loss_fct = torch.nn.CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism From 4f909fc5a47a55281b6955cb55015e2869225d4d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 4 Sep 2024 23:49:52 -0700 Subject: [PATCH 078/110] Flex Attention --- unsloth/kernels/flex_attention.py | 129 +++++++++++++++++++++--------- unsloth/models/_utils.py | 3 +- unsloth/models/cohere.py | 2 +- unsloth/models/gemma2.py | 2 +- unsloth/models/llama.py | 82 ++++++++----------- 5 files changed, 130 insertions(+), 88 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 9cf999e2..b16ab6a5 
100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -29,15 +29,11 @@ if hasattr(torch.nn, "attention"): import torch.nn.attention if hasattr(torch.nn.attention, "flex_attention"): - import torch.nn.attention.flex_attention - from torch.nn.attention.flex_attention import flex_attention - from torch.nn.attention.flex_attention import create_block_mask - FLEX_ATTENTION_PADDING = getattr( - torch.nn.attention.flex_attention, - "_DEFAULT_SPARSE_BLOCK_SIZE", - 1, + from torch.nn.attention.flex_attention import ( + flex_attention as _flex_attention, + create_block_mask as _create_block_mask, ) - flex_attention = torch.compile(flex_attention, dynamic = False) + _flex_attention = torch.compile(_flex_attention, dynamic = False) HAS_FLEX_ATTENTION = True else: HAS_FLEX_ATTENTION = False @@ -46,38 +42,95 @@ HAS_FLEX_ATTENTION = False pass -# Logit softcapping -@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) -def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): - n_heads = self.num_heads - head_dim = self.head_dim - n_kv_heads = self.num_key_value_heads - n_groups = self.num_key_value_groups - - # Grouped query attention - K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) - V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) - K = K.reshape(bsz, n_heads, q_len, head_dim) - V = V.reshape(bsz, n_heads, q_len, head_dim) +if not HAS_FLEX_ATTENTION: - # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e - # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below - # We default to using the config file itself - # s = self.config.hidden_size // self.config.num_attention_heads - s = self.config.query_pre_attn_scalar - t = self.config.attn_logit_softcapping + # Logit softcapping + @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) + def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): + n_heads = self.num_heads + head_dim = self.head_dim + n_kv_heads = self.num_key_value_heads + n_groups = self.num_key_value_groups + + # Grouped query attention + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + K = K.reshape(bsz, n_heads, q_len, head_dim) + V = V.reshape(bsz, n_heads, q_len, head_dim) - Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly - A = torch.matmul(Q, K.transpose(2, 3)) - A = t * torch.tanh(A / t) # Logit softcapping - A += causal_mask[:q_len, :q_len] - # Much slower in torch compile! - # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf")) - A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype) - A = torch.matmul(A, V) - A = A.transpose(1, 2).contiguous() - A = A.reshape(bsz, q_len, n_heads*head_dim) - return A + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e + # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below + # We default to using the config file itself + # s = self.config.hidden_size // self.config.num_attention_heads + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping + + Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly + A = torch.matmul(Q, K.transpose(2, 3)) + A = t * torch.tanh(A / t) # Logit softcapping + A += causal_mask[:q_len, :q_len] + # Much slower in torch compile! 
+ # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf")) + A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype) + A = torch.matmul(A, V) + A = A.transpose(1, 2).contiguous() + A = A.reshape(bsz, q_len, n_heads*head_dim) + return A + pass +else: + # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb + # for more examples + # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al + import functools, math + + def generate_tanh_softcap(t): + def tanh_softcap(x, b, h, q_idx, kv_idx): + return t * torch.tanh(x / t) + return tanh_softcap + pass + def causal_masker(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + pass + + @functools.lru_cache + def sliding_window_masker(size = 4096): + def sliding_window(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + window_mask = q_idx - kv_idx <= size + return causal_mask & window_mask + return sliding_window + pass + + @functools.lru_cache + def create_block_mask(mask, n = 128): + return _create_block_mask(mask, 1, 1, n, n, device = "cuda") + pass + + @functools.lru_cache + def flex_attention(s, t): + scale = 1.0 / math.sqrt(s) + score_mod = generate_tanh_softcap(t) + return functools.partial( + _flex_attention(score_mod = score_mod, scale = scale, enable_gqa = True) + ) + pass + + def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): + if causal_mask == 0: + # Global attention + causal_mask = create_block_mask(causal_masker, q_len) + else: + # Sliding window attention + causal_mask = create_block_mask(sliding_window_masker(causal_mask), q_len) + pass + + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping + A = flex_attention(s, t)(Q, K, V, block_mask = causal_mask) + A = A.transpose(1, 2).contiguous() + A = A.reshape(bsz, q_len, n_heads*head_dim) + return A + pass pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 6dd17e73..386fb84c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -338,9 +338,10 @@ def is_big_gpu(index): ] # Torch dynamo arguments torch_dynamo_arguments = [ - "config.accumulated_cache_size_limit = 512", # Bump up a bit from 256 + "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 "config.suppress_errors = True", # Supress errors for now "config.do_not_emit_runtime_asserts = True", + "config.cache_size_limit = 1024", # Flex Attention ] import torch._inductor.config as config for _try_compile_argument in torch_compile_arguments: diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index b0a9b653..aa0bcb55 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -216,7 +216,7 @@ def CohereDecoderLayer_fast_forward( hidden_states = residual else: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states) hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, causal_mask=causal_mask, diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 218849ef..15763dc3 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -160,7 +160,7 @@ def Gemma2Attention_fast_forward( fx = slow_inference_attention_softcapping \ if "_flag_for_generation" in kwargs else \ slow_attention_softcapping - A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len) + A = fx(Q, K, V, mask, self, bsz, kv_seq_len) pass A = self.apply_o(self, A) return A, None, past_key_value 
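For reference, the score_mod / mask_mod pair wired into Flex Attention above encodes Gemma-2 style attention: logits are scaled by query_pre_attn_scalar**-0.5, soft-capped with t * tanh(x / t), and masked either purely causally (global layers) or with an additional sliding window (local layers); enable_gqa = True lets the kernel broadcast the smaller number of K/V heads. A minimal eager sketch of the same math on single-head tensors — shapes, names and the default softcap are illustrative assumptions, not the patched kernels:

import torch

def softcapped_attention(Q, K, V, softcap = 50.0, scale = None, window = None):
    # Q, K, V: (seq_len, head_dim) single-head tensors, kept 2D for clarity
    q_len = Q.shape[0]
    scale = Q.shape[-1] ** -0.5 if scale is None else scale      # Gemma 2 uses query_pre_attn_scalar**-0.5
    scores = (Q * scale) @ K.T                                   # raw attention logits
    scores = softcap * torch.tanh(scores / softcap)              # tanh logit softcapping
    q_idx  = torch.arange(q_len)[:, None]
    kv_idx = torch.arange(q_len)[None, :]
    mask = q_idx >= kv_idx                                       # causal mask
    if window is not None:
        mask &= (q_idx - kv_idx) <= window                       # sliding-window mask
    scores = scores.masked_fill(~mask, -float("inf"))
    probs = torch.softmax(scores.float(), dim = -1).to(Q.dtype)  # softmax in float32
    return probs @ V

Q, K, V = (torch.randn(8, 16) for _ in range(3))
out_global = softcapped_attention(Q, K, V)              # global-attention layer
out_local  = softcapped_attention(Q, K, V, window = 4)  # sliding-window layer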
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7ae60260..a5dc381d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -711,12 +711,6 @@ def LlamaModel_fast_forward( offloaded_gradient_checkpointing = True pass - # Check for Flex Attention - # if IS_GEMMA2 and HAS_FLEX_ATTENTION: - # if not (seq_length % FLEX_ATTENTION_PADDING == 0): - # USE_FLEX_ATTENTION = True - - # Gemma2 has alternating SWA and global attn if IS_GEMMA2: if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None: @@ -738,23 +732,29 @@ def LlamaModel_fast_forward( sliding_window = None, ) elif not hasattr(self, "SWA_mask"): - n = self.max_seq_length # self.config.max_position_embeddings - # masked_fill is making stuff slower! - # self. GA_mask = create_boolean_mask(n = n, sliding_window = 0) - # self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window) - from transformers.modeling_attn_mask_utils import AttentionMaskConverter - self.SWA_mask = AttentionMaskConverter( - is_causal = True, - sliding_window = self.config.sliding_window, - )\ - .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\ - .squeeze(0).squeeze(0) - - self.GA_mask = AttentionMaskConverter( - is_causal = True, - )\ - .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\ - .squeeze(0).squeeze(0) + if HAS_FLEX_ATTENTION: + # Use Flex Attention instead! + self.SWA_mask = self.config.sliding_window + self.GA_mask = 0 + else: + n = self.max_seq_length # self.config.max_position_embeddings + # masked_fill is making stuff slower! + # self. GA_mask = create_boolean_mask(n = n, sliding_window = 0) + # self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window) + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + self.SWA_mask = AttentionMaskConverter( + is_causal = True, + sliding_window = self.config.sliding_window, + )\ + .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\ + .squeeze(0).squeeze(0) + + self.GA_mask = AttentionMaskConverter( + is_causal = True, + )\ + .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\ + .squeeze(0).squeeze(0) + pass pass pass @@ -964,31 +964,19 @@ def _CausalLM_fast_forward( logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) if labels is not None: - # shift_logits = logits - # if not hasattr(self, "extra_ignored_labels"): - # # Fixes https://github.com/unslothai/unsloth/issues/10 - # self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") - # pass + shift_logits = logits + if not hasattr(self, "extra_ignored_labels"): + # Fixes https://github.com/unslothai/unsloth/issues/10 + self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") + pass - # shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) - # loss = fast_cross_entropy_loss( - # logits = shift_logits, - # labels = shift_labels, - # logit_softcapping = logit_softcapping, - # logit_scaling = logit_scaling, - # ) - logits = logits.float() - logits = logits * self.logit_scale - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = torch.nn.CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model 
parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) + loss = fast_cross_entropy_loss( + logits = shift_logits, + labels = shift_labels, + logit_softcapping = logit_softcapping, + logit_scaling = logit_scaling, + ) else: if logit_scaling != 0: if logits.requires_grad: From efef0ee2e7d5131ecb9e21716bef90af5ce6ee2e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 4 Sep 2024 23:51:08 -0700 Subject: [PATCH 079/110] Update gemma2.py --- unsloth/models/gemma2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 15763dc3..bf40ea8a 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -156,11 +156,10 @@ def Gemma2Attention_fast_forward( ) A = A.reshape(bsz, q_len, n_heads*head_dim) else: - mask = causal_mask if attention_mask is None else attention_mask fx = slow_inference_attention_softcapping \ if "_flag_for_generation" in kwargs else \ slow_attention_softcapping - A = fx(Q, K, V, mask, self, bsz, kv_seq_len) + A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len) pass A = self.apply_o(self, A) return A, None, past_key_value From 6e8951ff5a5d3bfc9052f489638354cf528e59b9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 4 Sep 2024 23:52:46 -0700 Subject: [PATCH 080/110] Update __init__.py --- unsloth/kernels/__init__.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index 26f632ee..d302768b 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -39,12 +39,6 @@ slow_inference_attention_softcapping, ) -if HAS_FLEX_ATTENTION: - from .flex_attention import ( - FLEX_ATTENTION_PADDING, - ) -pass - try: print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.") except: From d60a18c8c3267bf32abf682e6195972700487aae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 4 Sep 2024 23:59:05 -0700 Subject: [PATCH 081/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index b16ab6a5..ac3c5fe7 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -103,7 +103,11 @@ def sliding_window(b, h, q_idx, kv_idx): @functools.lru_cache def create_block_mask(mask, n = 128): - return _create_block_mask(mask, 1, 1, n, n, device = "cuda") + return _create_block_mask( + mask, 1, 1, n, n, + BLOCK_SIZE = 128, + _compile = True, + ) pass @functools.lru_cache @@ -123,6 +127,7 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): # Sliding window attention causal_mask = create_block_mask(sliding_window_masker(causal_mask), q_len) pass + print(causal_mask) s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping From 1b4132ea5ef9383130b6c491f6de22c112d93dbe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:03:31 -0700 Subject: [PATCH 082/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index ac3c5fe7..2b585280 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -25,23 +25,18 @@ } # Flex Attention supported from 
torch 2.5 onwards only -import torch.nn -if hasattr(torch.nn, "attention"): - import torch.nn.attention - if hasattr(torch.nn.attention, "flex_attention"): - from torch.nn.attention.flex_attention import ( - flex_attention as _flex_attention, - create_block_mask as _create_block_mask, - ) - _flex_attention = torch.compile(_flex_attention, dynamic = False) - HAS_FLEX_ATTENTION = True - else: - HAS_FLEX_ATTENTION = False - pass +try: + from torch.nn.attention.flex_attention import ( + flex_attention as _flex_attention, + create_block_mask as _create_block_mask, + ) + _flex_attention = torch.compile(_flex_attention, dynamic = False) + HAS_FLEX_ATTENTION = True else: HAS_FLEX_ATTENTION = False pass + if not HAS_FLEX_ATTENTION: # Logit softcapping From f5d11dc0a0e4c1d5b90b835d686326f10f9537f7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:04:10 -0700 Subject: [PATCH 083/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 2b585280..c07e9ddf 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -32,7 +32,7 @@ ) _flex_attention = torch.compile(_flex_attention, dynamic = False) HAS_FLEX_ATTENTION = True -else: +except: HAS_FLEX_ATTENTION = False pass From 24546594d7852144bc681cb5c35629de3af96500 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:07:29 -0700 Subject: [PATCH 084/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index c07e9ddf..96c4f89b 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -126,7 +126,7 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping - A = flex_attention(s, t)(Q, K, V, block_mask = causal_mask) + A = flex_attention(s, t)(query = Q, key = K, value = V, block_mask = causal_mask) A = A.transpose(1, 2).contiguous() A = A.reshape(bsz, q_len, n_heads*head_dim) return A From 984d21708c11cacb99a58a703858c307bf642039 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:10:13 -0700 Subject: [PATCH 085/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 96c4f89b..5d42dff2 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -110,7 +110,7 @@ def flex_attention(s, t): scale = 1.0 / math.sqrt(s) score_mod = generate_tanh_softcap(t) return functools.partial( - _flex_attention(score_mod = score_mod, scale = scale, enable_gqa = True) + _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True, ) pass From e3846f51a46bb237d4392aede4fd667105de6bb5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:12:22 -0700 Subject: [PATCH 086/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 5d42dff2..bb412bc2 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -115,6 +115,8 @@ def flex_attention(s, t): pass def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, 
q_len): + n_heads = self.num_heads + head_dim = self.head_dim if causal_mask == 0: # Global attention causal_mask = create_block_mask(causal_masker, q_len) From 2d292995c89d3f110d72497240d709c5cfd8d6e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:16:25 -0700 Subject: [PATCH 087/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index bb412bc2..7724b70d 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -119,12 +119,13 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): head_dim = self.head_dim if causal_mask == 0: # Global attention + print(1) causal_mask = create_block_mask(causal_masker, q_len) else: # Sliding window attention + print(2) causal_mask = create_block_mask(sliding_window_masker(causal_mask), q_len) pass - print(causal_mask) s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping From 03310b9b224aebaf8f3d58f1cf0664a116c5a37c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:18:18 -0700 Subject: [PATCH 088/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 7724b70d..c5761b14 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -88,12 +88,12 @@ def causal_masker(b, h, q_idx, kv_idx): pass @functools.lru_cache - def sliding_window_masker(size = 4096): + def sliding_window_masker(size = 4096, q_len = 4096): def sliding_window(b, h, q_idx, kv_idx): causal_mask = q_idx >= kv_idx window_mask = q_idx - kv_idx <= size return causal_mask & window_mask - return sliding_window + return sliding_window if q_len >= size else causal_masker pass @functools.lru_cache @@ -124,7 +124,8 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): else: # Sliding window attention print(2) - causal_mask = create_block_mask(sliding_window_masker(causal_mask), q_len) + sliding_masker = sliding_window_masker(causal_mask, q_len) + causal_mask = create_block_mask(sliding_masker, q_len) pass s = self.config.query_pre_attn_scalar From eb376767cb4a27a04fe214c9a23a4988ec4287e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:21:04 -0700 Subject: [PATCH 089/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index c5761b14..b0d76de6 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -119,11 +119,9 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): head_dim = self.head_dim if causal_mask == 0: # Global attention - print(1) causal_mask = create_block_mask(causal_masker, q_len) else: # Sliding window attention - print(2) sliding_masker = sliding_window_masker(causal_mask, q_len) causal_mask = create_block_mask(sliding_masker, q_len) pass From cb6a835b68d65786bc831543523418f38e510179 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:28:25 -0700 Subject: [PATCH 090/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index b0d76de6..5ed6651c 100644 
--- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -31,7 +31,7 @@ create_block_mask as _create_block_mask, ) _flex_attention = torch.compile(_flex_attention, dynamic = False) - HAS_FLEX_ATTENTION = True + HAS_FLEX_ATTENTION = False except: HAS_FLEX_ATTENTION = False pass From cbd6a6a24bd8c49195c5e17c5c562034a29a4a79 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:33:30 -0700 Subject: [PATCH 091/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 52 ++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 5ed6651c..46aa390f 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -78,9 +78,59 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al import functools, math + from torch.nn.attention.flex_attention import _score_mod_signature + from torch._inductor.lowering import make_pointwise, register_lowering + + # Some internal torch.compile details + from torch._inductor.virtualized import ops + from functools import partial + + + @torch.library.custom_op("approx::tanh", mutates_args=()) + def _tanh_approx(inp: Tensor) -> Tensor: + return torch.tanh(inp) + + + @_tanh_approx.register_fake + def _(inp: torch.Tensor) -> torch.Tensor: + return torch.tanh(inp) + + + def _tanh_approx_lowering(inp): + fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;") + return make_pointwise(fn)(inp) + + + register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering) + + + class _TanhApprox(torch.autograd.Function): + @staticmethod + def forward(x): + return torch.ops.approx.tanh(x) + + @staticmethod + def setup_context(ctx, inputs, output): + (x,) = inputs + result = output + ctx.save_for_backward(result) + + @staticmethod + def backward(ctx, grad_output): + (result,) = ctx.saved_tensors + return grad_output * (1 - result * result) + + @staticmethod + def vmap(info, in_dims, x): + return torch.tanh(x), 0 + + + _tanh_approx = _TanhApprox.apply + + @functools.lru_cache def generate_tanh_softcap(t): def tanh_softcap(x, b, h, q_idx, kv_idx): - return t * torch.tanh(x / t) + return t * _tanh_approx(x / t) return tanh_softcap pass def causal_masker(b, h, q_idx, kv_idx): From 712deaa1494efdaa18cc5bfb6d2b5874244be6eb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:35:34 -0700 Subject: [PATCH 092/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 46aa390f..ff277934 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -31,7 +31,7 @@ create_block_mask as _create_block_mask, ) _flex_attention = torch.compile(_flex_attention, dynamic = False) - HAS_FLEX_ATTENTION = False + HAS_FLEX_ATTENTION = True except: HAS_FLEX_ATTENTION = False pass From 6e74563ee2e64b74842d3bb130b87b4b1f612853 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:36:44 -0700 Subject: [PATCH 093/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index ff277934..e9191522 100644 --- a/unsloth/kernels/flex_attention.py +++ 
b/unsloth/kernels/flex_attention.py @@ -80,7 +80,8 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): from torch.nn.attention.flex_attention import _score_mod_signature from torch._inductor.lowering import make_pointwise, register_lowering - + from torch import Tensor + # Some internal torch.compile details from torch._inductor.virtualized import ops from functools import partial From 0703ce81363738fc47d37ced551af24c93c1ea3a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:39:21 -0700 Subject: [PATCH 094/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index e9191522..ac13d3a1 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -81,7 +81,7 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): from torch.nn.attention.flex_attention import _score_mod_signature from torch._inductor.lowering import make_pointwise, register_lowering from torch import Tensor - + # Some internal torch.compile details from torch._inductor.virtualized import ops from functools import partial @@ -128,12 +128,25 @@ def vmap(info, in_dims, x): _tanh_approx = _TanhApprox.apply - @functools.lru_cache - def generate_tanh_softcap(t): - def tanh_softcap(x, b, h, q_idx, kv_idx): - return t * _tanh_approx(x / t) + def generate_tanh_softcap(soft_cap: int, approx: bool = True) -> _score_mod_signature: + """Returns an tanh bias score_mod given the number of heads H + + Args: + soft_cap: The soft cap value to use for normalizing logits + approx: Whether to use the `tanh.approx.` ptx instruction + + Returns: + tanh_softcap: score_mod + """ + tanh = _tanh_approx if approx else torch.tanh + + def tanh_softcap(score, b, h, q_idx, kv_idx): + return soft_cap * tanh(score / soft_cap) + + prefix = "tanh_softcap_approx" if approx else "tanh_softcap" + tanh_softcap.__name__ = f"{prefix}_{soft_cap}" + return tanh_softcap - pass def causal_masker(b, h, q_idx, kv_idx): return q_idx >= kv_idx pass From e2cafc4a75c2738165ecfda7465566b8af3a0607 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 00:41:27 -0700 Subject: [PATCH 095/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 72 ++----------------------------- 1 file changed, 4 insertions(+), 68 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index ac13d3a1..b0d76de6 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -78,75 +78,11 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al import functools, math - from torch.nn.attention.flex_attention import _score_mod_signature - from torch._inductor.lowering import make_pointwise, register_lowering - from torch import Tensor - - # Some internal torch.compile details - from torch._inductor.virtualized import ops - from functools import partial - - - @torch.library.custom_op("approx::tanh", mutates_args=()) - def _tanh_approx(inp: Tensor) -> Tensor: - return torch.tanh(inp) - - - @_tanh_approx.register_fake - def _(inp: torch.Tensor) -> torch.Tensor: - return torch.tanh(inp) - - - def _tanh_approx_lowering(inp): - fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;") - return make_pointwise(fn)(inp) - - - 
register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering) - - - class _TanhApprox(torch.autograd.Function): - @staticmethod - def forward(x): - return torch.ops.approx.tanh(x) - - @staticmethod - def setup_context(ctx, inputs, output): - (x,) = inputs - result = output - ctx.save_for_backward(result) - - @staticmethod - def backward(ctx, grad_output): - (result,) = ctx.saved_tensors - return grad_output * (1 - result * result) - - @staticmethod - def vmap(info, in_dims, x): - return torch.tanh(x), 0 - - - _tanh_approx = _TanhApprox.apply - - def generate_tanh_softcap(soft_cap: int, approx: bool = True) -> _score_mod_signature: - """Returns an tanh bias score_mod given the number of heads H - - Args: - soft_cap: The soft cap value to use for normalizing logits - approx: Whether to use the `tanh.approx.` ptx instruction - - Returns: - tanh_softcap: score_mod - """ - tanh = _tanh_approx if approx else torch.tanh - - def tanh_softcap(score, b, h, q_idx, kv_idx): - return soft_cap * tanh(score / soft_cap) - - prefix = "tanh_softcap_approx" if approx else "tanh_softcap" - tanh_softcap.__name__ = f"{prefix}_{soft_cap}" - + def generate_tanh_softcap(t): + def tanh_softcap(x, b, h, q_idx, kv_idx): + return t * torch.tanh(x / t) return tanh_softcap + pass def causal_masker(b, h, q_idx, kv_idx): return q_idx >= kv_idx pass From 25fb05926e68a69618fa3240fd99889d95983c97 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 01:39:57 -0700 Subject: [PATCH 096/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index b0d76de6..ddc2ffbf 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -30,7 +30,7 @@ flex_attention as _flex_attention, create_block_mask as _create_block_mask, ) - _flex_attention = torch.compile(_flex_attention, dynamic = False) + _flex_attention = torch.compile(_flex_attention, dynamic = True) HAS_FLEX_ATTENTION = True except: HAS_FLEX_ATTENTION = False From 6ddcd6074b4c65c6b98e236827141ce15e307c7b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 5 Sep 2024 23:19:04 -0700 Subject: [PATCH 097/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index ddc2ffbf..727691d0 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -128,7 +128,8 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping - A = flex_attention(s, t)(query = Q, key = K, value = V, block_mask = causal_mask) + fx = flex_attention(s, t) + A = fx(query = Q, key = K, value = V, block_mask = causal_mask) A = A.transpose(1, 2).contiguous() A = A.reshape(bsz, q_len, n_heads*head_dim) return A From a806b2063e4d26150c5198fa4e3c37d0146d140f Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Sat, 7 Sep 2024 01:48:19 -0400 Subject: [PATCH 098/110] Update chat_templates.py (#999) fix all misspelled "unsued" to "unused" --- unsloth/chat_templates.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index d81413ae..276fa7ca 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1033,7 +1033,7 @@ def 
to_sharegpt( merged_prompt = "", merged_column_name = "instruction", output_column_name = "output", - remove_unsued_columns = True, + remove_unused_columns = True, conversation_extension = 1, random_state = 3407, ): @@ -1047,7 +1047,7 @@ def to_sharegpt( merged_prompt = "", Prompt to merge columns into 1 input merged_column_name = "instruction", Final column name for the input field output_column_name = "output", Final column name for the output field - remove_unsued_columns = True, + remove_unused_columns = True, conversation_extension = 1, Automatically combines `conversation_extension` convos into 1 random_state = 3407, """ @@ -1080,8 +1080,8 @@ def __convert_to_sharegpt__(examples): __convert_to_sharegpt__, batched = True, desc = "Converting to ShareGPT", - # Remove unsued columns! - remove_columns = dataset.column_names if remove_unsued_columns else None, + # Remove unused columns! + remove_columns = dataset.column_names if remove_unused_columns else None, ) # Randomnly concat conversations to create a long stream! @@ -1115,8 +1115,8 @@ def __convert_to_sharegpt__(examples): __combine_conversations__, batched = True, desc = "Extending conversations", - # Remove unsued columns! - remove_columns = dataset.column_names if remove_unsued_columns else None, + # Remove unused columns! + remove_columns = dataset.column_names if remove_unused_columns else None, ) return dataset pass From a690e5e0b7a3129da4ade94a991c4914679e51f6 Mon Sep 17 00:00:00 2001 From: Peng Date: Sat, 7 Sep 2024 13:51:45 +0800 Subject: [PATCH 099/110] Update key from "from" to "user" (#1000) When use [tokenizer.apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating), the key should be "role" rather than "from", this is liknk to [this issue](https://github.com/unslothai/unsloth/issues/994) I don't know it is suitable for all situation, I also can add a dedicated parameter of the key if you think it is better. 
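The disagreement here is purely about key names: ShareGPT-style training data uses "from"/"value" pairs, while tokenizer.apply_chat_template expects "role"/"content". A minimal, hypothetical converter showing that mapping (the function name and sample strings are illustrative only, not part of this patch):

def sharegpt_to_chatml(conversation):
    # ShareGPT turns look like {"from": "human"/"gpt", "value": "..."};
    # apply_chat_template expects {"role": "user"/"assistant", "content": "..."}
    role_map = {"human": "user", "gpt": "assistant", "system": "system"}
    return [
        {"role": role_map[turn["from"]], "content": turn["value"]}
        for turn in conversation
    ]

example = [
    {"from": "human", "value": "Hello!"},
    {"from": "gpt",   "value": "Hi! How can I help?"},
]
print(sharegpt_to_chatml(example))

The next patch settles on the ShareGPT-style "from": "human"/"gpt" keys inside to_sharegpt, leaving the role/content mapping to the chat-template step.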
--- unsloth/chat_templates.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 276fa7ca..74479510 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1068,8 +1068,8 @@ def __convert_to_sharegpt__(examples): assistants = examples[output_column_name] texts = [ [ - {"from" : "user", "content" : str(user) }, - {"from" : "assistant", "content" : str(assistant)}, + {"role" : "user", "content" : str(user) }, + {"role" : "assistant", "content" : str(assistant)}, ] \ for user, assistant in zip(users, assistants) ] From 669371272c20e52deabe6aec574f60242b38cc04 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 6 Sep 2024 22:53:44 -0700 Subject: [PATCH 100/110] Update chat_templates.py --- unsloth/chat_templates.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 74479510..e19bea07 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1068,8 +1068,8 @@ def __convert_to_sharegpt__(examples): assistants = examples[output_column_name] texts = [ [ - {"role" : "user", "content" : str(user) }, - {"role" : "assistant", "content" : str(assistant)}, + {"from" : "human", "value" : str(user) }, + {"from" : "gpt", "value" : str(assistant)}, ] \ for user, assistant in zip(users, assistants) ] From fabda63997e9f263e706860dca87df988fb0e8e0 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Sat, 7 Sep 2024 11:19:17 -0700 Subject: [PATCH 101/110] Also patch the KTO trainer (#1001) --- unsloth/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 044629ea..b8f710b2 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1144,7 +1144,7 @@ def patch_sft_trainer_tokenizer(): # Patch train with fix_untrained_tokens for path_to_trainer in \ - ("sft_trainer.SFTTrainer", "dpo_trainer.DPOTrainer",): + ("sft_trainer.SFTTrainer", "dpo_trainer.DPOTrainer", "kto_trainer.KTOTrainer"): function_name, replacer = "train", "if resume_from_checkpoint is False:" function = getsource(eval(f"trl.trainer.{path_to_trainer}.{function_name}")) From f9b8a73f441e11f7e8adff24da450d1a6b556609 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 7 Sep 2024 15:13:35 -0700 Subject: [PATCH 102/110] flex attention --- unsloth/kernels/__init__.py | 2 ++ unsloth/kernels/flex_attention.py | 27 ++++++++++++++++----------- unsloth/models/llama.py | 4 ++-- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index d302768b..cd1d90f2 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -37,6 +37,8 @@ HAS_FLEX_ATTENTION, slow_attention_softcapping, slow_inference_attention_softcapping, + create_flex_attention_causal_mask, + create_flex_attention_sliding_window_mask, ) try: diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 727691d0..ac827fbb 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -72,6 +72,9 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): A = A.reshape(bsz, q_len, n_heads*head_dim) return A pass + + create_flex_attention_causal_mask = None + create_flex_attention_sliding_window_mask = None else: # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb # for more examples @@ -88,12 +91,12 @@ def 
causal_masker(b, h, q_idx, kv_idx): pass @functools.lru_cache - def sliding_window_masker(size = 4096, q_len = 4096): + def sliding_window_masker(size = 4096): def sliding_window(b, h, q_idx, kv_idx): causal_mask = q_idx >= kv_idx window_mask = q_idx - kv_idx <= size return causal_mask & window_mask - return sliding_window if q_len >= size else causal_masker + return sliding_window pass @functools.lru_cache @@ -105,6 +108,17 @@ def create_block_mask(mask, n = 128): ) pass + def create_flex_attention_causal_mask(max_seq_length = 8192): + causal_mask = create_block_mask(causal_masker, max_seq_length) + return causal_mask + pass + + def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096): + sliding_masker = sliding_window_masker(sliding_window) + causal_mask = create_block_mask(sliding_masker, max_seq_length) + return causal_mask + pass + @functools.lru_cache def flex_attention(s, t): scale = 1.0 / math.sqrt(s) @@ -117,15 +131,6 @@ def flex_attention(s, t): def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim - if causal_mask == 0: - # Global attention - causal_mask = create_block_mask(causal_masker, q_len) - else: - # Sliding window attention - sliding_masker = sliding_window_masker(causal_mask, q_len) - causal_mask = create_block_mask(sliding_masker, q_len) - pass - s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping fx = flex_attention(s, t) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a5dc381d..d7d1e99b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -734,8 +734,8 @@ def LlamaModel_fast_forward( elif not hasattr(self, "SWA_mask"): if HAS_FLEX_ATTENTION: # Use Flex Attention instead! - self.SWA_mask = self.config.sliding_window - self.GA_mask = 0 + self.SWA_mask = create_flex_attention_sliding_window_mask(self.config.sliding_window) + self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length) else: n = self.max_seq_length # self.config.max_position_embeddings # masked_fill is making stuff slower! From 2fa99790d32f719b8f6a6b3d540caee4e57db860 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 7 Sep 2024 15:17:36 -0700 Subject: [PATCH 103/110] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d7d1e99b..22dcb25a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -734,7 +734,7 @@ def LlamaModel_fast_forward( elif not hasattr(self, "SWA_mask"): if HAS_FLEX_ATTENTION: # Use Flex Attention instead! 
- self.SWA_mask = create_flex_attention_sliding_window_mask(self.config.sliding_window) + self.SWA_mask = create_flex_attention_sliding_window_mask(self.max_seq_length, self.config.sliding_window) self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length) else: n = self.max_seq_length # self.config.max_position_embeddings From 86017d3f220933738d06e545be3a92114e7d1bb6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 7 Sep 2024 15:49:13 -0700 Subject: [PATCH 104/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index ac827fbb..7899b925 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -124,13 +124,22 @@ def flex_attention(s, t): scale = 1.0 / math.sqrt(s) score_mod = generate_tanh_softcap(t) return functools.partial( - _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True, + _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = False, ) pass def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim + n_kv_heads = self.num_key_value_heads + n_groups = self.num_key_value_groups + + # Grouped query attention + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + K = K.reshape(bsz, n_heads, q_len, head_dim) + V = V.reshape(bsz, n_heads, q_len, head_dim) + s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping fx = flex_attention(s, t) From 130c739f051e63843632b5db1e140bcf15da301f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 7 Sep 2024 15:51:53 -0700 Subject: [PATCH 105/110] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 7899b925..ac827fbb 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -124,22 +124,13 @@ def flex_attention(s, t): scale = 1.0 / math.sqrt(s) score_mod = generate_tanh_softcap(t) return functools.partial( - _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = False, + _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True, ) pass def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim - n_kv_heads = self.num_key_value_heads - n_groups = self.num_key_value_groups - - # Grouped query attention - K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) - V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) - K = K.reshape(bsz, n_heads, q_len, head_dim) - V = V.reshape(bsz, n_heads, q_len, head_dim) - s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping fx = flex_attention(s, t) From 528c673d372aef4357486beeeb161d1ac1c92e1b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 7 Sep 2024 17:37:07 -0700 Subject: [PATCH 106/110] Update _utils.py --- unsloth/models/_utils.py | 60 ++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 386fb84c..d6b8ba27 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -323,36 +323,36 @@ def is_big_gpu(index): # Torch 
compile arguments -torch_compile_arguments = [ - "config.dce = True", - "config.memory_planning = True", - "config.memory_pool = 'combined'", - "config.coordinate_descent_tuning = True", - "config.max_autotune_gemm = False", # GEMM is unnecessary - "config.autotune_multi_device = False", - "config.max_autotune_gemm_backends = 'ATEN'", # Not much faster - "config.aggressive_fusion = False", # Careful changes results! - "config.cuda.enable_cuda_lto = True", - "config.cuda.use_fast_math = True", - "config.cuda.compile_opt_level = '-O2'", -] -# Torch dynamo arguments -torch_dynamo_arguments = [ - "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 - "config.suppress_errors = True", # Supress errors for now - "config.do_not_emit_runtime_asserts = True", - "config.cache_size_limit = 1024", # Flex Attention -] -import torch._inductor.config as config -for _try_compile_argument in torch_compile_arguments: - try: exec(_try_compile_argument) - except: pass -pass -import torch._dynamo.config as config -for _try_dynamo_argument in torch_dynamo_arguments: - try: exec(_try_dynamo_argument) - except: pass -pass +# torch_compile_arguments = [ +# "config.dce = True", +# "config.memory_planning = True", +# "config.memory_pool = 'combined'", +# "config.coordinate_descent_tuning = True", +# "config.max_autotune_gemm = False", # GEMM is unnecessary +# "config.autotune_multi_device = False", +# "config.max_autotune_gemm_backends = 'ATEN'", # Not much faster +# "config.aggressive_fusion = False", # Careful changes results! +# "config.cuda.enable_cuda_lto = True", +# "config.cuda.use_fast_math = True", +# "config.cuda.compile_opt_level = '-O2'", +# ] +# # Torch dynamo arguments +# torch_dynamo_arguments = [ +# "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 +# "config.suppress_errors = True", # Supress errors for now +# "config.do_not_emit_runtime_asserts = True", +# "config.cache_size_limit = 1024", # Flex Attention +# ] +# import torch._inductor.config as config +# for _try_compile_argument in torch_compile_arguments: +# try: exec(_try_compile_argument) +# except: pass +# pass +# import torch._dynamo.config as config +# for _try_dynamo_argument in torch_dynamo_arguments: +# try: exec(_try_dynamo_argument) +# except: pass +# pass torch_compile_options = { "epilogue_fusion" : True, "max_autotune" : True, From 7380ac532ca5ea4905c1315eca54a68a9d6fe78e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 7 Sep 2024 17:41:15 -0700 Subject: [PATCH 107/110] Update _utils.py --- unsloth/models/_utils.py | 60 ++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index d6b8ba27..b5a57a77 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -323,36 +323,36 @@ def is_big_gpu(index): # Torch compile arguments -# torch_compile_arguments = [ -# "config.dce = True", -# "config.memory_planning = True", -# "config.memory_pool = 'combined'", -# "config.coordinate_descent_tuning = True", -# "config.max_autotune_gemm = False", # GEMM is unnecessary -# "config.autotune_multi_device = False", -# "config.max_autotune_gemm_backends = 'ATEN'", # Not much faster -# "config.aggressive_fusion = False", # Careful changes results! 
-# "config.cuda.enable_cuda_lto = True",
-# "config.cuda.use_fast_math = True",
-# "config.cuda.compile_opt_level = '-O2'",
-# ]
-# # Torch dynamo arguments
-# torch_dynamo_arguments = [
-# "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256
-# "config.suppress_errors = True", # Supress errors for now
-# "config.do_not_emit_runtime_asserts = True",
-# "config.cache_size_limit = 1024", # Flex Attention
-# ]
-# import torch._inductor.config as config
-# for _try_compile_argument in torch_compile_arguments:
-# try: exec(_try_compile_argument)
-# except: pass
-# pass
-# import torch._dynamo.config as config
-# for _try_dynamo_argument in torch_dynamo_arguments:
-# try: exec(_try_dynamo_argument)
-# except: pass
-# pass
+torch_compile_arguments = [
+    "config.dce = True",
+    "config.memory_planning = True",
+    "config.memory_pool = 'combined'",
+    "config.coordinate_descent_tuning = True",
+    "config.max_autotune_gemm = False", # GEMM is unnecessary
+    "config.autotune_multi_device = False",
+    "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster
+    "config.aggressive_fusion = False", # Careful changes results!
+    "config.cuda.enable_cuda_lto = True",
+    "config.cuda.use_fast_math = True",
+    "config.cuda.compile_opt_level = '-O2'",
+]
+# Torch dynamo arguments
+torch_dynamo_arguments = [
+    "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256
+    "config.suppress_errors = True", # Supress errors for now
+    "config.do_not_emit_runtime_asserts = True",
+    "config.cache_size_limit = 1024", # Flex Attention
+]
+import torch._inductor.config as config
+for _try_compile_argument in torch_compile_arguments:
+    try: exec(_try_compile_argument)
+    except: pass
+pass
+import torch._dynamo.config as config
+for _try_dynamo_argument in torch_dynamo_arguments:
+    try: exec(_try_dynamo_argument)
+    except: pass
+pass

 torch_compile_options = {
     "epilogue_fusion" : True,
     "max_autotune" : True,

From 4e1a50c4f19673ef23a2d1059980517f19f6d7b5 Mon Sep 17 00:00:00 2001
From: Daniel Han
Date: Sat, 7 Sep 2024 17:43:58 -0700
Subject: [PATCH 108/110] Update flex_attention.py

---
 unsloth/kernels/flex_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py
index ac827fbb..2fba359b 100644
--- a/unsloth/kernels/flex_attention.py
+++ b/unsloth/kernels/flex_attention.py
@@ -30,7 +30,7 @@
         flex_attention as _flex_attention,
         create_block_mask as _create_block_mask,
     )
-    _flex_attention = torch.compile(_flex_attention, dynamic = True)
+    _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
     HAS_FLEX_ATTENTION = True
 except:
     HAS_FLEX_ATTENTION = False

From 6e9d3de33011f1b8f1074329364291e6fe5ef41f Mon Sep 17 00:00:00 2001
From: Daniel Han
Date: Sat, 7 Sep 2024 17:48:28 -0700
Subject: [PATCH 109/110] Update gemma2.py

---
 unsloth/models/gemma2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index bf40ea8a..7c9d5c92 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -62,7 +62,7 @@
 # [TODO] We must randomnly use torch.compile?
 # I checked the gradients and formulas and I'm sure it's correct.
 # I'm stumped :(
-@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
 def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
     old_dtype = X.dtype
     X = X.float()

From 879fc88e4b43e1e3ade9c5a61e139e7e5706af7f Mon Sep 17 00:00:00 2001
From: Daniel Han
Date: Sun, 8 Sep 2024 03:15:36 -0700
Subject: [PATCH 110/110] Update gemma2.py

---
 unsloth/models/gemma2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index 7c9d5c92..bf40ea8a 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -62,7 +62,7 @@
 # [TODO] We must randomnly use torch.compile?
 # I checked the gradients and formulas and I'm sure it's correct.
 # I'm stumped :(
-# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
 def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
     old_dtype = X.dtype
     X = X.float()
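Taken together, the flex_attention.py and llama.py changes above route Gemma-2's tanh soft-capped, sliding-window attention through PyTorch's FlexAttention, caching one block mask per mask type on the model (SWA_mask / GA_mask) instead of rebuilding it every forward pass. Below is a minimal, self-contained sketch of that idea, assuming PyTorch 2.5+ and a CUDA device; the window size, soft-cap value, and tensor shapes are illustrative stand-ins, not values taken from these patches.

import math
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

SLIDING_WINDOW = 4096   # assumed window size (Gemma-2 style)
SOFT_CAP       = 50.0   # assumed attn_logit_softcapping value

def sliding_window_causal(b, h, q_idx, kv_idx):
    # Causal mask restricted to a local window, as in Gemma-2's SWA layers
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= SLIDING_WINDOW)

def tanh_softcap(score, b, h, q_idx, kv_idx):
    # Soft-cap attention logits inside the kernel instead of materialising scores
    return SOFT_CAP * torch.tanh(score / SOFT_CAP)

bsz, n_heads, seq_len, head_dim = 1, 8, 8192, 256
Q = torch.randn(bsz, n_heads, seq_len, head_dim, device = "cuda", dtype = torch.bfloat16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# Built once per (mask rule, sequence length) pair and reused across layers,
# mirroring how the patches cache SWA_mask / GA_mask on the model
block_mask = create_block_mask(sliding_window_causal, 1, 1, seq_len, seq_len, device = "cuda")

# Optionally compile for speed, as the patches do:
# flex_attention = torch.compile(flex_attention, dynamic = False)
out = flex_attention(
    Q, K, V,
    score_mod  = tanh_softcap,
    block_mask = block_mask,
    scale      = 1.0 / math.sqrt(head_dim),   # stands in for query_pre_attn_scalar
)
print(out.shape)   # (bsz, n_heads, seq_len, head_dim)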