Merge pull request #1049 from bghira/feature/user-prompt-library-generator

add prompt expander based on 1B Llama model
bghira authored Oct 12, 2024
2 parents 086b993 + 1444e83 commit bb78898
Showing 2 changed files with 314 additions and 1 deletion.
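The generator writes a flat JSON object to config/user_prompt_library.json, keyed prompt_1, prompt_2, and so on, with each value holding one expanded caption. Roughly like the following (the caption text here is an illustrative placeholder, not real model output):

{
    "prompt_1": "an image of character name in the style of impressionism, soft brushwork and warm afternoon light",
    "prompt_2": "a watercolor photograph of character name, loose washes of muted color"
}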
39 changes: 38 additions & 1 deletion configure.py
@@ -772,6 +772,43 @@ def configure_env():
},
]

# Offer to generate a prompt library for the user. If one already exists, default to "n" so it is not overwritten by accident.
should_generate_by_default = "n"
if not os.path.exists("config/user_prompt_library.json"):
should_generate_by_default = "y"
should_generate_prompt_library = (
prompt_user(
(
"Would you like to generate a very rudimentary subject-centric prompt library for your dataset?"
" This will download a small 1B Llama 3.2 model."
" If a user prompt library exists, it will be overwritten. (y/n)"
),
should_generate_by_default,
).lower()
== "y"
)
if should_generate_prompt_library:
try:
user_caption_trigger = prompt_user(
"Enter a trigger word (or a few words) that you would like Llama 3.2 1B to expand.",
"Character Name",
)
number_of_prompts = int(
prompt_user("How many prompts would you like to generate?", 8)
)
from helpers.prompt_expander import PromptExpander

PromptExpander.initialize_model()
user_prompt_library = PromptExpander.generate_prompts(
trigger_phrase=user_caption_trigger, num_prompts=number_of_prompts
)
with open("config/user_prompt_library.json", "w", encoding="utf-8") as f:
f.write(json.dumps(user_prompt_library, indent=4))
print("Prompt library generated successfully!")
env_contents["--user_prompt_library"] = "config/user_prompt_library.json"
except Exception as e:
print(f"(warning) Failed to generate prompt library: {e}")

# Now we ask the user for the path to their data, the path to the cache (cache/), and the number of repeats,
# and update the id placeholder based on the user's dataset name. Then we'll write the file to multidatabackend.json.
should_configure_dataloader = (
@@ -807,7 +844,7 @@ def configure_env():
if "instanceprompt" in dataset_caption_strategy:
dataset_instance_prompt = prompt_user(
"Enter the instance_prompt you want to use for all images in this dataset",
"CatchPhrase",
"Character Name",
)
dataset_repeats = int(
prompt_user(
276 changes: 276 additions & 0 deletions helpers/prompt_expander/__init__.py
@@ -0,0 +1,276 @@
import torch
import random
import json
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class PromptExpander:
# Class variables to hold the model, tokenizer, and generator
model = None
tokenizer = None
generator = None

@staticmethod
def initialize_model(model_path="meta-llama/Llama-3.2-1B-Instruct"):
"""
Initializes the language model, tokenizer, and text generation pipeline.
"""
device = (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
PromptExpander.model = AutoModelForCausalLM.from_pretrained(model_path).to(
device
)
PromptExpander.tokenizer = AutoTokenizer.from_pretrained(model_path)
PromptExpander.generator = pipeline(
"text-generation",
model=PromptExpander.model,
tokenizer=PromptExpander.tokenizer,
device=0 if device == "cuda" else -1,
)

@staticmethod
def generate_prompts(trigger_phrase, num_prompts=25):
"""
Generates expanded prompts based on the provided trigger phrase.
Args:
trigger_phrase (str): The trigger phrase to include in the prompts.
num_prompts (int): The number of prompts to generate.
"""
# Check if the model is initialized
if PromptExpander.generator is None:
print("Model not initialized. Please call initialize_model() first.")
return

# Define the list of prompt templates and styles
prompt_templates = [
"An image of TRIGGERPHRASE in the style of {style}.",
"A painting of {style}, featuring TRIGGERPHRASE.",
"TRIGGERPHRASE as depicted in {style} art.",
"{style} illustration showing TRIGGERPHRASE.",
"An abstract representation of TRIGGERPHRASE in {style}.",
"A realistic portrayal of TRIGGERPHRASE, inspired by {style}.",
"An artistic depiction of TRIGGERPHRASE using {style} techniques.",
"A {style} photograph of TRIGGERPHRASE.",
"TRIGGERPHRASE captured in the style of {style}.",
"An image featuring TRIGGERPHRASE with {style} elements.",
"A {style}-inspired scene with TRIGGERPHRASE.",
"TRIGGERPHRASE illustrated in {style} style.",
"An artistic rendering of TRIGGERPHRASE in {style} fashion.",
"A {style}-themed artwork of TRIGGERPHRASE.",
"TRIGGERPHRASE portrayed in {style} aesthetics.",
"{style} art featuring TRIGGERPHRASE.",
"An expressive {style} depiction of TRIGGERPHRASE.",
"A creative {style} representation of TRIGGERPHRASE.",
"TRIGGERPHRASE shown in a {style} setting.",
"An imaginative {style} image of TRIGGERPHRASE.",
"A {style} design including TRIGGERPHRASE.",
"TRIGGERPHRASE in a {style} composition.",
"An illustration of TRIGGERPHRASE with {style} influence.",
"A {style}-inspired portrait of TRIGGERPHRASE.",
"An image where TRIGGERPHRASE meets {style}.",
"TRIGGERPHRASE blended into a {style} background.",
"A {style} collage featuring TRIGGERPHRASE.",
"An artistic scene of TRIGGERPHRASE in {style} mood.",
"A {style} depiction of TRIGGERPHRASE in motion.",
"TRIGGERPHRASE rendered in {style} tones.",
"An atmospheric {style} image of TRIGGERPHRASE.",
"An expressive portrait of TRIGGERPHRASE in {style} style.",
"A surreal {style} painting of TRIGGERPHRASE.",
"TRIGGERPHRASE integrated into a {style} landscape.",
"An abstract {style} representation of TRIGGERPHRASE.",
"A {style} sketch of TRIGGERPHRASE.",
"TRIGGERPHRASE depicted in {style} illustration.",
"An image of TRIGGERPHRASE with {style} patterns.",
"A {style} poster featuring TRIGGERPHRASE.",
"An iconic {style} image of TRIGGERPHRASE.",
"TRIGGERPHRASE in a {style} artwork.",
"A vibrant {style} depiction of TRIGGERPHRASE.",
"An ethereal {style} image of TRIGGERPHRASE.",
"A dynamic {style} scene with TRIGGERPHRASE.",
"TRIGGERPHRASE portrayed through {style} art.",
"A {style} mural of TRIGGERPHRASE.",
"An imaginative {style} illustration of TRIGGERPHRASE.",
"TRIGGERPHRASE set in a {style} environment.",
"A {style}-inspired depiction of TRIGGERPHRASE.",
"An image of TRIGGERPHRASE with {style} motifs.",
]

styles = [
"Impressionism",
"Cubism",
"Surrealism",
"Pop Art",
"Futurism",
"Baroque",
"Gothic",
"Abstract Expressionism",
"Renaissance",
"Minimalism",
"Digital Art",
"Vintage Photography",
"Sci-Fi",
"Fantasy",
"Steampunk",
"Cyberpunk",
"Art Deco",
"Graffiti",
"Watercolor",
"Oil Painting",
"Black and White",
"Colorful",
"Retro",
"Comic Book",
"Manga",
"3D Rendering",
"Low Poly",
"Pixel Art",
"Line Art",
"Flat Design",
"Concept Art",
"Photorealism",
"High Contrast",
"Monochrome",
"Collage",
"Typography",
"Street Art",
"Ukiyo-e",
"Pop Surrealism",
"Digital Illustration",
"Neon",
"Expressionism",
"Anime",
"Realism",
"Dadaism",
"Constructivism",
"Avant-Garde",
"Hyperrealism",
"Symbolism",
"Fauvism",
]

used_templates = []
used_styles = []
user_prompt_library = {}
idx = 0

for _ in range(num_prompts):
idx += 1
# Randomly select a prompt template and style that have not been used yet.
# Reset the used lists once they are exhausted so that asking for more prompts
# than there are templates or styles cannot loop forever.
if len(used_templates) >= len(prompt_templates):
used_templates = []
if len(used_styles) >= len(styles):
used_styles = []
prompt_template = None
style = None
while (prompt_template is None or prompt_template in used_templates) or (
style is None or style in used_styles
):
prompt_template = random.choice(prompt_templates)
style = random.choice(styles)

used_templates.append(prompt_template)
used_styles.append(style)

# Replace placeholders in the template
prompt = prompt_template.replace("{style}", style.lower())
prompt = prompt.replace("TRIGGERPHRASE", trigger_phrase)

# Generate the text
input_prompt = (
"You are a text-to-text interface that returns improved prompts. "
"The captions should be expanded to be more descriptive. "
"Captions look like sentence fragments and tags.\n\n"
f"WITHOUT CHANGING ANY SPELLINGS: Clean this prompt, and return NOTHING but the upgraded prompt: {prompt}"
)

begin = time.time()
output_prompt = None

# Heuristic refusal detector: returns the first trigger phrase found in the
# model output, or None when the output looks like a usable caption.
def refused(output_prompt):
triggers = [
"i cannot",
"can you",
"what kind",
"i can't",
"i won't",
"am unable to",
"here is the",
"here's the",
"improving the caption",
"improving the prompt",
"should ",
"improving the text",
"improving the sentence",
"improving the description",
"remember:",
"please describe",
"please",
"improving the tag",
"improving the fragment",
"i am unable to",
"i am not able to",
"i am not capable of",
"i am not capable to",
"i am not capable",
"i am not able",
]
if output_prompt is None:
return None
output_prompt_lower = output_prompt.lower()
for trigger in triggers:
if trigger in output_prompt_lower:
return trigger
return None

refused_term = None
# Keep regenerating until the model returns something that does not look like a refusal.
while output_prompt is None or refused_term is not None:
if output_prompt is not None and refused_term:
print(
f"-> (REFUSAL) Prompt contains '{refused_term}'. Generating a new prompt.",
end="\n\n",
)
generated_caption = PromptExpander.generator(
input_prompt,
temperature=0.4,
max_new_tokens=77,
return_full_text=False,
)

end = time.time()
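# Post-process the raw generation: lowercase it, strip common lead-ins the model
# tends to prepend ("create a ", "explore the", "imagine a", "find a") and drop
# stray quotation marks.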
output_prompt = (
generated_caption[0]["generated_text"]
.strip()
.lower()
.replace("create a ", "")
.replace("explore the", "")
.replace("imagine a", "")
.replace("find a", "")
.replace('"', "")
)
refused_term = refused(output_prompt)

time_taken = end - begin
print("Prompt expanded in", round(time_taken, 2), "seconds:")
print(f"Original prompt: {prompt}")
print(f"Expanded prompt: {output_prompt}", end="\n\n")
user_prompt_library[f"prompt_{idx}"] = output_prompt

# Output the generated prompts as JSON
return user_prompt_library


if __name__ == "__main__":
# Example usage:
# Initialize the model first
PromptExpander.initialize_model()

# Generate prompts with your trigger phrase
user_prompt_library = PromptExpander.generate_prompts(
trigger_phrase="your_trigger_phrase_here", num_prompts=25
)

# Write to disk
with open("config/user_prompt_library.json", "w") as f:
json.dump(user_prompt_library, f, indent=4)
