CI with pre-commit (#1040)
Co-authored-by: lixiang007666 <88304454@qq.com>
strint and lixiang007666 authored Jul 25, 2024
1 parent 38abb3a commit 94e8e4a
Showing 12 changed files with 107 additions and 35 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/examples.yml
@@ -36,6 +36,15 @@ concurrency:
cancel-in-progress: true

jobs:
pre-commit:
runs-on: [ubuntu-latest]
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'
- uses: pre-commit/action@v3.0.1
upload_src:
if: github.repository == 'siliconflow/onediff'
runs-on: [ubuntu-latest]
1 change: 1 addition & 0 deletions onediff_comfy_nodes/modules/booster_cache.py
@@ -7,6 +7,7 @@
from onediff.torch_utils.module_operations import get_sub_module
from onediff.utils.import_utils import is_oneflow_available


@singledispatch
def switch_to_cached_model(new_model, cached_model):
raise NotImplementedError(type(new_model))
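Note on the hunk above: the @singledispatch base is only the fallback; concrete overloads are presumably registered elsewhere in the module. A minimal, self-contained sketch of the same pattern, with a hypothetical ExampleModel class standing in for a real model wrapper (not onediff code):

from functools import singledispatch

@singledispatch
def switch_to_cached_model(new_model, cached_model):
    # Fallback when no overload is registered for this model type.
    raise NotImplementedError(type(new_model))

class ExampleModel:  # hypothetical stand-in for a model wrapper
    def __init__(self, state):
        self.state = state

@switch_to_cached_model.register(ExampleModel)
def _(new_model, cached_model):
    # Keep the cached object alive, but adopt the new model's state.
    cached_model.state = new_model.state
    return cached_model

cached = switch_to_cached_model(ExampleModel({"w": 1}), ExampleModel({"w": 0}))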
@@ -26,7 +26,7 @@ def tensor_to_size(source, dest_size):
return source


def get_weight_subidxs(weight,ad_params,sub_idxs):
def get_weight_subidxs(weight, ad_params, sub_idxs):
return weight[ad_params[sub_idxs]]


@@ -167,7 +167,7 @@ def ipadapter_attention(
if ad_params is not None and ad_params["sub_idxs"] is not None:
if isinstance(weight, torch.Tensor) and weight.dim() != 0:
weight = tensor_to_size(weight, ad_params["full_length"])
weight = get_weight_subidxs(weight,ad_params,"sub_idxs")
weight = get_weight_subidxs(weight, ad_params, "sub_idxs")
# if torch.all(weight == 0):
# return 0
weight = weight.repeat(
@@ -178,8 +178,8 @@ def ipadapter_attention(

# if image length matches or exceeds full_length get sub_idx images
if cond.shape[0] >= ad_params["full_length"]:
cond = get_weight_subidxs(cond,ad_params,"sub_idxs")
uncond = get_weight_subidxs(uncond,ad_params,"sub_idxs")
cond = get_weight_subidxs(cond, ad_params, "sub_idxs")
uncond = get_weight_subidxs(uncond, ad_params, "sub_idxs")
# otherwise get sub_idxs images
else:
cond = tensor_to_size(cond, ad_params["full_length"])
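Note on the hunk above: get_weight_subidxs simply gathers the entries of a per-frame weight (or of the cond/uncond embeddings) that belong to the current sub-index window given by ad_params["sub_idxs"]. A small self-contained illustration with made-up shapes and values:

import torch

def get_weight_subidxs(weight, ad_params, sub_idxs):
    return weight[ad_params[sub_idxs]]

# Illustrative values: 16 per-frame weights, current window covers frames 4..7.
weight = torch.arange(16, dtype=torch.float32)
ad_params = {"sub_idxs": [4, 5, 6, 7], "full_length": 16}
print(get_weight_subidxs(weight, ad_params, "sub_idxs"))  # tensor([4., 5., 6., 7.])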
1 change: 0 additions & 1 deletion onediff_diffusers_extensions/examples/kolors/README.md
@@ -118,4 +118,3 @@ python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \

The quality report for accelerating the kolors model with onediff is located at:
https://github.com/siliconflow/odeval/tree/main/models/kolors

@@ -2,10 +2,11 @@
import json
import time

import torch

from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
from onediffx import compile_pipe, quantize_pipe, load_pipe, save_pipe
from onediff.infer_compiler import oneflow_compile
import torch
from onediffx import compile_pipe, load_pipe, quantize_pipe, save_pipe


def parse_args():
@@ -103,7 +104,7 @@ def __init__(
elif compiler == "oneflow":
print("oneflow backend compile...")
# self.pipe.unet = self.oneflow_compile(self.pipe.unet)
self.pipe = compile_pipe(self.pipe, ignores=['text_encoder', 'vae'])
self.pipe = compile_pipe(self.pipe, ignores=["text_encoder", "vae"])

def warmup(self, gen_args, warmup_iterations):
warmup_args = gen_args.copy()
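For context on the hunk above: compile_pipe with ignores=["text_encoder", "vae"] compiles the pipeline while leaving the listed sub-modules untouched. A minimal usage sketch based on the imports in this script; the model id, dtype, and prompt below are illustrative assumptions, not taken from the diff:

import torch
from diffusers import KolorsPipeline
from onediffx import compile_pipe

# Assumes the Kolors checkpoint can be fetched from the Hugging Face Hub.
pipe = KolorsPipeline.from_pretrained(
    "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# Compile everything except the text encoder and VAE, as in the script above.
pipe = compile_pipe(pipe, ignores=["text_encoder", "vae"])
image = pipe(prompt="a photo of an astronaut riding a horse").images[0]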
@@ -1,6 +1,8 @@
from typing import Any, Dict, Optional

import torch
from typing import Optional, Dict, Any
from onediff.infer_compiler.backends.oneflow.transform import transform_mgr

transformed_diffusers = transform_mgr.transform_package("diffusers")
ConfigMixin = transformed_diffusers.configuration_utils.ConfigMixin
register_to_config = transformed_diffusers.configuration_utils.register_to_config
@@ -20,11 +22,14 @@
LoRACompatibleLinear = transformed_diffusers.models.lora.LoRACompatibleLinear
ModelMixin = transformed_diffusers.models.modeling_utils.ModelMixin
AdaLayerNormSingle = transformed_diffusers.models.normalization.AdaLayerNormSingle
Transformer2DModelOutput = transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput
Transformer2DModelOutput = (
transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput
)
proxy_Transformer2DModel = (
transformed_diffusers.models.transformers.transformer_2d.Transformer2DModel
)


class Transformer2DModel(proxy_Transformer2DModel):
def forward(
self,
@@ -79,7 +84,9 @@ def forward(
"""
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
logger.warning(
"Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -100,7 +107,9 @@ def forward(

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = (
1 - encoder_attention_mask.to(hidden_states.dtype)
) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

hidden_states_in = hidden_states
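As a quick numeric check of the mask-to-bias conversion a few lines above: a 0/1 keep mask becomes an additive bias of 0 for kept tokens and -10000 for masked ones, which drives the masked attention scores toward zero after softmax (values below are illustrative):

import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])   # 1 = attend, 0 = ignore
bias = (1 - mask) * -10000.0             # tensor([[0., 0., -10000.]])
bias = bias.unsqueeze(1)                 # singleton query-token dimension
print(bias.shape, bias)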
@@ -113,8 +122,16 @@ def forward(
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
height, width = (
hidden_states.shape[-2] // self.patch_size,
hidden_states.shape[-1] // self.patch_size,
)
(
hidden_states,
encoder_hidden_states,
timestep,
embedded_timestep,
) = self._operate_on_patched_inputs(
hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
)

@@ -131,7 +148,9 @@ def custom_forward(*inputs):

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
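The checkpointing call above wraps each transformer block so its activations are recomputed during backward instead of stored; use_reentrant=False is only accepted on torch >= 1.11, hence the is_torch_version guard. A minimal standalone sketch of the same mechanism (the Linear block is a toy stand-in, not the real transformer block):

import torch
from torch.utils.checkpoint import checkpoint

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = torch.nn.Linear(4, 4)              # toy stand-in for a transformer block
x = torch.randn(2, 4, requires_grad=True)

# Activations inside `block` are recomputed in backward rather than kept in memory.
y = checkpoint(create_custom_forward(block), x, use_reentrant=False)
y.sum().backward()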
@@ -203,7 +222,9 @@ def custom_forward(*inputs):

return Transformer2DModelOutput(sample=output)

def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
def _get_output_for_continuous_inputs(
self, hidden_states, residual, batch_size, height, width, inner_dim
):
# # hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
# hidden_states = (
# hidden_states.permute(0, 2, 1)
@@ -215,13 +236,17 @@ def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size,
return
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states.reshape(batch_size, height, width, inner_dim)
.permute(0, 3, 1, 2)
.contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states.reshape(batch_size, height, width, inner_dim)
.permute(0, 3, 1, 2)
.contiguous()
)

output = hidden_states + residual
@@ -238,7 +263,13 @@ def _get_output_for_vectorized_inputs(self, hidden_states):
return output

def _get_output_for_patched_inputs(
self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
self,
hidden_states,
timestep,
class_labels,
embedded_timestep,
height=None,
width=None,
):
raise
# import ipdb; ipdb.set_trace()
@@ -247,10 +278,14 @@ def _get_output_for_patched_inputs(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = (
self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
)
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None]
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
@@ -261,11 +296,23 @@ def _get_output_for_patched_inputs(
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
shape=(
-1,
height,
width,
self.patch_size,
self.patch_size,
self.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
shape=(
-1,
self.out_channels,
height * self.patch_size,
width * self.patch_size,
)
)
return output
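The reshape/einsum pair above is the usual "unpatchify" step: per-patch channels are folded back into a spatial feature map of shape (batch, out_channels, height * patch_size, width * patch_size). A small shape check with made-up sizes:

import torch

height = width = 2          # patches per side
patch_size = 2
out_channels = 3

# (batch, num_patches, patch_size * patch_size * out_channels)
tokens = torch.randn(1, height * width, patch_size * patch_size * out_channels)

x = tokens.reshape(-1, height, width, patch_size, patch_size, out_channels)
x = torch.einsum("nhwpqc->nchpwq", x)
output = x.reshape(-1, out_channels, height * patch_size, width * patch_size)
print(output.shape)  # torch.Size([1, 3, 4, 4])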

@@ -284,4 +331,4 @@ def _operate_on_continuous_inputs(self, hidden_states):
hidden_states = hidden_states.permute(0, 2, 3, 1).flatten(1, 2)
hidden_states = self.proj_in(hidden_states)

return hidden_states, inner_dim
return hidden_states, inner_dim
@@ -921,7 +921,7 @@ def forward(
)
proxy_Transformer2DModel = (
transformed_diffusers.models.transformer_2d.Transformer2DModel
)
)

class Transformer2DModel(proxy_Transformer2DModel):
def forward(
@@ -1213,5 +1213,6 @@ def custom_forward(*inputs):
return (output,)

return Transformer2DModelOutput(sample=output)

else:
from .transformer_2d.v_0_28 import Transformer2DModel, Transformer2DModelOutput
@@ -43,9 +43,7 @@ def forward(
else:
hidden_states = upsampler(hidden_states)

class CrossAttnUpBlock2D(
diffusers_unet_2d_blocks.CrossAttnUpBlock2D
):
class CrossAttnUpBlock2D(diffusers_unet_2d_blocks.CrossAttnUpBlock2D):
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -185,9 +183,7 @@ def forward(

return hidden_states

class CrossAttnUpBlock2D(
diffusers_unet_2d_blocks.CrossAttnUpBlock2D
):
class CrossAttnUpBlock2D(diffusers_unet_2d_blocks.CrossAttnUpBlock2D):
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -14,12 +14,16 @@
UNet2DConditionOutput = (
transformed_diffusers.models.unet_2d_condition.UNet2DConditionOutput
)
proxy_UNet2DConditionModel = transformed_diffusers.models.unet_2d_condition.UNet2DConditionModel
proxy_UNet2DConditionModel = (
transformed_diffusers.models.unet_2d_condition.UNet2DConditionModel
)
else:
UNet2DConditionOutput = (
transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput
)
proxy_UNet2DConditionModel = transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
proxy_UNet2DConditionModel = (
transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
)

try:
USE_PEFT_BACKEND = transformed_diffusers.utils.USE_PEFT_BACKEND
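Context for the hunk above: the if/else selects the proxy class from models.unet_2d_condition or models.unets.unet_2d_condition because diffusers moved its UNet modules under models/unets/ in newer releases. A sketch of that kind of version gate against plain diffusers; the 0.25.0 cut-off is an assumption for illustration, not the exact check used in this file:

from importlib.metadata import version
from packaging.version import parse

# diffusers relocated unet_2d_condition under models/unets/ in newer releases.
if parse(version("diffusers")) < parse("0.25.0"):  # assumed cut-off, for illustration only
    from diffusers.models.unet_2d_condition import UNet2DConditionModel
else:
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel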
@@ -1,6 +1,7 @@
import torch
import oneflow as flow # usort: skip
import functools

from oneflow.framework.args_tree import ArgsTree

from onediff.utils import logger
@@ -16,7 +16,13 @@ def mem_get_info(dev):

@staticmethod
def _scaled_dot_product_attention_math(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

@@ -51,7 +57,13 @@ def _scaled_dot_product_attention_math(

@staticmethod
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
"""Scaled Dot-Product Attention
Args:
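The math fallback above implements the textbook scaled dot-product attention, softmax(QK^T * scale + mask) V, with scale defaulting to 1/sqrt(d_k). A minimal version of the core computation (dropout and the is_causal path omitted), compared against torch's built-in, which requires torch >= 2.0:

import math
import torch
import torch.nn.functional as F

def sdpa_math(query, key, value, attn_mask=None, scale=None):
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn = query @ key.transpose(-2, -1) * scale_factor
    if attn_mask is not None:
        attn = attn + attn_mask  # additive attention bias
    attn = attn.softmax(dim=-1)
    return attn @ value

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
print(torch.allclose(sdpa_math(q, k, v), F.scaled_dot_product_attention(q, k, v), atol=1e-5))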
@@ -1,5 +1,6 @@
import functools


def patch_input_adapter(in_args, in_kwargs):
return in_args, in_kwargs

