Skip to content

Commit

Permalink
Merge pull request #112 from okotaku/feat/support_random_mask_choice
Browse files Browse the repository at this point in the history
[Feature] Support RandomChoice for random mask
  • Loading branch information
okotaku authored Dec 13, 2023
2 parents 62913d9 + 7427522 commit 3b62bbe
Show file tree
Hide file tree
Showing 14 changed files with 321 additions and 4 deletions.
86 changes: 86 additions & 0 deletions configs/_base_/datasets/dog_inpaint_multiple_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
train_pipeline = [
dict(type="torchvision/Resize", size=512, interpolation="bilinear"),
dict(type="RandomCrop", size=512),
dict(type="RandomHorizontalFlip", p=0.5),
dict(type="RandomChoice",
transforms=[
[dict(
type="LoadMask",
mask_mode="irregular",
mask_config=dict(
num_vertices=(4, 10),
max_angle=6.0,
length_range=(20, 200),
brush_width=(10, 100),
area_ratio_range=(0.15, 0.65)))],
[dict(
type="LoadMask",
mask_mode="irregular",
mask_config=dict(
num_vertices=(1, 5),
max_angle=6.0,
length_range=(40, 450),
brush_width=(20, 250),
area_ratio_range=(0.15, 0.65)))],
[dict(
type="LoadMask",
mask_mode="irregular",
mask_config=dict(
num_vertices=(4, 70),
max_angle=6.0,
length_range=(15, 100),
brush_width=(5, 20),
area_ratio_range=(0.15, 0.65)))],
[dict(
type="LoadMask",
mask_mode="bbox",
mask_config=dict(
max_bbox_shape=(150, 150),
max_bbox_delta=50,
min_margin=0))],
[dict(
type="LoadMask",
mask_mode="bbox",
mask_config=dict(
max_bbox_shape=(300, 300),
max_bbox_delta=100,
min_margin=10))],
]),
dict(type="torchvision/ToTensor"),
dict(type="MaskToTensor"),
dict(type="DumpImage", max_imgs=10, dump_dir="work_dirs/dump"),
dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]),
dict(type="GetMaskedImage"),
dict(type="PackInputs",
input_keys=["img", "mask", "masked_image", "text"]),
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
dataset=dict(
type="HFDreamBoothDataset",
dataset="diffusers/dog-example",
instance_prompt="a photo of sks dog",
pipeline=train_pipeline,
class_prompt=None),
sampler=dict(type="InfiniteSampler", shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
dict(
type="VisualizationHook",
prompt=["a photo of sks dog"] * 4,
image=["https://github.com/okotaku/diffengine/assets/24734142/8e02bd0e-9dcc-49b6-94b0-86ab3b40bc2b"] * 4, # noqa
mask=["https://github.com/okotaku/diffengine/assets/24734142/d0de4fb9-9183-418a-970d-582e9324f05d"] * 2 + [ # noqa
"https://github.com/okotaku/diffengine/assets/24734142/a40d1a4f-9c47-4fa0-936e-88a49c92c8d7"] * 2, # noqa
by_epoch=False,
width=512,
height=512,
interval=100),
dict(type="SDCheckpointHook"),
]
8 changes: 8 additions & 0 deletions configs/stable_diffusion_inpaint/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,11 @@ You can see more details on [`docs/source/run_guides/run_sd.md`](../../docs/sour
![mask](https://github.com/okotaku/diffengine/assets/24734142/d0de4fb9-9183-418a-970d-582e9324f05d)

![example](https://github.com/okotaku/diffengine/assets/24734142/f9ec820b-af75-4c74-8c0b-6558a0a19b95)

#### stable_diffusion_inpaint_dog_multi_mask

![input](https://github.com/okotaku/diffengine/assets/24734142/8e02bd0e-9dcc-49b6-94b0-86ab3b40bc2b)

![mask](https://github.com/okotaku/diffengine/assets/24734142/a40d1a4f-9c47-4fa0-936e-88a49c92c8d7)

![example](https://github.com/okotaku/diffengine/assets/24734142/f9766a71-0845-4dea-a037-f7dabfca200e)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ = [
"../_base_/models/stable_diffusion_inpaint.py",
"../_base_/datasets/dog_inpaint_multiple_mask.py",
"../_base_/schedules/stable_diffusion_1k.py",
"../_base_/default_runtime.py",
]
17 changes: 16 additions & 1 deletion diffengine/datasets/hf_dreambooth_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa: S311,RUF012
import copy
import hashlib
import os
import random
import shutil
from collections.abc import Sequence
Expand Down Expand Up @@ -45,6 +46,9 @@ class HFDreamBoothDataset(Dataset):
class_prompt (Optional[str]): The prompt to specify images in the same
class as provided instance images. Defaults to None.
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
csv (str, optional): Image path csv file name when loading local
folder. If None, the dataset will be loaded from image folders.
Defaults to None.
cache_dir (str, optional): The directory where the downloaded datasets
will be stored.Defaults to None.
"""
Expand All @@ -65,8 +69,12 @@ def __init__(self,
class_image_config: dict | None = None,
class_prompt: str | None = None,
pipeline: Sequence = (),
csv: str | None = None,
cache_dir: str | None = None) -> None:

self.dataset_name = dataset
self.csv = csv

if class_image_config is None:
class_image_config = {
"model": "runwayml/stable-diffusion-v1-5",
Expand All @@ -77,7 +85,12 @@ def __init__(self,
}
if Path(dataset).exists():
# load local folder
self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"]
if csv is not None:
data_file = os.path.join(dataset, csv)
self.dataset = load_dataset(
"csv", data_files=data_file, cache_dir=cache_dir)["train"]
else:
self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"]
else: # noqa
# load huggingface online
if dataset_sub_dir is not None:
Expand Down Expand Up @@ -172,6 +185,8 @@ def __getitem__(self, idx: int) -> dict:
data_info = self.dataset[idx]
image = data_info[self.image_column]
if isinstance(image, str):
if self.csv is not None:
image = os.path.join(self.dataset_name, image)
image = Image.open(image)
image = image.convert("RGB")
result = {"img": image, "text": self.instance_prompt}
Expand Down
4 changes: 4 additions & 0 deletions diffengine/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .loading import LoadMask
from .processing import (
TRANSFORMS,
AddConstantCaption,
CenterCrop,
CLIPImageProcessor,
ComputePixArtImgInfo,
Expand All @@ -17,6 +18,7 @@
SaveImageShape,
T5TextPreprocess,
)
from .wrappers import RandomChoice

__all__ = [
"BaseTransform",
Expand All @@ -36,4 +38,6 @@
"LoadMask",
"MaskToTensor",
"GetMaskedImage",
"RandomChoice",
"AddConstantCaption",
]
31 changes: 31 additions & 0 deletions diffengine/datasets/transforms/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,34 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
"""
results[self.key] = results["img"] * results["mask"]
return results


@TRANSFORMS.register_module()
class AddConstantCaption(BaseTransform):
"""AddConstantCaption.
Example. "a dog." * constant_caption="in szn style"
-> "a dog. in szn style"
Args:
----
keys (List[str]): `keys` to apply augmentation from results.
"""

def __init__(self, constant_caption: str, keys=None) -> None:
if keys is None:
keys = ["text"]
self.constant_caption: str = constant_caption
self.keys = keys

def transform(self,
results: dict) -> dict | tuple[list, list] | None:
"""Transform.
Args:
----
results (dict): The result dict.
"""
for k in self.keys:
results[k] = results[k] + " " + self.constant_caption
return results
69 changes: 69 additions & 0 deletions diffengine/datasets/transforms/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from collections.abc import Callable, Iterator

import mmengine
import numpy as np
from mmengine.dataset.base_dataset import Compose

from diffengine.datasets.transforms.base import BaseTransform
from diffengine.registry import TRANSFORMS

Transform = dict | Callable[[dict], dict]


@TRANSFORMS.register_module()
class RandomChoice(BaseTransform):
"""Process data with a randomly chosen transform from given candidates.
Copied from mmcv/transforms/wrappers.py.
Args:
----
transforms (list[list]): A list of transform candidates, each is a
sequence of transforms.
prob (list[float], optional): The probabilities associated
with each pipeline. The length should be equal to the pipeline
number and the sum should be 1. If not given, a uniform
distribution will be assumed.
Examples:
--------
>>> # config
>>> pipeline = [
>>> dict(type='RandomChoice',
>>> transforms=[
>>> [dict(type='RandomHorizontalFlip')], # subpipeline 1
>>> [dict(type='RandomRotate')], # subpipeline 2
>>> ]
>>> )
>>> ]
"""

def __init__(self,
transforms: list[Transform | list[Transform]],
prob: list[float] | None = None) -> None:

super().__init__()

if prob is not None:
assert mmengine.is_seq_of(prob, float)
assert len(transforms) == len(prob),(
"``transforms`` and ``prob`` must have same lengths. "
f"Got {len(transforms)} vs {len(prob)}.")
assert sum(prob) == 1

self.prob = prob
self.transforms = [Compose(transforms) for transforms in transforms]

def __iter__(self) -> Iterator:
"""Iterate over transforms."""
return iter(self.transforms)

def random_pipeline_index(self) -> int:
"""Return a random transform index."""
indices = np.arange(len(self.transforms))
return np.random.choice(indices, p=self.prob) # noqa

def transform(self, results: dict) -> dict | None:
"""Randomly choose a transform to apply."""
idx = self.random_pipeline_index()
return self.transforms[idx](results)
2 changes: 2 additions & 0 deletions diffengine/engine/hooks/peft_save_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
model.unet.save_pretrained(osp.join(ckpt_path, "unet"))
model_keys = ["unet"]
elif hasattr(model, "prior"):
# TODO(takuoko): Delete if bug is fixed in diffusers. # noqa
model.prior._internal_dict["_name_or_path"] = "prior" # noqa
model.prior.save_pretrained(osp.join(ckpt_path, "prior"))
model_keys = ["prior"]
elif hasattr(model, "transformer"):
Expand Down
2 changes: 1 addition & 1 deletion diffengine/models/editors/pixart_alpha/pixart_alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def infer(self,
torch_dtype=torch.float32,
)
if self.finetune_text_encoder:
# todo[takuoko]: When parsing text_encoder directly, the # noqa
# TODO(takuoko): When parsing text_encoder directly, the # noqa
# results are different. So we need to parse here.
pipeline.text_encoder = self.text_encoder
pipeline.to(self.device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def prepare_model(self) -> None:
self.unet, is_sdxl=True)

if self.gradient_checkpointing:
# todo[takuoko]: Support ControlNetXSModel for gradient # noqa
# TODO(takuoko): Support ControlNetXSModel for gradient # noqa
# checkpointing
# self.controlnet.enable_gradient_checkpointing()
self.unet.enable_gradient_checkpointing()
Expand Down
13 changes: 13 additions & 0 deletions tests/test_datasets/test_hf_dreambooth_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,16 @@ def test_dataset_from_local(self):
assert data["text"] == "a photo of sks dog"
assert isinstance(data["img"], Image.Image)
assert data["img"].width == 400

def test_dataset_from_local_with_csv(self):
dataset = HFDreamBoothDataset(
dataset="tests/testdata/dataset",
csv="metadata.csv",
image_column="file_name",
instance_prompt="a photo of sks dog")
assert len(dataset) == 1

data = dataset[0]
assert data["text"] == "a photo of sks dog"
assert isinstance(data["img"], Image.Image)
assert data["img"].width == 400
17 changes: 17 additions & 0 deletions tests/test_datasets/test_transforms/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,20 @@ def test_transform(self):
assert data["masked_image"].shape == img.shape
assert torch.allclose(data["masked_image"][10:, 10:], img[10:, 10:])
assert data["masked_image"][:10, :10].sum() == 0


class TestAddConstantCaption(TestCase):

def test_register(self):
assert "AddConstantCaption" in TRANSFORMS

def test_transform(self):
data = {
"text": "a dog.",
}

# test transform
trans = TRANSFORMS.build(dict(type="AddConstantCaption",
constant_caption="in szn style"))
data = trans(data)
assert data["text"] == "a dog. in szn style"
Loading

0 comments on commit 3b62bbe

Please sign in to comment.