Prerelease bugfixes (#49)
* disable flash attention for devices with CUDA compute capability below 8.0

* use attention_mask with vanilla self-attention

* get rid of weak lora presets

* get rid of twostage_filter
  handle zero-file cases in the GPU filter and finetune training

* typos

* removed legacy code

---------

Co-authored-by: mitya <dimitry.ageev@gmail.com>
JegernOUTT and mitya52 authored Jul 28, 2023
1 parent ddb6dd7 commit 73eeffd
Showing 7 changed files with 57 additions and 162 deletions.
10 changes: 1 addition & 9 deletions refact_data_pipeline/finetune/finetune_config.py
@@ -168,15 +168,7 @@ def set_lora_quality_by_heuristics(
}

scores_per_loraconfigs = {
(0, 2): dict(lora_target_modules=[
"qkv", "out",
], lora_r=4, lora_alpha=8, lora_dropout=0.05, lora_init_scale=0.01,
freeze_exceptions=["lora"]),
(2, 4): dict(lora_target_modules=[
"qkv", "out",
], lora_r=8, lora_alpha=16, lora_dropout=0.05, lora_init_scale=0.01,
freeze_exceptions=["lora"]),
(4, 6): dict(lora_target_modules=[
(0, 6): dict(lora_target_modules=[
"qkv", "out",
], lora_r=32, lora_alpha=64, lora_dropout=0.01, lora_init_scale=0.01,
freeze_exceptions=["lora"]),
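For context, a minimal sketch of how a heuristic quality score could be mapped onto range-keyed presets like `scores_per_loraconfigs` above. `pick_lora_config` and its fallback are illustrative assumptions, not the project's `set_lora_quality_by_heuristics` implementation.

```python
# Hypothetical helper: select the preset whose (low, high) score range contains
# the heuristic score; names and the fallback behaviour are assumptions.
from typing import Any, Dict, Tuple

def pick_lora_config(score: float,
                     presets: Dict[Tuple[int, int], Dict[str, Any]]) -> Dict[str, Any]:
    for (low, high), cfg in presets.items():
        if low <= score < high:
            return cfg
    # Out-of-range scores fall back to the last (highest) preset.
    return list(presets.values())[-1]

presets = {
    (0, 6): dict(lora_target_modules=["qkv", "out"],
                 lora_r=32, lora_alpha=64, lora_dropout=0.01,
                 lora_init_scale=0.01, freeze_exceptions=["lora"]),
}
print(pick_lora_config(3, presets)["lora_r"])  # 32
```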
164 changes: 19 additions & 145 deletions refact_data_pipeline/finetune/finetune_filter.py
@@ -166,125 +166,7 @@ def loss_based_filter(
traces.log("calculated frames %i " % len(train_files))
traces.log("avg loss %0.4f" % status_dict['avg_loss'])

return rejected, None


def twostage_filter(
train_files,
model,
loss_function,
dataopts,
*,
fcfg
):
t0 = time.time()
grad31_name = ""
grad31_size = []
for n, p in model.named_parameters():
if n == "blocks.31.sa.qkv.lora_B.weight":
grad31_name = n
grad31_size = p.numel() # [7680, 16]
assert grad31_size
all_documents_grad31 = th.zeros([grad31_size], dtype=th.float32, device="cuda")

def dig_out_grad31():
for n, p in model.named_parameters():
# if p.grad is not None and "lora_B" in n and "qkv.lora_B" in n:
# traces.log("%s %s mean=%0.6f std=%0.6f" % (n, p.shape, p.grad.mean(), p.grad.std()))
if n == grad31_name:
if not th.isnan(p.grad).any():
return p.grad.flatten()
else:
return None

scale = 65536
most_typical_file_fdict = None
most_typical_file_cos = -1
rejected = set()
status_dict['total_steps'] = len(train_files)
is_force_included, is_force_excluded = get_force_included_excluded_matchers()
for Pass in [1, 2]:
loss_list = []
logging.info("STATUS filtering %d" % Pass)
for iter_n, fdict in enumerate(train_files):
t0_iter = time.time()
_save_stats("filtering %d" % Pass)
test_ds = finetune_datasource.local_plain([fdict], dataopts)
test_batch_iter_fn = partial(BatchIterator, dataopts=dict(batch_size=1, drop_last=False))
if is_force_included(fdict['path']):
_file_accepted("INCLUDED_BY_MASK", fdict["path"])
status_dict["accepted"] += 1
continue
elif is_force_excluded(fdict['path']):
traces.log("REJECTED FILTER %-100s EXCLUDED_BY_MASK" % fdict["path"])
rejected.add(fdict["path"])
_file_rejected("FILTER1 EXCLUDED_BY_MASK", fdict["path"])
status_dict["rejected"] += 1
continue
for batch, _stats in test_batch_iter_fn(test_ds):
input = batch['input'][:].contiguous()
while 1:
model.zero_grad()
if fcfg["low_gpu_mem_mode"]:
logits = model.forward_train_cp(input)
else:
logits = model.lm_forward(model(input, attention_mask=None)[0])
loss = loss_function(
logits=logits,
labels=batch['labels'],
mask=batch['mask'],
)
floss = float(loss.item())
loss *= float(scale)
loss.backward()
grad31 = dig_out_grad31()
if grad31 is None:
traces.log("scaling %d -> %d" % (scale, scale // 2))
scale //= 2
if scale <= 128:
traces.log("The `scale` is too low, can't complete the task :(")
traces.log("One potential way to fix this is to delete very unusual files, "
"this last one was:\n%s" % fdict["path"])
sys.exit(1)
continue
break
assert floss is not None
loss_list.append(floss)
status_dict['avg_loss'] = sum(loss_list) / len(loss_list)
if Pass == 1:
if floss > fcfg['filter_loss_threshold']:
traces.log("REJECTED FILTER1 %-100s loss %0.3f" % (fdict["path"], floss))
rejected.add(fdict["path"])
_file_rejected("FILTER1 %0.3f" % (floss,), fdict["path"])
status_dict["rejected"] += 1
else:
all_documents_grad31 += grad31
elif Pass == 2:
cos = th.nn.functional.cosine_similarity(all_documents_grad31, grad31.float(), dim=0)
if cos < fcfg['filter_gradcosine_threshold']:
traces.log("REJECTED FILTER2 %-100s cosine_similarity %+0.3f" % (fdict["path"], cos))
rejected.add(fdict["path"])
_file_rejected("FILTER2 %+0.3f" % (cos,), fdict["path"])
status_dict["rejected"] += 1
else:
_file_accepted("LOSS %0.3f COSINE %+0.3f" % (floss, cos), fdict["path"])
status_dict["accepted"] += 1
if cos > most_typical_file_cos:
most_typical_file_cos = cos
most_typical_file_fdict = fdict
break
iter_time = time.time() - t0_iter
if iter_time_last is None:
eta = (len(train_files) - iter_n) * iter_time
else:
eta = (len(train_files) - iter_n) * ((iter_time + iter_time_last) / 2)
iter_time_last = iter_time
status_dict["eta_minutes"] = int(round(eta / 60))
status_dict["worked_steps"] = iter_n
status_dict["worked_minutes"] = int((time.time() - t0) / 60)
traces.log("calculated frames %i " % len(loss_list))
traces.log("avg loss %0.4f" % status_dict['avg_loss'])
return rejected, most_typical_file_fdict
return rejected


def pre_filtering():
@@ -293,6 +175,11 @@ def pre_filtering():
traces.log("Reading %s" % env.CONFIG_HOW_TO_FILTER)
fcfg.update(**json.load(open(env.CONFIG_HOW_TO_FILTER)))

has_train_files = os.path.exists(os.path.join(env.DIR_UNPACKED, unfiltered_train)) and \
len(list(jsonlines.open(os.path.join(env.DIR_UNPACKED, unfiltered_train))))
if not has_train_files:
raise RuntimeError("No train files have been provided for filtering")

logging.info("STATUS smart filter init")
logging.info("Train set filtering, loading model...")
traces.log("Train set filtering, loading model...")
@@ -333,31 +220,23 @@ def pre_filtering():
f"more than allowed {fcfg['limit_test_files']}.\n"
f"It could heavily slow down the training process")

text = "FILTER1 explanation: initial loss too big calculated on a single file, threshold is %0.3f. " \
text = "FILTER explanation: initial loss too big calculated on a single file, threshold is %0.3f. " \
"Likely means the file doesn't contain code." % fcfg["filter_loss_threshold"]
traces.log(textwrap.fill(text, width=100))
if fcfg["gradient_based_filter"]:
text = "FILTER2 explanation: gradient cosine similarity is bad, calculated on a file, threshold is %0.3f. " \
"This means the file does not pull the model in the same direction as the rest of the files." \
% fcfg["filter_gradcosine_threshold"]
traces.log(textwrap.fill(text, width=100))

filter = twostage_filter if fcfg["gradient_based_filter"] else loss_based_filter
filtered, most_typical_file_fdict = filter(

filtered = loss_based_filter(
train_files, model, loss_function, dataopts, fcfg=fcfg
)

test_filenames = set()
if most_typical_file_fdict is None and len(test_files) == 0:
traces.log("no most_typical_file_fdict was found, create the new test set")
test_files = random.choices(train_files, k=fcfg["limit_test_files"])
test_filenames.update([p['path'] for p in test_files])
elif most_typical_file_fdict is not None:
traces.log("detected \"most typical\" file: %s" % (most_typical_file_fdict["path"]))
if len(test_files) > 0:
most_typical_file_fdict = None
if len(test_files) == 0:
test_files_count = min(fcfg["limit_test_files"], len(train_files) // 2)
if test_files_count == 0:
traces.log("Warning: It is too little files to choose a test set from. "
"It's strongly recommended to choose a test set manually to be able to prevent overfitting")
else:
test_filenames.add(most_typical_file_fdict["path"])
test_files = random.choices(train_files, k=test_files_count)
test_filenames.update([p['path'] for p in test_files])

with open(filtered_train, "w") as f:
for fdict in train_files:
@@ -370,14 +249,9 @@

traces.log("-" * 40 + "TEST SET" + "-" * 40)
with open(filtered_test, "w") as f:
if most_typical_file_fdict is not None:
traces.log("test set is auto selected, consists of one file called: %s" % most_typical_file_fdict["path"])
traces.log("this file is removed from the train set.")
f.write(json.dumps(most_typical_file_fdict) + "\n")
else:
for fdict in test_files:
traces.log("test set file: %s" % (fdict["path"]))
f.write(json.dumps(fdict) + "\n")
for fdict in test_files:
traces.log("test set file: %s" % (fdict["path"]))
f.write(json.dumps(fdict) + "\n")


def needs_any_work():
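With `twostage_filter` gone, only the loss-based path remains. A simplified sketch of that idea, assuming per-file losses are already computed; `loss_filter` and its inputs are illustrative, not the project's `loss_based_filter` signature.

```python
# Illustrative only: reject files whose initial loss exceeds the threshold
# (the defaults below ship filter_loss_threshold = 3.5).
from typing import Dict, Set

def loss_filter(file_losses: Dict[str, float], loss_threshold: float = 3.5) -> Set[str]:
    rejected = set()
    for path, loss in file_losses.items():
        if loss > loss_threshold:
            # Mirrors the "REJECTED FILTER ... loss" trace above: a very high
            # initial loss usually means the file doesn't look like code.
            rejected.add(path)
    return rejected

print(loss_filter({"a.py": 1.2, "weird.bin": 7.9}))  # {'weird.bin'}
```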
@@ -4,6 +4,5 @@
"filter_loss_threshold": 3.5,
"filter_gradcosine_threshold": 0.1,
"low_gpu_mem_mode": True,
"gradient_based_filter": False,
"debug": False
}
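These defaults look like the base values that `pre_filtering()` above overlays with the user's JSON config (`fcfg.update(**json.load(...))`). A hedged sketch of that pattern; the path and variable names are placeholders, not the project's `env` values.

```python
# Start from the shipped defaults and overlay the user's JSON config, if any.
import json
import os

fcfg = {
    "filter_loss_threshold": 3.5,
    "low_gpu_mem_mode": True,
    "debug": False,
}

config_how_to_filter = "/tmp/config/how_to_filter.json"  # placeholder path
if os.path.exists(config_how_to_filter):
    with open(config_how_to_filter) as f:
        fcfg.update(**json.load(f))
```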
11 changes: 8 additions & 3 deletions refact_data_pipeline/finetune/finetune_train.py
@@ -116,9 +116,14 @@ def create_data(cfg, enc) -> Tuple[Any, Optional[Any]]:
batch_size=cfg['train_batch_size'],
drop_last=True
))

if os.path.exists(os.path.join(env.DIR_UNPACKED, filtered_test)) and \
len(list(jsonlines.open(os.path.join(env.DIR_UNPACKED, filtered_test)))) > 0:
has_train_files = os.path.exists(os.path.join(env.DIR_UNPACKED, filtered_train)) and \
len(list(jsonlines.open(os.path.join(env.DIR_UNPACKED, filtered_train)))) > 0
if not has_train_files:
raise RuntimeError("No train files have been provided")

has_test_files = os.path.exists(os.path.join(env.DIR_UNPACKED, filtered_test)) \
and len(list(jsonlines.open(os.path.join(env.DIR_UNPACKED, filtered_test)))) > 0
if has_test_files:
test_ds = finetune_datasource.local_sequence_plain_infill(filtered_test, test_dataopts)
test_ds = list(test_ds)
else:
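The new guards above boil down to "the JSONL file exists and has at least one record". A standalone sketch of that check using the `jsonlines` package; the path below is a placeholder, not the real `env.DIR_UNPACKED` value.

```python
# Fail fast if a required JSONL dataset file is missing or empty.
import os
import jsonlines

def assert_has_records(path: str, what: str) -> None:
    has_records = os.path.exists(path) and len(list(jsonlines.open(path))) > 0
    if not has_records:
        raise RuntimeError(f"No {what} files have been provided")

assert_has_records("/tmp/unpacked/train_set_filtered.jsonl", "train")  # placeholder path
```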
3 changes: 2 additions & 1 deletion refact_data_pipeline/finetune/model_handling.py
@@ -175,7 +175,8 @@ def make_model(
lora_dropout=lora_dropout,
lora_init_scale=lora_init_scale
)
model = apply_flash_attention(model)
if th.cuda.get_device_capability() >= (8, 0):
model = apply_flash_attention(model)
for param in list(model.parameters()):
param.requires_grad = True
model = freeze_model(
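The gate above only enables FlashAttention on GPUs with compute capability 8.0 or newer (A100, RTX 30xx and later). A minimal sketch of the same check in isolation; `apply_fn` stands in for the project's `apply_flash_attention` helper.

```python
import torch as th
from typing import Callable, TypeVar

M = TypeVar("M")

def maybe_apply_flash_attention(model: M, apply_fn: Callable[[M], M]) -> M:
    # FlashAttention kernels need compute capability >= 8.0 (Ampere or newer);
    # older devices keep the vanilla self-attention path with an explicit mask.
    if th.cuda.is_available() and th.cuda.get_device_capability() >= (8, 0):
        return apply_fn(model)
    return model
```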
16 changes: 14 additions & 2 deletions refact_models/codify_modules/attention.py
@@ -1,12 +1,20 @@
from typing import Optional, Tuple

import functools
import torch
import torch.nn.functional as F
from torch import nn
from typing import Optional, Tuple

from refact_models.codify_modules import ALiBiBias


@functools.lru_cache(maxsize=1)
def get_attention_mask(context_size: int, device) -> torch.Tensor:
mask = torch.ones((context_size, context_size),
dtype=torch.bool, device=device)
mask = torch.triu(mask, 1)
return mask


class MultiheadSelfAttention(nn.Module):
def __init__(self, config):
super(MultiheadSelfAttention, self).__init__()
@@ -43,6 +51,10 @@ def attention(self,
query.device,
query.dtype)
attn_weights = attn_weights + alibi
if attention_mask is None:
# get the attention mask anyway because we're dealing only with decoder models here
attention_mask = get_attention_mask(context_size=key.shape[2], device=key.device)

if attention_mask is not None:
attn_weights = torch.masked_fill(attn_weights, attention_mask, -10000)

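For intuition, a tiny standalone example of what the cached causal mask does once it reaches the `masked_fill(..., -10000)` line above: future positions get a large negative score and effectively zero probability after softmax.

```python
import torch
import torch.nn.functional as F

context_size = 4
# Same construction as get_attention_mask: True above the diagonal = future tokens.
mask = torch.triu(torch.ones(context_size, context_size, dtype=torch.bool), 1)

attn_weights = torch.randn(context_size, context_size)
attn_weights = attn_weights.masked_fill(mask, -10000)
probs = F.softmax(attn_weights, dim=-1)
print(probs)  # each row attends only to itself and earlier positions
```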
14 changes: 13 additions & 1 deletion refact_models/refact_modules/attention.py
@@ -1,10 +1,18 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

from refact_models.codify_modules import ALiBiBias

from typing import Optional, Tuple

@functools.lru_cache(maxsize=1)
def get_attention_mask(context_size: int, device) -> torch.Tensor:
mask = torch.ones((context_size, context_size),
dtype=torch.bool, device=device)
mask = torch.triu(mask, 1)
return mask


class MultiHeadAttention(nn.Module):
@@ -34,6 +42,10 @@ def _attention(self,
query.shape[0], query.shape[2], key.shape[2],
query.device, query.dtype)
attn_weights = attn_weights + alibi
if attention_mask is None:
# get the attention mask anyway because we're dealing only with decoder models here
attention_mask = get_attention_mask(context_size=key.shape[2], device=key.device)

if attention_mask is not None:
attn_weights = torch.masked_fill(attn_weights, attention_mask, -10000)

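A quick check of the `lru_cache(maxsize=1)` behaviour both attention modules now rely on: for a repeated (context_size, device) pair the very same mask tensor is returned, so the mask is only rebuilt when the context size or device changes.

```python
import functools
import torch

@functools.lru_cache(maxsize=1)
def get_attention_mask(context_size: int, device) -> torch.Tensor:
    mask = torch.ones((context_size, context_size), dtype=torch.bool, device=device)
    return torch.triu(mask, 1)

a = get_attention_mask(4, torch.device("cpu"))
b = get_attention_mask(4, torch.device("cpu"))
print(a is b)  # True: cached, no per-forward re-allocation
```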
