working version
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
samhita-alla committed Sep 17, 2024
1 parent 3b2ccac commit f11133f
Showing 4 changed files with 27 additions and 26 deletions.
1 change: 1 addition & 0 deletions _blogs/ollama/requirements.txt
@@ -13,3 +13,4 @@ torch
 trl
 union
 flytekitplugins-kfpytorch
+flytekitplugins-inference>=1.13.6b0
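
The flytekitplugins-inference package added here supplies the Model and Ollama classes that workflow.py imports below. A minimal sketch of how the plugin is typically wired into a task; the pod_template attribute follows the plugin's documented pattern but is an assumption, not shown in this diff:

from flytekit import task
from flytekitplugins.inference import Model, Ollama

ollama_instance = Ollama(model=Model(name="phi3:mini-4k"))

@task(pod_template=ollama_instance.pod_template)  # assumed attribute
def model_serving() -> str:
    ...  # query the sidecar here; see the sketch at the end of this diff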
5 changes: 1 addition & 4 deletions _blogs/ollama/train.py
@@ -16,6 +16,7 @@ def load_model(train_args):
     )

     model = AutoModelForCausalLM.from_pretrained(train_args.model, **model_kwargs)
+    model.save_pretrained(train_args.output_dir)
     return model


@@ -112,8 +113,4 @@ def save_model(trainer, train_args):

     trainer.train()
     trainer.save_model(train_args.adapter_dir)
-
-    model = PeftModel.from_pretrained(trainer.model, train_args.adapter_dir)
-    merged_model = model.merge_and_unload()
-    merged_model.save_pretrained(train_args.output_dir)
     trainer.tokenizer.save_pretrained(train_args.output_dir)
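
With the merge block deleted, save_model now leaves the LoRA adapter in train_args.adapter_dir, while the base checkpoint is written by load_model (the save_pretrained call added above). To reproduce the old merged model locally, a sketch using the same PEFT calls the removed lines used (paths are placeholders):

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("output_dir")   # train_args.output_dir
model = PeftModel.from_pretrained(base, "adapter_dir")      # train_args.adapter_dir
merged = model.merge_and_unload()                           # fold LoRA weights into the base
merged.save_pretrained("merged_dir")                        # placeholder output path
AutoTokenizer.from_pretrained("output_dir").save_pretrained("merged_dir")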
8 changes: 3 additions & 5 deletions _blogs/ollama/utils.py
@@ -17,11 +17,11 @@
     name="gguf-ollama",
     apt_packages=["git"],
     registry=REGISTRY,
-    packages=["huggingface_hub"],
+    packages=["huggingface_hub", "flytekitplugins-inference>=1.13.6b0"],
     python_version="3.11",
 ).with_commands(
     [
-        "git clone --branch b3046 https://github.com/ggerganov/llama.cpp /root/llama.cpp",
+        "git clone --branch master https://github.com/ggerganov/llama.cpp /root/llama.cpp",
         "pip install -r /root/llama.cpp/requirements.txt",
     ]
 )
@@ -30,9 +30,7 @@
     name="phi3-ollama-serve",
     registry=REGISTRY,
     apt_packages=["git"],
-    packages=[
-        "git+https://github.com/flyteorg/flytekit.git@bcc13f799da3ce28e81d6060fc5776b6b4bca0a0#subdirectory=plugins/flytekit-inference"
-    ],
+    packages=["flytekitplugins-inference>=1.13.6b0"],
 )
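
Both hunks swap the commit-pinned git install of the flytekit-inference plugin subdirectory for the published flytekitplugins-inference wheel, and the llama.cpp checkout moves from tag b3046 to master, presumably because the convert_lora_to_gguf.py script invoked in workflow.py postdates b3046. For reference, a sketch of how an ImageSpec like these is attached to a task (serve_image and the REGISTRY value are hypothetical; the diff does not show the assignment targets):

from flytekit import ImageSpec, task

REGISTRY = "ghcr.io/example"  # placeholder registry

serve_image = ImageSpec(  # hypothetical variable name
    name="phi3-ollama-serve",
    registry=REGISTRY,
    apt_packages=["git"],
    packages=["flytekitplugins-inference>=1.13.6b0"],
)

@task(container_image=serve_image)
def model_serving() -> None:
    ...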


39 changes: 22 additions & 17 deletions _blogs/ollama/workflow.py
@@ -8,13 +8,6 @@
 from flytekit.types.file import FlyteFile
 from flytekitplugins.inference import Model, Ollama

-from ollama.train import (
-    create_trainer,
-    initialize_tokenizer,
-    load_model,
-    prepare_dataset,
-    save_model,
-)
 from ollama.utils import (
     PEFTConfig,
     TrainingConfig,
@@ -27,9 +20,10 @@
     model=Model(
         name="phi3-pubmed",
         modelfile='''
-FROM {inputs.gguf}
-TEMPLATE """{{ if .System }}<|system|>
+FROM phi3:mini-4k
+ADAPTER {inputs.gguf}
+TEMPLATE """{{ if .System }}<|system|>
 {{ .System }}<|end|>
 {{ end }}{{ if .Prompt }}<|user|>
 {{ .Prompt }}<|end|>
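
This Modelfile change is the core of the commit: instead of building FROM a fully merged GGUF, Ollama now pulls the stock phi3:mini-4k base and layers the fine-tuned weights on top with ADAPTER, which accepts a GGUF-format LoRA adapter. A sketch of applying the equivalent Modelfile by hand with the ollama Python client (the create/generate signatures are assumed for the client version current at this commit; the plugin normally handles this step):

import ollama  # assumed client API

modelfile = """
FROM phi3:mini-4k
ADAPTER /tmp/model.gguf
"""
ollama.create(model="phi3-pubmed", modelfile=modelfile)
result = ollama.generate(model="phi3-pubmed", prompt="What does PubMed index?")
print(result["response"])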
@@ -76,22 +70,30 @@ def create_dataset(queries: list[str], top_n: int) -> FlyteDirectory:

 @task(
     cache=True,
-    cache_version="0.1",
+    cache_version="0.4",
     container_image=image_spec,
     accelerator=T4,
     requests=Resources(mem="10Gi", cpu="2", gpu="1"),
     environment={"TOKENIZERS_PARALLELISM": "false"},
 )
 def phi3_finetune(
     train_args: TrainingConfig, peft_args: PEFTConfig, dataset_dir: FlyteDirectory
-) -> FlyteDirectory:
+) -> tuple[FlyteDirectory, FlyteDirectory]:
+    from ollama.train import (
+        create_trainer,
+        initialize_tokenizer,
+        load_model,
+        prepare_dataset,
+        save_model,
+    )
+
     model = load_model(train_args)
     tokenizer = initialize_tokenizer(train_args.model)
     dataset_splits = prepare_dataset(dataset_dir, train_args, tokenizer)
     trainer = create_trainer(model, train_args, peft_args, dataset_splits, tokenizer)
     save_model(trainer, train_args)

-    return FlyteDirectory(train_args.output_dir)
+    return FlyteDirectory(train_args.adapter_dir), FlyteDirectory(train_args.output_dir)
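
Note the ollama.train imports now sit inside the task body, matching their removal from the top of the file in the first hunk. Deferring imports this way is a common flytekit pattern: the workflow module stays importable in containers that do not ship the training dependencies. A minimal sketch of the pattern, assuming torch is only present in this task's image:

from flytekit import task

@task
def gpu_count() -> int:
    import torch  # deferred: the module loads even where torch is absent
    return torch.cuda.device_count()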


 @task(
@@ -101,19 +103,22 @@ def phi3_finetune(
     requests=Resources(mem="5Gi", cpu="2", gpu="1"),
     accelerator=T4,
 )
-def hf_to_gguf(model_dir: FlyteDirectory) -> FlyteFile:
+def hf_to_gguf(adapter_dir: FlyteDirectory, model_dir: FlyteDirectory) -> FlyteFile:
+    adapter_dir.download()
     model_dir.download()
     output_dir = Path(current_context().working_directory)

     subprocess.run(
         [
             sys.executable,
-            "/root/llama.cpp/convert-hf-to-gguf.py",
+            "/root/llama.cpp/convert_lora_to_gguf.py",
+            adapter_dir.path,
+            "--base",
             model_dir.path,
             "--outfile",
             str(output_dir / "model.gguf"),
             "--outtype",
-            "q8_0", # quantize the model to 8-bit float representation
+            "q8_0",  # quantize the model to 8-bit float representation
         ],
         check=True,
     )
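
The conversion script changes from convert-hf-to-gguf.py to llama.cpp's convert_lora_to_gguf.py: the adapter directory is converted directly, with --base pointing at the Hugging Face checkpoint that load_model saved, so model.gguf now holds only the quantized adapter. A quick sanity check on the output, assuming the gguf-py package that llama.cpp's requirements install:

from gguf import GGUFReader  # assumed API from gguf-py

reader = GGUFReader("model.gguf")
print(len(reader.tensors), "tensors")  # LoRA A/B matrices rather than full weights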
@@ -163,10 +168,10 @@ def phi3_ollama(
     ],
 ) -> list[str]:
     dataset_dir = create_dataset(queries=queries, top_n=top_n)
-    model_dir = phi3_finetune(
+    adapter_dir, model_dir = phi3_finetune(
         train_args=train_args, peft_args=peft_args, dataset_dir=dataset_dir
     )
-    gguf_file = hf_to_gguf(model_dir=model_dir)
+    gguf_file = hf_to_gguf(adapter_dir=adapter_dir, model_dir=model_dir)
     return model_serving(
         questions=model_queries,
         gguf=gguf_file,
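
End to end, phi3_finetune now hands both directories to hf_to_gguf, and model_serving receives only the small adapter GGUF for Ollama to layer onto the base model. A sketch of how a consumer might query the served model once the sidecar is up (the /api/generate endpoint is Ollama's standard REST API; the base URL is an assumption about the plugin's Ollama object):

import requests

def query_model(prompt: str, base_url: str) -> str:
    # stream=False makes Ollama return one JSON payload instead of chunks.
    resp = requests.post(
        f"{base_url}/api/generate",
        json={"model": "phi3-pubmed", "prompt": prompt, "stream": False},
    )
    resp.raise_for_status()
    return resp.json()["response"]

# e.g. query_model("What is PubMed?", ollama_instance.base_url)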
