add nexfort and pixart alpha #882

Merged: 15 commits, May 21, 2024
13 changes: 13 additions & 0 deletions benchmarks/README.md
@@ -22,3 +22,16 @@ docker run -it --rm --gpus all --shm-size 12g --ipc=host --security-opt seccomp=
onediff:benchmark-community-default \
sh -c "cd /benchmark && sh run_all_benchmarks.sh -m models -o benchmark.md"
```

## Run Examples
### Run PixArt-alpha (with the nexfort backend)
```
# model_id_or_path_to_PixArt-XL-2-1024-MS: /data/hf_models/PixArt-XL-2-1024-MS/
python3 text_to_image.py --model model_id_or_path_to_PixArt-XL-2-1024-MS --scheduler none --compiler nexfort
```
Performance on NVIDIA A100-PCIE-40GB:
- Iterations per second (progress bar): 11.7
- Inference time: 2.045 s
- Iterations per second (overall): 10.517
- CUDA memory after inference: 13.569 GiB
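
The nexfort options can also be overridden from the command line via the new `--compiler-config` flag, which takes a JSON string. A hedged example, reusing the model path placeholder above (the keys mirror the defaults hard-coded in `text_to_image.py`):
```
python3 text_to_image.py --model model_id_or_path_to_PixArt-XL-2-1024-MS --scheduler none --compiler nexfort --compiler-config '{"mode": "max-autotune", "memory_format": "channels_last"}'
```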

47 changes: 35 additions & 12 deletions benchmarks/text_to_image.py
@@ -6,7 +6,7 @@
CONTROLNET = None
STEPS = 30
PROMPT = "best quality, realistic, unreal engine, 4K, a beautiful girl"
NEGATIVE_PROMPT = None
NEGATIVE_PROMPT = ""
SEED = None
WARMUPS = 3
BATCH = 1
@@ -19,6 +19,8 @@
CACHE_INTERVAL = 3
CACHE_LAYER_ID = 0
CACHE_BLOCK_ID = 0
COMPILER = "oneflow"
COMPILER_CONFIG = None

import os
import importlib
@@ -30,7 +32,7 @@
from PIL import Image, ImageDraw
from diffusers.utils import load_image

from onediffx import compile_pipe
from onediffx import compile_pipe, CompileOptions


def parse_args():
@@ -60,16 +62,23 @@ def parse_args():
parser.add_argument(
"--compiler",
type=str,
default="oneflow",
default=COMPILER,
choices=["none", "oneflow", "nexfort", "compile", "compile-max-autotune"],
)
parser.add_argument(
"--compiler-config",
type=str,
default=COMPILER_CONFIG,
)
return parser.parse_args()


def load_pipe(
pipeline_cls,
model_name,
variant=None,
dtype=torch.float16,
device="cuda",
custom_pipeline=None,
scheduler=None,
lora=None,
@@ -80,31 +89,34 @@ def load_pipe(
extra_kwargs["custom_pipeline"] = custom_pipeline
if variant is not None:
extra_kwargs["variant"] = variant
if dtype is not None:
extra_kwargs["torch_dtype"] = dtype
if controlnet is not None:
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained(
controlnet, torch_dtype=torch.float16,
controlnet, torch_dtype=dtype,
)
extra_kwargs["controlnet"] = controlnet
if os.path.exists(os.path.join(model_name, "calibrate_info.txt")):
from onediff.quantization import QuantPipeline

pipe = QuantPipeline.from_quantized(
pipeline_cls, model_name, torch_dtype=torch.float16, **extra_kwargs
pipeline_cls, model_name, **extra_kwargs
)
else:
pipe = pipeline_cls.from_pretrained(
model_name, torch_dtype=torch.float16, **extra_kwargs
model_name, **extra_kwargs
)
if scheduler is not None:
if scheduler is not None and scheduler != "none":
scheduler_cls = getattr(importlib.import_module("diffusers"), scheduler)
pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
if lora is not None:
pipe.load_lora_weights(lora)
pipe.fuse_lora()
pipe.safety_checker = None
pipe.to(torch.device("cuda"))
if device is not None:
pipe.to(torch.device(device))
return pipe


@@ -154,15 +166,25 @@ def main():
controlnet=args.controlnet,
)

height = args.height or pipe.unet.config.sample_size * pipe.vae_scale_factor
width = args.width or pipe.unet.config.sample_size * pipe.vae_scale_factor
core_net = None
if core_net is None:
core_net = getattr(pipe, "unet", None)
if core_net is None:
core_net = getattr(pipe, "transformer", None)
height = args.height or core_net.config.sample_size * pipe.vae_scale_factor
width = args.width or core_net.config.sample_size * pipe.vae_scale_factor

if args.compiler == "none":
pass
elif args.compiler == "oneflow":
pipe = compile_pipe(pipe)
elif args.compiler == "nexfort":
pipe = compile_pipe(pipe, backend="nexfort")
options = CompileOptions()
if args.compiler_config is not None:
options.nexfort = json.loads(args.compiler_config)
else:
options.nexfort = json.loads('{"mode": "max-autotune", "memory_format": "channels_last"}')
pipe = compile_pipe(pipe, backend="nexfort", options=options, fuse_qkv_projections=True)
elif args.compiler in ("compile", "compile-max-autotune"):
mode = "max-autotune" if args.compiler == "compile-max-autotune" else None
pipe.unet = torch.compile(pipe.unet, mode=mode)
@@ -199,7 +221,6 @@ def get_kwarg_inputs():
negative_prompt=args.negative_prompt,
height=height,
width=width,
num_inference_steps=args.steps,
num_images_per_prompt=args.batch,
generator=None
if args.seed is None
@@ -210,6 +231,8 @@
else json.loads(args.extra_call_kwargs)
),
)
if args.steps is not None:
kwarg_inputs["num_inference_steps"] = args.steps
if input_image is not None:
kwarg_inputs["image"] = input_image
if control_image is not None:
4 changes: 2 additions & 2 deletions onediff_diffusers_extensions/onediffx/__init__.py
@@ -1,5 +1,5 @@
__version__ = "1.1.0.dev1"
from onediff.infer_compiler import compile_options
from onediff.infer_compiler import compile_options, CompileOptions
from .compilers.diffusion_pipeline_compiler import compile_pipe, save_pipe, load_pipe

__all__ = ["compile_pipe", "compile_options", "save_pipe", "load_pipe"]
__all__ = ["compile_pipe", "compile_options", "CompileOptions", "save_pipe", "load_pipe"]
onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py
@@ -29,11 +29,11 @@ def _recursive_setattr(obj, attr, value):
"fast_unet", # for deepcache
"prior", # for StableCascadePriorPipeline
"decoder", # for StableCascadeDecoderPipeline
"transformer", # for Transformer-based DiffusionPipeline such as DiTPipeline and PixArtAlphaPipeline
"vqgan.down_blocks", # for StableCascadeDecoderPipeline
"vqgan.up_blocks", # for StableCascadeDecoderPipeline
"vae.decoder",
"vae.encoder",
"transformer", # for Transformer-based DiffusionPipeline such as DiTPipeline and PixArtAlphaPipeline
]


@@ -52,8 +52,17 @@ def _filter_parts(ignores=()):


def compile_pipe(
pipe, *, backend="oneflow", options=None, ignores=(),
pipe, *, backend="oneflow", options=None, ignores=(), fuse_qkv_projections=False,
):
if fuse_qkv_projections:
print("****** fuse qkv projections ******")
pipe = fuse_qkv_projections_in_pipe(pipe)

if options is not None and options.nexfort is not None and "memory_format" in options.nexfort:
memory_format = getattr(torch, options.nexfort["memory_format"])
pipe = convert_pipe_to_memory_format(pipe, ignores=ignores, memory_format=memory_format)
del options.nexfort["memory_format"]

# To fix the bug of graph load of vae. Please refer to: https://github.com/siliconflow/onediff/issues/452
if (
hasattr(pipe, "upcast_vae")
@@ -82,6 +91,33 @@

return pipe

def fuse_qkv_projections_in_pipe(pipe):
if hasattr(pipe, "fuse_qkv_projections"):
pipe.fuse_qkv_projections()
return pipe


def convert_pipe_to_memory_format(pipe, *, ignores=(), memory_format=torch.preserve_format):
from nexfort.utils.attributes import multi_recursive_apply
from nexfort.utils.memory_format import apply_memory_format
import functools
if memory_format == torch.preserve_format:
return pipe

parts = [
"unet",
"controlnet",
"fast_unet", # for deepcache
"prior", # for StableCascadePriorPipeline
"decoder", # for StableCascadeDecoderPipeline
"transformer", # for Transformer-based DiffusionPipeline such as DiTPipeline and PixArtAlphaPipeline
"vqgan", # for StableCascadeDecoderPipeline
"vae",
]
multi_recursive_apply(
pipe, parts, functools.partial(apply_memory_format, memory_format=memory_format), ignores=ignores, verbose=True
)
return pipe

def save_pipe(pipe, dir="cached_pipe", *, ignores=(), overwrite=True):
if not os.path.exists(dir):
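Putting the pieces together, the new nexfort path can be exercised end to end. A minimal sketch, assuming a local PixArt-alpha checkpoint (the path is a placeholder) and that `diffusers`, `onediffx`, and `nexfort` are installed; the option values mirror the benchmark defaults:

```python
import torch
from diffusers import PixArtAlphaPipeline
from onediffx import compile_pipe, CompileOptions

# Load the pipeline in fp16 (the checkpoint path is a placeholder).
pipe = PixArtAlphaPipeline.from_pretrained(
    "/data/hf_models/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
)
pipe.to("cuda")

# options.nexfort is a plain dict: compile_pipe consumes "memory_format"
# for the channels-last conversion and forwards the remaining keys to
# nexfort_compile as keyword arguments.
options = CompileOptions()
options.nexfort = {"mode": "max-autotune", "memory_format": "channels_last"}

pipe = compile_pipe(
    pipe, backend="nexfort", options=options, fuse_qkv_projections=True
)

image = pipe(
    "best quality, realistic, unreal engine, 4K, a beautiful girl",
    num_inference_steps=30,
).images[0]
image.save("pixart.png")
```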
21 changes: 4 additions & 17 deletions src/onediff/infer_compiler/backends/nexfort.py
@@ -3,16 +3,6 @@
from .registry import register_backend


def make_inductor_options(options):
inductor_options = {}
if options is None:
return inductor_options
for filed in dataclasses.fields(options):
filed_name = filed.name
inductor_options[f"inductor.{filed_name}"] = getattr(options, filed_name)
return inductor_options


@register_backend("nexfort")
def compile(torch_module: torch.nn.Module, *, options=None):
from nexfort.utils.memory_format import apply_memory_format
@@ -22,11 +12,8 @@ def compile(torch_module: torch.nn.Module, *, options=None):

options = options if options is not None else CompileOptions()
nexfort_options = options.nexfort
if nexfort_options.memory_format != torch.preserve_format:
model = apply_memory_format(
torch_module, memory_format=nexfort_options.memory_format
)
model = nexfort_compile(
model, options=make_inductor_options(nexfort_options.inductor)
compiled_model = nexfort_compile(
torch_module, **nexfort_options
)
return NexfortDeployableModule(model)
# return NexfortDeployableModule(compiled_model, torch_module)
return compiled_model
6 changes: 3 additions & 3 deletions src/onediff/infer_compiler/nexfort/deployable_module.py
@@ -3,10 +3,10 @@


class NexfortDeployableModule(DeployableModule):
def __init__(self, torch_module):
def __init__(self, compiled_module, torch_module):
torch.nn.Module.__init__(self)
object.__setattr__(self, "_deployable_module_model", torch_module)
object.__setattr__(self, "_modules", torch_module._modules)
object.__setattr__(self, "_deployable_module_model", compiled_module)
object.__setattr__(self, "_modules", compiled_module._modules)
object.__setattr__(self, "_torch_module", torch_module)

def __call__(self, *args, **kwargs):
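For illustration, a hedged sketch of what the two-argument wrapper enables. Note that the nexfort backend above currently returns the compiled module directly and leaves the wrapping line commented out; the import path and dispatch behavior here are assumptions based on the file layout and the `__call__` stub shown:

```python
import torch
# Import path inferred from the file location above (assumption).
from onediff.infer_compiler.nexfort.deployable_module import NexfortDeployableModule

# Toy stand-ins: an eager module and its compiled counterpart (hypothetical).
torch_module = torch.nn.Linear(8, 8)
compiled_module = torch.compile(torch_module)

wrapper = NexfortDeployableModule(compiled_module, torch_module)
y = wrapper(torch.randn(1, 8))  # __call__ presumably dispatches to the compiled module
eager = wrapper._torch_module   # the eager module remains reachable for save/load paths
```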
33 changes: 2 additions & 31 deletions src/onediff/infer_compiler/utils/options.py
@@ -42,35 +42,6 @@ class OneflowCompileOptions:
kernel_glu_enable_y_gemm_impl: bool = None
kernel_glu_quant_enable_dual_gemm_impl: bool = None


@dataclasses.dataclass
class NexfortInductorCompileOptions:
disable: bool = False
mode: str = None
options: Dict = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class NexfortCompileOptions:
memory_format: torch.memory_format
fuse_qkv_projections: bool
inductor: NexfortInductorCompileOptions

def __init__(
self,
memory_format=torch.channels_last,
fuse_qkv_projections=True,
inductor=None,
):
if isinstance(memory_format, str):
memory_format = getattr(torch, memory_format)
self.memory_format = memory_format
self.fuse_qkv_projections = fuse_qkv_projections
self.inductor = (
inductor if inductor is not None else NexfortInductorCompileOptions()
)


@dataclasses.dataclass
class CompileOptions:
# common options
@@ -80,12 +51,12 @@ class CompileOptions:
oneflow: OneflowCompileOptions

# nexfort specific options
nexfort: NexfortCompileOptions
nexfort: Dict

def __init__(self, dynamic=True, oneflow=None, nexfort=None):
self.dynamic = dynamic
self.oneflow = oneflow if oneflow is not None else OneflowCompileOptions()
self.nexfort = nexfort if nexfort is not None else NexfortCompileOptions()
self.nexfort = nexfort if nexfort is not None else dict()


# a global default compile options
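With `NexfortCompileOptions` and the `make_inductor_options` helper removed, `options.nexfort` is just a dict of keyword arguments. A minimal sketch of the simplified flow, assuming the `onediffx` re-export added in this PR:

```python
from onediffx import CompileOptions

options = CompileOptions()  # options.nexfort defaults to an empty dict
options.nexfort = {"mode": "max-autotune", "memory_format": "channels_last"}

# compile_pipe pops "memory_format" for the memory-format conversion; the
# nexfort backend then calls nexfort_compile(module, **options.nexfort).
```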