Skip to content

Commit

Permalink
seed fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jun 2, 2024
1 parent e514910 commit 5960045
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
15 changes: 9 additions & 6 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,10 @@ def shark_sd_fn(
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)

generated_imgs = []
if seed in [-1, "-1"]:
seed = randint(0, 4294967295)
if submit_run_kwargs["seed"] in [-1, "-1"]:
submit_run_kwargs["seed"] = randint(0, 4294967295)
seed_increment = "random"
print(f"\n[LOG] Random seed: {seed}")
#print(f"\n[LOG] Random seed: {seed}")
progress(None, desc=f"Generating...")

for current_batch in range(batch_count):
Expand All @@ -483,20 +483,23 @@ def shark_sd_fn(
sd_kwargs,
)
generated_imgs.extend(out_imgs)
seed = get_next_seed(seed, seed_increment)

yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
if batch_count > 1:
submit_run_kwargs["seed"] = get_next_seed(seed, seed_increment)

return (generated_imgs, "")


def get_next_seed(seed, seed_increment: str | int = 10):
if isinstance(seed_increment, int):
print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
#print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
return int(seed + seed_increment)
elif seed_increment == "random":
seed = randint(0, 4294967295)
print(f"\n[LOG] Random seed: {seed}")
#print(f"\n[LOG] Random seed: {seed}")
return seed


Expand Down
2 changes: 1 addition & 1 deletion apps/shark_studio/modules/shared_cmd_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def is_valid_file(arg):
p.add_argument(
"--batch_count",
type=int,
default=4,
default=1,
help="Number of batches to be generated with random seeds in " "single execution.",
)

Expand Down
6 changes: 4 additions & 2 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def base_model_changed(base_model_id):
elif ".py" in base_model_id:
new_steps = gr.Dropdown(
value=20,
choices=[10, 15, 20, 28],
choices=[10, 15, 20],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
Expand Down Expand Up @@ -462,7 +462,7 @@ def base_model_changed(base_model_id):
)
guidance_scale = gr.Slider(
0,
50,
5, #DEMO
value=cmd_opts.guidance_scale,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
Expand Down Expand Up @@ -636,6 +636,7 @@ def base_model_changed(base_model_id):
step=1,
label="Batch Count",
interactive=True,
visible=True,
)
batch_size = gr.Slider(
1,
Expand All @@ -649,6 +650,7 @@ def base_model_changed(base_model_id):
compiled_pipeline = gr.Checkbox(
True,
label="Faster txt2img (SDXL only)",
visible=False, # DEMO
)
with gr.Row():
stable_diffusion = gr.Button("Start")
Expand Down

0 comments on commit 5960045

Please sign in to comment.