Commit de9da2b

fix problems
JegernOUTT committed Sep 13, 2023
1 parent 3f1ce88 commit de9da2b
Showing 5 changed files with 23 additions and 6 deletions.
4 changes: 2 additions & 2 deletions known_models_db/refact_known_models/huggingface.py
@@ -55,8 +55,8 @@
         "diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
         "chat_scratchpad_class": None,
         "model_class_kwargs": {},
-        "required_memory_mb": 15000,
-        "T": 4096,
+        "required_memory_mb": 18000,
+        "T": 2048,
         "filter_caps": ["completion", "finetune"],
     },
     "wizardcoder/15b": {
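
Note: for context, the touched portion of this model entry reads roughly as below after the change. Only fields visible in the diff hunk are shown, the entry's key and remaining fields are omitted, and the dict name "entry" is a placeholder for illustration:

entry = {
    "diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
    "chat_scratchpad_class": None,
    "model_class_kwargs": {},
    "required_memory_mb": 18000,  # raised from 15000
    "T": 2048,                    # halved from 4096; presumably the context size in tokens
    "filter_caps": ["completion", "finetune"],
}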
16 changes: 16 additions & 0 deletions refact_data_pipeline/finetune/finetune_config.py
@@ -20,6 +20,14 @@
         "lm_head": "lm_head",
         "lora": "lora"
     },
+    "tokenizer": {
+        "eot_idx": 50256,
+        "padding_idx": 48049,
+        "fim_prefix": None,
+        "fim_middle": None,
+        "fim_suffix": None,
+        "escape": 47171
+    },
     "train_ds_pipeline": {
         "ds_opts": "n_ctx={n_ctx},pack_at_most=10,shuffle_depth=3000",
         "pipeline_name": "local_mix_plain_infill"
@@ -45,6 +53,14 @@
         "lm_head": "lm_head",
         "lora": "lora"
     },
+    "tokenizer": {
+        "eot_idx": 50256,
+        "padding_idx": 48049,
+        "fim_prefix": None,
+        "fim_middle": None,
+        "fim_suffix": None,
+        "escape": 47171
+    },
     "train_ds_pipeline": {
         "ds_opts": "n_ctx={n_ctx},pack_at_most=10,shuffle_depth=3000",
         "pipeline_name": "local_mix_plain_infill"
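
A minimal sketch of how a consumer of this config might use the new "tokenizer" block when preparing training sequences. The cfg wrapper and the pad_to_length helper below are hypothetical; only the field names and token ids come from this commit:

# Hypothetical consumer; only the "tokenizer" dict mirrors the new config block.
cfg = {
    "tokenizer": {
        "eot_idx": 50256,      # end-of-text token id
        "padding_idx": 48049,  # padding token id
        "fim_prefix": None,    # no fill-in-the-middle tokens in this model config
        "fim_middle": None,
        "fim_suffix": None,
        "escape": 47171,       # escape token id
    },
}

def pad_to_length(tokens, n_ctx, cfg):
    # Append the end-of-text id, then pad the sequence up to the context length.
    tok = cfg["tokenizer"]
    padded = tokens + [tok["eot_idx"]]
    padded += [tok["padding_idx"]] * (n_ctx - len(padded))
    return padded

print(pad_to_length([10, 20, 30], 8, cfg))  # [10, 20, 30, 50256, 48049, 48049, 48049, 48049]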
3 changes: 2 additions & 1 deletion refact_data_pipeline/finetune/finetune_filter.py
@@ -91,6 +91,7 @@ def loss_based_filter(
     logging.info("STATUS filtering")
     status_dict['total_steps'] = len(train_files)
     is_force_included, is_force_excluded = get_force_included_excluded_matchers()
+    forward = partial(model_forward, model=model, low_gpu_mem_mode=False, backend=cfg['model_info']['backend'])
     for iter_n, file in enumerate(train_files):
         t0_iter = time.time()
         status_dict = _update_and_dump_status(status_dict, "filtering")
@@ -107,7 +108,7 @@ def loss_based_filter(
             continue
 
         for batch, stats in batch_iter_fn(finetune_datasource.local_plain([file], dataopts)):
-            logits = model_forward(model, batch, low_gpu_mem_mode=False, backend=cfg['model_info']['backend'])
+            logits = forward(input=batch['input'])
             loss = float(loss_function(
                 logits=logits.to(th.bfloat16),  # more stable than float16 and takes much less memory than float32
                 labels=batch['labels'],
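
The change above hoists the arguments that never vary across the filtering loop into a functools.partial, so each batch only supplies its input. A standalone sketch of that pattern, with an assumed model_forward signature (the real one is not shown in this diff):

from functools import partial

# Stand-in for refact's model_forward; the real signature is assumed here.
def model_forward(model, input, low_gpu_mem_mode, backend):
    return f"{backend} forward on {model}: {input} (low_gpu_mem_mode={low_gpu_mem_mode})"

# Bind everything that stays constant once, outside the loop...
forward = partial(model_forward, model="gpt", low_gpu_mem_mode=False, backend="transformers")

# ...so each iteration passes only the per-batch data.
for batch in ({"input": [1, 2]}, {"input": [3, 4]}):
    print(forward(input=batch["input"]))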
4 changes: 2 additions & 2 deletions refact_data_pipeline/finetune/model_handling.py
@@ -234,9 +234,9 @@ def make_model(
     elif backend == "transformers":
         model = AutoModelForCausalLM.from_pretrained(
             repo_id, cache_dir=weights_path,
-            device_map="auto", torch_dtype="auto",
+            device_map=init_device, torch_dtype=dtype,
             trust_remote_code=True
-        ).to(dtype)
+        )
         model.encoding = encoding
         _apply_model_modifiers(model, MODELS_CONFIGS[model_name]['train_model_modifiers'])
     else:
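
This change asks transformers to place and cast the weights at load time, instead of loading with device_map="auto"/torch_dtype="auto" and then casting the whole module with .to(dtype). A hedged sketch of the resulting loading pattern; the repo id, cache path, device and dtype below are placeholders, not values from this repository:

import torch
from transformers import AutoModelForCausalLM

init_device = "cuda"    # placeholder for the device chosen by the caller
dtype = torch.bfloat16  # placeholder for the dtype chosen by the caller

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # placeholder repo id
    cache_dir="/tmp/weights",
    device_map=init_device,
    torch_dtype=dtype,
    trust_remote_code=True,
)
# Weights arrive on the target device in the target dtype, so no extra .to(dtype) pass is needed.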
2 changes: 1 addition & 1 deletion refact_models/codify_model.py
@@ -47,7 +47,7 @@ def encoding(self):
     def from_pretrained(cls, path: str, device: str = "cuda", repo_id: Optional[str] = None):
         config = load_config(path, repo_id)
         model = cls(config, device)
-        model = load_checkpoint(model, path, repo_id)
+        load_checkpoint(model, path, repo_id)
         return model
 
     def generate(self, *args, **kwargs):
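
Dropping the reassignment suggests load_checkpoint fills the passed model in place; if it does not also return the model, rebinding the name would throw the constructed object away. A toy illustration of that failure mode, using a hypothetical stand-in for load_checkpoint rather than the repo's implementation:

class Model:
    pass

# Hypothetical stand-in: loads weights into the existing object and returns nothing.
def load_checkpoint(model, path, repo_id=None):
    model.weights_path = path

model = Model()
load_checkpoint(model, "/weights/codify")            # fixed pattern: keep the original reference
# model = load_checkpoint(model, "/weights/codify")  # old pattern: would rebind model to None
print(type(model).__name__)                          # Model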
