Skip to content

Commit

Permalink
Merge pull request #1 from DKnight54/DKnight54-patch-1
Browse files Browse the repository at this point in the history
simplify per process prompt
  • Loading branch information
DKnight54 authored Jan 24, 2024
2 parents 51abf37 + fad055b commit 09bb026
Showing 1 changed file with 12 additions and 25 deletions.
37 changes: 12 additions & 25 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4749,35 +4749,22 @@ def sample_images_common(
vae.to(org_vae_device)

def generate_per_device_prompt_list(prompts, num_of_processes, default_sampler, prompt_replacement=None):
temp_prompts = []
for i, prompt_dict in enumerate(prompts):
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
temp_dict: dict = {}
temp_dict["negative_prompt"] = prompt_dict.get("negative_prompt")
temp_dict["sample_steps"] = prompt_dict.get("sample_steps", 30)
temp_dict["width"] = prompt_dict.get("width", 512)
temp_dict["height"] = prompt_dict.get("height", 512)
temp_dict["scale"] = prompt_dict.get("scale", 7.5)
temp_dict["seed"] = prompt_dict.get("seed")
temp_dict["controlnet_image"] = prompt_dict.get("controlnet_image")
temp_dict["prompt"]: str = prompt_dict.get("prompt", "")
temp_dict["sample_sampler"]: str = prompt_dict.get("sample_sampler", default_sampler)
temp_dict["enum"] = i
# Refactor prompt replacement to here in order to simplify sample_image_inference function.
if prompt_replacement is not None:
temp_dict["prompt"] = temp_dict["prompt"].replace(prompt_replacement[0], prompt_replacement[1])
if temp_dict["negative_prompt"] is not None:
temp_dict["negative_prompt"] = temp_dict["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1])
temp_prompts.append(temp_dict)
prompts = temp_prompts


# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [[] for i in range(num_of_processes)]
for i, prompt in enumerate(prompts):
if isinstance(prompt, str):
prompt = line_to_prompt_dict(prompt)
assert isinstance(prompt, dict)
prompt.pop("subset", None) # Clean up subset key
prompt["enum"] = i
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
if prompt_replacement is not None:
prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1])
if prompt["negative_prompt"] is not None:
prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1])
# Refactor prompt replacement to here in order to simplify sample_image_inference function.
per_process_prompts[i % num_of_processes].append(prompt)
return per_process_prompts

Expand Down

0 comments on commit 09bb026

Please sign in to comment.