From d4cd7b15104c16838e3c562cf2d33337e3d38897 Mon Sep 17 00:00:00 2001 From: Xuye Qin Date: Fri, 25 Oct 2024 12:51:06 +0800 Subject: [PATCH] FEAT: added MLX support for Flux.1 (#2459) --- setup.cfg | 2 +- xinference/core/chat_interface.py | 6 +- xinference/core/image_interface.py | 6 +- xinference/core/scheduler.py | 2 +- xinference/deploy/docker/requirements.txt | 2 +- xinference/deploy/docker/requirements_cpu.txt | 2 +- xinference/model/image/core.py | 28 +- xinference/model/image/model_spec.json | 18 +- .../model/image/model_spec_modelscope.json | 18 +- xinference/model/image/scheduler/flux.py | 2 +- .../model/image/stable_diffusion/core.py | 5 +- .../model/image/stable_diffusion/mlx.py | 221 +++++++++++ xinference/thirdparty/mlx/__init__.py | 13 + xinference/thirdparty/mlx/flux/__init__.py | 15 + xinference/thirdparty/mlx/flux/autoencoder.py | 357 ++++++++++++++++++ xinference/thirdparty/mlx/flux/clip.py | 154 ++++++++ xinference/thirdparty/mlx/flux/datasets.py | 75 ++++ xinference/thirdparty/mlx/flux/flux.py | 247 ++++++++++++ xinference/thirdparty/mlx/flux/layers.py | 302 +++++++++++++++ xinference/thirdparty/mlx/flux/lora.py | 76 ++++ xinference/thirdparty/mlx/flux/model.py | 134 +++++++ xinference/thirdparty/mlx/flux/sampler.py | 56 +++ xinference/thirdparty/mlx/flux/t5.py | 244 ++++++++++++ xinference/thirdparty/mlx/flux/tokenizers.py | 185 +++++++++ xinference/thirdparty/mlx/flux/trainer.py | 98 +++++ xinference/thirdparty/mlx/flux/utils.py | 179 +++++++++ 26 files changed, 2428 insertions(+), 19 deletions(-) create mode 100644 xinference/model/image/stable_diffusion/mlx.py create mode 100644 xinference/thirdparty/mlx/__init__.py create mode 100644 xinference/thirdparty/mlx/flux/__init__.py create mode 100644 xinference/thirdparty/mlx/flux/autoencoder.py create mode 100644 xinference/thirdparty/mlx/flux/clip.py create mode 100644 xinference/thirdparty/mlx/flux/datasets.py create mode 100644 xinference/thirdparty/mlx/flux/flux.py create mode 100644 xinference/thirdparty/mlx/flux/layers.py create mode 100644 xinference/thirdparty/mlx/flux/lora.py create mode 100644 xinference/thirdparty/mlx/flux/model.py create mode 100644 xinference/thirdparty/mlx/flux/sampler.py create mode 100644 xinference/thirdparty/mlx/flux/t5.py create mode 100644 xinference/thirdparty/mlx/flux/tokenizers.py create mode 100644 xinference/thirdparty/mlx/flux/trainer.py create mode 100644 xinference/thirdparty/mlx/flux/utils.py diff --git a/setup.cfg b/setup.cfg index ccde851616..3c08363e59 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ install_requires = tabulate requests pydantic - fastapi==0.110.3 + fastapi>=0.110.3 uvicorn huggingface-hub>=0.19.4 typing_extensions diff --git a/xinference/core/chat_interface.py b/xinference/core/chat_interface.py index 9de2dab252..08b30ab054 100644 --- a/xinference/core/chat_interface.py +++ b/xinference/core/chat_interface.py @@ -74,7 +74,11 @@ def build(self) -> "gr.Blocks": # Gradio initiates the queue during a startup event, but since the app has already been # started, that event will not run, so manually invoke the startup events. # See: https://github.com/gradio-app/gradio/issues/5228 - interface.startup_events() + try: + interface.run_startup_events() + except AttributeError: + # compatibility + interface.startup_events() favicon_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), os.path.pardir, diff --git a/xinference/core/image_interface.py b/xinference/core/image_interface.py index 56761e4101..b48636bfd5 100644 --- a/xinference/core/image_interface.py +++ b/xinference/core/image_interface.py @@ -63,7 +63,11 @@ def build(self) -> gr.Blocks: # Gradio initiates the queue during a startup event, but since the app has already been # started, that event will not run, so manually invoke the startup events. # See: https://github.com/gradio-app/gradio/issues/5228 - interface.startup_events() + try: + interface.run_startup_events() + except AttributeError: + # compatibility + interface.startup_events() favicon_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), os.path.pardir, diff --git a/xinference/core/scheduler.py b/xinference/core/scheduler.py index 1b91d62e27..8b91855daa 100644 --- a/xinference/core/scheduler.py +++ b/xinference/core/scheduler.py @@ -79,7 +79,7 @@ def __init__( # For tool call self.tools = None # Currently, for storing tool call streaming results. - self.outputs: List[str] = [] + self.outputs: List[str] = [] # type: ignore # inference results, # it is a list type because when stream=True, # self.completion contains all the results in a decode round. diff --git a/xinference/deploy/docker/requirements.txt b/xinference/deploy/docker/requirements.txt index 6fc624298b..a3aa0a5e93 100644 --- a/xinference/deploy/docker/requirements.txt +++ b/xinference/deploy/docker/requirements.txt @@ -8,7 +8,7 @@ tqdm>=4.27 tabulate requests pydantic -fastapi==0.110.3 +fastapi>=0.110.3 uvicorn huggingface-hub>=0.19.4 typing_extensions diff --git a/xinference/deploy/docker/requirements_cpu.txt b/xinference/deploy/docker/requirements_cpu.txt index 7c15d3a9df..9eb9409b4f 100644 --- a/xinference/deploy/docker/requirements_cpu.txt +++ b/xinference/deploy/docker/requirements_cpu.txt @@ -7,7 +7,7 @@ tqdm>=4.27 tabulate requests pydantic -fastapi==0.110.3 +fastapi>=0.110.3 uvicorn huggingface-hub>=0.19.4 typing_extensions diff --git a/xinference/model/image/core.py b/xinference/model/image/core.py index 098cdaa10b..581358b789 100644 --- a/xinference/model/image/core.py +++ b/xinference/model/image/core.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import collections.abc import logging import os +import platform from collections import defaultdict from typing import Dict, List, Literal, Optional, Tuple, Union @@ -23,6 +25,7 @@ from ..utils import valid_model_revision from .ocr.got_ocr2 import GotOCR2Model from .stable_diffusion.core import DiffusionModel +from .stable_diffusion.mlx import MLXDiffusionModel logger = logging.getLogger(__name__) @@ -46,6 +49,7 @@ class ImageModelFamilyV1(CacheableModelSpec): model_hub: str = "huggingface" model_ability: Optional[List[str]] controlnet: Optional[List["ImageModelFamilyV1"]] + default_model_config: Optional[dict] = {} default_generate_config: Optional[dict] = {} @@ -212,7 +216,9 @@ def create_image_model_instance( download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None, model_path: Optional[str] = None, **kwargs, -) -> Tuple[Union[DiffusionModel, GotOCR2Model], ImageModelDescription]: +) -> Tuple[ + Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription +]: model_spec = match_diffusion(model_name, download_hub) if model_spec.model_ability and "ocr" in model_spec.model_ability: return create_ocr_model_instance( @@ -224,6 +230,12 @@ def create_image_model_instance( model_path=model_path, **kwargs, ) + + # use default model config + model_default_config = (model_spec.default_model_config or {}).copy() + model_default_config.update(kwargs) + kwargs = model_default_config + controlnet = kwargs.get("controlnet") # Handle controlnet if controlnet is not None: @@ -265,10 +277,20 @@ def create_image_model_instance( lora_load_kwargs = None lora_fuse_kwargs = None - model = DiffusionModel( + if ( + platform.system() == "Darwin" + and "arm" in platform.machine().lower() + and model_name in MLXDiffusionModel.supported_models + ): + # Mac with M series silicon chips + model_cls = MLXDiffusionModel + else: + model_cls = DiffusionModel # type: ignore + + model = model_cls( model_uid, model_path, - lora_model_paths=lora_model, + lora_model=lora_model, lora_load_kwargs=lora_load_kwargs, lora_fuse_kwargs=lora_fuse_kwargs, model_spec=model_spec, diff --git a/xinference/model/image/model_spec.json b/xinference/model/image/model_spec.json index 43c77e0e8e..24933cb99e 100644 --- a/xinference/model/image/model_spec.json +++ b/xinference/model/image/model_spec.json @@ -8,7 +8,11 @@ "text2image", "image2image", "inpainting" - ] + ], + "default_model_config": { + "quantize": true, + "quantize_text_encoder": "text_encoder_2" + } }, { "model_name": "FLUX.1-dev", @@ -19,7 +23,11 @@ "text2image", "image2image", "inpainting" - ] + ], + "default_model_config": { + "quantize": true, + "quantize_text_encoder": "text_encoder_2" + } }, { "model_name": "sd3-medium", @@ -30,7 +38,11 @@ "text2image", "image2image", "inpainting" - ] + ], + "default_model_config": { + "quantize": true, + "quantize_text_encoder": "text_encoder_3" + } }, { "model_name": "sd-turbo", diff --git a/xinference/model/image/model_spec_modelscope.json b/xinference/model/image/model_spec_modelscope.json index 709de622b9..ad8af7a26f 100644 --- a/xinference/model/image/model_spec_modelscope.json +++ b/xinference/model/image/model_spec_modelscope.json @@ -9,7 +9,11 @@ "text2image", "image2image", "inpainting" - ] + ], + "default_model_config": { + "quantize": true, + "quantize_text_encoder": "text_encoder_2" + } }, { "model_name": "FLUX.1-dev", @@ -21,7 +25,11 @@ "text2image", "image2image", "inpainting" - ] + ], + "default_model_config": { + "quantize": true, + "quantize_text_encoder": "text_encoder_2" + } }, { "model_name": "sd3-medium", @@ -33,7 +41,11 @@ "text2image", "image2image", "inpainting" - ] + ], + "default_model_config": { + "quantize": true, + "quantize_text_encoder": "text_encoder_3" + } }, { "model_name": "sd-turbo", diff --git a/xinference/model/image/scheduler/flux.py b/xinference/model/image/scheduler/flux.py index 174acb82e3..b681e59fa7 100644 --- a/xinference/model/image/scheduler/flux.py +++ b/xinference/model/image/scheduler/flux.py @@ -124,7 +124,7 @@ def __init__(self): self._running_queue: deque[Text2ImageRequest] = deque() # type: ignore self._model = None self._available_device = get_available_device() - self._id_to_req: Dict[str, Text2ImageRequest] = {} + self._id_to_req: Dict[str, Text2ImageRequest] = {} # type: ignore def set_model(self, model): """ diff --git a/xinference/model/image/stable_diffusion/core.py b/xinference/model/image/stable_diffusion/core.py index ae9b6e4bd4..c5a9b33f86 100644 --- a/xinference/model/image/stable_diffusion/core.py +++ b/xinference/model/image/stable_diffusion/core.py @@ -283,9 +283,8 @@ def _load_to_device(self, model): model.enable_sequential_cpu_offload() elif not self._kwargs.get("device_map"): logger.debug("Loading model to available device") - model = move_model_to_available_device(self._model) - # Recommended if your computer has < 64 GB of RAM - if self._kwargs.get("attention_slicing", True): + model = move_model_to_available_device(model) + if self._kwargs.get("attention_slicing", False): model.enable_attention_slicing() if self._kwargs.get("vae_tiling", False): model.enable_vae_tiling() diff --git a/xinference/model/image/stable_diffusion/mlx.py b/xinference/model/image/stable_diffusion/mlx.py new file mode 100644 index 0000000000..849ff62aab --- /dev/null +++ b/xinference/model/image/stable_diffusion/mlx.py @@ -0,0 +1,221 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import gc +import logging +import re +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import numpy as np +from PIL import Image +from xoscar.utils import classproperty + +from ....types import LoRA +from ..sdapi import SDAPIDiffusionModelMixin +from ..utils import handle_image_result + +if TYPE_CHECKING: + from ....core.progress_tracker import Progressor + from ..core import ImageModelFamilyV1 + + +logger = logging.getLogger(__name__) + + +def quantization_predicate(name: str, m) -> bool: + return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0 + + +def to_latent_size(image_size: Tuple[int, int]): + h, w = image_size + h = ((h + 15) // 16) * 16 + w = ((w + 15) // 16) * 16 + + if (h, w) != image_size: + print( + "Warning: The image dimensions need to be divisible by 16px. " + f"Changing size to {h}x{w}." + ) + + return (h // 8, w // 8) + + +class MLXDiffusionModel(SDAPIDiffusionModelMixin): + def __init__( + self, + model_uid: str, + model_path: Optional[str] = None, + device: Optional[str] = None, + lora_model: Optional[List[LoRA]] = None, + lora_load_kwargs: Optional[Dict] = None, + lora_fuse_kwargs: Optional[Dict] = None, + model_spec: Optional["ImageModelFamilyV1"] = None, + **kwargs, + ): + self._model_uid = model_uid + self._model_path = model_path + self._device = device + # model info when loading + self._model = None + self._lora_model = lora_model + self._lora_load_kwargs = lora_load_kwargs or {} + self._lora_fuse_kwargs = lora_fuse_kwargs or {} + # info + self._model_spec = model_spec + self._abilities = model_spec.model_ability or [] # type: ignore + self._kwargs = kwargs + + @property + def model_ability(self): + return self._abilities + + @classproperty + def supported_models(self): + return ["FLUX.1-schnell", "FLUX.1-dev"] + + def load(self): + try: + import mlx.nn as nn + except ImportError: + error_message = "Failed to import module 'mlx'" + installation_guide = [ + "Please make sure 'mlx' is installed. ", + "You can install it by `pip install mlx`\n", + ] + + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + + from ....thirdparty.mlx.flux import FluxPipeline + + logger.debug( + "Loading model from %s, kwargs: %s", self._model_path, self._kwargs + ) + flux = self._model = FluxPipeline( + "flux-" + self._model_spec.model_name.split("-")[1], + model_path=self._model_path, + t5_padding=self._kwargs.get("t5_padding", True), + ) + self._apply_lora() + + quantize = self._kwargs.get("quantize", True) + if quantize: + nn.quantize(flux.flow, class_predicate=quantization_predicate) + nn.quantize(flux.t5, class_predicate=quantization_predicate) + nn.quantize(flux.clip, class_predicate=quantization_predicate) + + def _apply_lora(self): + if self._lora_model is not None: + import mlx.core as mx + + for lora_model in self._lora_model: + weights, lora_config = mx.load( + lora_model.local_path, return_metadata=True + ) + rank = int(lora_config.get("lora_rank", 8)) + num_blocks = int(lora_config.get("lora_blocks", -1)) + flux = self._model + flux.linear_to_lora_layers(rank, num_blocks) + flux.flow.load_weights(list(weights.items()), strict=False) + flux.fuse_lora_layers() + logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.") + + @staticmethod + @contextlib.contextmanager + def _release_after(): + import mlx.core as mx + + try: + yield + finally: + gc.collect() + mx.metal.clear_cache() + + def text_to_image( + self, + prompt: str, + n: int = 1, + size: str = "1024*1024", + response_format: str = "url", + **kwargs, + ): + import mlx.core as mx + + flux = self._model + width, height = map(int, re.split(r"[^\d]+", size)) + + # Make the generator + latent_size = to_latent_size((height, width)) + gen_latent_kwargs = {} + if (num_steps := kwargs.get("num_inference_steps")) is None: + num_steps = 50 if "dev" in self._model_spec.model_name else 2 # type: ignore + gen_latent_kwargs["num_steps"] = num_steps + if guidance := kwargs.get("guidance_scale"): + gen_latent_kwargs["guidance"] = guidance + if seed := kwargs.get("seed"): + gen_latent_kwargs["seed"] = seed + + with self._release_after(): + latents = flux.generate_latents( # type: ignore + prompt, n_images=n, latent_size=latent_size, **gen_latent_kwargs + ) + + # First we get and eval the conditioning + conditioning = next(latents) + mx.eval(conditioning) + peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3 + mx.metal.reset_peak_memory() + + progressor: Progressor = kwargs.pop("progressor", None) + # Actual denoising loop + for i, x_t in enumerate(latents): + mx.eval(x_t) + progressor.set_progress((i + 1) / num_steps) + + peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 + mx.metal.reset_peak_memory() + + # Decode them into images + decoded = [] + for i in range(n): + decoded.append(flux.decode(x_t[i : i + 1], latent_size)) # type: ignore + mx.eval(decoded[-1]) + peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 + peak_mem_overall = max( + peak_mem_conditioning, peak_mem_generation, peak_mem_decoding + ) + + images = [] + x = mx.concatenate(decoded, axis=0) + x = (x * 255).astype(mx.uint8) + for i in range(len(x)): + im = Image.fromarray(np.array(x[i])) + images.append(im) + + logger.debug( + f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB" + ) + logger.debug( + f"Peak memory used for the generation: {peak_mem_generation:.3f}GB" + ) + logger.debug(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB") + logger.debug(f"Peak memory used overall: {peak_mem_overall:.3f}GB") + + return handle_image_result(response_format, images) + + def image_to_image(self, **kwargs): + raise NotImplementedError + + def inpainting(self, **kwargs): + raise NotImplementedError diff --git a/xinference/thirdparty/mlx/__init__.py b/xinference/thirdparty/mlx/__init__.py new file mode 100644 index 0000000000..37f6558d95 --- /dev/null +++ b/xinference/thirdparty/mlx/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/xinference/thirdparty/mlx/flux/__init__.py b/xinference/thirdparty/mlx/flux/__init__.py new file mode 100644 index 0000000000..b1122d75d6 --- /dev/null +++ b/xinference/thirdparty/mlx/flux/__init__.py @@ -0,0 +1,15 @@ +# Copyright © 2024 Apple Inc. + +from .datasets import Dataset, load_dataset +from .flux import FluxPipeline +from .lora import LoRALinear +from .sampler import FluxSampler +from .trainer import Trainer +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) diff --git a/xinference/thirdparty/mlx/flux/autoencoder.py b/xinference/thirdparty/mlx/flux/autoencoder.py new file mode 100644 index 0000000000..6332bb570b --- /dev/null +++ b/xinference/thirdparty/mlx/flux/autoencoder.py @@ -0,0 +1,357 @@ +# Copyright © 2024 Apple Inc. + +from dataclasses import dataclass +from typing import List + +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.upsample import upsample_nearest + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: List[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm( + num_groups=32, + dims=in_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.q = nn.Linear(in_channels, in_channels) + self.k = nn.Linear(in_channels, in_channels) + self.v = nn.Linear(in_channels, in_channels) + self.proj_out = nn.Linear(in_channels, in_channels) + + def __call__(self, x: mx.array) -> mx.array: + B, H, W, C = x.shape + + y = x.reshape(B, 1, -1, C) + y = self.norm(y) + q = self.q(y) + k = self.k(y) + v = self.v(y) + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5)) + y = self.proj_out(y) + + return x + y.reshape(B, H, W, C) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm( + num_groups=32, + dims=in_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = nn.GroupNorm( + num_groups=32, + dims=out_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Linear(in_channels, out_channels) + + def __call__(self, x): + h = x + h = self.norm1(h) + h = nn.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nn.silu(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def __call__(self, x: mx.array): + x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__(self, x: mx.array): + x = upsample_nearest(x, (2, 2)) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = [] + block_in = self.ch + for i_level in range(self.num_resolutions): + block = [] + attn = [] # TODO: Remove the attn, nobody appends anything to it + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = {} + down["block"] = block + down["attn"] = attn + if i_level != self.num_resolutions - 1: + down["downsample"] = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid["attn_1"] = AttnBlock(block_in) + self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__(self, x: mx.array): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level]["block"][i_block](hs[-1]) + + # TODO: Remove the attn + if len(self.down[i_level]["attn"]) > 0: + h = self.down[i_level]["attn"][i_block](h) + + hs.append(h) + + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level]["downsample"](hs[-1])) + + # middle + h = hs[-1] + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # end + h = self.norm_out(h) + h = nn.silu(h) + h = self.conv_out(h) + + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid["attn_1"] = AttnBlock(block_in) + self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = [] + for i_level in reversed(range(self.num_resolutions)): + block = [] + attn = [] # TODO: Remove the attn, nobody appends anything to it + + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = {} + up["block"] = block + up["attn"] = attn + if i_level != 0: + up["upsample"] = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True + ) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def __call__(self, z: mx.array): + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level]["block"][i_block](h) + + # TODO: Remove the attn + if len(self.up[i_level]["attn"]) > 0: + h = self.up[i_level]["attn"][i_block](h) + + if i_level != 0: + h = self.up[i_level]["upsample"](h) + + # end + h = self.norm_out(h) + h = nn.silu(h) + h = self.conv_out(h) + + return h + + +class DiagonalGaussian(nn.Module): + def __call__(self, z: mx.array): + mean, logvar = mx.split(z, 2, axis=-1) + if self.training: + std = mx.exp(0.5 * logvar) + eps = mx.random.normal(shape=z.shape, dtype=z.dtype) + return mean + std * eps + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def sanitize(self, weights): + new_weights = {} + for k, w in weights.items(): + if w.ndim == 4: + w = w.transpose(0, 2, 3, 1) + w = w.reshape(-1).reshape(w.shape) + if w.shape[1:3] == (1, 1): + w = w.squeeze((1, 2)) + new_weights[k] = w + return new_weights + + def encode(self, x: mx.array): + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: mx.array): + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def __call__(self, x: mx.array): + return self.decode(self.encode(x)) diff --git a/xinference/thirdparty/mlx/flux/clip.py b/xinference/thirdparty/mlx/flux/clip.py new file mode 100644 index 0000000000..d5a30dbf34 --- /dev/null +++ b/xinference/thirdparty/mlx/flux/clip.py @@ -0,0 +1,154 @@ +# Copyright © 2024 Apple Inc. + +from dataclasses import dataclass +from typing import List, Optional + +import mlx.core as mx +import mlx.nn as nn + +_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu} + + +@dataclass +class CLIPTextModelConfig: + num_layers: int = 23 + model_dims: int = 1024 + num_heads: int = 16 + max_length: int = 77 + vocab_size: int = 49408 + hidden_act: str = "quick_gelu" + + @classmethod + def from_dict(cls, config): + return cls( + num_layers=config["num_hidden_layers"], + model_dims=config["hidden_size"], + num_heads=config["num_attention_heads"], + max_length=config["max_position_embeddings"], + vocab_size=config["vocab_size"], + hidden_act=config["hidden_act"], + ) + + +@dataclass +class CLIPOutput: + # The last_hidden_state indexed at the EOS token and possibly projected if + # the model has a projection layer + pooled_output: Optional[mx.array] = None + + # The full sequence output of the transformer after the final layernorm + last_hidden_state: Optional[mx.array] = None + + # A list of hidden states corresponding to the outputs of the transformer layers + hidden_states: Optional[List[mx.array]] = None + + +class CLIPEncoderLayer(nn.Module): + """The transformer encoder layer from CLIP.""" + + def __init__(self, model_dims: int, num_heads: int, activation: str): + super().__init__() + + self.layer_norm1 = nn.LayerNorm(model_dims) + self.layer_norm2 = nn.LayerNorm(model_dims) + + self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True) + + self.linear1 = nn.Linear(model_dims, 4 * model_dims) + self.linear2 = nn.Linear(4 * model_dims, model_dims) + + self.act = _ACTIVATIONS[activation] + + def __call__(self, x, attn_mask=None): + y = self.layer_norm1(x) + y = self.attention(y, y, y, attn_mask) + x = y + x + + y = self.layer_norm2(x) + y = self.linear1(y) + y = self.act(y) + y = self.linear2(y) + x = y + x + + return x + + +class CLIPTextModel(nn.Module): + """Implements the text encoder transformer from CLIP.""" + + def __init__(self, config: CLIPTextModelConfig): + super().__init__() + + self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims) + self.position_embedding = nn.Embedding(config.max_length, config.model_dims) + self.layers = [ + CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act) + for i in range(config.num_layers) + ] + self.final_layer_norm = nn.LayerNorm(config.model_dims) + + def _get_mask(self, N, dtype): + indices = mx.arange(N) + mask = indices[:, None] < indices[None] + mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9) + return mask + + def sanitize(self, weights): + new_weights = {} + for key, w in weights.items(): + # Remove prefixes + if key.startswith("text_model."): + key = key[11:] + if key.startswith("embeddings."): + key = key[11:] + if key.startswith("encoder."): + key = key[8:] + + # Map attention layers + if "self_attn." in key: + key = key.replace("self_attn.", "attention.") + if "q_proj." in key: + key = key.replace("q_proj.", "query_proj.") + if "k_proj." in key: + key = key.replace("k_proj.", "key_proj.") + if "v_proj." in key: + key = key.replace("v_proj.", "value_proj.") + + # Map ffn layers + if "mlp.fc1" in key: + key = key.replace("mlp.fc1", "linear1") + if "mlp.fc2" in key: + key = key.replace("mlp.fc2", "linear2") + + new_weights[key] = w + + return new_weights + + def __call__(self, x): + # Extract some shapes + B, N = x.shape + eos_tokens = x.argmax(-1) + + # Compute the embeddings + x = self.token_embedding(x) + x = x + self.position_embedding.weight[:N] + + # Compute the features from the transformer + mask = self._get_mask(N, x.dtype) + hidden_states = [] + for l in self.layers: + x = l(x, mask) + hidden_states.append(x) + + # Apply the final layernorm and return + x = self.final_layer_norm(x) + last_hidden_state = x + + # Select the EOS token + pooled_output = x[mx.arange(len(x)), eos_tokens] + + return CLIPOutput( + pooled_output=pooled_output, + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + ) diff --git a/xinference/thirdparty/mlx/flux/datasets.py b/xinference/thirdparty/mlx/flux/datasets.py new file mode 100644 index 0000000000..d31a09f179 --- /dev/null +++ b/xinference/thirdparty/mlx/flux/datasets.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path + +from PIL import Image + + +class Dataset: + def __getitem__(self, index: int): + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + +class LocalDataset(Dataset): + prompt_key = "prompt" + + def __init__(self, dataset: str, data_file): + self.dataset_base = Path(dataset) + with open(data_file, "r") as fid: + self._data = [json.loads(l) for l in fid] + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int): + item = self._data[index] + image = Image.open(self.dataset_base / item["image"]) + return image, item[self.prompt_key] + + +class LegacyDataset(LocalDataset): + prompt_key = "text" + + def __init__(self, dataset: str): + self.dataset_base = Path(dataset) + with open(self.dataset_base / "index.json") as f: + self._data = json.load(f)["data"] + + +class HuggingFaceDataset(Dataset): + + def __init__(self, dataset: str): + from datasets import load_dataset as hf_load_dataset + + self._df = hf_load_dataset(dataset)["train"] + + def __len__(self): + return len(self._df) + + def __getitem__(self, index: int): + item = self._df[index] + return item["image"], item["prompt"] + + +def load_dataset(dataset: str): + dataset_base = Path(dataset) + data_file = dataset_base / "train.jsonl" + legacy_file = dataset_base / "index.json" + + if data_file.exists(): + print(f"Load the local dataset {data_file} .", flush=True) + dataset = LocalDataset(dataset, data_file) + elif legacy_file.exists(): + print(f"Load the local dataset {legacy_file} .") + print() + print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.") + print(" See the README for details.") + print(flush=True) + dataset = LegacyDataset(dataset) + else: + print(f"Load the Hugging Face dataset {dataset} .", flush=True) + dataset = HuggingFaceDataset(dataset) + + return dataset diff --git a/xinference/thirdparty/mlx/flux/flux.py b/xinference/thirdparty/mlx/flux/flux.py new file mode 100644 index 0000000000..425cb4b9ea --- /dev/null +++ b/xinference/thirdparty/mlx/flux/flux.py @@ -0,0 +1,247 @@ +# Copyright © 2024 Apple Inc. + +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from tqdm import tqdm + +from .lora import LoRALinear +from .sampler import FluxSampler +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) + + +class FluxPipeline: + def __init__(self, name: str, model_path: str, t5_padding: bool = True): + self.dtype = mx.bfloat16 + self.name = name + self.t5_padding = t5_padding + + self.model_path = model_path + self.ae = load_ae(name, model_path) + self.flow = load_flow_model(name, model_path) + self.clip = load_clip(name, model_path) + self.clip_tokenizer = load_clip_tokenizer(name, model_path) + self.t5 = load_t5(name, model_path) + self.t5_tokenizer = load_t5_tokenizer(name, model_path) + self.sampler = FluxSampler(name) + + def ensure_models_are_loaded(self): + mx.eval( + self.ae.parameters(), + self.flow.parameters(), + self.clip.parameters(), + self.t5.parameters(), + ) + + def reload_text_encoders(self): + self.t5 = load_t5(self.name, self.model_path) + self.clip = load_clip(self.name, self.model_path) + + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + + def _prepare_latent_images(self, x): + b, h, w, c = x.shape + + # Pack the latent image to 2x2 patches + x = x.reshape(b, h // 2, 2, w // 2, 2, c) + x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) + + # Create positions ids used to positionally encode each patch. Due to + # the way RoPE works, this results in an interesting positional + # encoding where parts of the feature are holding different positional + # information. Namely, the first part holds information independent of + # the spatial position (hence 0s), the 2nd part holds vertical spatial + # information and the last one horizontal. + i = mx.zeros((h // 2, w // 2), dtype=mx.int32) + j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") + x_ids = mx.stack([i, j, k], axis=-1) + x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) + + return x, x_ids + + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): + # Prepare the text features + txt = self.t5(t5_tokens) + if len(txt) == 1 and n_images > 1: + txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) + txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) + + # Prepare the clip text features + vec = self.clip(clip_tokens).pooled_output + if len(vec) == 1 and n_images > 1: + vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) + + return txt, txt_ids, vec + + def _denoising_loop( + self, + x_t, + x_ids, + txt, + txt_ids, + vec, + num_steps: int = 35, + guidance: float = 4.0, + start: float = 1, + stop: float = 0, + ): + B = len(x_t) + + def scalar(x): + return mx.full((B,), x, dtype=self.dtype) + + guidance = scalar(guidance) + timesteps = self.sampler.timesteps( + num_steps, + x_t.shape[1], + start=start, + stop=stop, + ) + for i in range(num_steps): + t = timesteps[i] + t_prev = timesteps[i + 1] + + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=scalar(t), + guidance=guidance, + ) + x_t = self.sampler.step(pred, x_t, t, t_prev) + + yield x_t + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + ): + # Set the PRNG state + if seed is not None: + mx.random.seed(seed) + + # Create the latent variables + x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) + x_T, x_ids = self._prepare_latent_images(x_T) + + # Get the conditioning + t5_tokens, clip_tokens = self.tokenize(text) + txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) + + # Yield the conditioning for controlled evaluation by the caller + yield (x_T, x_ids, txt, txt_ids, vec) + + # Yield the latent sequences from the denoising loop + yield from self._denoising_loop( + x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance + ) + + def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): + h, w = latent_size + x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) + x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) + x = self.ae.decode(x) + return mx.clip(x + 1, 0, 2) * 0.5 + + def generate_images( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + reload_text_encoders: bool = True, + progress: bool = True, + ): + latents = self.generate_latents( + text, n_images, num_steps, guidance, latent_size, seed + ) + mx.eval(next(latents)) + + if reload_text_encoders: + self.reload_text_encoders() + + for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): + mx.eval(x_t) + + images = [] + for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"): + images.append(self.decode(x_t[i : i + 1])) + mx.eval(images[-1]) + images = mx.concatenate(images, axis=0) + mx.eval(images) + + return images + + def training_loss( + self, + x_0: mx.array, + t5_features: mx.array, + clip_features: mx.array, + guidance: mx.array, + ): + # Get the text conditioning + txt = t5_features + txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) + vec = clip_features + + # Prepare the latent input + x_0, x_ids = self._prepare_latent_images(x_0) + + # Forward process + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) + x_t = mx.stop_gradient(x_t) + + # Do the denoising + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t, + guidance=guidance, + ) + + return (pred + x_0 - eps).square().mean() + + def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): + """Swap the linear layers in the transformer blocks with LoRA layers.""" + all_blocks = self.flow.double_blocks + self.flow.single_blocks + all_blocks.reverse() + num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) + for i, block in zip(range(num_blocks), all_blocks): + loras = [] + for name, module in block.named_modules(): + if isinstance(module, nn.Linear): + loras.append((name, LoRALinear.from_base(module, r=rank))) + block.update_modules(tree_unflatten(loras)) + + def fuse_lora_layers(self): + fused_layers = [] + for name, module in self.flow.named_modules(): + if isinstance(module, LoRALinear): + fused_layers.append((name, module.fuse())) + self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/xinference/thirdparty/mlx/flux/layers.py b/xinference/thirdparty/mlx/flux/layers.py new file mode 100644 index 0000000000..12397904e8 --- /dev/null +++ b/xinference/thirdparty/mlx/flux/layers.py @@ -0,0 +1,302 @@ +# Copyright © 2024 Apple Inc. + +import math +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + + +def _rope(pos: mx.array, dim: int, theta: float): + scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim + omega = 1.0 / (theta**scale) + x = pos[..., None] * omega + cosx = mx.cos(x) + sinx = mx.sin(x) + pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1) + pe = pe.reshape(*pe.shape[:-1], 2, 2) + + return pe + + +@partial(mx.compile, shapeless=True) +def _ab_plus_cd(a, b, c, d): + return a * b + c * d + + +def _apply_rope(x, pe): + s = x.shape + x = x.reshape(*s[:-1], -1, 1, 2) + x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1]) + return x.reshape(s) + + +def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array): + B, H, L, D = q.shape + + q = _apply_rope(q, pe) + k = _apply_rope(k, pe) + x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5)) + + return x.transpose(0, 2, 1, 3).reshape(B, L, -1) + + +def timestep_embedding( + t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0 +): + half = dim // 2 + freqs = mx.arange(0, half, dtype=mx.float32) / half + freqs = freqs * (-math.log(max_period)) + freqs = mx.exp(freqs) + + x = (time_factor * t)[:, None] * freqs[None] + x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1) + + return x.astype(t.dtype) + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: List[int]): + super().__init__() + + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def __call__(self, ids: mx.array): + n_axes = ids.shape[-1] + pe = mx.concatenate( + [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + axis=-3, + ) + + return pe[:, None] + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def __call__(self, x: mx.array) -> mx.array: + return self.out_layer(nn.silu(self.in_layer(x))) + + +class QKNorm(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = nn.RMSNorm(dim) + self.key_norm = nn.RMSNorm(dim) + + def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]: + return self.query_norm(q), self.key_norm(k) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def __call__(self, x: mx.array, pe: mx.array) -> mx.array: + H = self.num_heads + B, L, _ = x.shape + qkv = self.qkv(x) + q, k, v = mx.split(qkv, 3, axis=-1) + q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + q, k = self.norm(q, k) + x = _attention(q, k, v, pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: mx.array + scale: mx.array + gate: mx.array + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]: + x = self.lin(nn.silu(x)) + xs = mx.split(x[:, None, :], self.multiplier, axis=-1) + + mod1 = ModulationOut(*xs[:3]) + mod2 = ModulationOut(*xs[3:]) if self.is_double else None + + return mod1, mod2 + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approx="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approx="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def __call__( + self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array + ) -> Tuple[mx.array, mx.array]: + B, L, _ = img.shape + _, S, _ = txt.shape + H = self.num_heads + + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1) + img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_q, img_k = self.img_attn.norm(img_q, img_k) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1) + txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3) + txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3) + txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k) + + # run actual attention + q = mx.concatenate([txt_q, img_q], axis=2) + k = mx.concatenate([txt_k, img_k], axis=2) + v = mx.concatenate([txt_v, img_v], axis=2) + + attn = _attention(q, k, v, pe) + txt_attn, img_attn = mx.split(attn, [S], axis=1) + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp( + (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift + ) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp( + (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift + ) + + return img, txt + + +class SingleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: Optional[float] = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approx="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def __call__(self, x: mx.array, vec: mx.array, pe: mx.array): + B, L, _ = x.shape + H = self.num_heads + + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + + q, k, v, mlp = mx.split( + self.linear1(x_mod), + [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size], + axis=-1, + ) + q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + q, k = self.norm(q, k) + + # compute attention + y = _attention(q, k, v, pe) + + # compute activation in mlp stream, cat again and run second linear layer + y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2)) + return x + mod.gate * y + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def __call__(self, x: mx.array, vec: mx.array): + shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/xinference/thirdparty/mlx/flux/lora.py b/xinference/thirdparty/mlx/flux/lora.py new file mode 100644 index 0000000000..b0c8ae5605 --- /dev/null +++ b/xinference/thirdparty/mlx/flux/lora.py @@ -0,0 +1,76 @@ +# Copyright © 2024 Apple Inc. + +import math + +import mlx.core as mx +import mlx.nn as nn + + +class LoRALinear(nn.Module): + @staticmethod + def from_base( + linear: nn.Linear, + r: int = 8, + dropout: float = 0.0, + scale: float = 1.0, + ): + output_dims, input_dims = linear.weight.shape + lora_lin = LoRALinear( + input_dims=input_dims, + output_dims=output_dims, + r=r, + dropout=dropout, + scale=scale, + ) + lora_lin.linear = linear + return lora_lin + + def fuse(self): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + dtype = weight.dtype + + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = self.scale * self.lora_b.T + lora_a = self.lora_a.T + fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype) + if bias: + fused_linear.bias = linear.bias + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 1.0, + bias: bool = False, + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, r), + ) + self.lora_b = mx.zeros(shape=(r, output_dims)) + + def __call__(self, x): + y = self.linear(x) + z = (self.dropout(x) @ self.lora_a) @ self.lora_b + return y + (self.scale * z).astype(x.dtype) diff --git a/xinference/thirdparty/mlx/flux/model.py b/xinference/thirdparty/mlx/flux/model.py new file mode 100644 index 0000000000..18ea70b08a --- /dev/null +++ b/xinference/thirdparty/mlx/flux/model.py @@ -0,0 +1,134 @@ +# Copyright © 2024 Apple Inc. + +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from .layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + if params.guidance_embed + else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + + self.single_blocks = [ + SingleStreamBlock( + self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio + ) + for _ in range(params.depth_single_blocks) + ] + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def sanitize(self, weights): + new_weights = {} + for k, w in weights.items(): + if k.endswith(".scale"): + k = k[:-6] + ".weight" + for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: + if f".{seq}." in k: + k = k.replace(f".{seq}.", f".{seq}.layers.") + break + new_weights[k] = w + return new_weights + + def __call__( + self, + img: mx.array, + img_ids: mx.array, + txt: mx.array, + txt_ids: mx.array, + timesteps: mx.array, + y: mx.array, + guidance: Optional[mx.array] = None, + ) -> mx.array: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = mx.concatenate([txt_ids, img_ids], axis=1) + pe = self.pe_embedder(ids).astype(img.dtype) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = mx.concatenate([txt, img], axis=1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) + + return img diff --git a/xinference/thirdparty/mlx/flux/sampler.py b/xinference/thirdparty/mlx/flux/sampler.py new file mode 100644 index 0000000000..3bff1ca275 --- /dev/null +++ b/xinference/thirdparty/mlx/flux/sampler.py @@ -0,0 +1,56 @@ +# Copyright © 2024 Apple Inc. + +import math +from functools import lru_cache + +import mlx.core as mx + + +class FluxSampler: + def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5): + self._base_shift = base_shift + self._max_shift = max_shift + self._schnell = "schnell" in name + + def _time_shift(self, x, t): + x1, x2 = 256, 4096 + t1, t2 = self._base_shift, self._max_shift + exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1) + t = exp_mu / (exp_mu + (1 / t - 1)) + return t + + @lru_cache + def timesteps( + self, num_steps, image_sequence_length, start: float = 1, stop: float = 0 + ): + t = mx.linspace(start, stop, num_steps + 1) + + if self._schnell: + t = self._time_shift(image_sequence_length, t) + + return t.tolist() + + def random_timesteps(self, B, L, dtype=mx.float32, key=None): + if self._schnell: + # TODO: Should we upweigh 1 and 0.75? + t = mx.random.randint(1, 5, shape=(B,), key=key) + t = t.astype(dtype) / 4 + else: + t = mx.random.uniform(shape=(B,), dtype=dtype, key=key) + t = self._time_shift(L, t) + + return t + + def sample_prior(self, shape, dtype=mx.float32, key=None): + return mx.random.normal(shape, dtype=dtype, key=key) + + def add_noise(self, x, t, noise=None, key=None): + noise = ( + noise + if noise is not None + else mx.random.normal(x.shape, dtype=x.dtype, key=key) + ) + return x * (1 - t) + t * noise + + def step(self, pred, x_t, t, t_prev): + return x_t + (t_prev - t) * pred diff --git a/xinference/thirdparty/mlx/flux/t5.py b/xinference/thirdparty/mlx/flux/t5.py new file mode 100644 index 0000000000..cf0515cd5e --- /dev/null +++ b/xinference/thirdparty/mlx/flux/t5.py @@ -0,0 +1,244 @@ +# Copyright © 2024 Apple Inc. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +_SHARED_REPLACEMENT_PATTERNS = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), +] + +_ENCODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), +] + + +@dataclass +class T5Config: + vocab_size: int + num_layers: int + num_heads: int + relative_attention_num_buckets: int + d_kv: int + d_model: int + feed_forward_proj: str + tie_word_embeddings: bool + + d_ff: Optional[int] = None + num_decoder_layers: Optional[int] = None + relative_attention_max_distance: int = 128 + layer_norm_epsilon: float = 1e-6 + + @classmethod + def from_dict(cls, config): + return cls( + vocab_size=config["vocab_size"], + num_layers=config["num_layers"], + num_heads=config["num_heads"], + relative_attention_num_buckets=config["relative_attention_num_buckets"], + d_kv=config["d_kv"], + d_model=config["d_model"], + feed_forward_proj=config["feed_forward_proj"], + tie_word_embeddings=config["tie_word_embeddings"], + d_ff=config.get("d_ff", 4 * config["d_model"]), + num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]), + relative_attention_max_distance=config.get( + "relative_attention_max_distance", 128 + ), + layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6), + ) + + +class RelativePositionBias(nn.Module): + def __init__(self, config: T5Config, bidirectional: bool): + self.bidirectional = bidirectional + self.num_buckets = config.relative_attention_num_buckets + self.max_distance = config.relative_attention_max_distance + self.n_heads = config.num_heads + self.embeddings = nn.Embedding(self.num_buckets, self.n_heads) + + @staticmethod + def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance): + num_buckets = num_buckets // 2 if bidirectional else num_buckets + max_exact = num_buckets // 2 + + abspos = rpos.abs() + is_small = abspos < max_exact + + scale = (num_buckets - max_exact) / math.log(max_distance / max_exact) + buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16) + buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1) + + buckets = mx.where(is_small, abspos, buckets_large) + if bidirectional: + buckets = buckets + (rpos > 0) * num_buckets + else: + buckets = buckets * (rpos < 0) + + return buckets + + def __call__(self, query_length: int, key_length: int, offset: int = 0): + """Compute binned relative position bias""" + context_position = mx.arange(offset, query_length)[:, None] + memory_position = mx.arange(key_length)[None, :] + + # shape (query_length, key_length) + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + + # shape (query_length, key_length, num_heads) + values = self.embeddings(relative_position_bucket) + + # shape (num_heads, query_length, key_length) + return values.transpose(2, 0, 1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + inner_dim = config.d_kv * config.num_heads + self.num_heads = config.num_heads + self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) + + def __call__( + self, + queries: mx.array, + keys: mx.array, + values: mx.array, + mask: Optional[mx.array], + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> [mx.array, Tuple[mx.array, mx.array]]: + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, _ = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + keys = mx.concatenate([key_cache, keys], axis=3) + values = mx.concatenate([value_cache, values], axis=2) + + values_hat = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype) + ) + values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat), (keys, values) + + +class DenseActivation(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + mlp_dims = config.d_ff or config.d_model * 4 + self.gated = config.feed_forward_proj.startswith("gated") + if self.gated: + self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) + else: + self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) + activation = config.feed_forward_proj.removeprefix("gated-") + if activation == "relu": + self.act = nn.relu + elif activation == "gelu": + self.act = nn.gelu + elif activation == "silu": + self.act = nn.silu + else: + raise ValueError(f"Unknown activation: {activation}") + + def __call__(self, x): + if self.gated: + hidden_act = self.act(self.wi_0(x)) + hidden_linear = self.wi_1(x) + x = hidden_act * hidden_linear + else: + x = self.act(self.wi(x)) + return self.wo(x) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.attention = MultiHeadAttention(config) + self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) + + def __call__(self, x, mask): + y = self.ln1(x) + y, _ = self.attention(y, y, y, mask=mask) + x = x + y + + y = self.ln2(x) + y = self.dense(y) + return x + y + + +class TransformerEncoder(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.layers = [ + TransformerEncoderLayer(config) for i in range(config.num_layers) + ] + self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) + + def __call__(self, x: mx.array): + pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1]) + pos_bias = pos_bias.astype(x.dtype) + for layer in self.layers: + x = layer(x, mask=pos_bias) + return self.ln(x) + + +class T5Encoder(nn.Module): + def __init__(self, config: T5Config): + self.wte = nn.Embedding(config.vocab_size, config.d_model) + self.encoder = TransformerEncoder(config) + + def sanitize(self, weights): + new_weights = {} + for k, w in weights.items(): + for old, new in _SHARED_REPLACEMENT_PATTERNS: + k = k.replace(old, new) + if k.startswith("encoder."): + for old, new in _ENCODER_REPLACEMENT_PATTERNS: + k = k.replace(old, new) + new_weights[k] = w + return new_weights + + def __call__(self, inputs: mx.array): + return self.encoder(self.wte(inputs)) diff --git a/xinference/thirdparty/mlx/flux/tokenizers.py b/xinference/thirdparty/mlx/flux/tokenizers.py new file mode 100644 index 0000000000..796ef3896f --- /dev/null +++ b/xinference/thirdparty/mlx/flux/tokenizers.py @@ -0,0 +1,185 @@ +# Copyright © 2024 Apple Inc. + +import mlx.core as mx +import regex +from sentencepiece import SentencePieceProcessor + + +class CLIPTokenizer: + """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ .""" + + def __init__(self, bpe_ranks, vocab, max_length=77): + self.max_length = max_length + self.bpe_ranks = bpe_ranks + self.vocab = vocab + self.pat = regex.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + regex.IGNORECASE, + ) + + self._cache = {self.bos: self.bos, self.eos: self.eos} + + @property + def bos(self): + return "<|startoftext|>" + + @property + def bos_token(self): + return self.vocab[self.bos] + + @property + def eos(self): + return "<|endoftext|>" + + @property + def eos_token(self): + return self.vocab[self.eos] + + def bpe(self, text): + if text in self._cache: + return self._cache[text] + + unigrams = list(text[:-1]) + [text[-1] + ""] + unique_bigrams = set(zip(unigrams, unigrams[1:])) + + if not unique_bigrams: + return unigrams + + # In every iteration try to merge the two most likely bigrams. If none + # was merged we are done. + # + # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py + while unique_bigrams: + bigram = min( + unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")) + ) + if bigram not in self.bpe_ranks: + break + + new_unigrams = [] + skip = False + for a, b in zip(unigrams, unigrams[1:]): + if skip: + skip = False + continue + + if (a, b) == bigram: + new_unigrams.append(a + b) + skip = True + + else: + new_unigrams.append(a) + + if not skip: + new_unigrams.append(b) + + unigrams = new_unigrams + unique_bigrams = set(zip(unigrams, unigrams[1:])) + + self._cache[text] = unigrams + + return unigrams + + def tokenize(self, text, prepend_bos=True, append_eos=True): + if isinstance(text, list): + return [self.tokenize(t, prepend_bos, append_eos) for t in text] + + # Lower case cleanup and split according to self.pat. Hugging Face does + # a much more thorough job here but this should suffice for 95% of + # cases. + clean_text = regex.sub(r"\s+", " ", text.lower()) + tokens = regex.findall(self.pat, clean_text) + + # Split the tokens according to the byte-pair merge file + bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] + + # Map to token ids and return + tokens = [self.vocab[t] for t in bpe_tokens] + if prepend_bos: + tokens = [self.bos_token] + tokens + if append_eos: + tokens.append(self.eos_token) + + if len(tokens) > self.max_length: + tokens = tokens[: self.max_length] + if append_eos: + tokens[-1] = self.eos_token + + return tokens + + def encode(self, text): + if not isinstance(text, list): + return self.encode([text]) + + tokens = self.tokenize(text) + length = max(len(t) for t in tokens) + for t in tokens: + t.extend([self.eos_token] * (length - len(t))) + + return mx.array(tokens) + + +class T5Tokenizer: + def __init__(self, model_file, max_length=512): + self._tokenizer = SentencePieceProcessor(model_file) + self.max_length = max_length + + @property + def pad(self): + try: + return self._tokenizer.id_to_piece(self.pad_token) + except IndexError: + return None + + @property + def pad_token(self): + return self._tokenizer.pad_id() + + @property + def bos(self): + try: + return self._tokenizer.id_to_piece(self.bos_token) + except IndexError: + return None + + @property + def bos_token(self): + return self._tokenizer.bos_id() + + @property + def eos(self): + try: + return self._tokenizer.id_to_piece(self.eos_token) + except IndexError: + return None + + @property + def eos_token(self): + return self._tokenizer.eos_id() + + def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True): + if isinstance(text, list): + return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text] + + tokens = self._tokenizer.encode(text) + + if prepend_bos and self.bos_token >= 0: + tokens = [self.bos_token] + tokens + if append_eos and self.eos_token >= 0: + tokens.append(self.eos_token) + if pad and len(tokens) < self.max_length and self.pad_token >= 0: + tokens += [self.pad_token] * (self.max_length - len(tokens)) + + return tokens + + def encode(self, text, pad=True): + if not isinstance(text, list): + return self.encode([text], pad=pad) + + pad_token = self.pad_token if self.pad_token >= 0 else 0 + tokens = self.tokenize(text, pad=pad) + length = max(len(t) for t in tokens) + for t in tokens: + t.extend([pad_token] * (length - len(t))) + + return mx.array(tokens) diff --git a/xinference/thirdparty/mlx/flux/trainer.py b/xinference/thirdparty/mlx/flux/trainer.py new file mode 100644 index 0000000000..40a126e886 --- /dev/null +++ b/xinference/thirdparty/mlx/flux/trainer.py @@ -0,0 +1,98 @@ +import mlx.core as mx +import numpy as np +from PIL import Image, ImageFile +from tqdm import tqdm + +from .datasets import Dataset +from .flux import FluxPipeline + + +class Trainer: + + def __init__(self, flux: FluxPipeline, dataset: Dataset, args): + self.flux = flux + self.dataset = dataset + self.args = args + self.latents = [] + self.t5_features = [] + self.clip_features = [] + + def _random_crop_resize(self, img): + resolution = self.args.resolution + width, height = img.size + + a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() + + # Random crop the input image between 0.8 to 1.0 of its original dimensions + crop_size = ( + max((0.8 + 0.2 * a) * width, resolution[0]), + max((0.8 + 0.2 * b) * height, resolution[1]), + ) + pan = (width - crop_size[0], height - crop_size[1]) + img = img.crop( + ( + pan[0] * c, + pan[1] * d, + crop_size[0] + pan[0] * c, + crop_size[1] + pan[1] * d, + ) + ) + + # Fit the largest rectangle with the ratio of resolution in the image + # rectangle. + width, height = crop_size + ratio = resolution[0] / resolution[1] + r1 = (height * ratio, height) + r2 = (width, width / ratio) + r = r1 if r1[0] <= width else r2 + img = img.crop( + ( + (width - r[0]) / 2, + (height - r[1]) / 2, + (width + r[0]) / 2, + (height + r[1]) / 2, + ) + ) + + # Finally resize the image to resolution + img = img.resize(resolution, Image.LANCZOS) + + return mx.array(np.array(img)) + + def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int): + for i in range(num_augmentations): + img = self._random_crop_resize(input_img) + img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 + x_0 = self.flux.ae.encode(img[None]) + x_0 = x_0.astype(self.flux.dtype) + mx.eval(x_0) + self.latents.append(x_0) + + def _encode_prompt(self, prompt): + t5_tok, clip_tok = self.flux.tokenize([prompt]) + t5_feat = self.flux.t5(t5_tok) + clip_feat = self.flux.clip(clip_tok).pooled_output + mx.eval(t5_feat, clip_feat) + self.t5_features.append(t5_feat) + self.clip_features.append(clip_feat) + + def encode_dataset(self): + """Encode the images & prompt in the latent space to prepare for training.""" + self.flux.ae.eval() + for image, prompt in tqdm(self.dataset, desc="encode dataset"): + self._encode_image(image, self.args.num_augmentations) + self._encode_prompt(prompt) + + def iterate(self, batch_size): + xs = mx.concatenate(self.latents) + t5 = mx.concatenate(self.t5_features) + clip = mx.concatenate(self.clip_features) + mx.eval(xs, t5, clip) + n_aug = self.args.num_augmentations + while True: + x_indices = mx.random.permutation(len(self.latents)) + c_indices = x_indices // n_aug + for i in range(0, len(self.latents), batch_size): + x_i = x_indices[i : i + batch_size] + c_i = c_indices[i : i + batch_size] + yield xs[x_i], t5[c_i], clip[c_i] diff --git a/xinference/thirdparty/mlx/flux/utils.py b/xinference/thirdparty/mlx/flux/utils.py new file mode 100644 index 0000000000..47e7fe9e33 --- /dev/null +++ b/xinference/thirdparty/mlx/flux/utils.py @@ -0,0 +1,179 @@ +# Copyright © 2024 Apple Inc. + +import json +import os +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx + +from .autoencoder import AutoEncoder, AutoEncoderParams +from .clip import CLIPTextModel, CLIPTextModelConfig +from .model import Flux, FluxParams +from .t5 import T5Config, T5Encoder +from .tokenizers import CLIPTokenizer, T5Tokenizer + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: Optional[str] + ae_path: Optional[str] + repo_id: Optional[str] + repo_flow: Optional[str] + repo_ae: Optional[str] + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def load_flow_model(name: str, ckpt_path: str): + # Make the model + model = Flux(configs[name].params) + + # Load the checkpoint if needed + if os.path.isdir(ckpt_path): + ckpt_path = os.path.join(ckpt_path, configs[name].repo_flow) + weights = mx.load(ckpt_path) + weights = model.sanitize(weights) + model.load_weights(list(weights.items())) + + return model + + +def load_ae(name: str, ckpt_path: str): + # Make the autoencoder + ae = AutoEncoder(configs[name].ae_params) + + # Load the checkpoint if needed + ckpt_path = os.path.join(ckpt_path, "ae.safetensors") + weights = mx.load(ckpt_path) + weights = ae.sanitize(weights) + ae.load_weights(list(weights.items())) + + return ae + + +def load_clip(name: str, ckpt_path: str): + config_path = os.path.join(ckpt_path, "text_encoder/config.json") + with open(config_path) as f: + config = CLIPTextModelConfig.from_dict(json.load(f)) + + # Make the clip text encoder + clip = CLIPTextModel(config) + + ckpt_path = os.path.join(ckpt_path, "text_encoder/model.safetensors") + weights = mx.load(ckpt_path) + weights = clip.sanitize(weights) + clip.load_weights(list(weights.items())) + + return clip + + +def load_t5(name: str, ckpt_path: str): + config_path = os.path.join(ckpt_path, "text_encoder_2/config.json") + with open(config_path) as f: + config = T5Config.from_dict(json.load(f)) + + # Make the T5 model + t5 = T5Encoder(config) + + model_index = os.path.join(ckpt_path, "text_encoder_2/model.safetensors.index.json") + weight_files = set() + with open(model_index) as f: + for _, w in json.load(f)["weight_map"].items(): + weight_files.add(w) + weights = {} + for w in weight_files: + w = f"text_encoder_2/{w}" + w = os.path.join(ckpt_path, w) + weights.update(mx.load(w)) + weights = t5.sanitize(weights) + t5.load_weights(list(weights.items())) + + return t5 + + +def load_clip_tokenizer(name: str, ckpt_path: str): + vocab_file = os.path.join(ckpt_path, "tokenizer/vocab.json") + with open(vocab_file, encoding="utf-8") as f: + vocab = json.load(f) + + merges_file = os.path.join(ckpt_path, "tokenizer/merges.txt") + with open(merges_file, encoding="utf-8") as f: + bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] + bpe_merges = [tuple(m.split()) for m in bpe_merges] + bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) + + return CLIPTokenizer(bpe_ranks, vocab, max_length=77) + + +def load_t5_tokenizer(name: str, ckpt_path: str, pad: bool = True): + model_file = os.path.join(ckpt_path, "tokenizer_2/spiece.model") + return T5Tokenizer(model_file, 256 if "schnell" in name else 512)