From 77d2dc8202bfa972fc67a1a1847a01e66d0d09de Mon Sep 17 00:00:00 2001 From: edbeeching Date: Fri, 12 Apr 2024 13:29:40 +0000 Subject: [PATCH 1/2] adds gen prompt to template and processor to hub --- examples/scripts/vsft_llava.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/scripts/vsft_llava.py b/examples/scripts/vsft_llava.py index 2c4c8ed458..8b9110cc55 100644 --- a/examples/scripts/vsft_llava.py +++ b/examples/scripts/vsft_llava.py @@ -66,6 +66,7 @@ import logging import os from contextlib import nullcontext +from pathlib import Path TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) @@ -79,6 +80,7 @@ from rich.logging import RichHandler import torch +from accelerate import Accelerator from datasets import load_dataset from tqdm.rich import tqdm @@ -111,7 +113,7 @@ ################ # Model, Tokenizer & Processor ################ - LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}""" + LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}""" torch_dtype = ( model_config.torch_dtype @@ -205,3 +207,5 @@ def __call__(self, examples): with save_context: trainer.save_model(training_args.output_dir) trainer.push_to_hub() + if Accelerator().is_main_process: + processor.push_to_hub(Path(training_args.output_dir).name) From 6010b30be23be2deace81c46e04d2d1378acbe06 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Fri, 12 Apr 2024 14:04:17 +0000 Subject: [PATCH 2/2] fixes hub model id, removes Path --- examples/scripts/vsft_llava.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/scripts/vsft_llava.py b/examples/scripts/vsft_llava.py index 8b9110cc55..64ef239d61 100644 --- a/examples/scripts/vsft_llava.py +++ b/examples/scripts/vsft_llava.py @@ -66,7 +66,6 @@ import logging import os from contextlib import nullcontext -from pathlib import Path TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) @@ -208,4 +207,4 @@ def __call__(self, examples): trainer.save_model(training_args.output_dir) trainer.push_to_hub() if Accelerator().is_main_process: - processor.push_to_hub(Path(training_args.output_dir).name) + processor.push_to_hub(training_args.hub_model_id)