Skip to content

Commit

Permalink
[CodeCamp2023-648]MMagic 新 config 体验与适配 GuidedDiffusion (#2005)
Browse files Browse the repository at this point in the history
* fix

* fix code

* fix datasets

* fix style
  • Loading branch information
ooooo-create authored Aug 29, 2023
1 parent 7320e5d commit db3a630
Show file tree
Hide file tree
Showing 8 changed files with 402 additions and 0 deletions.
55 changes: 55 additions & 0 deletions mmagic/configs/_base_/datasets/imagenet_512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset.sampler import DefaultSampler

from mmagic.datasets.imagenet_dataset import ImageNet
from mmagic.datasets.transforms.aug_shape import Flip, Resize
from mmagic.datasets.transforms.crop import (CenterCropLongEdge,
RandomCropLongEdge)
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile

# dataset settings
# The dataset is referenced by its imported class object, not by a string.
dataset_type = ImageNet

# Different from mmcls, we adopt the setting used in BigGAN.
# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing.
# Training transforms, in listed order: LoadImageFromFile, RandomCropLongEdge,
# Resize to 512x512 (pillow backend), horizontal Flip with flip_ratio=0.5,
# then PackInputs. Every keyed transform operates on the 'gt' entry.
train_pipeline = [
    dict(type=LoadImageFromFile, key='gt'),
    dict(type=RandomCropLongEdge, keys='gt'),
    dict(type=Resize, scale=(512, 512), keys='gt', backend='pillow'),
    dict(type=Flip, keys='gt', flip_ratio=0.5, direction='horizontal'),
    dict(type=PackInputs)
]

# Test-time transforms: same as training except that CenterCropLongEdge
# replaces RandomCropLongEdge and there is no Flip step.
test_pipeline = [
    dict(type=LoadImageFromFile, key='gt'),
    dict(type=CenterCropLongEdge, keys='gt'),
    dict(type=Resize, scale=(512, 512), keys='gt', backend='pillow'),
    dict(type=PackInputs)
]

# Training dataloader settings.
# NOTE(review): `batch_size` is None in this base file; presumably a config
# importing it must supply a concrete value before training -- TODO confirm.
train_dataloader = dict(
    batch_size=None,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='./data/imagenet/',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
    persistent_workers=True)

# Validation dataloader settings. It points at the same annotation file and
# data prefix as training ('meta/train.txt', 'train') but uses the test
# pipeline and an unshuffled sampler.
# NOTE(review): `batch_size` is None here as well; importing configs are
# presumably expected to override it -- TODO confirm.
val_dataloader = dict(
    batch_size=None,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='./data/imagenet/',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=test_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=False),
    persistent_workers=True)

# Testing reuses the validation dataloader settings.
test_dataloader = val_dataloader
55 changes: 55 additions & 0 deletions mmagic/configs/_base_/datasets/imagenet_64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset.sampler import DefaultSampler

from mmagic.datasets.imagenet_dataset import ImageNet
from mmagic.datasets.transforms.aug_shape import Flip, Resize
from mmagic.datasets.transforms.crop import (CenterCropLongEdge,
RandomCropLongEdge)
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile

# dataset settings
# The dataset is referenced by its imported class object, not by a string.
dataset_type = ImageNet

# Different from mmcls, we adopt the setting used in BigGAN.
# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing.
# Training transforms, in listed order: LoadImageFromFile, RandomCropLongEdge,
# Resize to 64x64 (pillow backend), horizontal Flip with flip_ratio=0.5,
# then PackInputs. Every keyed transform operates on the 'gt' entry.
train_pipeline = [
    dict(type=LoadImageFromFile, key='gt'),
    dict(type=RandomCropLongEdge, keys='gt'),
    dict(type=Resize, scale=(64, 64), keys='gt', backend='pillow'),
    dict(type=Flip, keys='gt', flip_ratio=0.5, direction='horizontal'),
    dict(type=PackInputs)
]

# Test-time transforms: same as training except that CenterCropLongEdge
# replaces RandomCropLongEdge and there is no Flip step.
test_pipeline = [
    dict(type=LoadImageFromFile, key='gt'),
    dict(type=CenterCropLongEdge, keys='gt'),
    dict(type=Resize, scale=(64, 64), keys='gt', backend='pillow'),
    dict(type=PackInputs)
]

