Skip to content

Commit

Permalink
Remove print statement styling (#504)
Browse files Browse the repository at this point in the history
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
  • Loading branch information
blessedcoolant and lstein authored Sep 11, 2022
1 parent 4951e66 commit b86a1de
Showing 1 changed file with 61 additions and 50 deletions.
111 changes: 61 additions & 50 deletions scripts/dream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,24 @@
# Just want to get the formatting look right for now.
output_cntr = 0


def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()

if opt.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1)
if opt.weights != 'model':
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
sys.exit(-1)

try:
models = OmegaConf.load(opt.config)
width = models[opt.model].width
height = models[opt.model].height
config = models[opt.model].config
models = OmegaConf.load(opt.config)
width = models[opt.model].width
height = models[opt.model].height
config = models[opt.model].config
weights = models[opt.model].weights
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
Expand All @@ -58,18 +59,18 @@ def main():
# additional parameters will be added (or overriden) during
# the user input loop
t2i = Generate(
width = width,
height = height,
sampler_name = opt.sampler_name,
weights = weights,
full_precision = opt.full_precision,
config = config,
grid = opt.grid,
width=width,
height=height,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
config=config,
grid=opt.grid,
# this is solely for recreating the prompt
seamless = opt.seamless,
embedding_path = opt.embedding_path,
device_type = opt.device,
ignore_ctrl_c = opt.infile is None,
seamless=opt.seamless,
embedding_path=opt.embedding_path,
device_type=opt.device,
ignore_ctrl_c=opt.infile is None,
)

# make sure the output directory exists
Expand Down Expand Up @@ -113,8 +114,8 @@ def main():

def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
"""prompt/read/execute loop"""
done = False
path_filter = re.compile(r'[<>:"/\\|?*]')
done = False
path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list()

# os.pathconf is not available on Windows
Expand All @@ -134,7 +135,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
except KeyboardInterrupt:
done = True
continue

# skip empty lines
if not command.strip():
continue
Expand Down Expand Up @@ -183,15 +184,17 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
if len(opt.prompt) == 0:
print('Try again with a prompt!')
continue
if opt.init_img is not None and re.match('^-\\d+$',opt.init_img): # retrieve previous value!
# retrieve previous value!
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
try:
opt.init_img = last_results[int(opt.init_img)][0]
print(f'>> Reusing previous image {opt.init_img}')
except IndexError:
print(f'>> No previous initial image at position {opt.init_img} found')
print(
f'>> No previous initial image at position {opt.init_img} found')
opt.init_img = None
continue

if opt.seed is not None and opt.seed < 0: # retrieve previous value!
try:
opt.seed = last_results[opt.seed][1]
Expand All @@ -201,12 +204,12 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
opt.seed = None
continue

do_grid = opt.grid or t2i.grid
do_grid = opt.grid or t2i.grid

if opt.with_variations is not None:
# shotgun parsing, woo
parts = []
broken = False # python doesn't have labeled loops...
broken = False # python doesn't have labeled loops...
for part in opt.with_variations.split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
Expand Down Expand Up @@ -241,7 +244,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))]
current_outdir = os.path.join(outdir, subdir)

print ('Writing files to directory: "' + current_outdir + '"')
print('Writing files to directory: "' + current_outdir + '"')

# make sure the output directory exists
if not os.path.exists(current_outdir):
Expand All @@ -253,9 +256,10 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
last_results = []
try:
file_writer = PngWriter(current_outdir)
prefix = file_writer.unique_prefix()
results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `do_grid`
prefix = file_writer.unique_prefix()
results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `do_grid`

def image_writer(image, seed, upscaled=False):
if do_grid:
grid_images[seed] = image
Expand All @@ -265,35 +269,41 @@ def image_writer(image, seed, upscaled=False):
else:
filename = f'{prefix}.{seed}.png'
if opt.variation_amount > 0:
iter_opt = argparse.Namespace(**vars(opt)) # copy
iter_opt = argparse.Namespace(**vars(opt)) # copy
this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None:
iter_opt.with_variations = this_variation
else:
iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0
normalized_prompt = PromptFormatter(t2i, iter_opt).normalize_prompt()
normalized_prompt = PromptFormatter(
t2i, iter_opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
elif opt.with_variations is not None:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{opt.seed}' # use the original seed - the per-iteration value is the last variation-seed
normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
# use the original seed - the per-iteration value is the last variation-seed
metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
else:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{seed}'
path = file_writer.save_image_and_prompt_to_png(image, metadata_prompt, filename)
path = file_writer.save_image_and_prompt_to_png(
image, metadata_prompt, filename)
if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output
results.append([path, metadata_prompt])
last_results.append([path,seed])
last_results.append([path, seed])

t2i.prompt2image(image_callback=image_writer, **vars(opt))

if do_grid and len(grid_images) > 0:
grid_img = make_grid(list(grid_images.values()))
grid_img = make_grid(list(grid_images.values()))
first_seed = last_results[0][1]
filename = f'{prefix}.{first_seed}.png'
filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}'
path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
Expand All @@ -308,18 +318,16 @@ def image_writer(image, seed, upscaled=False):
print(e)
continue

print('\033[1mOutputs:\033[0m')
print('Outputs:')
log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(results, log_path)

print('goodbye!\033[0m')
print('goodbye!')


def get_next_command(infile=None) -> str: #command string
def get_next_command(infile=None) -> str: # command string
if infile is None:
print('\033[1m') # add some boldface
command = input('dream> ')
print('\033[0m',end='')
command = input('dream> ')
else:
command = infile.readline()
if not command:
Expand All @@ -329,6 +337,7 @@ def get_next_command(infile=None) -> str: #command string
print(f'#{command}')
return command


def dream_server_loop(t2i, host, port, outdir):
print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory
Expand All @@ -342,7 +351,8 @@ def dream_server_loop(t2i, host, port, outdir):
dream_server = ThreadingDreamServer((host, port))
print(">> Started Stable Diffusion dream server!")
if host == '0.0.0.0':
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
print(
f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
else:
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
print(f">> Point your browser at http://{host}:{port}.")
Expand All @@ -361,13 +371,13 @@ def write_log_message(results, log_path):
log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
for l in log_lines:
output_cntr += 1
print(f'\033[1m[{output_cntr}]\033[0m {l}',end='')
print(output_cntr)

with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines)


SAMPLER_CHOICES=[
SAMPLER_CHOICES = [
'ddim',
'k_dpm_2_a',
'k_dpm_2',
Expand All @@ -378,6 +388,7 @@ def write_log_message(results, log_path):
'plms',
]


def create_argv_parser():
parser = argparse.ArgumentParser(
description="""Generate images using Stable Diffusion.
Expand Down Expand Up @@ -518,8 +529,8 @@ def create_argv_parser():
)
parser.add_argument(
'--config',
default ='configs/models.yaml',
help ='Path to configuration file for alternate models.',
default='configs/models.yaml',
help='Path to configuration file for alternate models.',
)
return parser

Expand Down

0 comments on commit b86a1de

Please sign in to comment.