diff --git a/examples/scripts/vsft_llava.py b/examples/scripts/vsft_llava.py index 2c4c8ed458..64ef239d61 100644 --- a/examples/scripts/vsft_llava.py +++ b/examples/scripts/vsft_llava.py @@ -79,6 +79,7 @@ from rich.logging import RichHandler import torch +from accelerate import Accelerator from datasets import load_dataset from tqdm.rich import tqdm @@ -111,7 +112,7 @@ ################ # Model, Tokenizer & Processor ################ - LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}""" + LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}""" torch_dtype = ( model_config.torch_dtype @@ -205,3 +206,5 @@ def __call__(self, examples): with save_context: trainer.save_model(training_args.output_dir) trainer.push_to_hub() + if Accelerator().is_main_process: + processor.push_to_hub(training_args.hub_model_id)