# Training dataloader settings.
# NOTE(review): `batch_size` is None in this base file; presumably a config
# importing it must supply a concrete value before training -- TODO confirm.
train_dataloader = dict(
    batch_size=None,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='./data/imagenet/',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
    persistent_workers=True)

# Validation dataloader settings with a batch size of 64. It points at the
# same annotation file and data prefix as training ('meta/train.txt',
# 'train') but uses the test pipeline and an unshuffled sampler.
val_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='./data/imagenet/',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=test_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=False),
    persistent_workers=True)

# Testing reuses the validation dataloader settings.
test_dataloader = val_dataloader
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .adm_ddim250_8xb32_imagenet_256x256 import *

from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel

# Add a `classifier` entry (an EncoderUNetModel with image_size=256 and
# 1000 output channels) to the existing `model` dict.
# `model` is expected to come from the star-import of the base ADM config.
model.update(
    dict(
        classifier=dict(
            type=EncoderUNetModel,
            image_size=256,
            in_channels=3,
            model_channels=128,
            out_channels=1000,
            num_res_blocks=2,
            attention_resolutions=(8, 16, 32),
            channel_mult=(1, 1, 2, 2, 4, 4),
            use_fp16=False,
            num_head_channels=64,
            use_scale_shift_norm=True,
            resblock_updown=True,
            pool='attention')))

