Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to SD pipeline and abstracted components for saving output .npys #792

Merged
merged 4 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,21 @@ class PipelineComponent:
"""

def __init__(
self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False
self,
printer,
dest_type="devicearray",
dest_dtype="float16",
benchmark=False,
save_outputs=False,
):
self.runner = None
self.module_name = None
self.device = None
self.metadata = None
self.printer = printer
self.benchmark = benchmark
self.save_outputs = save_outputs
self.output_counter = 0
self.dest_type = dest_type
self.dest_dtype = dest_dtype

Expand Down Expand Up @@ -218,6 +225,16 @@ def _output_cast(self, output):
case _:
return output

def save_output(self, function_name, output):
if isinstance(output, tuple) or isinstance(output, list):
for i in output:
self.save_output(function_name, i)
else:
np.save(
f"{function_name}_output_{self.output_counter}.npy", output.to_host()
)
self.output_counter += 1

def _run(self, function_name, inputs: list):
return self.module[function_name](*inputs)

Expand All @@ -239,6 +256,8 @@ def __call__(self, function_name, inputs: list):
output = self._run_and_benchmark(function_name, inputs)
else:
output = self._run(function_name, inputs)
if self.save_outputs:
self.save_output(function_name, output)
output = self._output_cast(output)
return output

Expand Down Expand Up @@ -340,6 +359,7 @@ def __init__(
hf_model_name: str | dict[str] = None,
benchmark: bool | dict[bool] = False,
verbose: bool = False,
save_outputs: bool | dict[bool] = False,
common_export_args: dict = {},
):
self.map = model_map
Expand Down Expand Up @@ -374,6 +394,7 @@ def __init__(
"external_weights": external_weights,
"hf_model_name": hf_model_name,
"benchmark": benchmark,
"save_outputs": save_outputs,
}
for arg in map_arguments.keys():
self.map = merge_arg_into_map(self.map, map_arguments[arg], arg)
Expand All @@ -391,7 +412,7 @@ def __init__(
)
for submodel in self.map.keys():
for key, value in map_arguments.items():
if key != "benchmark":
if key not in ["benchmark", "save_outputs"]:
self.map = merge_export_arg(self.map, value, key)
for key, value in self.map[submodel].get("export_args", {}).items():
if key == "hf_model_name":
Expand Down Expand Up @@ -744,6 +765,7 @@ def load_submodel(self, submodel):
printer=self.printer,
dest_type=dest_type,
benchmark=self.map[submodel].get("benchmark", False),
save_outputs=self.map[submodel].get("save_outputs", False),
)
self.map[submodel]["runner"].load(
self.map[submodel]["driver"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ def is_valid_file(arg):
help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.",
)

p.add_argument(
"--save_outputs",
type=str,
default=None,
help="A comma-separated list of submodel IDs for which to save output .npys for, or 'all' for all components.",
)
##############################################################################
# SDXL Modelling Options
# These options are used to control model defining parameters for SDXL.
Expand Down
13 changes: 12 additions & 1 deletion models/turbine_models/custom_models/sd_inference/sd_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ def __init__(
batch_prompts: bool = False,
punet_quant_paths: dict[str] = None,
vae_weight_path: str = None,
vae_harness: bool = False,
vae_harness: bool = True,
add_tk_kernels: bool = False,
save_outputs: bool | dict[bool] = False,
):
common_export_args = {
"hf_model_name": None,
Expand Down Expand Up @@ -286,6 +287,7 @@ def __init__(
hf_model_name,
benchmark,
verbose,
save_outputs,
common_export_args,
)
for submodel in sd_model_map:
Expand Down Expand Up @@ -742,6 +744,14 @@ def numpy_to_pil_image(images):
benchmark[i] = True
else:
benchmark = False
if args.save_outputs:
if args.save_outputs.lower() == "all":
save_outputs = True
else:
for i in args.save_outputs.split(","):
save_outputs[i] = True
else:
save_outputs = False
if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]):
args.decomp_attn = {
"text_encoder": args.decomp_attn,
Expand Down Expand Up @@ -772,6 +782,7 @@ def numpy_to_pil_image(images):
args.use_i8_punet,
benchmark,
args.verbose,
save_outputs=save_outputs,
)
sd_pipe.prepare_all()
sd_pipe.load_map()
Expand Down
Loading