Move filter to finetune tab #84

Merged 18 commits on Aug 31, 2023
4 changes: 2 additions & 2 deletions known_models_db/refact_known_models/refact.py
@@ -6,7 +6,7 @@
         "chat_scratchpad_class": "refact_scratchpads:ScratchpadHuggingfaceRefact",
         "model_class_kwargs": {},
         "required_memory_mb": 6000,
-        "filter_caps": ["Refact", "completion"],
+        "filter_caps": ["Refact", "completion", "finetune"],
     },

     "CONTRASTcode/medium/multi": {
@@ -17,7 +17,7 @@
         "model_class": "refact_models:CodifyModel",
         "T": 2048,
         "required_memory_mb": 3500,
-        "filter_caps": ["CONTRASTcode", "completion"],
+        "filter_caps": ["CONTRASTcode", "completion", "finetune"],
     },

     "CONTRASTcode/3b/multi": {
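Note: the "finetune" entry added to filter_caps above is what the rest of this PR keys on; a model is offered in the finetune tab only if it declares that capability. A minimal sketch of the gating, assuming the models_mini_db layout shown above (the helper name here is hypothetical):

from known_models_db.refact_known_models import models_mini_db

def finetune_capable_models():
    # Mirrors the capability checks added in finetune_config.py and
    # finetune_utils.py below: keep only models that declare "finetune".
    return [
        name for name, info in models_mini_db.items()
        if "finetune" in info.get("filter_caps", [])
    ]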
14 changes: 11 additions & 3 deletions refact_data_pipeline/finetune/finetune_config.py
@@ -1,17 +1,25 @@
 import math
 import torch

+from known_models_db.refact_known_models import models_mini_db
 from refact_data_pipeline.finetune import traces
 from self_hosting_machinery import env

 from typing import Any, Dict, List


-def base_config(env):
+def base_config(model_name: str):
+    if model_name not in models_mini_db:
+        raise RuntimeError(f"Unknown model {model_name}, try to update repo")
+    model_info = models_mini_db[model_name]
+    if "finetune" not in model_info.get("filter_caps", []):
+        raise RuntimeError(f"Model {model_name} does not support finetune")
     return dict(
+        model_name=model_name,
         model_info=dict(
             weight_path=env.DIR_WEIGHTS,
-            repo_id='smallcloudai/codify_3b_multi',
-            ctx_size=2048,
+            repo_id=model_info['model_path'],
+            ctx_size=model_info['T'],
             lora={
                 "lora_target_modules": [
                     "qkv",
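base_config now resolves repo_id and ctx_size from models_mini_db instead of hard-coding smallcloudai/codify_3b_multi, and it refuses models that lack the "finetune" capability. A hedged usage sketch, with model names taken from the diffs above:

from refact_data_pipeline.finetune.finetune_config import base_config

cfg = base_config("CONTRASTcode/3b/multi")  # known model with "finetune" in filter_caps
ctx = cfg["model_info"]["ctx_size"]         # comes from models_mini_db[...]["T"], 2048 here
base_config("no/such/model")                # raises RuntimeError("Unknown model ...")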
80 changes: 34 additions & 46 deletions refact_data_pipeline/finetune/finetune_filter.py
@@ -17,14 +17,15 @@
 import refact_data_pipeline.finetune.traces as traces
 from refact_data_pipeline import DatasetOpts, finetune_datasource
 from refact_data_pipeline.datautils import BatchIterator
-from refact_data_pipeline.finetune import finetune_filtering_defaults
+from refact_data_pipeline.finetune.finetune_utils import get_finetune_config
+from refact_data_pipeline.finetune.finetune_utils import get_finetune_filter_stats
+from refact_data_pipeline.finetune.finetune_filtering_defaults import finetune_filtering_defaults
 from refact_data_pipeline.finetune.finetune_config import base_config
 from refact_data_pipeline.finetune.model_handling import make_model, masked_loss
-from refact_data_pipeline.finetune.finetune_train import save_status_json
 from refact_data_pipeline.finetune.process_uploaded_files import make_matcher
 from self_hosting_machinery import env

-from typing import List
+from typing import List, Dict, Any


 unfiltered_train = os.path.join(env.DIR_UNPACKED, "train_set.jsonl")
@@ -33,32 +34,16 @@
 filtered_train = os.path.join(env.DIR_UNPACKED, "train_set_filtered.jsonl")
 filtered_test = os.path.join(env.DIR_UNPACKED, "test_set_filtered.jsonl")

-status_dict = {
-    "started_ts": time.time(),
-    "total_steps": 0,
-    "worked_steps": 0,
-    "worked_minutes": 0,
-    "eta_minutes": 0,
-    "status": "starting",
-    "accepted": 0,
-    "rejected": 0,
-    "avg_loss": 0.0
-}
-
-
-def _save_stats(status_string):
-    save_status_json(status_dict, status_string)
-
-
-def _try_load_stats():
-    global status_dict
-    if not os.path.exists(env.CONFIG_FINETUNE_FILTER_STATS):
-        return
-    with open(env.CONFIG_FINETUNE_FILTER_STATS, "r") as f:
-        status_dict = json.load(f)
+
+def _update_and_dump_status(status_dict: Dict[str, Any], status_string: str):
+    if status_string in ["starting"]:
+        status_dict = get_finetune_filter_stats(default=True)
+        status_dict["started_ts"] = time.time()
+    status_dict["status"] = status_string
+    with open(env.CONFIG_FINETUNE_FILTER_STATS + ".tmp", "w") as f:
+        json.dump(status_dict, f, indent=4)
+    os.rename(env.CONFIG_FINETUNE_FILTER_STATS + ".tmp", env.CONFIG_FINETUNE_FILTER_STATS)
+    return status_dict


 def _file_accepted(reason, path):
@@ -71,12 +56,6 @@ def _file_rejected(reason, path):
         f.write("%s %s\n" % (reason, path))


-def catch_sigusr1(signum, frame):
-    status_dict["error"] = "interrupted"
-    _save_stats("interrupted")
-    sys.exit(1)
-
-
 def get_force_included_excluded_matchers():
     fcfg = {
         "filetypes_finetune": {},
@@ -100,7 +79,8 @@ def loss_based_filter(
         loss_function,
         dataopts,
         *,
-        fcfg
+        fcfg,
+        status_dict,
 ):
     t0 = time.time()
     iter_times = []
@@ -112,7 +92,7 @@
     is_force_included, is_force_excluded = get_force_included_excluded_matchers()
     for iter_n, file in enumerate(train_files):
         t0_iter = time.time()
-        _save_stats("filtering")
+        status_dict = _update_and_dump_status(status_dict, "filtering")
         file_losses = []
         if is_force_included(file['path']):
             _file_accepted("FILTER1 INCLUDED_BY_MASK", file["path"])
@@ -169,8 +149,10 @@ def loss_based_filter(
     return rejected


-def pre_filtering():
-    fcfg = {**finetune_filtering_defaults.finetune_filtering_defaults}
+def pre_filtering(status_dict):
+    finetune_cfg = get_finetune_config(logger=traces.log)
+
+    fcfg = {**finetune_filtering_defaults}
     if os.path.exists(env.CONFIG_HOW_TO_FILTER):
         traces.log("Reading %s" % env.CONFIG_HOW_TO_FILTER)
         fcfg.update(**json.load(open(env.CONFIG_HOW_TO_FILTER)))
@@ -184,7 +166,7 @@ def pre_filtering():
     logging.info("Train set filtering, loading model...")
     traces.log("Train set filtering, loading model...")
     t0 = time.time()
-    cfg = base_config(env)
+    cfg = base_config(finetune_cfg["model_name"])
     model = make_model(
         weights_path=cfg['model_info']['weight_path'],
         repo_id=cfg['model_info']['repo_id'],
@@ -225,7 +207,7 @@ def pre_filtering():
     traces.log(textwrap.fill(text, width=100))

     filtered = loss_based_filter(
-        train_files, model, loss_function, dataopts, fcfg=fcfg
+        train_files, model, loss_function, dataopts, fcfg=fcfg, status_dict=status_dict,
     )

     test_filenames = set()
@@ -267,35 +249,41 @@ def needs_any_work():
     return any(has_updates)


-def main():
+def main(status_dict):
     if not needs_any_work():
-        _try_load_stats()
-        _save_stats("finished")
+        _update_and_dump_status(status_dict, "finished")
         logging.info("Train set filtering: nothing changed since last time, quit")
         return

-    _save_stats("starting")
+    status_dict = _update_and_dump_status(status_dict, "starting")
     with open(env.LOG_FILES_ACCEPTED_FTF, "w") as f:
         f.write("")
     with open(env.LOG_FILES_REJECTED_FTF, "w") as f:
         f.write("")
     try:
-        pre_filtering()
-        _save_stats("finished")
+        pre_filtering(status_dict)
+        _update_and_dump_status(status_dict, "finished")
     except BaseException as e:  # BaseException includes KeyboardInterrupt
         if traces.context():
             logging.error("FAILED finetune filter at %s" % traces.context().path)
         if "error" not in status_dict:  # if there is, a more detailed error is already in place
             t = str(e) or str(type(e))
             status_dict["error"] = t
             logging.error(t)
-        _save_stats("failed")
+        _update_and_dump_status(status_dict, "failed")
         if not isinstance(e, ValueError):  # don't print stack for ValueError which is used for mundane data problems
             raise e


 if __name__ == "__main__":
     YMD_hms = os.environ.get("LORA_LOGDIR", "") or time.strftime("lora-%Y%m%d-%H%M%S")
     traces.configure(task_dir="loras", task_name=YMD_hms, work_dir=env.PERMDIR)
+    status_dict = get_finetune_filter_stats()
+
+    def catch_sigusr1(signum, frame):
+        status_dict["error"] = "interrupted"
+        _update_and_dump_status(status_dict, "interrupted")
+        sys.exit(1)
+
     signal.signal(signal.SIGUSR1, catch_sigusr1)
-    main()
+    main(status_dict)
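_update_and_dump_status replaces the old module-level status_dict plus _save_stats pair: the dict is now passed explicitly, and the stats JSON is written to a .tmp file and then os.rename'd into place, so a concurrent reader of CONFIG_FINETUNE_FILTER_STATS never observes a half-written file. A standalone sketch of that write pattern, with a hypothetical target path:

import json, os

def dump_json_atomically(payload: dict, target_path: str):
    # os.rename within one filesystem is atomic, so readers see either the
    # previous file or the complete new one, never a partial write.
    tmp_path = target_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(payload, f, indent=4)
    os.rename(tmp_path, target_path)

dump_json_atomically({"status": "filtering"}, "/tmp/filter_stats.json")  # hypothetical path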
10 changes: 4 additions & 6 deletions refact_data_pipeline/finetune/finetune_train.py
@@ -20,7 +20,7 @@
 from refact_data_pipeline import DatasetOpts, finetune_datasource
 from refact_data_pipeline.datautils import BatchIterator
 from refact_data_pipeline.finetune.finetune_config import base_config, ConfigBuilder
-from refact_data_pipeline.finetune.finetune_train_defaults import finetune_train_defaults
+from refact_data_pipeline.finetune.finetune_utils import get_finetune_config
 from refact_data_pipeline.finetune.model_handling import make_model, masked_loss, save_model_state
 from self_hosting_machinery import env

@@ -67,11 +67,9 @@ def _get_ds_len_per_epoch(cfg_builder):

     with open(env.CONFIG_FINETUNE_FILTER_STATS, 'r') as f:
         initial_loss = json.load(f)["avg_loss"]
-    cfg_builder = ConfigBuilder(base_config(env))
-    user_cfg = {**finetune_train_defaults}
-    if os.path.exists(env.CONFIG_FINETUNE):
-        traces.log("Reading %s" % env.CONFIG_FINETUNE)
-        user_cfg.update(**json.load(open(env.CONFIG_FINETUNE)))
+
+    user_cfg = get_finetune_config(logger=traces.log)
+    cfg_builder = ConfigBuilder(base_config(user_cfg['model_name']))
     if user_cfg['use_heuristics']:
         traces.log("Retrieving dataset length per epoch, it may take a while...")
         ds_len = _get_ds_len_per_epoch(cfg_builder)
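The trainer now obtains user settings the same way the filter does, through get_finetune_config, instead of merging finetune_train_defaults with CONFIG_FINETUNE by hand. A short sketch of the shared call, assuming the helper defined in finetune_utils.py below:

from refact_data_pipeline.finetune.finetune_utils import get_finetune_config

user_cfg = get_finetune_config(logger=print)  # defaults, overridden by CONFIG_FINETUNE if present
model_name = user_cfg["model_name"]           # "CONTRASTcode/3b/multi" unless configured otherwise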
135 changes: 135 additions & 0 deletions refact_data_pipeline/finetune/finetune_utils.py
@@ -0,0 +1,135 @@
import os
import json
import time

from known_models_db.refact_known_models import models_mini_db
from refact_data_pipeline.finetune.finetune_train_defaults import finetune_train_defaults

from self_hosting_machinery import env

from typing import Any, Dict, Optional, Callable


default_finetune_model = "CONTRASTcode/3b/multi"


def get_run_model_name(run_dir: str) -> str:
    config_json_fn = os.path.join(run_dir, "config.json")
    if not os.path.isfile(config_json_fn):
        raise RuntimeError("get run model name: no config.json found")
    with open(config_json_fn) as f:
        return json.load(f).get("model_name", default_finetune_model)


def get_finetune_runs():
    res = []
    anyone_works = False
    if not os.path.isdir(env.DIR_LORAS):
        return [], anyone_works
    for dirname in sorted(os.listdir(env.DIR_LORAS)):
        dir_path = os.path.join(env.DIR_LORAS, dirname)
        if not os.path.isdir(dir_path):
            continue
        d = {
            "run_id": dirname,
            "worked_minutes": "0",
            "worked_steps": "0",
            "status": "unknown",  # working, starting, completed, failed
        }
        try:
            d["model_name"] = get_run_model_name(dir_path)
        except RuntimeError:
            continue
        status_fn = os.path.join(dir_path, "status.json")
        if os.path.exists(status_fn):
            d.update(json.load(open(status_fn, "r")))
        if d["status"] in ["working", "starting"]:
            mtime = os.path.getmtime(status_fn)
            if mtime + 300 < time.time():
                d["status"] = "failed"
            else:
                anyone_works = True
        d["checkpoints"] = []
        checkpoints_dir = os.path.join(dir_path, "checkpoints")
        if os.path.isdir(checkpoints_dir):
            for checkpoint_dir in sorted(os.listdir(checkpoints_dir)):
                checkpoint_path = os.path.join(checkpoints_dir, checkpoint_dir)
                if not os.path.isdir(checkpoint_path):
                    continue
                d["checkpoints"].append({
                    "checkpoint_name": checkpoint_dir,
                })
        res.append(d)
    return res, anyone_works


def get_active_loras() -> Dict[str, Dict[str, Any]]:
    active_loras = {}
    if os.path.exists(env.CONFIG_ACTIVE_LORA):
        active_loras = json.load(open(env.CONFIG_ACTIVE_LORA))
    if "lora_mode" in active_loras:  # NOTE: old config format
        active_loras = {
            default_finetune_model: active_loras,
        }
    return {
        model_name: {
            "lora_mode": "latest-best",
            **active_loras.get(model_name, {}),
        }
        for model_name, model_info in models_mini_db.items()
        if "finetune" in model_info["filter_caps"]
    }


def get_finetune_config(logger: Optional[Callable] = None) -> Dict[str, Any]:
    cfg = {
        "model_name": default_finetune_model,
        **finetune_train_defaults
    }
    if os.path.exists(env.CONFIG_FINETUNE):
        if logger is not None:
            logger("Reading %s" % env.CONFIG_FINETUNE)
        cfg.update(**json.load(open(env.CONFIG_FINETUNE)))
    return cfg


def get_finetune_filter_stats(default: bool = False) -> Dict[str, Any]:
    filter_stats = {
        "started_ts": 0,
        "total_steps": 0,
        "worked_steps": 0,
        "worked_minutes": 0,
        "eta_minutes": 0,
        "accepted": 0,
        "rejected": 0,
        "avg_loss": 0.0,
        "status": "idle",
    }
    if not default and os.path.isfile(env.CONFIG_FINETUNE_FILTER_STATS):
        filter_stats.update(**json.load(open(env.CONFIG_FINETUNE_FILTER_STATS)))
    return filter_stats


def get_finetune_step() -> Optional[str]:

    def get_sources_stats():
        scan_stats = {
            "scan_status": "idle",
        }
        if os.path.isfile(env.CONFIG_PROCESSING_STATS):
            scan_stats.update(**json.load(open(env.CONFIG_PROCESSING_STATS, "r")))
        return scan_stats

    if os.path.exists(env.FLAG_LAUNCH_PROCESS_UPLOADS) or \
            get_sources_stats()["scan_status"] in ["working"]:
        return "sources"

    if os.path.exists(env.FLAG_LAUNCH_FINETUNE_FILTER_ONLY) or \
            get_finetune_filter_stats()["status"] in ["starting", "filtering"]:
        return "filter"

    if os.path.exists(env.FLAG_LAUNCH_FINETUNE) or \
            get_finetune_runs()[1]:
        return "finetune"

    return None
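get_finetune_step makes the pipeline ordering explicit: uploads are scanned first ("sources"), then loss-based filtering runs ("filter"), then training ("finetune"), and None means idle. A hedged sketch of how a status endpoint could aggregate these helpers (the endpoint itself is an assumption, not part of this PR):

from refact_data_pipeline.finetune.finetune_utils import (
    get_active_loras, get_finetune_runs, get_finetune_step)

def finetune_status() -> dict:
    # Hypothetical aggregation mirroring what the finetune tab could poll.
    runs, anyone_works = get_finetune_runs()
    return {
        "step": get_finetune_step(),         # "sources" | "filter" | "finetune" | None
        "busy": anyone_works,
        "runs": runs,
        "active_loras": get_active_loras(),  # one entry per finetune-capable model
    }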