Bug fixes #1004

Merged: 116 commits, Sep 8, 2024

Commits
7b81ca5
Update _utils.py
danielhanchen Aug 22, 2024
94f2d34
Update _utils.py
danielhanchen Aug 22, 2024
7c5222d
Update _utils.py
danielhanchen Aug 22, 2024
15d4417
Update _utils.py
danielhanchen Aug 22, 2024
1ea463c
Update _utils.py
danielhanchen Aug 22, 2024
cf929e2
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
5a7be98
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
2590b4c
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
621e65b
update token retrieval logic (#952)
not-lain Aug 23, 2024
b62e5cd
Update llama.py
danielhanchen Aug 23, 2024
fb9dd65
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Aug 23, 2024
3b49609
get_token
danielhanchen Aug 24, 2024
9c8875e
Update README.md
danielhanchen Aug 24, 2024
c25de14
Merge branch 'main' into nightly
danielhanchen Aug 25, 2024
646a27b
Merge branch 'main' into nightly
danielhanchen Aug 27, 2024
a44357d
Update gemma2.py
danielhanchen Aug 30, 2024
7ed1c16
Update rms_layernorm.py
danielhanchen Aug 30, 2024
d7ef49e
synchronize
danielhanchen Aug 30, 2024
9a69548
Update gemma2.py
danielhanchen Aug 30, 2024
e6dadb4
Update rms_layernorm.py
danielhanchen Aug 30, 2024
f8e77cf
Update rms_layernorm.py
danielhanchen Aug 30, 2024
cfbaa97
Update rms_layernorm.py
danielhanchen Aug 30, 2024
32b2f3f
layernorm
danielhanchen Aug 30, 2024
9e7057d
Update rms_layernorm.py
danielhanchen Aug 30, 2024
a193508
Update gemma2.py
danielhanchen Aug 30, 2024
65eaa2d
Update rms_layernorm.py
danielhanchen Aug 30, 2024
1beeb22
Update rms_layernorm.py
danielhanchen Aug 30, 2024
1eb7705
revert
danielhanchen Aug 30, 2024
c3fe972
Gemma
danielhanchen Aug 31, 2024
75dbfba
Update rms_layernorm.py
danielhanchen Aug 31, 2024
332b091
Update rms_layernorm.py
danielhanchen Aug 31, 2024
4ecc119
Update rms_layernorm.py
danielhanchen Aug 31, 2024
07a1246
Update rms_layernorm.py
danielhanchen Aug 31, 2024
e3239e4
Update rms_layernorm.py
danielhanchen Aug 31, 2024
6ae1ac2
Update rms_layernorm.py
danielhanchen Aug 31, 2024
4d89f27
Update rms_layernorm.py
danielhanchen Aug 31, 2024
c76be22
Update rms_layernorm.py
danielhanchen Aug 31, 2024
ace509c
Update rms_layernorm.py
danielhanchen Aug 31, 2024
e474cfe
Update rms_layernorm.py
danielhanchen Aug 31, 2024
1576a1e
Update rms_layernorm.py
danielhanchen Aug 31, 2024
a2c4691
Update rms_layernorm.py
danielhanchen Aug 31, 2024
1a02e75
Update rms_layernorm.py
danielhanchen Aug 31, 2024
a26e1d1
Update rms_layernorm.py
danielhanchen Aug 31, 2024
afdb443
Update rms_layernorm.py
danielhanchen Sep 1, 2024
c3e14d8
Update rms_layernorm.py
danielhanchen Sep 1, 2024
1830bdd
Update rms_layernorm.py
danielhanchen Sep 1, 2024
6abf66a
Update rms_layernorm.py
danielhanchen Sep 1, 2024
f5cf796
Update rms_layernorm.py
danielhanchen Sep 1, 2024
b191530
Update rms_layernorm.py
danielhanchen Sep 1, 2024
512c61f
Update rms_layernorm.py
danielhanchen Sep 1, 2024
f5d50ef
Update rms_layernorm.py
danielhanchen Sep 1, 2024
d791bb9
Update rms_layernorm.py
danielhanchen Sep 1, 2024
9225608
Update gemma2.py
danielhanchen Sep 1, 2024
f61869c
Change UnslothTrainingArguments base class to SFTConfig (#979)
vTuanpham Sep 2, 2024
73d49ad
Cohere
danielhanchen Sep 2, 2024
86b6236
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Sep 2, 2024
edef5ca
Update trainer.py
danielhanchen Sep 2, 2024
6d4300c
Cohere
danielhanchen Sep 2, 2024
754e670
Cohere
danielhanchen Sep 2, 2024
d242866
New models
danielhanchen Sep 3, 2024
0b7e973
Update llama.py
danielhanchen Sep 3, 2024
19549f2
Update llama.py
danielhanchen Sep 3, 2024
8823e13
Update cohere.py
danielhanchen Sep 3, 2024
90050b7
Update llama.py
danielhanchen Sep 3, 2024
4c1ec3a
Update cohere.py
danielhanchen Sep 3, 2024
97b3956
retry
danielhanchen Sep 3, 2024
fd615ea
Update fast_lora.py
danielhanchen Sep 3, 2024
fe45990
Update llama.py
danielhanchen Sep 3, 2024
f564b8a
Update fast_lora.py
danielhanchen Sep 3, 2024
b26da84
Update llama.py
danielhanchen Sep 3, 2024
61be6a3
Update llama.py
danielhanchen Sep 3, 2024
ea48761
Update cross_entropy_loss.py
danielhanchen Sep 3, 2024
6e795c6
_apply_lora_mlp
danielhanchen Sep 3, 2024
dacba39
Update _utils.py
danielhanchen Sep 3, 2024
5074427
Gemma fixes
danielhanchen Sep 3, 2024
743ba55
Update llama.py
danielhanchen Sep 3, 2024
315136a
Merge branch 'main' into nightly
danielhanchen Sep 3, 2024
7ea6395
Update flex_attention.py
danielhanchen Sep 3, 2024
91d6773
Merge branch 'main' into nightly
danielhanchen Sep 4, 2024
df06a04
Update llama.py
danielhanchen Sep 4, 2024
7f139f1
layernorm
danielhanchen Sep 4, 2024
068fc0d
Update llama.py
danielhanchen Sep 4, 2024
4eaccb0
Update llama.py
danielhanchen Sep 4, 2024
4f909fc
Flex Attention
danielhanchen Sep 5, 2024
efef0ee
Update gemma2.py
danielhanchen Sep 5, 2024
6e8951f
Update __init__.py
danielhanchen Sep 5, 2024
d60a18c
Update flex_attention.py
danielhanchen Sep 5, 2024
1b4132e
Update flex_attention.py
danielhanchen Sep 5, 2024
f5d11dc
Update flex_attention.py
danielhanchen Sep 5, 2024
2454659
Update flex_attention.py
danielhanchen Sep 5, 2024
984d217
Update flex_attention.py
danielhanchen Sep 5, 2024
e3846f5
Update flex_attention.py
danielhanchen Sep 5, 2024
2d29299
Update flex_attention.py
danielhanchen Sep 5, 2024
03310b9
Update flex_attention.py
danielhanchen Sep 5, 2024
eb37676
Update flex_attention.py
danielhanchen Sep 5, 2024
cb6a835
Update flex_attention.py
danielhanchen Sep 5, 2024
cbd6a6a
Update flex_attention.py
danielhanchen Sep 5, 2024
712deaa
Update flex_attention.py
danielhanchen Sep 5, 2024
6e74563
Update flex_attention.py
danielhanchen Sep 5, 2024
0703ce8
Update flex_attention.py
danielhanchen Sep 5, 2024
e2cafc4
Update flex_attention.py
danielhanchen Sep 5, 2024
25fb059
Update flex_attention.py
danielhanchen Sep 5, 2024
6ddcd60
Update flex_attention.py
danielhanchen Sep 6, 2024
a806b20
Update chat_templates.py (#999)
AgainstEntropy Sep 7, 2024
a690e5e
Update key from "from" to "user" (#1000)
wa008 Sep 7, 2024
6693712
Update chat_templates.py
danielhanchen Sep 7, 2024
fabda63
Also patch the KTO trainer (#1001)
corbt Sep 7, 2024
f9b8a73
flex attention
danielhanchen Sep 7, 2024
2fa9979
Update llama.py
danielhanchen Sep 7, 2024
86017d3
Update flex_attention.py
danielhanchen Sep 7, 2024
130c739
Update flex_attention.py
danielhanchen Sep 7, 2024
528c673
Update _utils.py
danielhanchen Sep 8, 2024
7380ac5
Update _utils.py
danielhanchen Sep 8, 2024
4e1a50c
Update flex_attention.py
danielhanchen Sep 8, 2024
6e9d3de
Update gemma2.py
danielhanchen Sep 8, 2024
879fc88
Update gemma2.py
danielhanchen Sep 8, 2024
16 changes: 8 additions & 8 deletions unsloth/chat_templates.py
@@ -1033,7 +1033,7 @@ def to_sharegpt(
merged_prompt = "",
merged_column_name = "instruction",
output_column_name = "output",
remove_unsued_columns = True,
remove_unused_columns = True,
conversation_extension = 1,
random_state = 3407,
):
@@ -1047,7 +1047,7 @@ def to_sharegpt(
merged_prompt = "", Prompt to merge columns into 1 input
merged_column_name = "instruction", Final column name for the input field
output_column_name = "output", Final column name for the output field
remove_unsued_columns = True,
remove_unused_columns = True,
conversation_extension = 1, Automatically combines `conversation_extension` convos into 1
random_state = 3407,
"""
@@ -1068,8 +1068,8 @@ def __convert_to_sharegpt__(examples):
assistants = examples[output_column_name]
texts = [
[
{"from" : "user", "content" : str(user) },
{"from" : "assistant", "content" : str(assistant)},
{"from" : "human", "value" : str(user) },
{"from" : "gpt", "value" : str(assistant)},
] \
for user, assistant in zip(users, assistants)
]
@@ -1080,8 +1080,8 @@ def __convert_to_sharegpt__(examples):
__convert_to_sharegpt__,
batched = True,
desc = "Converting to ShareGPT",
# Remove unsued columns!
remove_columns = dataset.column_names if remove_unsued_columns else None,
# Remove unused columns!
remove_columns = dataset.column_names if remove_unused_columns else None,
)

# Randomnly concat conversations to create a long stream!
@@ -1115,8 +1115,8 @@ def __convert_to_sharegpt__(examples):
__combine_conversations__,
batched = True,
desc = "Extending conversations",
# Remove unsued columns!
remove_columns = dataset.column_names if remove_unsued_columns else None,
# Remove unused columns!
remove_columns = dataset.column_names if remove_unused_columns else None,
)
return dataset
pass
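For reference, this hunk does two things: it renames the misspelled remove_unsued_columns argument to remove_unused_columns, and it restores the standard ShareGPT keys ("from"/"value" with "human"/"gpt" roles) for the converted records. A minimal usage sketch under those changes follows; the dataset name and prompt template are illustrative assumptions, not part of this PR.

# Illustrative sketch only: the dataset and prompt template are hypothetical.
from datasets import load_dataset
from unsloth.chat_templates import to_sharegpt

dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
dataset = to_sharegpt(
    dataset,
    merged_prompt = "{instruction}\n{input}",
    output_column_name = "output",
    remove_unused_columns = True,   # renamed from the old remove_unsued_columns
    conversation_extension = 3,     # randomly merges 3 single-turn convos into one
)
# Each converted example now holds ShareGPT-style turns:
# [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}]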
8 changes: 2 additions & 6 deletions unsloth/kernels/__init__.py
@@ -37,14 +37,10 @@
HAS_FLEX_ATTENTION,
slow_attention_softcapping,
slow_inference_attention_softcapping,
create_flex_attention_causal_mask,
create_flex_attention_sliding_window_mask,
)

if HAS_FLEX_ATTENTION:
from .flex_attention import (
FLEX_ATTENTION_PADDING,
)
pass

try:
print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
except:
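With this change the mask helpers are re-exported unconditionally from kernels.flex_attention (and the FLEX_ATTENTION_PADDING re-export is dropped), so callers only branch on HAS_FLEX_ATTENTION. A small hedged sketch of the guarded use, assuming the names defined in the flex_attention.py diff below:

from unsloth.kernels.flex_attention import (
    HAS_FLEX_ATTENTION,
    create_flex_attention_causal_mask,
)

if HAS_FLEX_ATTENTION:
    causal_mask = create_flex_attention_causal_mask(max_seq_length = 8192)
else:
    causal_mask = None  # fall back to the dense-mask path used elsewhere in this PR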
157 changes: 109 additions & 48 deletions unsloth/kernels/flex_attention.py
@@ -25,59 +25,120 @@
}

# Flex Attention supported from torch 2.5 onwards only
import torch.nn
if hasattr(torch.nn, "attention"):
import torch.nn.attention
if hasattr(torch.nn.attention, "flex_attention"):
import torch.nn.attention.flex_attention
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask
FLEX_ATTENTION_PADDING = getattr(
torch.nn.attention.flex_attention,
"_DEFAULT_SPARSE_BLOCK_SIZE",
1,
)
flex_attention = torch.compile(flex_attention, dynamic = False)
HAS_FLEX_ATTENTION = True
else:
HAS_FLEX_ATTENTION = False
pass
else:
try:
from torch.nn.attention.flex_attention import (
flex_attention as _flex_attention,
create_block_mask as _create_block_mask,
)
_flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
HAS_FLEX_ATTENTION = True
except:
HAS_FLEX_ATTENTION = False
pass

# Logit softcapping
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
head_dim = self.head_dim
n_kv_heads = self.num_key_value_heads
n_groups = self.num_key_value_groups

# Grouped query attention
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
K = K.reshape(bsz, n_heads, q_len, head_dim)
V = V.reshape(bsz, n_heads, q_len, head_dim)

# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
# We default to using the config file itself
# s = self.config.hidden_size // self.config.num_attention_heads
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
if not HAS_FLEX_ATTENTION:

Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch.matmul(Q, K.transpose(2, 3))
A = t * torch.tanh(A / t) # Logit softcapping
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch.matmul(A, V)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
# Logit softcapping
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
head_dim = self.head_dim
n_kv_heads = self.num_key_value_heads
n_groups = self.num_key_value_groups

# Grouped query attention
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
K = K.reshape(bsz, n_heads, q_len, head_dim)
V = V.reshape(bsz, n_heads, q_len, head_dim)

# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
# We default to using the config file itself
# s = self.config.hidden_size // self.config.num_attention_heads
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping

Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch.matmul(Q, K.transpose(2, 3))
A = t * torch.tanh(A / t) # Logit softcapping
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch.matmul(A, V)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
pass

create_flex_attention_causal_mask = None
create_flex_attention_sliding_window_mask = None
else:
# See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
# for more examples
# BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
import functools, math

def generate_tanh_softcap(t):
def tanh_softcap(x, b, h, q_idx, kv_idx):
return t * torch.tanh(x / t)
return tanh_softcap
pass
def causal_masker(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
pass

@functools.lru_cache
def sliding_window_masker(size = 4096):
def sliding_window(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
window_mask = q_idx - kv_idx <= size
return causal_mask & window_mask
return sliding_window
pass

@functools.lru_cache
def create_block_mask(mask, n = 128):
return _create_block_mask(
mask, 1, 1, n, n,
BLOCK_SIZE = 128,
_compile = True,
)
pass

def create_flex_attention_causal_mask(max_seq_length = 8192):
causal_mask = create_block_mask(causal_masker, max_seq_length)
return causal_mask
pass

def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
sliding_masker = sliding_window_masker(sliding_window)
causal_mask = create_block_mask(sliding_masker, max_seq_length)
return causal_mask
pass

@functools.lru_cache
def flex_attention(s, t):
scale = 1.0 / math.sqrt(s)
score_mod = generate_tanh_softcap(t)
return functools.partial(
_flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
)
pass

def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
head_dim = self.head_dim
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
fx = flex_attention(s, t)
A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
pass
pass


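The rewrite above compiles torch.nn.attention.flex_attention once, caches block masks built with create_block_mask, and expresses Gemma-2's tanh logit softcapping as a score_mod. For orientation, here is a standalone sketch of the same pattern against the public torch >= 2.5 API; the shapes, softcap value, and window size are illustrative assumptions rather than Unsloth's settings.

# Requires torch >= 2.5 and a CUDA device; all values below are illustrative.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def tanh_softcap(score, b, h, q_idx, kv_idx, t = 50.0):
    # Gemma-2 style logit softcapping applied to every attention score
    return t * torch.tanh(score / t)

def sliding_window_causal(b, h, q_idx, kv_idx, window = 4096):
    # Causal mask restricted to a local window, as in Gemma-2's SWA layers
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= window)

B, H, S, D = 1, 8, 8192, 64
q = torch.randn(B, H, S, D, device = "cuda", dtype = torch.bfloat16)
k = torch.randn(B, H, S, D, device = "cuda", dtype = torch.bfloat16)
v = torch.randn(B, H, S, D, device = "cuda", dtype = torch.bfloat16)

block_mask = create_block_mask(sliding_window_causal, B = None, H = None, Q_LEN = S, KV_LEN = S)
out = flex_attention(q, k, v, score_mod = tanh_softcap, block_mask = block_mask)

The production code additionally wraps flex_attention in torch.compile and lru_caches the mask builders and partials, which is what the bumped dynamo cache limits in _utils.py below accommodate.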
5 changes: 3 additions & 2 deletions unsloth/models/_utils.py
@@ -330,17 +330,18 @@ def is_big_gpu(index):
"config.coordinate_descent_tuning = True",
"config.max_autotune_gemm = False", # GEMM is unnecessary
"config.autotune_multi_device = False",
"config.max_autotune_gemm_backends = 'ATEN'", # Not much faster
"config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster
"config.aggressive_fusion = False", # Careful changes results!
"config.cuda.enable_cuda_lto = True",
"config.cuda.use_fast_math = True",
"config.cuda.compile_opt_level = '-O2'",
]
# Torch dynamo arguments
torch_dynamo_arguments = [
"config.accumulated_cache_size_limit = 512", # Bump up a bit from 256
"config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256
"config.suppress_errors = True", # Supress errors for now
"config.do_not_emit_runtime_asserts = True",
"config.cache_size_limit = 1024", # Flex Attention
]
import torch._inductor.config as config
for _try_compile_argument in torch_compile_arguments:
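The visible part of this hunk widens the inductor autotune backends and raises torch._dynamo's cache limits (the new cache_size_limit entry keeps Flex Attention's many compiled graph variants from being evicted). The loop that applies these strings is truncated above; a hedged reconstruction of that pattern, with the lists shortened for illustration:

import torch._inductor.config as config

torch_compile_arguments = [
    "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'",
    "config.cuda.use_fast_math = True",
]
for _try_compile_argument in torch_compile_arguments:
    try:
        exec(_try_compile_argument)  # apply the inductor option
    except Exception:
        pass  # skip options missing on older torch builds

import torch._dynamo.config as config  # rebind `config` for the dynamo options

torch_dynamo_arguments = [
    "config.accumulated_cache_size_limit = 1024",
    "config.cache_size_limit = 1024",  # Flex Attention
]
for _try_dynamo_argument in torch_dynamo_arguments:
    try:
        exec(_try_dynamo_argument)
    except Exception:
        pass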
1 change: 0 additions & 1 deletion unsloth/models/gemma2.py
@@ -156,7 +156,6 @@ def Gemma2Attention_fast_forward(
)
A = A.reshape(bsz, q_len, n_heads*head_dim)
else:
mask = causal_mask if attention_mask is None else attention_mask
fx = slow_inference_attention_softcapping \
if "_flag_for_generation" in kwargs else \
slow_attention_softcapping
48 changes: 24 additions & 24 deletions unsloth/models/llama.py
@@ -711,12 +711,6 @@ def LlamaModel_fast_forward(
offloaded_gradient_checkpointing = True
pass

# Check for Flex Attention
# if IS_GEMMA2 and HAS_FLEX_ATTENTION:
# if not (seq_length % FLEX_ATTENTION_PADDING == 0):
# USE_FLEX_ATTENTION = True


# Gemma2 has alternating SWA and global attn
if IS_GEMMA2:
if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
@@ -738,23 +732,29 @@
sliding_window = None,
)
elif not hasattr(self, "SWA_mask"):
n = self.max_seq_length # self.config.max_position_embeddings
# masked_fill is making stuff slower!
# self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
# self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
self.SWA_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)

self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
if HAS_FLEX_ATTENTION:
# Use Flex Attention instead!
self.SWA_mask = create_flex_attention_sliding_window_mask(self.max_seq_length, self.config.sliding_window)
self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length)
else:
n = self.max_seq_length # self.config.max_position_embeddings
# masked_fill is making stuff slower!
# self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
# self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
self.SWA_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)

self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
pass
pass
pass

@@ -821,7 +821,7 @@ def custom_forward(*inputs):
(fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
(self.norm, hidden_states)
elif IS_COHERE:
hidden_states = fast_layernorm_compiled(self.norm, hidden_states)
hidden_states = self.norm(hidden_states)
else:
hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
pass
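Two things change here: for Gemma-2, the cached sliding-window (SWA_mask) and global (GA_mask) masks become Flex Attention block masks when available, falling back to the dense 4-D masks from transformers' AttentionMaskConverter otherwise; and Cohere's final norm is now called directly instead of through fast_layernorm_compiled. A hedged sketch of the two mask-construction paths, with names mirroring the diff and everything else illustrative:

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

max_seq_length, sliding_window = 8192, 4096
HAS_FLEX_ATTENTION = False  # assume the dense fallback for this sketch

if HAS_FLEX_ATTENTION:
    # Tiny sparse block-mask descriptors, reused across layers
    from unsloth.kernels.flex_attention import (
        create_flex_attention_sliding_window_mask,
        create_flex_attention_causal_mask,
    )
    SWA_mask = create_flex_attention_sliding_window_mask(max_seq_length, sliding_window)
    GA_mask = create_flex_attention_causal_mask(max_seq_length)
else:
    # Dense (n, n) additive masks built once with transformers' converter
    n = max_seq_length
    SWA_mask = AttentionMaskConverter(is_causal = True, sliding_window = sliding_window)\
        .to_causal_4d(1, n, n, dtype = torch.float16, device = "cuda:0")\
        .squeeze(0).squeeze(0)
    GA_mask = AttentionMaskConverter(is_causal = True)\
        .to_causal_4d(1, n, n, dtype = torch.float16, device = "cuda:0")\
        .squeeze(0).squeeze(0)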
2 changes: 1 addition & 1 deletion unsloth/tokenizer_utils.py
@@ -1144,7 +1144,7 @@ def patch_sft_trainer_tokenizer():

# Patch train with fix_untrained_tokens
for path_to_trainer in \
("sft_trainer.SFTTrainer", "dpo_trainer.DPOTrainer",):
("sft_trainer.SFTTrainer", "dpo_trainer.DPOTrainer", "kto_trainer.KTOTrainer"):

function_name, replacer = "train", "if resume_from_checkpoint is False:"
function = getsource(eval(f"trl.trainer.{path_to_trainer}.{function_name}"))
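The last hunk registers kto_trainer.KTOTrainer next to the SFT and DPO trainers in the patching loop, matching the "Also patch the KTO trainer" commit. The loop itself, truncated above, pulls each trainer's train source with inspect.getsource, splices the untrained-token fix in ahead of the resume_from_checkpoint check, and re-execs it. Below is a toy, hedged illustration of that source-patching pattern; the Toy class and injected print are made up for demonstration.

# Run as a .py file so inspect can locate the source of Toy.train.
from inspect import getsource
import textwrap

class Toy:
    def train(self, resume_from_checkpoint = False):
        if resume_from_checkpoint is False:
            print("fresh run")

marker = "if resume_from_checkpoint is False:"
source = textwrap.dedent(getsource(Toy.train))
patched = source.replace(marker, "print('fix_untrained_tokens would run here')\n    " + marker)

namespace = {}
exec(patched, namespace)         # recompile the edited function
Toy.train = namespace["train"]   # and swap it back onto the class

Toy().train()  # prints the injected line, then "fresh run"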