From f91dae4044133589f300e816553428f089245750 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 9 Apr 2024 12:16:53 +0800 Subject: [PATCH 01/13] integrate nexfort --- benchmarks/text_to_image.py | 16 ++++++---- .../compilers/diffusion_pipeline_compiler.py | 9 ++++-- .../infer_compiler/backends/nexfort.py | 32 +++++++++++++++++++ src/onediff/infer_compiler/utils/options.py | 23 ++++++++++--- 4 files changed, 66 insertions(+), 14 deletions(-) create mode 100644 src/onediff/infer_compiler/backends/nexfort.py diff --git a/benchmarks/text_to_image.py b/benchmarks/text_to_image.py index 03414a8b8..a8c97c510 100644 --- a/benchmarks/text_to_image.py +++ b/benchmarks/text_to_image.py @@ -30,7 +30,6 @@ from PIL import Image, ImageDraw from diffusers.utils import load_image -import oneflow as flow from onediffx import compile_pipe @@ -62,7 +61,7 @@ def parse_args(): "--compiler", type=str, default="oneflow", - choices=["none", "oneflow", "compile", "compile-max-autotune"], + choices=["none", "oneflow", "nexfort", "compile", "compile-max-autotune"], ) return parser.parse_args() @@ -162,6 +161,8 @@ def main(): pass elif args.compiler == "oneflow": pipe = compile_pipe(pipe) + elif args.compiler == "nexfort": + pipe = compile_pipe(pipe, backend="nexfort") elif args.compiler in ("compile", "compile-max-autotune"): mode = "max-autotune" if args.compiler == "compile-max-autotune" else None pipe.unet = torch.compile(pipe.unet, mode=mode) @@ -248,10 +249,13 @@ def get_kwarg_inputs(): iter_per_sec = iter_profiler.get_iter_per_sec() if iter_per_sec is not None: print(f"Iterations per second: {iter_per_sec:.3f}") - cuda_mem_after_used = flow._oneflow_internal.GetCUDAMemoryUsed() - host_mem_after_used = flow._oneflow_internal.GetCPUMemoryUsed() - print(f"CUDA Mem after: {cuda_mem_after_used / 1024:.3f}GiB") - print(f"Host Mem after: {host_mem_after_used / 1024:.3f}GiB") + if args.compiler == "oneflow": + import oneflow as flow + + cuda_mem_after_used = flow._oneflow_internal.GetCUDAMemoryUsed() / 1024 + else: + cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024 ** 3) + print(f"CUDA Mem after: {cuda_mem_after_used:.3f}GiB") print("=======================================") if args.output_image is not None: diff --git a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py index 3040f9970..2986403f4 100644 --- a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py +++ b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py @@ -1,6 +1,6 @@ import os import torch -from onediff.infer_compiler import oneflow_compile +from onediff.infer_compiler import compile from onediff.infer_compiler.deployable_module import DeployableModule from onediff.infer_compiler.utils.log_utils import logger @@ -34,6 +34,7 @@ def _recursive_setattr(obj, attr, value): "vqgan.up_blocks", # for StableCascadeDecoderPipeline "vae.decoder", "vae.encoder", + "transformer", # for Transformer-based DiffusionPipeline such as DiTPipeline and PixArtAlphaPipeline ] @@ -52,7 +53,7 @@ def _filter_parts(ignores=()): def compile_pipe( - pipe, *, ignores=(), + pipe, *, backend="oneflow", options=None, ignores=(), ): # To fix the bug of graph load of vae. Please refer to: https://github.com/siliconflow/onediff/issues/452 if ( @@ -67,7 +68,9 @@ def compile_pipe( obj = _recursive_getattr(pipe, part, None) if obj is not None: logger.info(f"Compiling {part}") - _recursive_setattr(pipe, part, oneflow_compile(obj)) + _recursive_setattr( + pipe, part, compile(obj, backend=backend, options=options) + ) if hasattr(pipe, "image_processor") and "image_processor" not in ignores: logger.info("Patching image_processor") diff --git a/src/onediff/infer_compiler/backends/nexfort.py b/src/onediff/infer_compiler/backends/nexfort.py new file mode 100644 index 000000000..f858791f8 --- /dev/null +++ b/src/onediff/infer_compiler/backends/nexfort.py @@ -0,0 +1,32 @@ +import dataclasses +import torch +from .registry import register_backend +from ..utils.options import CompileOptions + + +def make_inductor_options(options): + inductor_options = {} + if options is None: + return inductor_options + for filed in dataclasses.fields(options): + filed_name = filed.name + inductor_options[f"inductor.{filed_name}"] = getattr(options, filed_name) + return inductor_options + + +@register_backend("nexfort") +def compile(torch_module: torch.nn.Module, *, options=None): + from nexfort.utils.memory_format import apply_memory_format + from nexfort.compilers import nexfort_compile + from ..nexfort.deployable_module import NexfortDeployableModule + + options = options if options is not None else CompileOptions() + nexfort_options = options.nexfort + if nexfort_options.memory_format != torch.preserve_format: + model = apply_memory_format( + torch_module, memory_format=nexfort_options.memory_format + ) + model = nexfort_compile( + model, options=make_inductor_options(nexfort_options.inductor) + ) + return NexfortDeployableModule(model) diff --git a/src/onediff/infer_compiler/utils/options.py b/src/onediff/infer_compiler/utils/options.py index 1e7ba46df..1061d4f7c 100644 --- a/src/onediff/infer_compiler/utils/options.py +++ b/src/onediff/infer_compiler/utils/options.py @@ -21,10 +21,23 @@ class NexfortInductorCompileOptions: @dataclasses.dataclass class NexfortCompileOptions: + memory_format: torch.memory_format + fuse_qkv_projections: bool inductor: NexfortInductorCompileOptions - def __init__(self): - self.inductor = NexfortInductorCompileOptions() + def __init__( + self, + memory_format=torch.channels_last, + fuse_qkv_projections=True, + inductor=None, + ): + if isinstance(memory_format, str): + memory_format = getattr(torch, memory_format) + self.memory_format = memory_format + self.fuse_qkv_projections = fuse_qkv_projections + self.inductor = ( + inductor if inductor is not None else NexfortInductorCompileOptions() + ) @dataclasses.dataclass @@ -38,7 +51,7 @@ class CompileOptions: # nexfort specific options nexfort: NexfortCompileOptions - def __init__(self, dynamic=True): + def __init__(self, dynamic=True, oneflow=None, nexfort=None): self.dynamic = dynamic - self.oneflow = OneflowCompileOptions() - self.nexfort = NexfortCompileOptions() + self.oneflow = oneflow if oneflow is not None else OneflowCompileOptions() + self.nexfort = nexfort if nexfort is not None else NexfortCompileOptions() From 058811b90465435ddc4e3f94cc61b759381c7173 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 9 Apr 2024 12:29:54 +0800 Subject: [PATCH 02/13] upload nexfort --- src/onediff/infer_compiler/nexfort/__init__.py | 0 .../infer_compiler/nexfort/deployable_module.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 src/onediff/infer_compiler/nexfort/__init__.py create mode 100644 src/onediff/infer_compiler/nexfort/deployable_module.py diff --git a/src/onediff/infer_compiler/nexfort/__init__.py b/src/onediff/infer_compiler/nexfort/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/onediff/infer_compiler/nexfort/deployable_module.py b/src/onediff/infer_compiler/nexfort/deployable_module.py new file mode 100644 index 000000000..421b2d8bc --- /dev/null +++ b/src/onediff/infer_compiler/nexfort/deployable_module.py @@ -0,0 +1,16 @@ +import torch +from ..deployable_module import DeployableModule + + +class NexfortDeployableModule(DeployableModule): + def __init__(self, torch_module): + torch.nn.Module.__init__(self) + object.__setattr__(self, "_deployable_module_model", torch_module) + object.__setattr__(self, "_modules", torch_module._modules) + object.__setattr__(self, "_torch_module", torch_module) + + def __call__(self, *args, **kwargs): + return self._deployable_module_model(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._deployable_module_model, name) From ef0773edd2bf2b29da11c88639b096be7fb5ec90 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 9 Apr 2024 13:04:59 +0800 Subject: [PATCH 03/13] move options out of utils --- src/onediff/infer_compiler/__init__.py | 2 +- src/onediff/infer_compiler/backends/nexfort.py | 2 +- src/onediff/infer_compiler/backends/oneflow.py | 2 +- src/onediff/infer_compiler/oneflow/deployable_module.py | 8 +++++--- src/onediff/infer_compiler/{utils => }/options.py | 0 5 files changed, 8 insertions(+), 6 deletions(-) rename src/onediff/infer_compiler/{utils => }/options.py (100%) diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index 686149e3d..ed15f119c 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -3,7 +3,7 @@ import oneflow as flow from .utils.patch_for_compiler import * # TODO: -from .utils.options import * +from .options import * from .transform.custom_transform import register from .with_onediff_compile import compile, oneflow_compile from oneflow.framework.args_tree import ArgsTree diff --git a/src/onediff/infer_compiler/backends/nexfort.py b/src/onediff/infer_compiler/backends/nexfort.py index f858791f8..9d702aaf1 100644 --- a/src/onediff/infer_compiler/backends/nexfort.py +++ b/src/onediff/infer_compiler/backends/nexfort.py @@ -1,7 +1,7 @@ import dataclasses import torch from .registry import register_backend -from ..utils.options import CompileOptions +from ..options import CompileOptions def make_inductor_options(options): diff --git a/src/onediff/infer_compiler/backends/oneflow.py b/src/onediff/infer_compiler/backends/oneflow.py index 4a66f539b..58a24567c 100644 --- a/src/onediff/infer_compiler/backends/oneflow.py +++ b/src/onediff/infer_compiler/backends/oneflow.py @@ -1,6 +1,6 @@ import torch from .registry import register_backend -from ..utils.options import CompileOptions +from ..options import CompileOptions @register_backend("oneflow") diff --git a/src/onediff/infer_compiler/oneflow/deployable_module.py b/src/onediff/infer_compiler/oneflow/deployable_module.py index 01ab9550f..45bbe00e9 100644 --- a/src/onediff/infer_compiler/oneflow/deployable_module.py +++ b/src/onediff/infer_compiler/oneflow/deployable_module.py @@ -9,7 +9,7 @@ from ..utils.param_utils import parse_device, check_device from ..utils.graph_management_utils import graph_file_management from ..utils.online_quantization_utils import quantize_and_deploy_wrapper -from ..utils.options import OneflowCompileOptions +from ..options import OneflowCompileOptions from ..deployable_module import DeployableModule from .utils import handle_deployable_exception, get_mixed_dual_module, get_oneflow_graph @@ -50,7 +50,9 @@ def from_existing(cls, existing_module, dynamic=True, options=None): instance._deployable_module_input_count = ( existing_module._deployable_module_input_count ) - instance._deployable_module_quant_config = existing_module._deployable_module_quant_config + instance._deployable_module_quant_config = ( + existing_module._deployable_module_quant_config + ) return instance @@ -85,7 +87,7 @@ def apply_model(self, *args, **kwargs): *args, **kwargs ) return output - + @quantize_and_deploy_wrapper @input_output_processor @handle_deployable_exception diff --git a/src/onediff/infer_compiler/utils/options.py b/src/onediff/infer_compiler/options.py similarity index 100% rename from src/onediff/infer_compiler/utils/options.py rename to src/onediff/infer_compiler/options.py From 7fc99ff61fdc71701f2b699a1389c2f435a0b296 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 9 Apr 2024 13:08:00 +0800 Subject: [PATCH 04/13] fix --- src/onediff/infer_compiler/utils/graph_management_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onediff/infer_compiler/utils/graph_management_utils.py b/src/onediff/infer_compiler/utils/graph_management_utils.py index 14515f137..607709eeb 100644 --- a/src/onediff/infer_compiler/utils/graph_management_utils.py +++ b/src/onediff/infer_compiler/utils/graph_management_utils.py @@ -5,11 +5,11 @@ from pathlib import Path from functools import wraps from oneflow.framework.args_tree import ArgsTree +from ..options import OneflowCompileOptions from ..transform.builtin_transform import torch2oflow from ..transform.manager import transform_mgr from .log_utils import logger from .cost_util import cost_time -from .options import OneflowCompileOptions def calculate_model_hash(model): From b3137f0b39773c04a6f9e9fc5f5d8c748552a4f5 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 9 Apr 2024 13:49:28 +0800 Subject: [PATCH 05/13] fix oneflow import --- .../infer_compiler_registry/register_comfy/__init__.py | 2 +- .../infer_compiler_registry/register_onediff_quant.py | 2 +- .../register_diffusers_enterprise_lite/__init__.py | 2 +- .../register_onediff_quant/__init__.py | 2 +- src/onediff/infer_compiler/__init__.py | 8 ++++---- src/onediff/infer_compiler/oneflow/utils.py | 1 + src/onediff/infer_compiler/utils/oneflow_exec_mode.py | 6 ++++-- src/onediff/infer_compiler/with_fx_graph.py | 10 +++++++--- src/onediff/infer_compiler/with_fx_interpreter.py | 7 ++++++- 9 files changed, 26 insertions(+), 14 deletions(-) diff --git a/onediff_comfy_nodes/infer_compiler_registry/register_comfy/__init__.py b/onediff_comfy_nodes/infer_compiler_registry/register_comfy/__init__.py index 748a3bf6e..b789c7498 100644 --- a/onediff_comfy_nodes/infer_compiler_registry/register_comfy/__init__.py +++ b/onediff_comfy_nodes/infer_compiler_registry/register_comfy/__init__.py @@ -1,4 +1,4 @@ -from onediff.infer_compiler import register +from onediff.infer_compiler.transform import register from onediff.infer_compiler.utils import is_community_version from nodes import * # must imported before import comfy from pathlib import Path diff --git a/onediff_comfy_nodes/infer_compiler_registry/register_onediff_quant.py b/onediff_comfy_nodes/infer_compiler_registry/register_onediff_quant.py index e8324a89e..e9ab3afd8 100644 --- a/onediff_comfy_nodes/infer_compiler_registry/register_onediff_quant.py +++ b/onediff_comfy_nodes/infer_compiler_registry/register_onediff_quant.py @@ -1,4 +1,4 @@ -from onediff.infer_compiler import register +from onediff.infer_compiler.transform import register import oneflow as flow import onediff_quant diff --git a/src/infer_compiler_registry/register_diffusers_enterprise_lite/__init__.py b/src/infer_compiler_registry/register_diffusers_enterprise_lite/__init__.py index f101ce692..d8bf735f9 100644 --- a/src/infer_compiler_registry/register_diffusers_enterprise_lite/__init__.py +++ b/src/infer_compiler_registry/register_diffusers_enterprise_lite/__init__.py @@ -1,4 +1,4 @@ -from onediff.infer_compiler import register +from onediff.infer_compiler.transform import register import oneflow as flow import diffusers_enterprise_lite diff --git a/src/infer_compiler_registry/register_onediff_quant/__init__.py b/src/infer_compiler_registry/register_onediff_quant/__init__.py index e8324a89e..e9ab3afd8 100644 --- a/src/infer_compiler_registry/register_onediff_quant/__init__.py +++ b/src/infer_compiler_registry/register_onediff_quant/__init__.py @@ -1,4 +1,4 @@ -from onediff.infer_compiler import register +from onediff.infer_compiler.transform import register import oneflow as flow import onediff_quant diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index ed15f119c..ab6fce98d 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -1,18 +1,18 @@ import os import torch -import oneflow as flow -from .utils.patch_for_compiler import * # TODO: +from .deployable_module import DeployableModule from .options import * -from .transform.custom_transform import register from .with_onediff_compile import compile, oneflow_compile -from oneflow.framework.args_tree import ArgsTree from .with_fx_interpreter import OneFlowInterpreter from .with_fx_graph import fx_node_tranform def oneflow_backend(gm, example_inputs, *args, **kwargs): + import oneflow as flow + from oneflow.framework.args_tree import ArgsTree + with_interp = os.getenv( "ONEDIFF_INFER_COMPILER_USE_INTERPRETER", "False" ).lower() in ("true", "1", "t",) diff --git a/src/onediff/infer_compiler/oneflow/utils.py b/src/onediff/infer_compiler/oneflow/utils.py index 4a5e899aa..dfdf97a6e 100644 --- a/src/onediff/infer_compiler/oneflow/utils.py +++ b/src/onediff/infer_compiler/oneflow/utils.py @@ -3,6 +3,7 @@ from ..transform.builtin_transform import torch2oflow from ..transform.manager import transform_mgr from ..utils.log_utils import logger +from ..utils.patch_for_compiler import * from .dual_module import DualModule diff --git a/src/onediff/infer_compiler/utils/oneflow_exec_mode.py b/src/onediff/infer_compiler/utils/oneflow_exec_mode.py index cab31d0a6..13635c05f 100644 --- a/src/onediff/infer_compiler/utils/oneflow_exec_mode.py +++ b/src/onediff/infer_compiler/utils/oneflow_exec_mode.py @@ -1,5 +1,3 @@ -import oneflow as flow - _ONEFLOW_EXEC_MODE = False @@ -11,6 +9,8 @@ def __init__(self, enabled=None): self.enabled = True def __enter__(self): + import oneflow as flow + global _ONEFLOW_EXEC_MODE self.prev_mode = _ONEFLOW_EXEC_MODE _ONEFLOW_EXEC_MODE = self.enabled @@ -18,6 +18,8 @@ def __enter__(self): _ = flow.set_grad_enabled(False) def __exit__(self, exc_type, exc_val, exc_tb): + import oneflow as flow + global _ONEFLOW_EXEC_MODE _ONEFLOW_EXEC_MODE = self.prev_mode _ = flow.set_grad_enabled(self.prev_grad_mode) diff --git a/src/onediff/infer_compiler/with_fx_graph.py b/src/onediff/infer_compiler/with_fx_graph.py index 1c9cd6a63..9f17772c8 100644 --- a/src/onediff/infer_compiler/with_fx_graph.py +++ b/src/onediff/infer_compiler/with_fx_graph.py @@ -1,14 +1,13 @@ import os import torch import torch.fx as fx -import oneflow as flow from torch.fx.node import map_aggregate from typing import Any, Dict, Iterator, List, Optional, Tuple, Union -from .transform import get_attr, torch2oflow - def fx_node_tranform(gm): + import oneflow as flow + of_gm = to_of_transform(gm) enable_graph = os.getenv("ONEDIFF_INFER_COMPILER_USE_GRAPH", "True").lower() in ( @@ -40,6 +39,9 @@ def build(self, *args, **kwargs): def to_of_transform( gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer ) -> torch.fx.GraphModule: + import oneflow as flow + from .transform import get_attr, torch2oflow + name2node = {} name2obj = {} torch2flow = {} @@ -94,6 +96,8 @@ def to_of_transform( def replace_node(node, name2node): + from .transform import torch2oflow + if isinstance(node, torch.fx.Node): return name2node[node.name] else: diff --git a/src/onediff/infer_compiler/with_fx_interpreter.py b/src/onediff/infer_compiler/with_fx_interpreter.py index 4f85956c7..884a201d7 100644 --- a/src/onediff/infer_compiler/with_fx_interpreter.py +++ b/src/onediff/infer_compiler/with_fx_interpreter.py @@ -1,23 +1,28 @@ import torch from typing import Any, Dict, Iterator, List, Optional, Tuple, Union -from .transform import map_args, ProxySubmodule class OneFlowInterpreter(torch.fx.Interpreter): from torch.fx.node import Argument, Target def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: + from .transform import map_args + args, kwargs = map_args(args, kwargs) target = torch2oflow(target) return super().call_function(target, args, kwargs) def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: + from .transform import map_args + args, kwargs = map_args(args, kwargs) return super().call_method(target, args, kwargs) def call_module( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: + from .transform import ProxySubmodule + submod = self.fetch_attr(target) submod = ProxySubmodule(submod) return submod(*args, **kwargs) From 09882e88077e736550fb51bdb2593dd34ec05fdd Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 9 Apr 2024 15:44:55 +0800 Subject: [PATCH 06/13] remove oneflow config --- benchmarks/image_to_video.py | 6 +- onediff_comfy_nodes/_nodes.py | 2 +- .../utils/deep_cache_speedup.py | 2 +- onediff_diffusers_extensions/README.md | 14 +- .../text_to_image_sdxl_torch_compile.py | 3 +- ..._to_image_deep_cache_sd_sdxl_enterprise.py | 23 +-- .../examples/text_to_image_deep_cache_sdxl.py | 11 +- .../examples/text_to_image_sd_enterprise.py | 11 +- .../examples/text_to_image_sdxl_enterprise.py | 15 +- .../examples/text_to_image_sdxl_light.py | 2 +- .../examples/text_to_image_sdxl_reuse_pipe.py | 20 +-- .../examples/text_to_image_sdxl_save_load.py | 3 +- .../onediffx/__init__.py | 5 +- .../scripts/onediff.py | 2 +- .../attention_processor_oflow.py | 6 +- src/onediff/infer_compiler/__init__.py | 5 + .../infer_compiler/backends/oneflow.py | 2 + .../{utils/env_var.py => env/utils.py} | 0 .../infer_compiler/oneflow/__init__.py | 1 - src/onediff/infer_compiler/oneflow/config.py | 146 ------------------ src/onediff/infer_compiler/options.py | 34 ++++ src/onediff/infer_compiler/utils/__init__.py | 6 - .../optimization/attention_processor.py | 2 +- 23 files changed, 98 insertions(+), 223 deletions(-) rename src/onediff/infer_compiler/{utils/env_var.py => env/utils.py} (100%) delete mode 100644 src/onediff/infer_compiler/oneflow/config.py diff --git a/benchmarks/image_to_video.py b/benchmarks/image_to_video.py index a1c707d69..b257f1f64 100644 --- a/benchmarks/image_to_video.py +++ b/benchmarks/image_to_video.py @@ -41,7 +41,7 @@ import oneflow as flow import torch -from onediffx import compile_pipe, compiler_config +from onediffx import compile_pipe, compile_options from diffusers.utils import load_image, export_to_video @@ -187,10 +187,10 @@ def main(): # especially for 40xx series cards. # So here by partially disabling the half accumulation in MHA partially, # we can get a good balance. - compiler_config.attention_allow_half_precision_score_accumulation_max_m = ( + compile_options.oneflow.attention_allow_half_precision_score_accumulation_max_m = ( args.attention_fp16_score_accum_max_m ) - pipe = compile_pipe(pipe,) + pipe = compile_pipe(pipe, options=compile_options) elif args.compiler == "compile": pipe.unet = torch.compile(pipe.unet) if hasattr(pipe, "controlnet"): diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py index 9ceb9c029..47aa3edda 100644 --- a/onediff_comfy_nodes/_nodes.py +++ b/onediff_comfy_nodes/_nodes.py @@ -1,7 +1,7 @@ from functools import partial from onediff.infer_compiler.transform import torch2oflow from ._config import _USE_UNET_INT8, ONEDIFF_QUANTIZED_OPTIMIZED_MODELS -from onediff.infer_compiler.utils import set_boolean_env_var +from onediff.infer_compiler.env import set_boolean_env_var from onediff.optimization.quant_optimizer import quantize_model from onediff.infer_compiler import oneflow_compile, CompileOptions from onediff.infer_compiler.deployable_module import DeployableModule diff --git a/onediff_comfy_nodes/utils/deep_cache_speedup.py b/onediff_comfy_nodes/utils/deep_cache_speedup.py index 2580da9b0..fd0c53d25 100644 --- a/onediff_comfy_nodes/utils/deep_cache_speedup.py +++ b/onediff_comfy_nodes/utils/deep_cache_speedup.py @@ -2,7 +2,7 @@ from comfy import model_management from comfy.model_base import SVD_img2vid -from onediff.infer_compiler.utils import set_boolean_env_var +from onediff.infer_compiler.env import set_boolean_env_var from .model_patcher import OneFlowDeepCacheSpeedUpModelPatcher diff --git a/onediff_diffusers_extensions/README.md b/onediff_diffusers_extensions/README.md index fd2d308d1..86b9f279b 100644 --- a/onediff_diffusers_extensions/README.md +++ b/onediff_diffusers_extensions/README.md @@ -197,7 +197,7 @@ deepcache_output = pipe( import torch from diffusers.utils import load_image, export_to_video -from onediffx import compile_pipe, compiler_config +from onediffx import compile_pipe, compile_options from onediffx.deep_cache import StableVideoDiffusionPipeline pipe = StableVideoDiffusionPipeline.from_pretrained( @@ -208,8 +208,8 @@ pipe = StableVideoDiffusionPipeline.from_pretrained( ) pipe.to("cuda") -compiler_config.attention_allow_half_precision_score_accumulation_max_m = 0 -pipe = compile_pipe(pipe) +compile_options.oneflow.attention_allow_half_precision_score_accumulation_max_m = 0 +pipe = compile_pipe(pipe, options=compile_options) input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true") input_image = input_image.resize((1024, 576)) @@ -415,12 +415,12 @@ We tested the performance of `set_adapters`, still using the five LoRA models me 2. If your LoRA model only contains the weights of the Linear module, you can directly use OneDiffX without any modifications. But if your LoRA model includes the weights of the Conv module (such as LyCORIS), you need to disable constant folding optimization by above methods (which may cause a performance drop of around 4.4%), otherwise the weights of the Conv module may not be loaded into the model. - Set the env var `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to 0 - - Set compiler_config.mlir_enable_inference_optimization to 0 before invoking `oneflow_compile` as the code below + - Set mlir_enable_inference_optimization to 0 when invoking `oneflow_compile` as the code below ``` - from onediffx import compiler_config - compiler_config.mlir_enable_inference_optimization = 0 + from onediffx import compile_options + compile_options.oneflow.mlir_enable_inference_optimization = 0 ... - pipe.unet = oneflow_compile(pipe.unet) + pipe.unet = oneflow_compile(pipe.unet, options=compile_options) ... ``` ### Optimization diff --git a/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_torch_compile.py b/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_torch_compile.py index 8e38be060..85212ed1a 100644 --- a/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_torch_compile.py +++ b/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_torch_compile.py @@ -10,7 +10,7 @@ import oneflow as flow from diffusers import DiffusionPipeline -from onediff.infer_compiler import oneflow_compile, CompileOptions +from onediff.infer_compiler import oneflow_compile, compile_options parser = argparse.ArgumentParser() parser.add_argument( @@ -53,7 +53,6 @@ # Compile unet with oneflow if cmd_args.compile: print("unet is compiled to oneflow.") - compile_options = CompileOptions() compile_options.oneflow.max_cached_graph_size = cmd_args.num_dynamic_input_size base.unet = oneflow_compile(base.unet, options=compile_options) diff --git a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd_sdxl_enterprise.py b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd_sdxl_enterprise.py index 13999ed5e..7b414f0ae 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd_sdxl_enterprise.py +++ b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd_sdxl_enterprise.py @@ -6,7 +6,7 @@ import torch.nn as nn # oneflow_compile should be imported before importing any diffusers -from onediff.infer_compiler import oneflow_compile, CompileOptions +from onediff.infer_compiler import oneflow_compile, compile_options def parse_args(): @@ -110,26 +110,29 @@ def parse_args(): pipe.unet, sub_module_name, sub_calibrate_info, False, False, args.bits, ) -options = CompileOptions() -options.oneflow.use_graph = args.graph +compile_options.oneflow.use_graph = args.graph if args.compile_text_encoder: if pipe.text_encoder is not None: - pipe.text_encoder = oneflow_compile(pipe.text_encoder, options=options) + pipe.text_encoder = oneflow_compile(pipe.text_encoder, options=compile_options) if hasattr(pipe, "text_encoder_2"): - pipe.text_encoder_2 = oneflow_compile(pipe.text_encoder_2, options=options) + pipe.text_encoder_2 = oneflow_compile( + pipe.text_encoder_2, options=compile_options + ) if args.compile: if pipe.text_encoder is not None: - pipe.text_encoder = oneflow_compile(pipe.text_encoder, options=options) + pipe.text_encoder = oneflow_compile(pipe.text_encoder, options=compile_options) if hasattr(pipe, "text_encoder_2"): - pipe.text_encoder_2 = oneflow_compile(pipe.text_encoder_2, options=options) - pipe.unet = oneflow_compile(pipe.unet, options=options) - pipe.fast_unet = oneflow_compile(pipe.fast_unet, options=options) + pipe.text_encoder_2 = oneflow_compile( + pipe.text_encoder_2, options=compile_options + ) + pipe.unet = oneflow_compile(pipe.unet, options=compile_options) + pipe.fast_unet = oneflow_compile(pipe.fast_unet, options=compile_options) if hasattr(pipe, "text_encoder_2") and pipe.needs_upcasting: # To avoid mis-match of loaded graph and loaded model pipe.upcast_vae() - pipe.vae.decoder = oneflow_compile(pipe.vae.decoder, options=options) + pipe.vae.decoder = oneflow_compile(pipe.vae.decoder, options=compile_options) if args.load_graph: print("Loading graphs to avoid compilation...") diff --git a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sdxl.py b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sdxl.py index 741a40d31..7839be91f 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sdxl.py +++ b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sdxl.py @@ -7,9 +7,7 @@ import torch -from onediffx import compile_pipe, compiler_config -from onediff.schedulers import EulerDiscreteScheduler - +from onediffx import compile_pipe from onediffx.deep_cache import StableDiffusionXLPipeline parser = argparse.ArgumentParser() @@ -42,13 +40,8 @@ OUTPUT_TYPE = "pil" # SDXL base: StableDiffusionXLPipeline -scheduler = EulerDiscreteScheduler.from_pretrained(args.base, subfolder="scheduler") base = StableDiffusionXLPipeline.from_pretrained( - args.base, - scheduler=scheduler, - torch_dtype=torch.float16, - variant=args.variant, - use_safetensors=True, + args.base, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True, ) base.to("cuda") diff --git a/onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py b/onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py index 668b44282..6d398aa77 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py @@ -2,7 +2,7 @@ import time import argparse -from onediff.infer_compiler import oneflow_compile, CompileOptions +from onediff.infer_compiler import oneflow_compile, compile_options import torch import torch.nn as nn @@ -92,16 +92,15 @@ def parse_args(): pipe.unet, sub_module_name, sub_calibrate_info, False, False, args.bits, ) -options = CompileOptions() -options.oneflow.use_graph = args.graph +compile_options.oneflow.use_graph = args.graph if args.compile_text_encoder: if pipe.text_encoder is not None: - pipe.text_encoder = oneflow_compile(pipe.text_encoder, options=options) + pipe.text_encoder = oneflow_compile(pipe.text_encoder, options=compile_options) if args.compile: - pipe.unet = oneflow_compile(pipe.unet, options=options) - pipe.vae.decoder = oneflow_compile(pipe.vae.decoder, options=options) + pipe.unet = oneflow_compile(pipe.unet, options=compile_options) + pipe.vae.decoder = oneflow_compile(pipe.vae.decoder, options=compile_options) if args.load_graph: print("Loading graphs to avoid compilation...") diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_enterprise.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_enterprise.py index 408d136d0..24162195d 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_enterprise.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_enterprise.py @@ -6,7 +6,7 @@ import torch.nn as nn # oneflow_compile should be imported before importing any diffusers -from onediff.infer_compiler import oneflow_compile, CompileOptions +from onediff.infer_compiler import oneflow_compile, compile_options def parse_args(): @@ -90,18 +90,19 @@ def parse_args(): pipe.unet, sub_module_name, sub_calibrate_info, False, False, args.bits, ) -options = CompileOptions() -options.oneflow.use_graph = args.graph +compile_options.oneflow.use_graph = args.graph if args.compile_text_encoder: if pipe.text_encoder is not None: - pipe.text_encoder = oneflow_compile(pipe.text_encoder, options=options) + pipe.text_encoder = oneflow_compile(pipe.text_encoder, options=compile_options) if pipe.text_encoder_2 is not None: - pipe.text_encoder_2 = oneflow_compile(pipe.text_encoder_2, options=options) + pipe.text_encoder_2 = oneflow_compile( + pipe.text_encoder_2, options=compile_options + ) if args.compile: - pipe.unet = oneflow_compile(pipe.unet, options=options) - pipe.vae.decoder = oneflow_compile(pipe.vae.decoder, options=options) + pipe.unet = oneflow_compile(pipe.unet, options=compile_options) + pipe.vae.decoder = oneflow_compile(pipe.vae.decoder, options=compile_options) if args.load_graph: print("Loading graphs to avoid compilation...") diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py index d3e368411..5f2ffa313 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py @@ -5,7 +5,7 @@ import torch from safetensors.torch import load_file from diffusers import StableDiffusionXLPipeline -from onediffx import compile_pipe, compiler_config, save_pipe, load_pipe +from onediffx import compile_pipe, save_pipe, load_pipe from huggingface_hub import hf_hub_download try: diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index 79d356029..6d675f27c 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -3,9 +3,7 @@ import torch -from onediff.infer_compiler import oneflow_compile -from onediff.infer_compiler.oneflow import oneflow_compiler_config -from onediff.schedulers import EulerDiscreteScheduler +from onediff.infer_compiler import oneflow_compile, compile_options from diffusers import StableDiffusionXLPipeline # import diffusers @@ -51,28 +49,22 @@ OUTPUT_TYPE = "pil" # SDXL base: StableDiffusionXLPipeline -scheduler = EulerDiscreteScheduler.from_pretrained(args.base, subfolder="scheduler") base = StableDiffusionXLPipeline.from_pretrained( - args.base, - scheduler=scheduler, - torch_dtype=torch.float16, - variant=args.variant, - use_safetensors=True, + args.base, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True, ) base.to("cuda") - -oneflow_compiler_config.mlir_enable_inference_optimization = False +compile_options.oneflow.mlir_enable_inference_optimization = False # Compile unet with oneflow if args.compile_unet: print("Compiling unet with oneflow.") - compiled_unet = oneflow_compile(base.unet) + compiled_unet = oneflow_compile(base.unet, options=compile_options) base.unet = compiled_unet # Compile vae with oneflow if args.compile_vae: print("Compiling vae with oneflow.") - compiled_decoder = oneflow_compile(base.vae.decoder) + compiled_decoder = oneflow_compile(base.vae.decoder, options=compile_options) base.vae.decoder = compiled_decoder # Warmup with run @@ -96,7 +88,6 @@ if str(args.new_base).endswith(".safetensors"): new_base = StableDiffusionXLPipeline.from_single_file( args.new_base, - scheduler=scheduler, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True, @@ -104,7 +95,6 @@ else: new_base = StableDiffusionXLPipeline.from_pretrained( args.new_base, - scheduler=scheduler, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True, diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py index 8d07742bc..ae9488221 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py @@ -7,7 +7,7 @@ import torch import oneflow as flow -from onediff.infer_compiler import oneflow_compile, CompileOptions +from onediff.infer_compiler import oneflow_compile, compile_options from diffusers import DiffusionPipeline parser = argparse.ArgumentParser() @@ -47,7 +47,6 @@ # Compile unet and vae print("unet and vae is compiled to oneflow.") -compile_options = CompileOptions() compile_options.oneflow.max_cached_graph_size = cmd_args.num_dynamic_input_size base.unet = oneflow_compile(base.unet, options=compile_options) diff --git a/onediff_diffusers_extensions/onediffx/__init__.py b/onediff_diffusers_extensions/onediffx/__init__.py index 28a5efcef..1ada09ca6 100644 --- a/onediff_diffusers_extensions/onediffx/__init__.py +++ b/onediff_diffusers_extensions/onediffx/__init__.py @@ -1,5 +1,6 @@ __version__ = "0.13.0.dev" -from onediff.infer_compiler.oneflow import oneflow_compiler_config as compiler_config + from .compilers.diffusion_pipeline_compiler import compile_pipe, save_pipe, load_pipe +from onediff.infer_compiler import compile_options -__all__ = ["compile_pipe", "compiler_config", "save_pipe", "load_pipe"] +__all__ = ["compile_pipe", "compile_options", "save_pipe", "load_pipe"] diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index a11a622f0..a9c518a0c 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -19,7 +19,7 @@ from onediff_hijack import do_hijack as onediff_do_hijack from onediff.infer_compiler.utils.log_utils import logger -from onediff.infer_compiler.utils.env_var import parse_boolean_from_env +from onediff.infer_compiler.env import parse_boolean_from_env from onediff.optimization.quant_optimizer import ( quantize_model, varify_can_use_quantization, diff --git a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py index da3170341..220349d5d 100644 --- a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py @@ -21,7 +21,7 @@ import diffusers from diffusers.utils import deprecate, logging -from onediff.infer_compiler.utils import parse_boolean_from_env, set_boolean_env_var +from onediff.infer_compiler.env import parse_boolean_from_env, set_boolean_env_var def is_xformers_available(): @@ -392,7 +392,9 @@ def get_attention_scores(self, query, key, attention_mask=None): if self.upcast_attention and parse_boolean_from_env( "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", True ): - set_boolean_env_var("ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", False) + set_boolean_env_var( + "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", False + ) dtype = query.dtype # if self.upcast_attention: # query = query.float() diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index ab6fce98d..50c431400 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -2,13 +2,18 @@ import torch from .deployable_module import DeployableModule +from .env import populate_default_env_var from .options import * +from .options import _GLOBAL_compile_options as compile_options from .with_onediff_compile import compile, oneflow_compile from .with_fx_interpreter import OneFlowInterpreter from .with_fx_graph import fx_node_tranform +populate_default_env_var() + + def oneflow_backend(gm, example_inputs, *args, **kwargs): import oneflow as flow from oneflow.framework.args_tree import ArgsTree diff --git a/src/onediff/infer_compiler/backends/oneflow.py b/src/onediff/infer_compiler/backends/oneflow.py index 58a24567c..d30893439 100644 --- a/src/onediff/infer_compiler/backends/oneflow.py +++ b/src/onediff/infer_compiler/backends/oneflow.py @@ -20,6 +20,7 @@ def compile(torch_module: torch.nn.Module, *, options=None): - 'graph_file' (None) generates a compilation cache file. If the file exists, loading occurs; if not, the compilation result is saved after the first run. - 'graph_file_device' (None) sets the device for the graph file, default None. If set, the compilation result will be converted to the specified device. """ + from ..env import populate_oneflow_env_var from ..transform.custom_transform import set_default_registry from ..oneflow.deployable_module import OneflowDeployableModule from ..oneflow.utils import get_mixed_deployable_module @@ -27,6 +28,7 @@ def compile(torch_module: torch.nn.Module, *, options=None): set_default_registry() options = options if options is not None else CompileOptions() + populate_oneflow_env_var(options.oneflow) def wrap_module(module): if isinstance(module, OneflowDeployableModule): diff --git a/src/onediff/infer_compiler/utils/env_var.py b/src/onediff/infer_compiler/env/utils.py similarity index 100% rename from src/onediff/infer_compiler/utils/env_var.py rename to src/onediff/infer_compiler/env/utils.py diff --git a/src/onediff/infer_compiler/oneflow/__init__.py b/src/onediff/infer_compiler/oneflow/__init__.py index f6f3d9100..e69de29bb 100644 --- a/src/onediff/infer_compiler/oneflow/__init__.py +++ b/src/onediff/infer_compiler/oneflow/__init__.py @@ -1 +0,0 @@ -from .config import OneFlowCompilerConfig, oneflow_compiler_config diff --git a/src/onediff/infer_compiler/oneflow/config.py b/src/onediff/infer_compiler/oneflow/config.py deleted file mode 100644 index 48b3c5afa..000000000 --- a/src/onediff/infer_compiler/oneflow/config.py +++ /dev/null @@ -1,146 +0,0 @@ -import os -from typing import Optional -import dataclasses -from ..utils import ( - parse_boolean_from_env, - set_boolean_env_var, - parse_integer_from_env, - set_integer_env_var, -) - - -def init_default_env(): - # ONEFLOW_RUN_GRAPH_BY_VM must set here to enable nn.Graph init with vm run - os.environ.setdefault("ONEFLOW_RUN_GRAPH_BY_VM", "1") - os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1") - - os.environ.setdefault("ONEFLOW_MLIR_CSE", "1") - os.environ.setdefault("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1") - os.environ.setdefault("ONEFLOW_MLIR_ENABLE_ROUND_TRIP", "1") - os.environ.setdefault("ONEFLOW_MLIR_FUSE_FORWARD_OPS", "1") - os.environ.setdefault("ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL", "1") - os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL", "1") - os.environ.setdefault("ONEFLOW_MLIR_PREFER_NHWC", "1") - - os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS", "1") - os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR", "1") - os.environ.setdefault("ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", "1") - os.environ.setdefault("ONEFLOW_KERNEL_GEMM_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", "1") - os.environ.setdefault("ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL", "1") - os.environ.setdefault("ONEFLOW_KERNEL_GEMM_ENABLE_CUTLASS_IMPL", "1") - os.environ.setdefault("ONEFLOW_CONVOLUTION_BIAS_ADD_ACT_FUSION", "1") - # os.environ.setdefault("ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL", "0") - # os.environ.setdefault("ONEFLOW_KERNEL_GLU_ENABLE_Y_GEMM_IMPL", "0") - # os.environ.setdefault("ONEFLOW_KERNEL_GLU_QUANT_ENABLE_DUAL_GEMM_IMPL", "0") - - os.environ.setdefault("ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION", "1") - os.environ.setdefault("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", "1") - os.environ.setdefault("ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT", "1") - # os.environ.setdefault("ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", "1") - # os.environ.setdefault("ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M", "-1") - # os.environ.setdefault("ONEFLOW_ATTENTION_ALLOW_QUANTIZATION", "1") - - os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL_QUANT", "1") - - # TODO: enable this will cause the failure of multi resolution warmup - # os.environ.setdefault("ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH", "1") - # os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", "1") - - -@dataclasses.dataclass -class OneFlowCompilerConfig: - run_graph_by_vm: Optional[bool] = None - graph_delay_variable_op_execution: Optional[bool] = None - - mlir_cse: Optional[bool] = None - mlir_enable_inference_optimization: Optional[bool] = None - mlir_enable_round_trip: Optional[bool] = None - mlir_fuse_forward_ops: Optional[bool] = None - mlir_fuse_ops_with_backward_impl: Optional[bool] = None - mlir_group_matmul: Optional[bool] = None - mlir_prefer_nhwc: Optional[bool] = None - mlir_fuse_kernel_launch: Optional[bool] = None - - kernel_enable_cuda_graph: Optional[bool] = None - kernel_enable_fused_conv_bias: Optional[bool] = None - kernel_enable_fused_linear: Optional[bool] = None - kernel_conv_cutlass_impl_enable_tuning_warmup: Optional[bool] = None - kernel_gemm_cutlass_impl_enable_tuning_warmup: Optional[bool] = None - kernel_conv_enable_cutlass_impl: Optional[bool] = None - kernel_gemm_enable_cutlass_impl: Optional[bool] = None - kernel_glu_enable_dual_gemm_impl: Optional[bool] = None - kernel_glu_enable_y_gemm_impl: Optional[bool] = None - kernel_glu_quant_enable_dual_gemm_impl: Optional[bool] = None - - conv_allow_half_precision_accumulation: Optional[bool] = None - matmul_allow_half_precision_accumulation: Optional[bool] = None - linear_embedding_skip_init: Optional[bool] = None - attention_allow_half_precision_accumulation: Optional[bool] = None - attention_allow_half_precision_score_accumulation_max_m: Optional[int] = None - attention_allow_quantization: Optional[bool] = None - - attr2env_var = { - "run_graph_by_vm": "ONEFLOW_RUN_GRAPH_BY_VM", - "graph_delay_variable_op_execution": "ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", - "mlir_cse": "ONEFLOW_MLIR_CSE", - "mlir_enable_inference_optimization": "ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", - "mlir_enable_round_trip": "ONEFLOW_MLIR_ENABLE_ROUND_TRIP", - "mlir_fuse_forward_ops": "ONEFLOW_MLIR_FUSE_FORWARD_OPS", - "mlir_fuse_ops_with_backward_impl": "ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL", - "mlir_group_matmul": "ONEFLOW_MLIR_GROUP_MATMUL", - "mlir_prefer_nhwc": "ONEFLOW_MLIR_PREFER_NHWC", - "mlir_fuse_kernel_launch": "ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH", - "kernel_enable_cuda_graph": "ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", - "kernel_enable_fused_conv_bias": "ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS", - "kernel_enable_fused_linear": "ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR", - "kernel_conv_cutlass_impl_enable_tuning_warmup": "ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", - "kernel_gemm_cutlass_impl_enable_tuning_warmup": "ONEFLOW_KERNEL_GEMM_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", - "kernel_conv_enable_cutlass_impl": "ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL", - "kernel_gemm_enable_cutlass_impl": "ONEFLOW_KERNEL_GEMM_ENABLE_CUTLASS_IMPL", - "kernel_glu_enable_dual_gemm_impl": "ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL", - "kernel_glu_enable_y_gemm_impl": "ONEFLOW_KERNEL_GLU_ENABLE_Y_GEMM_IMPL", - "kernel_glu_quant_enable_dual_gemm_impl": "ONEFLOW_KERNEL_GLU_QUANT_ENABLE_DUAL_GEMM_IMPL", - "conv_allow_half_precision_accumulation": "ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION", - "matmul_allow_half_precision_accumulation": "ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", - "linear_embedding_skip_init": "ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT", - "attention_allow_half_precision_accumulation": "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", - "attention_allow_half_precision_score_accumulation_max_m": "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M", - } - - def __post_init__(self): - fields = dataclasses.fields(self) - fields = {field.name: field for field in fields} - for name in self.attr2env_var: - if fields[name].type in (bool, Optional[bool]): - super().__setattr__( - name, parse_boolean_from_env(self.attr2env_var[name]) - ) - elif fields[name].type in (int, Optional[int]): - super().__setattr__( - name, parse_integer_from_env(self.attr2env_var[name]) - ) - else: - raise ValueError( - f"Unsupported type {dataclasses.fields(self)[name].type}" - ) - - super().__setattr__("_initialized", True) - - def __setattr__(self, name, value): - super().__setattr__(name, value) - if getattr(self, "_initialized", False) and name in self.attr2env_var: - fields = dataclasses.fields(self) - fields = dataclasses.fields(self) - fields = {field.name: field for field in fields} - if fields[name].type in (bool, Optional[bool]): - set_boolean_env_var(self.attr2env_var[name], value) - elif fields[name].type in (int, Optional[int]): - set_integer_env_var(self.attr2env_var[name], value) - else: - raise ValueError( - f"Unsupported type {dataclasses.fields(self)[name].type}" - ) - - -init_default_env() -oneflow_compiler_config = OneFlowCompilerConfig() diff --git a/src/onediff/infer_compiler/options.py b/src/onediff/infer_compiler/options.py index 1061d4f7c..a164958c1 100644 --- a/src/onediff/infer_compiler/options.py +++ b/src/onediff/infer_compiler/options.py @@ -11,6 +11,36 @@ class OneflowCompileOptions: graph_file: str = None graph_file_device: torch.device = None + # Optimization related environment variables + run_graph_by_vm: bool = None + graph_delay_variable_op_execution: bool = None + + conv_allow_half_precision_accumulation: bool = None + matmul_allow_half_precision_accumulation: bool = None + attention_allow_half_precision_accumulation: bool = None + attention_allow_half_precision_score_accumulation_max_m: int = None + attention_allow_quantization: bool = None + + mlir_cse: bool = None + mlir_enable_inference_optimization: bool = None + mlir_enable_round_trip: bool = None + mlir_fuse_forward_ops: bool = None + mlir_fuse_ops_with_backward_impl: bool = None + mlir_group_matmul: bool = None + mlir_prefer_nhwc: bool = None + mlir_fuse_kernel_launch: bool = None + + kernel_enable_cuda_graph: bool = None + kernel_enable_fused_conv_bias: bool = None + kernel_enable_fused_linear: bool = None + kernel_conv_cutlass_impl_enable_tuning_warmup: bool = None + kernel_gemm_cutlass_impl_enable_tuning_warmup: bool = None + kernel_conv_enable_cutlass_impl: bool = None + kernel_gemm_enable_cutlass_impl: bool = None + kernel_glu_enable_dual_gemm_impl: bool = None + kernel_glu_enable_y_gemm_impl: bool = None + kernel_glu_quant_enable_dual_gemm_impl: bool = None + @dataclasses.dataclass class NexfortInductorCompileOptions: @@ -55,3 +85,7 @@ def __init__(self, dynamic=True, oneflow=None, nexfort=None): self.dynamic = dynamic self.oneflow = oneflow if oneflow is not None else OneflowCompileOptions() self.nexfort = nexfort if nexfort is not None else NexfortCompileOptions() + + +# a global default compile options +_GLOBAL_compile_options = CompileOptions() diff --git a/src/onediff/infer_compiler/utils/__init__.py b/src/onediff/infer_compiler/utils/__init__.py index 1f1c4855c..21eb48d4e 100644 --- a/src/onediff/infer_compiler/utils/__init__.py +++ b/src/onediff/infer_compiler/utils/__init__.py @@ -1,10 +1,4 @@ from .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled -from .env_var import ( - parse_boolean_from_env, - set_boolean_env_var, - parse_integer_from_env, - set_integer_env_var, -) from .model_inplace_assign import TensorInplaceAssign from .version_util import ( get_support_message, diff --git a/src/onediff/optimization/attention_processor.py b/src/onediff/optimization/attention_processor.py index 22650ab62..188e66435 100644 --- a/src/onediff/optimization/attention_processor.py +++ b/src/onediff/optimization/attention_processor.py @@ -84,7 +84,7 @@ def __call__( hidden_states = flow.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) else: - from ..infer_compiler.utils import ( + from ..infer_compiler.env import ( parse_boolean_from_env, set_boolean_env_var, ) From faa4ed65cd87916d6a248324a1c52a2bea95a313 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 9 Apr 2024 17:03:01 +0800 Subject: [PATCH 07/13] fix --- onediff_comfy_nodes/_nodes.py | 2 +- .../utils/deep_cache_speedup.py | 2 +- .../scripts/onediff.py | 2 +- .../attention_processor_oflow.py | 2 +- src/onediff/infer_compiler/__init__.py | 2 +- .../infer_compiler/backends/oneflow.py | 2 +- .../infer_compiler/env_var/__init__.py | 2 + .../env_var/populate_env_var.py | 109 ++++++++++++++++++ .../infer_compiler/{env => env_var}/utils.py | 0 .../optimization/attention_processor.py | 2 +- 10 files changed, 118 insertions(+), 7 deletions(-) create mode 100644 src/onediff/infer_compiler/env_var/__init__.py create mode 100644 src/onediff/infer_compiler/env_var/populate_env_var.py rename src/onediff/infer_compiler/{env => env_var}/utils.py (100%) diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py index 47aa3edda..5d97bda6c 100644 --- a/onediff_comfy_nodes/_nodes.py +++ b/onediff_comfy_nodes/_nodes.py @@ -1,7 +1,7 @@ from functools import partial from onediff.infer_compiler.transform import torch2oflow from ._config import _USE_UNET_INT8, ONEDIFF_QUANTIZED_OPTIMIZED_MODELS -from onediff.infer_compiler.env import set_boolean_env_var +from onediff.infer_compiler.env_var import set_boolean_env_var from onediff.optimization.quant_optimizer import quantize_model from onediff.infer_compiler import oneflow_compile, CompileOptions from onediff.infer_compiler.deployable_module import DeployableModule diff --git a/onediff_comfy_nodes/utils/deep_cache_speedup.py b/onediff_comfy_nodes/utils/deep_cache_speedup.py index fd0c53d25..98ad25c47 100644 --- a/onediff_comfy_nodes/utils/deep_cache_speedup.py +++ b/onediff_comfy_nodes/utils/deep_cache_speedup.py @@ -2,7 +2,7 @@ from comfy import model_management from comfy.model_base import SVD_img2vid -from onediff.infer_compiler.env import set_boolean_env_var +from onediff.infer_compiler.env_var import set_boolean_env_var from .model_patcher import OneFlowDeepCacheSpeedUpModelPatcher diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index a9c518a0c..9ba436a22 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -19,7 +19,7 @@ from onediff_hijack import do_hijack as onediff_do_hijack from onediff.infer_compiler.utils.log_utils import logger -from onediff.infer_compiler.env import parse_boolean_from_env +from onediff.infer_compiler.env_var import parse_boolean_from_env from onediff.optimization.quant_optimizer import ( quantize_model, varify_can_use_quantization, diff --git a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py index 220349d5d..1322a9812 100644 --- a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py @@ -21,7 +21,7 @@ import diffusers from diffusers.utils import deprecate, logging -from onediff.infer_compiler.env import parse_boolean_from_env, set_boolean_env_var +from onediff.infer_compiler.env_var import parse_boolean_from_env, set_boolean_env_var def is_xformers_available(): diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index 50c431400..e65c6c7bc 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -2,7 +2,7 @@ import torch from .deployable_module import DeployableModule -from .env import populate_default_env_var +from .env_var import populate_default_env_var from .options import * from .options import _GLOBAL_compile_options as compile_options from .with_onediff_compile import compile, oneflow_compile diff --git a/src/onediff/infer_compiler/backends/oneflow.py b/src/onediff/infer_compiler/backends/oneflow.py index d30893439..ac4607dc0 100644 --- a/src/onediff/infer_compiler/backends/oneflow.py +++ b/src/onediff/infer_compiler/backends/oneflow.py @@ -20,7 +20,7 @@ def compile(torch_module: torch.nn.Module, *, options=None): - 'graph_file' (None) generates a compilation cache file. If the file exists, loading occurs; if not, the compilation result is saved after the first run. - 'graph_file_device' (None) sets the device for the graph file, default None. If set, the compilation result will be converted to the specified device. """ - from ..env import populate_oneflow_env_var + from ..env_var import populate_oneflow_env_var from ..transform.custom_transform import set_default_registry from ..oneflow.deployable_module import OneflowDeployableModule from ..oneflow.utils import get_mixed_deployable_module diff --git a/src/onediff/infer_compiler/env_var/__init__.py b/src/onediff/infer_compiler/env_var/__init__.py new file mode 100644 index 000000000..6207844f8 --- /dev/null +++ b/src/onediff/infer_compiler/env_var/__init__.py @@ -0,0 +1,2 @@ +from .utils import * +from .populate_env_var import * diff --git a/src/onediff/infer_compiler/env_var/populate_env_var.py b/src/onediff/infer_compiler/env_var/populate_env_var.py new file mode 100644 index 000000000..2e674c886 --- /dev/null +++ b/src/onediff/infer_compiler/env_var/populate_env_var.py @@ -0,0 +1,109 @@ +import dataclasses +import os +from typing import Optional + + +def _populate_env_var(field2env_var, options): + from .utils import ( + parse_boolean_from_env, + set_boolean_env_var, + parse_integer_from_env, + set_integer_env_var, + ) + + for field in dataclasses.fields(options): + field_name = field.name + field_value = getattr(options, field_name) + if field_value is None or field_name not in field2env_var: + continue + env_var = field2env_var[field_name] + set_env_var = None + if field.type in (bool, Optional[bool]): + set_env_var = set_boolean_env_var + elif field.type in (int, Optional[int]): + set_env_var = set_integer_env_var + else: + raise ValueError(f"Unsupported type {field.type}") + set_env_var(env_var, field_value) + + +def populate_oneflow_env_var(options): + field2env_var = { + "run_graph_by_vm": "ONEFLOW_RUN_GRAPH_BY_VM", + "graph_delay_variable_op_execution": "ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", + "mlir_cse": "ONEFLOW_MLIR_CSE", + "mlir_enable_inference_optimization": "ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", + "mlir_enable_round_trip": "ONEFLOW_MLIR_ENABLE_ROUND_TRIP", + "mlir_fuse_forward_ops": "ONEFLOW_MLIR_FUSE_FORWARD_OPS", + "mlir_fuse_ops_with_backward_impl": "ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL", + "mlir_group_matmul": "ONEFLOW_MLIR_GROUP_MATMUL", + "mlir_prefer_nhwc": "ONEFLOW_MLIR_PREFER_NHWC", + "mlir_fuse_kernel_launch": "ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH", + "kernel_enable_cuda_graph": "ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", + "kernel_enable_fused_conv_bias": "ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS", + "kernel_enable_fused_linear": "ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR", + "kernel_conv_cutlass_impl_enable_tuning_warmup": "ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", + "kernel_gemm_cutlass_impl_enable_tuning_warmup": "ONEFLOW_KERNEL_GEMM_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", + "kernel_conv_enable_cutlass_impl": "ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL", + "kernel_gemm_enable_cutlass_impl": "ONEFLOW_KERNEL_GEMM_ENABLE_CUTLASS_IMPL", + "kernel_glu_enable_dual_gemm_impl": "ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL", + "kernel_glu_enable_y_gemm_impl": "ONEFLOW_KERNEL_GLU_ENABLE_Y_GEMM_IMPL", + "kernel_glu_quant_enable_dual_gemm_impl": "ONEFLOW_KERNEL_GLU_QUANT_ENABLE_DUAL_GEMM_IMPL", + "conv_allow_half_precision_accumulation": "ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION", + "matmul_allow_half_precision_accumulation": "ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", + "attention_allow_half_precision_accumulation": "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", + "attention_allow_half_precision_score_accumulation_max_m": "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M", + } + _populate_env_var(field2env_var, options) + + +def populate_oneflow_default_env_var(): + # ONEFLOW_RUN_GRAPH_BY_VM must set here to enable nn.Graph init with vm run + os.environ.setdefault("ONEFLOW_RUN_GRAPH_BY_VM", "1") + os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1") + + os.environ.setdefault("ONEFLOW_MLIR_CSE", "1") + os.environ.setdefault("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1") + os.environ.setdefault("ONEFLOW_MLIR_ENABLE_ROUND_TRIP", "1") + os.environ.setdefault("ONEFLOW_MLIR_FUSE_FORWARD_OPS", "1") + os.environ.setdefault("ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL", "1") + os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL", "1") + os.environ.setdefault("ONEFLOW_MLIR_PREFER_NHWC", "1") + + os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS", "1") + os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR", "1") + os.environ.setdefault("ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", "1") + os.environ.setdefault("ONEFLOW_KERNEL_GEMM_CUTLASS_IMPL_ENABLE_TUNING_WARMUP", "1") + os.environ.setdefault("ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL", "1") + os.environ.setdefault("ONEFLOW_KERNEL_GEMM_ENABLE_CUTLASS_IMPL", "1") + os.environ.setdefault("ONEFLOW_CONVOLUTION_BIAS_ADD_ACT_FUSION", "1") + # os.environ.setdefault("ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL", "0") + # os.environ.setdefault("ONEFLOW_KERNEL_GLU_ENABLE_Y_GEMM_IMPL", "0") + # os.environ.setdefault("ONEFLOW_KERNEL_GLU_QUANT_ENABLE_DUAL_GEMM_IMPL", "0") + + os.environ.setdefault("ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION", "1") + os.environ.setdefault("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", "1") + os.environ.setdefault("ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT", "1") + # os.environ.setdefault("ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", "1") + # os.environ.setdefault("ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M", "-1") + # os.environ.setdefault("ONEFLOW_ATTENTION_ALLOW_QUANTIZATION", "1") + + os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL_QUANT", "1") + + # TODO: enable this will cause the failure of multi resolution warmup + # os.environ.setdefault("ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH", "1") + # os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", "1") + + +def populate_nexfort_env_var(options): + field2env_var = {} + _populate_env_var(field2env_var, options) + + +def populate_nexfort_default_env_var(): + pass + + +def populate_default_env_var(): + populate_oneflow_default_env_var() + populate_nexfort_default_env_var() diff --git a/src/onediff/infer_compiler/env/utils.py b/src/onediff/infer_compiler/env_var/utils.py similarity index 100% rename from src/onediff/infer_compiler/env/utils.py rename to src/onediff/infer_compiler/env_var/utils.py diff --git a/src/onediff/optimization/attention_processor.py b/src/onediff/optimization/attention_processor.py index 188e66435..0c3db7d9d 100644 --- a/src/onediff/optimization/attention_processor.py +++ b/src/onediff/optimization/attention_processor.py @@ -84,7 +84,7 @@ def __call__( hidden_states = flow.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) else: - from ..infer_compiler.env import ( + from ..infer_compiler.env_var import ( parse_boolean_from_env, set_boolean_env_var, ) From 9f018c7b551f80ba82e01753596d007662663ac5 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 9 Apr 2024 18:37:08 +0800 Subject: [PATCH 08/13] early apply patch --- src/onediff/infer_compiler/import_tools/dyn_mock_mod.py | 1 + src/onediff/infer_compiler/oneflow/utils.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onediff/infer_compiler/import_tools/dyn_mock_mod.py b/src/onediff/infer_compiler/import_tools/dyn_mock_mod.py index 1a0ed4716..e7723f78d 100644 --- a/src/onediff/infer_compiler/import_tools/dyn_mock_mod.py +++ b/src/onediff/infer_compiler/import_tools/dyn_mock_mod.py @@ -11,6 +11,7 @@ from oneflow.mock_torch.mock_importer import _importer from .import_module_utils import import_module_from_path from ..utils.log_utils import logger +from ..utils.patch_for_compiler import * __all__ = ["DynamicMockModule"] diff --git a/src/onediff/infer_compiler/oneflow/utils.py b/src/onediff/infer_compiler/oneflow/utils.py index dfdf97a6e..4a5e899aa 100644 --- a/src/onediff/infer_compiler/oneflow/utils.py +++ b/src/onediff/infer_compiler/oneflow/utils.py @@ -3,7 +3,6 @@ from ..transform.builtin_transform import torch2oflow from ..transform.manager import transform_mgr from ..utils.log_utils import logger -from ..utils.patch_for_compiler import * from .dual_module import DualModule From c974d3a74c8676eca12f3271f6bc6f87e8e1f4b7 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 15 Apr 2024 10:51:18 +0800 Subject: [PATCH 09/13] refine --- .../oneflow/hijack_animatediff/sampling.py | 5 +- .../hijack_ipadapter_plus/IPAdapterPlus.py | 5 +- .../modules/oneflow/utils/__init__.py | 14 ++++-- .../oneflow/utils/deep_cache_speedup.py | 7 +-- .../modules/oneflow/utils/model_patcher.py | 14 ++++-- .../compilers/diffusion_pipeline_compiler.py | 3 +- .../onediffx/lora/unet.py | 2 +- onediff_sd_webui_extensions/onediff_lora.py | 2 +- .../attention_processor_oflow.py | 2 +- src/onediff/infer_compiler/__init__.py | 13 ++--- .../infer_compiler/backends/oneflow.py | 15 +++--- src/onediff/infer_compiler/core/__init__.py | 4 ++ .../{ => core}/deployable_module.py | 0 .../{ => core}/with_fx_graph.py | 0 .../{ => core}/with_fx_interpreter.py | 0 .../{ => core}/with_onediff_compile.py | 0 .../infer_compiler/env_var/__init__.py | 2 - src/onediff/infer_compiler/env_var/utils.py | 31 ------------ .../oneflow/deployable_module.py | 4 +- src/onediff/infer_compiler/utils/__init__.py | 12 +++++ .../populate_env_var.py => utils/env_var.py} | 49 +++++++++++++++---- .../utils/graph_management_utils.py | 2 +- .../utils/model_inplace_assign.py | 2 +- .../infer_compiler/{ => utils}/options.py | 0 .../infer_compiler/utils/param_utils.py | 10 ++-- .../optimization/attention_processor.py | 2 +- 26 files changed, 107 insertions(+), 93 deletions(-) create mode 100644 src/onediff/infer_compiler/core/__init__.py rename src/onediff/infer_compiler/{ => core}/deployable_module.py (100%) rename src/onediff/infer_compiler/{ => core}/with_fx_graph.py (100%) rename src/onediff/infer_compiler/{ => core}/with_fx_interpreter.py (100%) rename src/onediff/infer_compiler/{ => core}/with_onediff_compile.py (100%) delete mode 100644 src/onediff/infer_compiler/env_var/__init__.py delete mode 100644 src/onediff/infer_compiler/env_var/utils.py rename src/onediff/infer_compiler/{env_var/populate_env_var.py => utils/env_var.py} (82%) rename src/onediff/infer_compiler/{ => utils}/options.py (100%) diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/sampling.py b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/sampling.py index bc9c14f9d..720c5ab2a 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/sampling.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/sampling.py @@ -1,12 +1,11 @@ # /ComfyUI/custom_nodes/ComfyUI-AnimateDiff-Evolved/animatediff/sampling.py import oneflow as flow from einops import rearrange -from onediff.infer_compiler.deployable_module import DeployableModule +from onediff.infer_compiler import DeployableModule from onediff.infer_compiler.transform import register from oneflow.nn.functional import group_norm -from ._config import (animatediff_hijacker, animatediff_of, animatediff_pt, - comfy_of) +from ._config import animatediff_hijacker, animatediff_of, animatediff_pt, comfy_of FunctionInjectionHolder = animatediff_pt.animatediff.sampling.FunctionInjectionHolder diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/IPAdapterPlus.py b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/IPAdapterPlus.py index 387e258a2..1c2f3e50c 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/IPAdapterPlus.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/IPAdapterPlus.py @@ -2,11 +2,10 @@ import os from pathlib import Path -from onediff.infer_compiler.deployable_module import DeployableModule +from onediff.infer_compiler import DeployableModule from onediff.infer_compiler.transform import torch2oflow -from ._config import (ipadapter_plus_hijacker, ipadapter_plus_of, - ipadapter_plus_pt) +from ._config import ipadapter_plus_hijacker, ipadapter_plus_of, ipadapter_plus_pt from .CrossAttentionPatch import CrossAttentionPatch as CrossAttentionPatch_OF set_model_patch_replace_fn_pt = ipadapter_plus_pt.IPAdapterPlus.set_model_patch_replace diff --git a/onediff_comfy_nodes/modules/oneflow/utils/__init__.py b/onediff_comfy_nodes/modules/oneflow/utils/__init__.py index 9efc8fb0b..528694ab9 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/__init__.py @@ -2,13 +2,17 @@ import re import time -from onediff.infer_compiler.deployable_module import DeployableModule +from onediff.infer_compiler import DeployableModule -from .model_patcher import (OneFlowDeepCacheSpeedUpModelPatcher, - OneFlowSpeedUpModelPatcher) +from .model_patcher import ( + OneFlowDeepCacheSpeedUpModelPatcher, + OneFlowSpeedUpModelPatcher, +) from .onediff_load_utils import onediff_load_quant_checkpoint_advanced -from .onediff_quant_utils import (quantize_and_save_model, - replace_module_with_quantizable_module) +from .onediff_quant_utils import ( + quantize_and_save_model, + replace_module_with_quantizable_module, +) OUTPUT_FOLDER = os.path.join( os.path.dirname(os.path.realpath(__file__)), "..", "graphs" diff --git a/onediff_comfy_nodes/modules/oneflow/utils/deep_cache_speedup.py b/onediff_comfy_nodes/modules/oneflow/utils/deep_cache_speedup.py index 3de4c5d66..a69b7c57a 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/deep_cache_speedup.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/deep_cache_speedup.py @@ -2,7 +2,6 @@ from comfy import model_management from comfy.model_base import SVD_img2vid from onediff.infer_compiler import oneflow_compile -from onediff.infer_compiler.env_var import set_boolean_env_var from register_comfy import DeepCacheUNet, FastDeepCacheUNet from .model_patcher import OneFlowDeepCacheSpeedUpModelPatcher @@ -21,7 +20,7 @@ def deep_cache_speedup( gen_compile_options=None, use_oneflow_deepcache_speedup_modelpatcher=True, ): - + offload_device = model_management.unet_offload_device() if use_oneflow_deepcache_speedup_modelpatcher: model_patcher = OneFlowDeepCacheSpeedUpModelPatcher( @@ -41,9 +40,7 @@ def deep_cache_speedup( model_patcher.fast_deep_cache_unet = FastDeepCacheUNet( model_patcher.model.diffusion_model, cache_layer_id, cache_block_id ) - model_patcher.deep_cache_unet = oneflow_compile( - model_patcher.deep_cache_unet - ) + model_patcher.deep_cache_unet = oneflow_compile(model_patcher.deep_cache_unet) model_patcher.fast_deep_cache_unet = oneflow_compile( model_patcher.fast_deep_cache_unet ) diff --git a/onediff_comfy_nodes/modules/oneflow/utils/model_patcher.py b/onediff_comfy_nodes/modules/oneflow/utils/model_patcher.py index 79e2af709..be22c7e64 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/model_patcher.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/model_patcher.py @@ -32,8 +32,11 @@ def __init__( graph_path=None, graph_device=None, ): - from onediff.infer_compiler import CompileOptions, oneflow_compile - from onediff.infer_compiler.deployable_module import DeployableModule + from onediff.infer_compiler import ( + CompileOptions, + oneflow_compile, + DeployableModule, + ) self.weight_inplace_update = weight_inplace_update self.object_patches = {} @@ -502,8 +505,11 @@ def __init__( use_graph=None, gen_compile_options=None, ): - from onediff.infer_compiler import CompileOptions, oneflow_compile - from onediff.infer_compiler.deployable_module import DeployableModule + from onediff.infer_compiler import ( + CompileOptions, + oneflow_compile, + DeployableModule, + ) self.weight_inplace_update = weight_inplace_update self.object_patches = {} diff --git a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py index 2986403f4..3307991e3 100644 --- a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py +++ b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py @@ -1,7 +1,6 @@ import os import torch -from onediff.infer_compiler import compile -from onediff.infer_compiler.deployable_module import DeployableModule +from onediff.infer_compiler import compile, DeployableModule from onediff.infer_compiler.utils.log_utils import logger diff --git a/onediff_diffusers_extensions/onediffx/lora/unet.py b/onediff_diffusers_extensions/onediffx/lora/unet.py index 59d308b00..cca033aa1 100644 --- a/onediff_diffusers_extensions/onediffx/lora/unet.py +++ b/onediff_diffusers_extensions/onediffx/lora/unet.py @@ -3,7 +3,7 @@ from collections import defaultdict import torch -from onediff.infer_compiler.deployable_module import DeployableModule +from onediff.infer_compiler import DeployableModule from onediff.infer_compiler.utils.log_utils import logger from diffusers.models.lora import ( LoRACompatibleConv, diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py index dd233660c..77066873f 100644 --- a/onediff_sd_webui_extensions/onediff_lora.py +++ b/onediff_sd_webui_extensions/onediff_lora.py @@ -1,5 +1,5 @@ import torch -from onediff.infer_compiler.deployable_module import DeployableModule +from onediff.infer_compiler import DeployableModule from onediff.infer_compiler.utils.param_utils import update_graph_related_tensor diff --git a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py index 1322a9812..6406e01d5 100644 --- a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py @@ -21,7 +21,7 @@ import diffusers from diffusers.utils import deprecate, logging -from onediff.infer_compiler.env_var import parse_boolean_from_env, set_boolean_env_var +from onediff.infer_compiler.utils import parse_boolean_from_env, set_boolean_env_var def is_xformers_available(): diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index e65c6c7bc..7c310292a 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -1,17 +1,12 @@ import os import torch -from .deployable_module import DeployableModule -from .env_var import populate_default_env_var -from .options import * -from .options import _GLOBAL_compile_options as compile_options -from .with_onediff_compile import compile, oneflow_compile +from .core import * +from .utils import set_default_env_vars +from .utils.options import _GLOBAL_compile_options as compile_options -from .with_fx_interpreter import OneFlowInterpreter -from .with_fx_graph import fx_node_tranform - -populate_default_env_var() +set_default_env_vars() def oneflow_backend(gm, example_inputs, *args, **kwargs): diff --git a/src/onediff/infer_compiler/backends/oneflow.py b/src/onediff/infer_compiler/backends/oneflow.py index 735725009..71b010950 100644 --- a/src/onediff/infer_compiler/backends/oneflow.py +++ b/src/onediff/infer_compiler/backends/oneflow.py @@ -19,18 +19,21 @@ def compile(torch_module: torch.nn.Module, *, options=None): - 'graph_file' (None) generates a compilation cache file. If the file exists, loading occurs; if not, the compilation result is saved after the first run. - 'graph_file_device' (None) sets the device for the graph file, default None. If set, the compilation result will be converted to the specified device. """ - from ..env_var import populate_oneflow_env_var - from ..transform.custom_transform import set_default_registry from ..oneflow.deployable_module import OneflowDeployableModule from ..oneflow.utils import get_mixed_deployable_module - from ..options import CompileOptions - from ..utils.param_utils import state_update_hook, init_state_update_attr, forward_pre_check_and_update_state_hook, forward_generate_constant_folding_info_hook - + from ..transform.custom_transform import set_default_registry + from ..utils import CompileOptions, set_oneflow_env_vars + from ..utils.param_utils import ( + state_update_hook, + init_state_update_attr, + forward_pre_check_and_update_state_hook, + forward_generate_constant_folding_info_hook, + ) set_default_registry() options = options if options is not None else CompileOptions() - populate_oneflow_env_var(options.oneflow) + set_oneflow_env_vars(options.oneflow) def wrap_module(module): if isinstance(module, OneflowDeployableModule): diff --git a/src/onediff/infer_compiler/core/__init__.py b/src/onediff/infer_compiler/core/__init__.py new file mode 100644 index 000000000..734fe9adb --- /dev/null +++ b/src/onediff/infer_compiler/core/__init__.py @@ -0,0 +1,4 @@ +from .deployable_module import DeployableModule +from .with_onediff_compile import compile, oneflow_compile +from .with_fx_interpreter import OneFlowInterpreter +from .with_fx_graph import fx_node_tranform diff --git a/src/onediff/infer_compiler/deployable_module.py b/src/onediff/infer_compiler/core/deployable_module.py similarity index 100% rename from src/onediff/infer_compiler/deployable_module.py rename to src/onediff/infer_compiler/core/deployable_module.py diff --git a/src/onediff/infer_compiler/with_fx_graph.py b/src/onediff/infer_compiler/core/with_fx_graph.py similarity index 100% rename from src/onediff/infer_compiler/with_fx_graph.py rename to src/onediff/infer_compiler/core/with_fx_graph.py diff --git a/src/onediff/infer_compiler/with_fx_interpreter.py b/src/onediff/infer_compiler/core/with_fx_interpreter.py similarity index 100% rename from src/onediff/infer_compiler/with_fx_interpreter.py rename to src/onediff/infer_compiler/core/with_fx_interpreter.py diff --git a/src/onediff/infer_compiler/with_onediff_compile.py b/src/onediff/infer_compiler/core/with_onediff_compile.py similarity index 100% rename from src/onediff/infer_compiler/with_onediff_compile.py rename to src/onediff/infer_compiler/core/with_onediff_compile.py diff --git a/src/onediff/infer_compiler/env_var/__init__.py b/src/onediff/infer_compiler/env_var/__init__.py deleted file mode 100644 index 6207844f8..000000000 --- a/src/onediff/infer_compiler/env_var/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .utils import * -from .populate_env_var import * diff --git a/src/onediff/infer_compiler/env_var/utils.py b/src/onediff/infer_compiler/env_var/utils.py deleted file mode 100644 index 23b6e749b..000000000 --- a/src/onediff/infer_compiler/env_var/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -import os -from typing import Optional - - -def parse_boolean_from_env(env_var, default_value=None): - env_var = os.getenv(env_var) - if env_var is None: - return default_value - env_var = env_var.lower() - return env_var in ("1", "true", "yes", "on", "y") - - -def set_boolean_env_var(env_var: str, val: Optional[bool]): - if val is None: - os.environ.pop(env_var, None) - else: - os.environ[env_var] = "1" if val else "0" - - -def parse_integer_from_env(env_var, default_value=None): - env_var = os.getenv(env_var) - if env_var is None: - return default_value - return int(env_var) - - -def set_integer_env_var(env_var: str, val: Optional[int]): - if val is None: - os.environ.pop(env_var, None) - else: - os.environ[env_var] = str(int(val)) diff --git a/src/onediff/infer_compiler/oneflow/deployable_module.py b/src/onediff/infer_compiler/oneflow/deployable_module.py index 7c0c28bd8..49ee5629d 100644 --- a/src/onediff/infer_compiler/oneflow/deployable_module.py +++ b/src/onediff/infer_compiler/oneflow/deployable_module.py @@ -2,6 +2,7 @@ import torch import oneflow as flow +from ..core.deployable_module import DeployableModule from ..transform.manager import transform_mgr from ..utils.oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled from ..utils.args_tree_util import input_output_processor @@ -9,8 +10,7 @@ from ..utils.param_utils import parse_device, check_device from ..utils.graph_management_utils import graph_file_management from ..utils.online_quantization_utils import quantize_and_deploy_wrapper -from ..options import OneflowCompileOptions -from ..deployable_module import DeployableModule +from ..utils.options import OneflowCompileOptions from .utils import handle_deployable_exception, get_mixed_dual_module, get_oneflow_graph diff --git a/src/onediff/infer_compiler/utils/__init__.py b/src/onediff/infer_compiler/utils/__init__.py index 21eb48d4e..076b41bcd 100644 --- a/src/onediff/infer_compiler/utils/__init__.py +++ b/src/onediff/infer_compiler/utils/__init__.py @@ -1,7 +1,19 @@ from .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled +from .env_var import ( + parse_boolean_from_env, + set_boolean_env_var, + parse_integer_from_env, + set_integer_env_var, + set_oneflow_env_vars, + set_oneflow_default_env_vars, + set_nexfort_env_vars, + set_nexfort_default_env_vars, + set_default_env_vars, +) from .model_inplace_assign import TensorInplaceAssign from .version_util import ( get_support_message, is_quantization_enabled, is_community_version, ) +from .options import * diff --git a/src/onediff/infer_compiler/env_var/populate_env_var.py b/src/onediff/infer_compiler/utils/env_var.py similarity index 82% rename from src/onediff/infer_compiler/env_var/populate_env_var.py rename to src/onediff/infer_compiler/utils/env_var.py index 2e674c886..b385f5181 100644 --- a/src/onediff/infer_compiler/env_var/populate_env_var.py +++ b/src/onediff/infer_compiler/utils/env_var.py @@ -3,7 +3,36 @@ from typing import Optional -def _populate_env_var(field2env_var, options): +def parse_boolean_from_env(env_var, default_value=None): + env_var = os.getenv(env_var) + if env_var is None: + return default_value + env_var = env_var.lower() + return env_var in ("1", "true", "yes", "on", "y") + + +def set_boolean_env_var(env_var: str, val: Optional[bool]): + if val is None: + os.environ.pop(env_var, None) + else: + os.environ[env_var] = "1" if val else "0" + + +def parse_integer_from_env(env_var, default_value=None): + env_var = os.getenv(env_var) + if env_var is None: + return default_value + return int(env_var) + + +def set_integer_env_var(env_var: str, val: Optional[int]): + if val is None: + os.environ.pop(env_var, None) + else: + os.environ[env_var] = str(int(val)) + + +def _set_env_vars(field2env_var, options): from .utils import ( parse_boolean_from_env, set_boolean_env_var, @@ -27,7 +56,7 @@ def _populate_env_var(field2env_var, options): set_env_var(env_var, field_value) -def populate_oneflow_env_var(options): +def set_oneflow_env_vars(options): field2env_var = { "run_graph_by_vm": "ONEFLOW_RUN_GRAPH_BY_VM", "graph_delay_variable_op_execution": "ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", @@ -54,10 +83,10 @@ def populate_oneflow_env_var(options): "attention_allow_half_precision_accumulation": "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", "attention_allow_half_precision_score_accumulation_max_m": "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M", } - _populate_env_var(field2env_var, options) + _set_env_vars(field2env_var, options) -def populate_oneflow_default_env_var(): +def set_oneflow_default_env_vars(): # ONEFLOW_RUN_GRAPH_BY_VM must set here to enable nn.Graph init with vm run os.environ.setdefault("ONEFLOW_RUN_GRAPH_BY_VM", "1") os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1") @@ -95,15 +124,15 @@ def populate_oneflow_default_env_var(): # os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", "1") -def populate_nexfort_env_var(options): +def set_nexfort_env_vars(options): field2env_var = {} - _populate_env_var(field2env_var, options) + _set_env_vars(field2env_var, options) -def populate_nexfort_default_env_var(): +def set_nexfort_default_env_vars(): pass -def populate_default_env_var(): - populate_oneflow_default_env_var() - populate_nexfort_default_env_var() +def set_default_env_vars(): + set_oneflow_default_env_vars() + set_nexfort_default_env_vars() diff --git a/src/onediff/infer_compiler/utils/graph_management_utils.py b/src/onediff/infer_compiler/utils/graph_management_utils.py index 607709eeb..14515f137 100644 --- a/src/onediff/infer_compiler/utils/graph_management_utils.py +++ b/src/onediff/infer_compiler/utils/graph_management_utils.py @@ -5,11 +5,11 @@ from pathlib import Path from functools import wraps from oneflow.framework.args_tree import ArgsTree -from ..options import OneflowCompileOptions from ..transform.builtin_transform import torch2oflow from ..transform.manager import transform_mgr from .log_utils import logger from .cost_util import cost_time +from .options import OneflowCompileOptions def calculate_model_hash(model): diff --git a/src/onediff/infer_compiler/utils/model_inplace_assign.py b/src/onediff/infer_compiler/utils/model_inplace_assign.py index e0ac23822..f61276f5b 100644 --- a/src/onediff/infer_compiler/utils/model_inplace_assign.py +++ b/src/onediff/infer_compiler/utils/model_inplace_assign.py @@ -2,7 +2,7 @@ from typing import Union, List from collections import defaultdict import torch -from onediff.infer_compiler.deployable_module import DeployableModule +from onediff.infer_compiler import DeployableModule _nested_counter = defaultdict(lambda: 0) diff --git a/src/onediff/infer_compiler/options.py b/src/onediff/infer_compiler/utils/options.py similarity index 100% rename from src/onediff/infer_compiler/options.py rename to src/onediff/infer_compiler/utils/options.py diff --git a/src/onediff/infer_compiler/utils/param_utils.py b/src/onediff/infer_compiler/utils/param_utils.py index 080c56a4b..dd51653ab 100644 --- a/src/onediff/infer_compiler/utils/param_utils.py +++ b/src/onediff/infer_compiler/utils/param_utils.py @@ -38,7 +38,7 @@ def _convert(device): def init_state_update_attr(module: torch.nn.Module): - from onediff.infer_compiler.deployable_module import DeployableModule + from onediff.infer_compiler import DeployableModule if isinstance(module, DeployableModule): module = module._torch_module @@ -50,7 +50,7 @@ def init_state_update_attr(module: torch.nn.Module): def set_constant_folded_conv_attr( deployable_module, constant_folding_info: Dict[str, flow.Tensor] = None ) -> None: - from onediff.infer_compiler.deployable_module import DeployableModule + from onediff.infer_compiler import DeployableModule if not isinstance(deployable_module, DeployableModule): raise TypeError( @@ -86,7 +86,7 @@ def convert_var_name(s: str, prefix="variable_transpose_"): s = re.sub(r"_[0-9]+$", "", s.removeprefix(prefix)).removeprefix("model.") return s - from onediff.infer_compiler.deployable_module import DeployableModule + from onediff.infer_compiler import DeployableModule if not isinstance(deployable_module, DeployableModule): raise TypeError( @@ -111,7 +111,7 @@ def convert_var_name(s: str, prefix="variable_transpose_"): def update_graph_with_constant_folding_info( module: torch.nn.Module, info: Dict[str, flow.Tensor] = None ) -> None: - from onediff.infer_compiler.deployable_module import DeployableModule + from onediff.infer_compiler import DeployableModule if isinstance(module, DeployableModule): if info is None: @@ -142,7 +142,7 @@ def update_graph_related_tensor(module: torch.nn.Conv2d) -> None: def get_constant_folding_info(module) -> Union[Dict[str, flow.Tensor], None]: - from onediff.infer_compiler.deployable_module import DeployableModule + from onediff.infer_compiler import DeployableModule if not isinstance(module, DeployableModule): raise TypeError(f"module must be a DeployableModule, got {type(module)}") diff --git a/src/onediff/optimization/attention_processor.py b/src/onediff/optimization/attention_processor.py index 0c3db7d9d..22650ab62 100644 --- a/src/onediff/optimization/attention_processor.py +++ b/src/onediff/optimization/attention_processor.py @@ -84,7 +84,7 @@ def __call__( hidden_states = flow.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) else: - from ..infer_compiler.env_var import ( + from ..infer_compiler.utils import ( parse_boolean_from_env, set_boolean_env_var, ) From e025cb00216536464ffe7f50d4d9692e9f7e2259 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 15 Apr 2024 11:00:25 +0800 Subject: [PATCH 10/13] fix --- src/onediff/infer_compiler/backends/nexfort.py | 2 +- src/onediff/infer_compiler/core/with_onediff_compile.py | 2 +- src/onediff/infer_compiler/nexfort/deployable_module.py | 2 +- src/onediff/infer_compiler/utils/env_var.py | 7 ------- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/onediff/infer_compiler/backends/nexfort.py b/src/onediff/infer_compiler/backends/nexfort.py index 9d702aaf1..67cca8cbc 100644 --- a/src/onediff/infer_compiler/backends/nexfort.py +++ b/src/onediff/infer_compiler/backends/nexfort.py @@ -1,7 +1,6 @@ import dataclasses import torch from .registry import register_backend -from ..options import CompileOptions def make_inductor_options(options): @@ -19,6 +18,7 @@ def compile(torch_module: torch.nn.Module, *, options=None): from nexfort.utils.memory_format import apply_memory_format from nexfort.compilers import nexfort_compile from ..nexfort.deployable_module import NexfortDeployableModule + from ..utils import CompileOptions options = options if options is not None else CompileOptions() nexfort_options = options.nexfort diff --git a/src/onediff/infer_compiler/core/with_onediff_compile.py b/src/onediff/infer_compiler/core/with_onediff_compile.py index 6ebc080e6..c2ec0568b 100644 --- a/src/onediff/infer_compiler/core/with_onediff_compile.py +++ b/src/onediff/infer_compiler/core/with_onediff_compile.py @@ -5,7 +5,7 @@ def compile( torch_module: torch.nn.Module, *, backend="nexfort", options=None ) -> DeployableModule: - from .backends.registry import lookup_backend + from ..backends.registry import lookup_backend backend = lookup_backend(backend) model = backend(torch_module, options=options) diff --git a/src/onediff/infer_compiler/nexfort/deployable_module.py b/src/onediff/infer_compiler/nexfort/deployable_module.py index 421b2d8bc..eb8a91be2 100644 --- a/src/onediff/infer_compiler/nexfort/deployable_module.py +++ b/src/onediff/infer_compiler/nexfort/deployable_module.py @@ -1,5 +1,5 @@ import torch -from ..deployable_module import DeployableModule +from ..core.deployable_module import DeployableModule class NexfortDeployableModule(DeployableModule): diff --git a/src/onediff/infer_compiler/utils/env_var.py b/src/onediff/infer_compiler/utils/env_var.py index b385f5181..a40d3b68d 100644 --- a/src/onediff/infer_compiler/utils/env_var.py +++ b/src/onediff/infer_compiler/utils/env_var.py @@ -33,13 +33,6 @@ def set_integer_env_var(env_var: str, val: Optional[int]): def _set_env_vars(field2env_var, options): - from .utils import ( - parse_boolean_from_env, - set_boolean_env_var, - parse_integer_from_env, - set_integer_env_var, - ) - for field in dataclasses.fields(options): field_name = field.name field_value = getattr(options, field_name) From 2555e0b76244804f787de1954019e5e21939d70d Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 15 Apr 2024 11:26:01 +0800 Subject: [PATCH 11/13] fix --- .../oneflow/infer_compiler_registry/register_comfy/__init__.py | 2 +- .../oneflow/infer_compiler_registry/register_onediff_quant.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/__init__.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/__init__.py index 073197d8e..b719b9ba5 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/__init__.py @@ -3,7 +3,7 @@ import comfy from comfy.ldm.modules.diffusionmodules.model import AttnBlock from nodes import * # must imported before import comfy -from onediff.infer_compiler import register +from onediff.infer_compiler.transform import register from onediff.infer_compiler.utils import is_community_version from .attention import CrossAttention as CrossAttention1f diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_onediff_quant.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_onediff_quant.py index 0539cdb5c..d05e8acb5 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_onediff_quant.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_onediff_quant.py @@ -1,6 +1,6 @@ import onediff_quant import oneflow as flow -from onediff.infer_compiler import register +from onediff.infer_compiler.transform import register torch2oflow_class_map = { onediff_quant.FakeQuantModule: onediff_quant.OneFlowFakeQuantModule, From c4d3370b2036d1384c8182dbe9d46599aeb0af5a Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 15 Apr 2024 12:24:32 +0800 Subject: [PATCH 12/13] refine --- src/onediff/infer_compiler/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index 7c310292a..c1131bd2f 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -3,6 +3,7 @@ from .core import * from .utils import set_default_env_vars +from .utils.options import CompileOptions from .utils.options import _GLOBAL_compile_options as compile_options From 2dc005a52d888f5071bb17f0fd8c9d9b7b3c89df Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 14 May 2024 20:47:54 +0800 Subject: [PATCH 13/13] rm useless --- .../experimental/control_net_canny.py | 54 --------- .../experimental/text_to_image_sdxl_fp16.py | 57 --------- .../text_to_image_sdxl_torch_compile.py | 83 -------------- .../experimental/torch_interpretor.py | 33 ------ onediff_diffusers_extensions/setup.py | 4 +- setup.py | 5 +- src/onediff/infer_compiler/__init__.py | 33 ------ src/onediff/infer_compiler/core/__init__.py | 2 - .../infer_compiler/core/with_fx_graph.py | 108 ------------------ .../core/with_fx_interpreter.py | 28 ----- .../core/with_onediff_compile.py | 3 +- 11 files changed, 7 insertions(+), 403 deletions(-) delete mode 100644 onediff_diffusers_extensions/examples/experimental/control_net_canny.py delete mode 100644 onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_fp16.py delete mode 100644 onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_torch_compile.py delete mode 100644 onediff_diffusers_extensions/examples/experimental/torch_interpretor.py delete mode 100644 src/onediff/infer_compiler/core/with_fx_graph.py delete mode 100644 src/onediff/infer_compiler/core/with_fx_interpreter.py diff --git a/onediff_diffusers_extensions/examples/experimental/control_net_canny.py b/onediff_diffusers_extensions/examples/experimental/control_net_canny.py deleted file mode 100644 index e30bdd393..000000000 --- a/onediff_diffusers_extensions/examples/experimental/control_net_canny.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -performs image generation using a stable diffusion model with a control network. -""" -import cv2 -from onediff.infer_compiler import oneflow_compile -from PIL import Image -import numpy as np - - -import oneflow as flow -from diffusers.utils import load_image -from diffusers import ControlNetModel -from diffusers import StableDiffusionControlNetPipeline -import torch - - -image = load_image( - "http://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" -) - -image = np.array(image) - -LOW_THRESHOLD = 100 -HIGH_THRESHOLD = 200 -PROMPT = "disco dancer with colorful lights, best quality, extremely detailed" -NEGATIVE_PROMPT = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality" - -image = cv2.Canny(image, LOW_THRESHOLD, HIGH_THRESHOLD) -image = image[:, :, None] -image = np.concatenate([image, image, image], axis=2) -canny_image = Image.fromarray(image) - -controlnet = ControlNetModel.from_pretrained( - "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16 -) - -pipe = StableDiffusionControlNetPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 -) - -pipe.to("cuda") -pipe.unet = oneflow_compile(pipe.unet) -generator = torch.manual_seed(0) - - -out_images = pipe( - prompt=PROMPT, - negative_prompt=NEGATIVE_PROMPT, - num_inference_steps=20, - generator=generator, - image=canny_image, -).images -for i, image in enumerate(out_images): - image.save(f"{PROMPT}-of-{i}.png") diff --git a/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_fp16.py b/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_fp16.py deleted file mode 100644 index 716ed75f6..000000000 --- a/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_fp16.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Compile to oneflow graph with : -oneflow_compile example: python examples/text_to_image_sdxl_fp16.py --compile -torch.compile example: python examples/text_to_image_sdxl_fp16.py -""" -import os -import argparse -from diffusers import StableDiffusionXLPipeline -import torch - -from onediff.infer_compiler import torchbackend - -parser = argparse.ArgumentParser() -parser.add_argument( - "--model", type=str, default="/share_nfs/hf_models/stable-diffusion-xl-base-1.0" -) -parser.add_argument("--variant", type=str, default="fp16") -parser.add_argument( - "--prompt", - type=str, - default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", -) -parser.add_argument( - "--saved_image", type=str, required=False, default="xl-base-out.png" -) -parser.add_argument("--seed", type=int, default=1) -parser.add_argument( - "--compile", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=True -) -parser.add_argument("--graph", action=argparse.BooleanOptionalAction) -args = parser.parse_args() - -if args.compile: - print("unet is compiled to oneflow.") - if args.graph: - print("unet is compiled to oneflow graph.") - -torch.manual_seed(args.seed) - -pipe = StableDiffusionXLPipeline.from_pretrained( - args.model, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True -) - -if args.compile: - if args.graph: - os.environ["with_graph"] = "1" - pipe.unet = torch.compile( - pipe.unet, fullgraph=True, mode="reduce-overhead", backend=torchbackend - ) - -pipe.to("cuda") - -for i in range(3): - image = pipe( - prompt=args.prompt, height=768, width=768, num_inference_steps=50 - ).images[0] - image.save(f"{i}-{args.saved_image}") diff --git a/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_torch_compile.py b/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_torch_compile.py deleted file mode 100644 index 85212ed1a..000000000 --- a/onediff_diffusers_extensions/examples/experimental/text_to_image_sdxl_torch_compile.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -Compile to oneflow graph with : -oneflow_compile example: python examples/text_to_image_sdxl.py --compile -torch.compile example: python examples/text_to_image_sdxl.py --compile_with_dynamo -""" -import os -import argparse - -import torch -import oneflow as flow - -from diffusers import DiffusionPipeline -from onediff.infer_compiler import oneflow_compile, compile_options - -parser = argparse.ArgumentParser() -parser.add_argument( - "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" -) -parser.add_argument( - "--refiner", type=str, default="stabilityai/stable-diffusion-xl-refiner-1.0" -) -parser.add_argument("--variant", type=str, default="fp16") -parser.add_argument( - "--prompt", - type=str, - default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", -) -parser.add_argument("--n_steps", type=int, default=30) -parser.add_argument("--saved_image", type=str, required=False, default="sdxl-out.png") -parser.add_argument("--seed", type=int, default=1) -parser.add_argument( - "--compile", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=True -) -parser.add_argument("--compile_with_dynamo", action=argparse.BooleanOptionalAction) -parser.add_argument("--num_dynamic_input_size", type=int, default=9) -cmd_args = parser.parse_args() - -if cmd_args.compile and cmd_args.compile_with_dynamo: - parser.error("--compile and --compile_with_dynamo cannot be used together.") - -# Normal SDXL pipeline init. -SEED = torch.Generator("cuda").manual_seed(cmd_args.seed) -OUTPUT_TYPE = "pil" -# SDXL base: StableDiffusionXLPipeline -base = DiffusionPipeline.from_pretrained( - cmd_args.base, - torch_dtype=torch.float16, - variant=cmd_args.variant, - use_safetensors=True, -) -base.to("cuda") - -# Compile unet with oneflow -if cmd_args.compile: - print("unet is compiled to oneflow.") - compile_options.oneflow.max_cached_graph_size = cmd_args.num_dynamic_input_size - base.unet = oneflow_compile(base.unet, options=compile_options) - -# Compile unet with torch.compile to oneflow. -# Note this is at alpha stage(experimental) and may be changed later. -if cmd_args.compile_with_dynamo: - print("unet is compiled to oneflow with torch.compile.") - from onediff.infer_compiler import oneflow_backend - - base.unet = torch.compile( - base.unet, fullgraph=True, mode="reduce-overhead", backend=oneflow_backend - ) - -# Normal SDXL run -# sizes = [1024, 896, 768] -sizes = [1024] -for h in sizes: - for w in sizes: - for i in range(3): - image = base( - prompt=cmd_args.prompt, - height=h, - width=w, - generator=SEED, - num_inference_steps=cmd_args.n_steps, - output_type=OUTPUT_TYPE, - ).images - image[0].save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}") diff --git a/onediff_diffusers_extensions/examples/experimental/torch_interpretor.py b/onediff_diffusers_extensions/examples/experimental/torch_interpretor.py deleted file mode 100644 index 9c3421010..000000000 --- a/onediff_diffusers_extensions/examples/experimental/torch_interpretor.py +++ /dev/null @@ -1,33 +0,0 @@ -# HF_HUB_OFFLINE=1 python3 examples/torch_interpretor.py -import os -import torch -from diffusers import StableDiffusionPipeline -from onediff.infer_compiler import oneflow_backend - -pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - use_auth_token=True, - revision="fp16", - torch_dtype=torch.float16, -) - -# run with interpreter mode to oneflow -# ONEDIFF_INFER_COMPILER_USE_INTERPRETER's default value is 0 -os.environ["ONEDIFF_INFER_COMPILER_USE_INTERPRETER"] = "0" - -# optimize with oneflow graph -# ONEDIFF_INFER_COMPILER_USE_GRAPH's default value is 0 -os.environ["ONEDIFF_INFER_COMPILER_USE_GRAPH"] = "1" - - -pipe.unet = torch.compile( - pipe.unet, fullgraph=True, mode="reduce-overhead", backend=oneflow_backend -) -pipe = pipe.to("cuda") - -PROMPT = "a photo of an astronaut riding a horse on mars" -with torch.autocast("cuda"): - for i in range(3): - images = pipe(PROMPT).images - for j, image in enumerate(images): - image.save(f"{PROMPT}-of-{i}-{j}.png") diff --git a/onediff_diffusers_extensions/setup.py b/onediff_diffusers_extensions/setup.py index a14e275f0..eea4fcf59 100644 --- a/onediff_diffusers_extensions/setup.py +++ b/onediff_diffusers_extensions/setup.py @@ -16,8 +16,8 @@ def get_version(): description="onediff extensions for diffusers", url="https://github.com/siliconflow/onediff", author="OneDiff contributors", - license="Apache", - author_email="caishenghang@oneflow.org", + license="Apache-2.0", + author_email="contact@siliconflow.com", packages=find_packages(), python_requires=">=3.7.0", install_requires=[ diff --git a/setup.py b/setup.py index fcfb6b9aa..cf3ab2356 100644 --- a/setup.py +++ b/setup.py @@ -16,8 +16,9 @@ def get_version(): description="an out-of-the-box acceleration library for diffusion models", url="https://github.com/siliconflow/onediff", author="OneDiff contributors", - license="Apache", - author_email="caishenghang@oneflow.org", + license="Apache-2.0", + license_files=('LICENSE',), + author_email="contact@siliconflow.com", package_dir={"": "src"}, packages=find_packages("src"), python_requires=">=3.8.0", diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index c1131bd2f..bff98d894 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -8,36 +8,3 @@ set_default_env_vars() - - -def oneflow_backend(gm, example_inputs, *args, **kwargs): - import oneflow as flow - from oneflow.framework.args_tree import ArgsTree - - with_interp = os.getenv( - "ONEDIFF_INFER_COMPILER_USE_INTERPRETER", "False" - ).lower() in ("true", "1", "t",) - if not with_interp: - transformed_fn = fx_node_tranform(gm) - - def wrapped_forward(*args, **kwargs): - def input_fn(value): - if isinstance(value, torch.Tensor): - return flow.utils.tensor.from_torch(value.contiguous()) - else: - return value - - args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) - out = args_tree.map_leaf(input_fn) - args = out[0] - if with_interp: - output = OneFlowInterpreter(gm, garbage_collect_values=False).run( - *args, **kwargs - ) - else: - output = transformed_fn(*args, **kwargs) - if isinstance(output, tuple): - return tuple(flow.utils.tensor.to_torch(i) for i in output) - return flow.utils.tensor.to_torch(output) - - return wrapped_forward diff --git a/src/onediff/infer_compiler/core/__init__.py b/src/onediff/infer_compiler/core/__init__.py index 734fe9adb..2c2324087 100644 --- a/src/onediff/infer_compiler/core/__init__.py +++ b/src/onediff/infer_compiler/core/__init__.py @@ -1,4 +1,2 @@ from .deployable_module import DeployableModule from .with_onediff_compile import compile, oneflow_compile -from .with_fx_interpreter import OneFlowInterpreter -from .with_fx_graph import fx_node_tranform diff --git a/src/onediff/infer_compiler/core/with_fx_graph.py b/src/onediff/infer_compiler/core/with_fx_graph.py deleted file mode 100644 index 9f17772c8..000000000 --- a/src/onediff/infer_compiler/core/with_fx_graph.py +++ /dev/null @@ -1,108 +0,0 @@ -import os -import torch -import torch.fx as fx -from torch.fx.node import map_aggregate -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union - - -def fx_node_tranform(gm): - import oneflow as flow - - of_gm = to_of_transform(gm) - - enable_graph = os.getenv("ONEDIFF_INFER_COMPILER_USE_GRAPH", "True").lower() in ( - "true", - "1", - "t", - ) - - if not enable_graph: - oneflow_fn = of_gm.forward - else: - - class OfGraph(flow.nn.Graph): - def __init__(self): - super().__init__() - self.fx_md = of_gm - # self.config.enable_cudnn_conv_heuristic_search_algo(False) - self.config.allow_fuse_add_to_output(True) - - def build(self, *args, **kwargs): - return self.fx_md(*args, **kwargs) - - of_g = OfGraph() - oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs) - - return oneflow_fn - - -def to_of_transform( - gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer -) -> torch.fx.GraphModule: - import oneflow as flow - from .transform import get_attr, torch2oflow - - name2node = {} - name2obj = {} - torch2flow = {} - of_g = flow.fx.Graph() - modules = dict(gm.named_modules()) - for node in gm.graph.nodes: - if node.op == "placeholder": - of_node = of_g.create_node("placeholder", node.target) - name2node[node.name] = of_node - elif node.op == "output": - of_node = of_g.output(node_replace_args(node.args, name2node)[0]) - name2node[node.name] = of_node - elif node.op == "call_function": - of_node = of_g.create_node( - "call_function", - torch2oflow(node.target), - args=node_replace_args(node.args, name2node), - kwargs=node_replace_args(node.kwargs, name2node), - ) - name2node[node.name] = of_node - elif node.op == "call_method": - of_node = of_g.create_node( - "call_method", - node.target, - args=node_replace_args(node.args, name2node), - kwargs=node_replace_args(node.kwargs, name2node), - ) - name2node[node.name] = of_node - elif node.op == "call_module": - torch_md = modules[node.target] - name2obj[node.target] = torch2oflow(torch_md) - - of_node = of_g.create_node( - "call_module", - node.target, - args=node_replace_args(node.args, name2node), - kwargs=node_replace_args(node.kwargs, name2node), - ) - name2node[node.name] = of_node - elif node.op == "get_attr": - of_node = of_g.create_node("get_attr", node.target) - name2node[node.name] = of_node - name2obj[node.target] = get_attr(gm, node, torch2flow) - else: - raise ValueError(f"not valid node type{node.foramt_node()}") - - of_gm = flow.fx.GraphModule(name2obj, of_g) - of_gm.training = False - of_gm.graph.lint() - of_gm.recompile() - return of_gm - - -def replace_node(node, name2node): - from .transform import torch2oflow - - if isinstance(node, torch.fx.Node): - return name2node[node.name] - else: - return torch2oflow(node) - - -def node_replace_args(args, name2node): - return map_aggregate(args, lambda node: replace_node(node, name2node)) diff --git a/src/onediff/infer_compiler/core/with_fx_interpreter.py b/src/onediff/infer_compiler/core/with_fx_interpreter.py deleted file mode 100644 index 884a201d7..000000000 --- a/src/onediff/infer_compiler/core/with_fx_interpreter.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union - - -class OneFlowInterpreter(torch.fx.Interpreter): - from torch.fx.node import Argument, Target - - def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: - from .transform import map_args - - args, kwargs = map_args(args, kwargs) - target = torch2oflow(target) - return super().call_function(target, args, kwargs) - - def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: - from .transform import map_args - - args, kwargs = map_args(args, kwargs) - return super().call_method(target, args, kwargs) - - def call_module( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] - ) -> Any: - from .transform import ProxySubmodule - - submod = self.fetch_attr(target) - submod = ProxySubmodule(submod) - return submod(*args, **kwargs) diff --git a/src/onediff/infer_compiler/core/with_onediff_compile.py b/src/onediff/infer_compiler/core/with_onediff_compile.py index c2ec0568b..3ab038162 100644 --- a/src/onediff/infer_compiler/core/with_onediff_compile.py +++ b/src/onediff/infer_compiler/core/with_onediff_compile.py @@ -1,9 +1,10 @@ import torch from .deployable_module import DeployableModule +_DEFAULT_BACKEND = "oneflow" def compile( - torch_module: torch.nn.Module, *, backend="nexfort", options=None + torch_module: torch.nn.Module, *, backend=_DEFAULT_BACKEND, options=None ) -> DeployableModule: from ..backends.registry import lookup_backend