# Evaluation metrics: a single FrechetInceptionDistance entry using 50000
# generated samples; sampling is requested with 250 inference steps and
# classifier_scale=1.
metrics = [
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Full-50k',
        fake_nums=50000,
        inception_style='StyleGAN',
        sample_model='orig',
        sample_kwargs=dict(
            num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

# Validation and test evaluators share the same metric list.
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .adm_ddim250_8xb32_imagenet_512x512 import *

from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel

# Add a `classifier` entry (an EncoderUNetModel with image_size=512 and
# 1000 output channels) to the existing `model` dict.
# `model` is expected to come from the star-import of the base ADM config.
model.update(
    dict(
        classifier=dict(
            type=EncoderUNetModel,
            image_size=512,
            in_channels=3,
            model_channels=128,
            out_channels=1000,
            num_res_blocks=2,
            attention_resolutions=(16, 32, 64),
            channel_mult=(0.5, 1, 1, 2, 2, 4, 4),
            use_fp16=False,
            num_head_channels=64,
            use_scale_shift_norm=True,
            resblock_updown=True,
            pool='attention')))

# Evaluation metrics: a single FrechetInceptionDistance entry using 50000
# generated samples; sampling is requested with 250 inference steps and
# classifier_scale=1.
metrics = [
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Full-50k',
        fake_nums=50000,
        inception_style='StyleGAN',
        sample_model='orig',
        sample_kwargs=dict(
            num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

# Validation and test evaluators share the same metric list.
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .adm_ddim250_8xb32_imagenet_64x64 import *

from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel

# Add a `classifier` entry (an EncoderUNetModel with image_size=64 and
# 1000 output channels) to the existing `model` dict.
# `model` is expected to come from the star-import of the base ADM config.
model.update(
    dict(
        classifier=dict(
            type=EncoderUNetModel,
            image_size=64,
            in_channels=3,
            model_channels=128,
            out_channels=1000,
            num_res_blocks=4,
            attention_resolutions=(2, 4, 8),
            channel_mult=(1, 2, 3, 4),
            use_fp16=False,
            num_head_channels=64,
            use_scale_shift_norm=True,
            resblock_updown=True,
            pool='attention')))

# Evaluation metrics: a single FrechetInceptionDistance entry using 50000
# generated samples; sampling is requested with 250 inference steps and
# classifier_scale=1.
metrics = [
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Full-50k',
        fake_nums=50000,
        inception_style='StyleGAN',
        sample_model='orig',
        sample_kwargs=dict(
            num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

# Validation and test evaluators share the same metric list.
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .._base_.datasets.imagenet_64 import *
from .._base_.gen_default_runtime import *

from mmagic.engine.hooks.visualization_hook import VisualizationHook
from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor
from mmagic.models.diffusion_schedulers.ddim_scheduler import EditDDIMScheduler
from mmagic.models.editors.ddpm.denoising_unet import (DenoisingUnet,
MultiHeadAttentionBlock)
from mmagic.models.editors.guided_diffusion.adm import AblatedDiffusionModel

# Ablated Diffusion Model (ADM) definition for 64x64 ImageNet: a
# class-conditional DenoisingUnet (num_classes=1000) driven by a DDIM
# scheduler with a learned-range variance.
#
# `image_size` is 64 so that the UNet matches the 64x64 images produced by
# the dataset pipeline this config imports; it was previously 256.
# NOTE(review): the remaining UNet/scheduler values (base_channels,
# resblocks_per_downsample, use_new_attention_order, beta_schedule) are
# left untouched. The OpenAI guided-diffusion README appears to list
# different flags for its 64x64 model -- TODO confirm against the
# checkpoint these configs are meant to load.
model = dict(
    type=AblatedDiffusionModel,
    data_preprocessor=dict(type=DataPreprocessor),
    unet=dict(
        type=DenoisingUnet,
        image_size=64,
        in_channels=3,
        base_channels=256,
        resblocks_per_downsample=2,
        attention_res=(32, 16, 8),
        norm_cfg=dict(type='GN32', num_groups=32),
        dropout=0.1,
        num_classes=1000,
        use_fp16=False,
        resblock_updown=True,
        attention_cfg=dict(
            type=MultiHeadAttentionBlock,
            num_heads=4,
            num_head_channels=64,
            use_new_attention_order=False),
        use_scale_shift_norm=True),
    diffusion_scheduler=dict(
        type=EditDDIMScheduler,
        variance_type='learned_range',
        beta_schedule='linear'),
    rgb2bgr=True,
    use_fp16=False)

# Override the test dataloader's batch size and worker count in place.
# `test_dataloader` is expected to come from the star-import of the dataset
# base config.
test_dataloader.update(dict(batch_size=32, num_workers=8))
# Training config: max_iters is set to 100000.
train_cfg = dict(max_iters=100000)
# Evaluation metrics: a single FrechetInceptionDistance entry using 50000
# generated samples; sampling is requested with 250 inference steps and
# classifier_scale=1.
metrics = [
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Full-50k',
        fake_nums=50000,
        inception_style='StyleGAN',
        sample_model='orig',
        sample_kwargs=dict(
            num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

# Validation and test evaluators share the same metric list.
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)

# VIS_HOOK
# Registers one VisualizationHook with interval=5000 and fixed_input=True.
custom_hooks = [dict(type=VisualizationHook, interval=5000, fixed_input=True)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .._base_.datasets.imagenet_512 import *
from .._base_.gen_default_runtime import *

from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor
from mmagic.models.diffusion_schedulers.ddim_scheduler import EditDDIMScheduler
from mmagic.models.editors.ddpm.denoising_unet import (DenoisingUnet,
MultiHeadAttentionBlock)
from mmagic.models.editors.guided_diffusion import AblatedDiffusionModel

# Ablated Diffusion Model (ADM) definition for 512x512 ImageNet: a
# class-conditional DenoisingUnet (num_classes=1000, image_size=512) driven
# by a DDIM scheduler with a learned-range variance and a linear beta
# schedule.
model = dict(
    type=AblatedDiffusionModel,
    data_preprocessor=dict(type=DataPreprocessor),
    unet=dict(
        type=DenoisingUnet,
        image_size=512,
        in_channels=3,
        base_channels=256,
        resblocks_per_downsample=2,
        attention_res=(32, 16, 8),
        norm_cfg=dict(type='GN32', num_groups=32),
        dropout=0.1,
        num_classes=1000,
        use_fp16=False,
        resblock_updown=True,
        attention_cfg=dict(
            type=MultiHeadAttentionBlock,
            num_heads=4,
            num_head_channels=64,
            use_new_attention_order=False),
        use_scale_shift_norm=True),
    diffusion_scheduler=dict(
        type=EditDDIMScheduler,
        variance_type='learned_range',
        beta_schedule='linear'),
    rgb2bgr=True,
    use_fp16=False)

# Override the test dataloader's batch size and worker count in place.
# `test_dataloader` is expected to come from the star-import of the dataset
# base config.
test_dataloader.update(dict(batch_size=32, num_workers=8))
# Training config: max_iters is set to 100000.
train_cfg = dict(max_iters=100000)
# Evaluation metrics: a single FrechetInceptionDistance entry using 50000
# generated samples; sampling is requested with 250 inference steps and
# classifier_scale=1.
metrics = [
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Full-50k',
        fake_nums=50000,
        inception_style='StyleGAN',
        sample_model='orig',
        sample_kwargs=dict(
            num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

# Validation and test evaluators share the same metric list.
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Loading

0 comments on commit db3a630

Please sign in to comment.