This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Refactoring model generator (#77)
* refactoring generator

* add batch size

* fine-tuning codes
mikecovlee authored Jul 18, 2024
1 parent a644da4 commit 39cb1f4
Showing 6 changed files with 306 additions and 93 deletions.
3 changes: 0 additions & 3 deletions generate.py
@@ -45,9 +45,6 @@ def main(
         prompts=[(instruction, input)],
     )
 
-    if max_seq_len is None:
-        max_seq_len = model.config_.max_seq_len_
-
     output = mlora.generate(
         model,
         tokenizer,
2 changes: 1 addition & 1 deletion inference.py
@@ -168,7 +168,7 @@ def generate_with_streaming(**kwargs):
             minimum=1,
             maximum=model.config_.max_seq_len_,
             step=1,
-            value=128,
+            value=1024,
             label="Max Tokens",
         ),
         gr.components.Checkbox(label="Stream Output", value=True),
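
For reference, a minimal sketch of the slider as configured after this change. The surrounding gr.Interface wiring is not part of this hunk, and the concrete maximum below is a placeholder for model.config_.max_seq_len_, which depends on the loaded model.

import gradio as gr

# Sketch only: 4096 stands in for model.config_.max_seq_len_.
max_tokens_slider = gr.components.Slider(
    minimum=1,
    maximum=4096,
    step=1,
    value=1024,  # new default, raised from 128
    label="Max Tokens",
)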
20 changes: 14 additions & 6 deletions mlora.py
@@ -234,22 +234,25 @@ def inference_callback(cur_pos, outputs):
 
 
 def inference(
-    llm_model: mlora.LLMModel,
+    model: mlora.LLMModel,
     tokenizer: mlora.Tokenizer,
-    adapters: List[mlora.GenerateConfig],
+    configs: List[mlora.GenerateConfig],
+    concurrent_jobs: int,
 ):
     while True:
         input_raw = input("INPUT WITHOUT PROMPT: ")
         if input_raw == "QUIT":
             return
-        for config in adapters:
+        for config in configs:
             config.prompts = [input_raw]
         callback = None if args.disable_log else inference_callback
         outputs = mlora.generate(
-            llm_model,
+            model,
             tokenizer,
-            adapters,
+            configs,
             max_gen_len=128,
             use_cache=args.disable_cache,
+            concurrent_jobs=concurrent_jobs,
+            cache_implementation=args.cache_implementation,
             stream_callback=callback,
         )
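
The new concurrent_jobs argument is forwarded to mlora.generate. As a rough illustration only, not the actual mlora.generate implementation, a limit like this typically means pending generation configs are processed in slices of at most that size; the helper below is hypothetical.

from typing import List, TypeVar

T = TypeVar("T")

def chunk_jobs(configs: List[T], concurrent_jobs: int) -> List[List[T]]:
    # Hypothetical helper: split pending configs into slices of at most
    # `concurrent_jobs` items, the kind of batching the new argument suggests.
    return [
        configs[i : i + concurrent_jobs]
        for i in range(0, len(configs), concurrent_jobs)
    ]

print(chunk_jobs(["lora_a", "lora_b", "lora_c"], concurrent_jobs=2))
# [['lora_a', 'lora_b'], ['lora_c']]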
@@ -298,7 +301,12 @@ def inference(
     mlora_backend.empty_cache()
 
     if args.inference:
-        inference(model, tokenizer, adapters)
+        inference(
+            model=model,
+            tokenizer=tokenizer,
+            configs=adapters,
+            concurrent_jobs=config.get("inference_lora_simultaneously_num", 2),
+        )
     elif args.evaluate:
         mlora.evaluate(
             model=model,
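
The concurrent job count is read from the run configuration with a fallback of 2. A minimal sketch of that lookup, assuming config is a plain dict loaded from the run's configuration file; only the key name comes from this diff, the rest is placeholder.

# Hypothetical excerpt of the loaded run configuration; only the
# "inference_lora_simultaneously_num" key appears in this commit.
config = {
    "inference_lora_simultaneously_num": 2,
}

# Mirrors the call site above: fall back to 2 concurrent jobs if the key is absent.
concurrent_jobs = config.get("inference_lora_simultaneously_num", 2)
print(concurrent_jobs)  # 2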
4 changes: 0 additions & 4 deletions mlora/evaluator.py
@@ -80,11 +80,7 @@ def _dispatch_task_in(tokenizer, configs, concurrent_jobs, max_seq_len):
             if len(tokens) > max_seq_len:
                 tokens = tokens[:max_seq_len]
             max_tokens_len = max(len(tokens), max_tokens_len)
-            # sequence_lengths.append(len(tokens))
-            # while len(tokens) < max_seq_len:
-            #     tokens.append(tokenizer.pad_id_)
             batch_tokens.append(tokens)
-            # atten_masks.append(tokenizer.mask_from(tokens))
             batch_labels.append(labels.copy())
 
         config.batch_start_idx_ = config.batch_end_idx_
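
The deleted comments point at explicit padding of each sequence up to max_seq_len, which this dispatcher no longer does; only truncation and the running max_tokens_len remain. For context, a generic sketch of right-padding a token batch to its longest sequence; the helper and pad id are hypothetical, not the mlora tokenizer API.

from typing import List

def pad_batch(batch_tokens: List[List[int]], pad_id: int = 0) -> List[List[int]]:
    # Hypothetical helper: right-pad every sequence to the length of the
    # longest one in the batch, the step the removed comments alluded to.
    max_tokens_len = max(len(tokens) for tokens in batch_tokens)
    return [
        tokens + [pad_id] * (max_tokens_len - len(tokens))
        for tokens in batch_tokens
    ]

print(pad_batch([[5, 6, 7], [8]], pad_id=0))  # [[5, 6, 7], [8, 0, 0]]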