api for new backend #794

Merged · 18 commits · May 15, 2024
16 changes: 10 additions & 6 deletions benchmarks/text_to_image.py
@@ -30,7 +30,6 @@
from PIL import Image, ImageDraw
from diffusers.utils import load_image

import oneflow as flow
from onediffx import compile_pipe


@@ -62,7 +61,7 @@ def parse_args():
"--compiler",
type=str,
default="oneflow",
choices=["none", "oneflow", "compile", "compile-max-autotune"],
choices=["none", "oneflow", "nexfort", "compile", "compile-max-autotune"],
)
return parser.parse_args()

@@ -162,6 +161,8 @@ def main():
pass
elif args.compiler == "oneflow":
pipe = compile_pipe(pipe)
elif args.compiler == "nexfort":
pipe = compile_pipe(pipe, backend="nexfort")
elif args.compiler in ("compile", "compile-max-autotune"):
mode = "max-autotune" if args.compiler == "compile-max-autotune" else None
pipe.unet = torch.compile(pipe.unet, mode=mode)
@@ -248,10 +249,13 @@ def get_kwarg_inputs():
iter_per_sec = iter_profiler.get_iter_per_sec()
if iter_per_sec is not None:
print(f"Iterations per second: {iter_per_sec:.3f}")
cuda_mem_after_used = flow._oneflow_internal.GetCUDAMemoryUsed()
host_mem_after_used = flow._oneflow_internal.GetCPUMemoryUsed()
print(f"CUDA Mem after: {cuda_mem_after_used / 1024:.3f}GiB")
print(f"Host Mem after: {host_mem_after_used / 1024:.3f}GiB")
if args.compiler == "oneflow":
import oneflow as flow

cuda_mem_after_used = flow._oneflow_internal.GetCUDAMemoryUsed() / 1024
else:
cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"CUDA Mem after: {cuda_mem_after_used:.3f}GiB")
print("=======================================")

if args.output_image is not None:
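With this change, the benchmark routes backend selection through a single `compile_pipe` call and reports memory with the accounting that matches the chosen backend. A minimal sketch of the new path — the model id and pipeline class are illustrative, not part of the diff:

```python
import torch
from diffusers import StableDiffusionPipeline  # illustrative pipeline choice
from onediffx import compile_pipe

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# New in this PR: the backend is selected by keyword argument.
pipe = compile_pipe(pipe, backend="nexfort")  # default remains "oneflow"

image = pipe("an astronaut riding a horse").images[0]

# Memory reporting is now backend-aware: torch accounting for the
# nexfort/compile paths, oneflow's internal counter only when oneflow ran.
print(f"CUDA Mem after: {torch.cuda.max_memory_allocated() / (1024 ** 3):.3f}GiB")
```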
@@ -1,4 +1,4 @@
from onediff.infer_compiler import register
from onediff.infer_compiler.transform import register
from onediff.infer_compiler.utils import is_community_version
from nodes import * # must be imported before importing comfy
from pathlib import Path
@@ -1,4 +1,4 @@
from onediff.infer_compiler import register
from onediff.infer_compiler.transform import register

import oneflow as flow
import onediff_quant
@@ -1,6 +1,6 @@
import os
import torch
from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler import compile
from onediff.infer_compiler.deployable_module import DeployableModule
from onediff.infer_compiler.utils.log_utils import logger

@@ -34,6 +34,7 @@ def _recursive_setattr(obj, attr, value):
"vqgan.up_blocks", # for StableCascadeDecoderPipeline
"vae.decoder",
"vae.encoder",
"transformer", # for Transformer-based DiffusionPipeline such as DiTPipeline and PixArtAlphaPipeline
]


@@ -52,7 +53,7 @@ def _filter_parts(ignores=()):


def compile_pipe(
pipe, *, ignores=(),
pipe, *, backend="oneflow", options=None, ignores=(),
):
# Fix a bug in VAE graph loading. See https://github.com/siliconflow/onediff/issues/452
if (
@@ -67,7 +68,9 @@
obj = _recursive_getattr(pipe, part, None)
if obj is not None:
logger.info(f"Compiling {part}")
_recursive_setattr(pipe, part, oneflow_compile(obj))
_recursive_setattr(
pipe, part, compile(obj, backend=backend, options=options)
)

if hasattr(pipe, "image_processor") and "image_processor" not in ignores:
logger.info("Patching image_processor")
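The `_recursive_getattr`/`_recursive_setattr` helpers are only referenced in this hunk, not shown. A minimal sketch of what such dotted-path helpers typically look like, with names and behavior assumed from the call sites above:

```python
import functools


def _recursive_getattr(obj, attr, default=None):
    # Walk a dotted path such as "vae.decoder" down the pipeline object.
    try:
        return functools.reduce(getattr, attr.split("."), obj)
    except AttributeError:
        return default


def _recursive_setattr(obj, attr, value):
    # Replace the leaf attribute, e.g. pipe.vae.decoder = compiled_decoder.
    parent, _, leaf = attr.rpartition(".")
    target = _recursive_getattr(obj, parent) if parent else obj
    setattr(target, leaf, value)
```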
@@ -1,4 +1,4 @@
from onediff.infer_compiler import register
from onediff.infer_compiler.transform import register

import oneflow as flow
import diffusers_enterprise_lite
@@ -1,4 +1,4 @@
from onediff.infer_compiler import register
from onediff.infer_compiler.transform import register

import oneflow as flow
import onediff_quant
10 changes: 5 additions & 5 deletions src/onediff/infer_compiler/__init__.py
@@ -1,18 +1,18 @@
import os
import torch
import oneflow as flow

from .utils.patch_for_compiler import * # TODO:
from .utils.options import *
from .transform.custom_transform import register
from .deployable_module import DeployableModule
from .options import *
from .with_onediff_compile import compile, oneflow_compile
from oneflow.framework.args_tree import ArgsTree

from .with_fx_interpreter import OneFlowInterpreter
from .with_fx_graph import fx_node_tranform


def oneflow_backend(gm, example_inputs, *args, **kwargs):
import oneflow as flow
from oneflow.framework.args_tree import ArgsTree

with_interp = os.getenv(
"ONEDIFF_INFER_COMPILER_USE_INTERPRETER", "False"
).lower() in ("true", "1", "t",)
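A recurring change in this PR is moving `import oneflow` from module scope into function bodies. The intended effect, sketched under the assumption that no other module-level oneflow import remains:

```python
# After this PR, importing the package no longer requires oneflow:
from onediff.infer_compiler import compile

# oneflow is first imported only when the oneflow backend actually runs;
# the nexfort path never touches it. `pipe` is as in the earlier sketch.
compiled_unet = compile(pipe.unet, backend="nexfort")
```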
32 changes: 32 additions & 0 deletions src/onediff/infer_compiler/backends/nexfort.py
@@ -0,0 +1,32 @@
import dataclasses
import torch
from .registry import register_backend
from ..options import CompileOptions


def make_inductor_options(options):
inductor_options = {}
if options is None:
return inductor_options
for field in dataclasses.fields(options):
field_name = field.name
inductor_options[f"inductor.{field_name}"] = getattr(options, field_name)
return inductor_options


@register_backend("nexfort")
def compile(torch_module: torch.nn.Module, *, options=None):
from nexfort.utils.memory_format import apply_memory_format
from nexfort.compilers import nexfort_compile
from ..nexfort.deployable_module import NexfortDeployableModule

options = options if options is not None else CompileOptions()
nexfort_options = options.nexfort
if nexfort_options.memory_format != torch.preserve_format:
model = apply_memory_format(
torch_module, memory_format=nexfort_options.memory_format
)
model = nexfort_compile(
model, options=make_inductor_options(nexfort_options.inductor)
)
return NexfortDeployableModule(model)
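`make_inductor_options` flattens the inductor dataclass into the dotted-key dict handed to `nexfort_compile`. A small illustration with a hypothetical stand-in dataclass — the real `NexfortInductorCompileOptions` fields are not shown in this hunk:

```python
import dataclasses


@dataclasses.dataclass
class DemoInductorOptions:  # hypothetical stand-in, not the real class
    max_autotune: bool = True
    epilogue_fusion: bool = False


demo = DemoInductorOptions()
flattened = {
    f"inductor.{f.name}": getattr(demo, f.name)
    for f in dataclasses.fields(demo)
}
print(flattened)  # {'inductor.max_autotune': True, 'inductor.epilogue_fusion': False}
```

Dispatch itself goes through the registry: `compile(module, backend="nexfort", options=...)` resolves to the function registered above via `@register_backend("nexfort")`.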
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/backends/oneflow.py
@@ -1,6 +1,6 @@
import torch
from .registry import register_backend
from ..utils.options import CompileOptions
from ..options import CompileOptions


@register_backend("oneflow")
Empty file.
16 changes: 16 additions & 0 deletions src/onediff/infer_compiler/nexfort/deployable_module.py
@@ -0,0 +1,16 @@
import torch
from ..deployable_module import DeployableModule


class NexfortDeployableModule(DeployableModule):
def __init__(self, torch_module):
torch.nn.Module.__init__(self)
object.__setattr__(self, "_deployable_module_model", torch_module)
object.__setattr__(self, "_modules", torch_module._modules)
object.__setattr__(self, "_torch_module", torch_module)

def __call__(self, *args, **kwargs):
return self._deployable_module_model(*args, **kwargs)

def __getattr__(self, name):
return getattr(self._deployable_module_model, name)
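The `object.__setattr__` calls bypass `torch.nn.Module.__setattr__`, which would otherwise run its own submodule-registration bookkeeping; sharing `_modules` with the wrapped module keeps standard module traversal working. A hedged usage sketch, where `compiled_unet` stands for a `nexfort_compile` result and the call arguments are placeholders:

```python
module = NexfortDeployableModule(compiled_unet)

out = module(latents, timestep)       # __call__ forwards to the compiled model
named = dict(module.named_modules())  # works: _modules is shared with the wrappee
cfg = module.config                   # any attribute defined on the wrapped model
                                      # falls through via __getattr__
```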
8 changes: 5 additions & 3 deletions src/onediff/infer_compiler/oneflow/deployable_module.py
@@ -9,7 +9,7 @@
from ..utils.param_utils import parse_device, check_device
from ..utils.graph_management_utils import graph_file_management
from ..utils.online_quantization_utils import quantize_and_deploy_wrapper
from ..utils.options import OneflowCompileOptions
from ..options import OneflowCompileOptions
from ..deployable_module import DeployableModule

from .utils import handle_deployable_exception, get_mixed_dual_module, get_oneflow_graph
@@ -50,7 +50,9 @@ def from_existing(cls, existing_module, dynamic=True, options=None):
instance._deployable_module_input_count = (
existing_module._deployable_module_input_count
)
instance._deployable_module_quant_config = existing_module._deployable_module_quant_config
instance._deployable_module_quant_config = (
existing_module._deployable_module_quant_config
)

return instance

@@ -85,7 +87,7 @@ def apply_model(self, *args, **kwargs):
*args, **kwargs
)
return output

@quantize_and_deploy_wrapper
@input_output_processor
@handle_deployable_exception
1 change: 1 addition & 0 deletions src/onediff/infer_compiler/oneflow/utils.py
@@ -3,6 +3,7 @@
from ..transform.builtin_transform import torch2oflow
from ..transform.manager import transform_mgr
from ..utils.log_utils import logger
from ..utils.patch_for_compiler import *
from .dual_module import DualModule


@@ -21,10 +21,23 @@ class NexfortInductorCompileOptions:

@dataclasses.dataclass
class NexfortCompileOptions:
memory_format: torch.memory_format
fuse_qkv_projections: bool
inductor: NexfortInductorCompileOptions

def __init__(self):
self.inductor = NexfortInductorCompileOptions()
def __init__(
self,
memory_format=torch.channels_last,
fuse_qkv_projections=True,
inductor=None,
):
if isinstance(memory_format, str):
memory_format = getattr(torch, memory_format)
self.memory_format = memory_format
self.fuse_qkv_projections = fuse_qkv_projections
self.inductor = (
inductor if inductor is not None else NexfortInductorCompileOptions()
)


@dataclasses.dataclass
@@ -38,7 +51,7 @@ class CompileOptions:
# nexfort specific options
nexfort: NexfortCompileOptions

def __init__(self, dynamic=True):
def __init__(self, dynamic=True, oneflow=None, nexfort=None):
self.dynamic = dynamic
self.oneflow = OneflowCompileOptions()
self.nexfort = NexfortCompileOptions()
self.oneflow = oneflow if oneflow is not None else OneflowCompileOptions()
self.nexfort = nexfort if nexfort is not None else NexfortCompileOptions()
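With the new keyword arguments, an options object can be built up front and handed to either entry point; note the string-to-`torch.memory_format` coercion added above. A sketch — the import path for these classes is assumed from the other hunks in this PR:

```python
import torch
from onediff.infer_compiler.options import CompileOptions, NexfortCompileOptions

opts = CompileOptions(
    nexfort=NexfortCompileOptions(memory_format="channels_last")  # str is coerced
)
assert opts.nexfort.memory_format is torch.channels_last
assert opts.nexfort.fuse_qkv_projections is True  # the new default

# pipe as in the earlier sketch:
# pipe = compile_pipe(pipe, backend="nexfort", options=opts)
```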
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/utils/graph_management_utils.py
@@ -5,11 +5,11 @@
from pathlib import Path
from functools import wraps
from oneflow.framework.args_tree import ArgsTree
from ..options import OneflowCompileOptions
from ..transform.builtin_transform import torch2oflow
from ..transform.manager import transform_mgr
from .log_utils import logger
from .cost_util import cost_time
from .options import OneflowCompileOptions


def calculate_model_hash(model):
6 changes: 4 additions & 2 deletions src/onediff/infer_compiler/utils/oneflow_exec_mode.py
@@ -1,5 +1,3 @@
import oneflow as flow

_ONEFLOW_EXEC_MODE = False


@@ -11,13 +9,17 @@ def __init__(self, enabled=None):
self.enabled = True

def __enter__(self):
import oneflow as flow

global _ONEFLOW_EXEC_MODE
self.prev_mode = _ONEFLOW_EXEC_MODE
_ONEFLOW_EXEC_MODE = self.enabled
self.prev_grad_mode = flow.is_grad_enabled()
_ = flow.set_grad_enabled(False)

def __exit__(self, exc_type, exc_val, exc_tb):
import oneflow as flow

global _ONEFLOW_EXEC_MODE
_ONEFLOW_EXEC_MODE = self.prev_mode
_ = flow.set_grad_enabled(self.prev_grad_mode)
10 changes: 7 additions & 3 deletions src/onediff/infer_compiler/with_fx_graph.py
@@ -1,14 +1,13 @@
import os
import torch
import torch.fx as fx
import oneflow as flow
from torch.fx.node import map_aggregate
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from .transform import get_attr, torch2oflow


def fx_node_tranform(gm):
import oneflow as flow

of_gm = to_of_transform(gm)

enable_graph = os.getenv("ONEDIFF_INFER_COMPILER_USE_GRAPH", "True").lower() in (
@@ -40,6 +39,9 @@ def build(self, *args, **kwargs):
def to_of_transform(
gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer
) -> torch.fx.GraphModule:
import oneflow as flow
from .transform import get_attr, torch2oflow

name2node = {}
name2obj = {}
torch2flow = {}
@@ -94,6 +96,8 @@


def replace_node(node, name2node):
from .transform import torch2oflow

if isinstance(node, torch.fx.Node):
return name2node[node.name]
else:
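Both fx code paths are gated by environment variables visible in these hunks; a brief sketch of toggling them before compilation (the values are illustrative, and both variables are read with a truthy-string check):

```python
import os

os.environ["ONEDIFF_INFER_COMPILER_USE_GRAPH"] = "1"        # wrap in an nn.Graph
os.environ["ONEDIFF_INFER_COMPILER_USE_INTERPRETER"] = "0"  # skip OneFlowInterpreter
```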
7 changes: 6 additions & 1 deletion src/onediff/infer_compiler/with_fx_interpreter.py
@@ -1,23 +1,28 @@
import torch
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from .transform import map_args, ProxySubmodule


class OneFlowInterpreter(torch.fx.Interpreter):
from torch.fx.node import Argument, Target

def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
from .transform import map_args, torch2oflow

args, kwargs = map_args(args, kwargs)
target = torch2oflow(target)
return super().call_function(target, args, kwargs)

def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
from .transform import map_args

args, kwargs = map_args(args, kwargs)
return super().call_method(target, args, kwargs)

def call_module(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
from .transform import ProxySubmodule

submod = self.fetch_attr(target)
submod = ProxySubmodule(submod)
return submod(*args, **kwargs)
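`torch.fx.Interpreter` subclasses are driven through `run()`; a hedged sketch of how this interpreter is presumably exercised, where `model` and `example_inputs` are placeholders:

```python
import torch

gm = torch.fx.symbolic_trace(model)                # any traceable nn.Module
out = OneFlowInterpreter(gm).run(*example_inputs)  # node-by-node execution
```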