-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CodeCamp2023-648]MMagic 新 config 体验与适配 GuidedDiffusion (#2005)
* fix * fix code * fix datasets * fix style
- Loading branch information
1 parent
7320e5d
commit db3a630
Showing
8 changed files
with
402 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.dataset.sampler import DefaultSampler | ||
|
||
from mmagic.datasets.imagenet_dataset import ImageNet | ||
from mmagic.datasets.transforms.aug_shape import Flip, Resize | ||
from mmagic.datasets.transforms.crop import (CenterCropLongEdge, | ||
RandomCropLongEdge) | ||
from mmagic.datasets.transforms.formatting import PackInputs | ||
from mmagic.datasets.transforms.loading import LoadImageFromFile | ||
|
||
# dataset settings | ||
dataset_type = ImageNet | ||
|
||
# different from mmcls, we adopt the setting used in BigGAN. | ||
# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing. | ||
train_pipeline = [ | ||
dict(type=LoadImageFromFile, key='gt'), | ||
dict(type=RandomCropLongEdge, keys='gt'), | ||
dict(type=Resize, scale=(512, 512), keys='gt', backend='pillow'), | ||
dict(type=Flip, keys='gt', flip_ratio=0.5, direction='horizontal'), | ||
dict(type=PackInputs) | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type=LoadImageFromFile, key='gt'), | ||
dict(type=CenterCropLongEdge, keys='gt'), | ||
dict(type=Resize, scale=(512, 512), keys='gt', backend='pillow'), | ||
dict(type=PackInputs) | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=None, | ||
num_workers=5, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root='./data/imagenet/', | ||
ann_file='meta/train.txt', | ||
data_prefix='train', | ||
pipeline=train_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=True), | ||
persistent_workers=True) | ||
|
||
val_dataloader = dict( | ||
batch_size=None, | ||
num_workers=5, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root='./data/imagenet/', | ||
ann_file='meta/train.txt', | ||
data_prefix='train', | ||
pipeline=test_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
persistent_workers=True) | ||
|
||
test_dataloader = val_dataloader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.dataset.sampler import DefaultSampler | ||
|
||
from mmagic.datasets.imagenet_dataset import ImageNet | ||
from mmagic.datasets.transforms.aug_shape import Flip, Resize | ||
from mmagic.datasets.transforms.crop import (CenterCropLongEdge, | ||
RandomCropLongEdge) | ||
from mmagic.datasets.transforms.formatting import PackInputs | ||
from mmagic.datasets.transforms.loading import LoadImageFromFile | ||
|
||
# dataset settings | ||
dataset_type = ImageNet | ||
|
||
# different from mmcls, we adopt the setting used in BigGAN. | ||
# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing. | ||
train_pipeline = [ | ||
dict(type=LoadImageFromFile, key='gt'), | ||
dict(type=RandomCropLongEdge, keys='gt'), | ||
dict(type=Resize, scale=(64, 64), keys='gt', backend='pillow'), | ||
dict(type=Flip, keys='gt', flip_ratio=0.5, direction='horizontal'), | ||
dict(type=PackInputs) | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type=LoadImageFromFile, key='gt'), | ||
dict(type=CenterCropLongEdge, keys='gt'), | ||
dict(type=Resize, scale=(64, 64), keys='gt', backend='pillow'), | ||
dict(type=PackInputs) | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=None, | ||
num_workers=5, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root='./data/imagenet/', | ||
ann_file='meta/train.txt', | ||
data_prefix='train', | ||
pipeline=train_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=True), | ||
persistent_workers=True) | ||
|
||
val_dataloader = dict( | ||
batch_size=64, | ||
num_workers=5, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root='./data/imagenet/', | ||
ann_file='meta/train.txt', | ||
data_prefix='train', | ||
pipeline=test_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
persistent_workers=True) | ||
|
||
test_dataloader = val_dataloader |
39 changes: 39 additions & 0 deletions
39
mmagic/configs/guided_diffusion/adm-g_ddim25_8xb32_imagenet_256x256.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .adm_ddim250_8xb32_imagenet_256x256 import * | ||
|
||
from mmagic.evaluation.metrics import FrechetInceptionDistance | ||
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel | ||
|
||
model.update( | ||
dict( | ||
classifier=dict( | ||
type=EncoderUNetModel, | ||
image_size=256, | ||
in_channels=3, | ||
model_channels=128, | ||
out_channels=1000, | ||
num_res_blocks=2, | ||
attention_resolutions=(8, 16, 32), | ||
channel_mult=(1, 1, 2, 2, 4, 4), | ||
use_fp16=False, | ||
num_head_channels=64, | ||
use_scale_shift_norm=True, | ||
resblock_updown=True, | ||
pool='attention'))) | ||
|
||
metrics = [ | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Full-50k', | ||
fake_nums=50000, | ||
inception_style='StyleGAN', | ||
sample_model='orig', | ||
sample_kwargs=dict( | ||
num_inference_steps=250, show_progress=True, classifier_scale=1.)) | ||
] | ||
|
||
val_evaluator = dict(metrics=metrics) | ||
test_evaluator = dict(metrics=metrics) |
39 changes: 39 additions & 0 deletions
39
mmagic/configs/guided_diffusion/adm-g_ddim25_8xb32_imagenet_512x512.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .adm_ddim250_8xb32_imagenet_512x512 import * | ||
|
||
from mmagic.evaluation.metrics import FrechetInceptionDistance | ||
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel | ||
|
||
model.update( | ||
dict( | ||
classifier=dict( | ||
type=EncoderUNetModel, | ||
image_size=512, | ||
in_channels=3, | ||
model_channels=128, | ||
out_channels=1000, | ||
num_res_blocks=2, | ||
attention_resolutions=(16, 32, 64), | ||
channel_mult=(0.5, 1, 1, 2, 2, 4, 4), | ||
use_fp16=False, | ||
num_head_channels=64, | ||
use_scale_shift_norm=True, | ||
resblock_updown=True, | ||
pool='attention'))) | ||
|
||
metrics = [ | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Full-50k', | ||
fake_nums=50000, | ||
inception_style='StyleGAN', | ||
sample_model='orig', | ||
sample_kwargs=dict( | ||
num_inference_steps=250, show_progress=True, classifier_scale=1.)) | ||
] | ||
|
||
val_evaluator = dict(metrics=metrics) | ||
test_evaluator = dict(metrics=metrics) |
39 changes: 39 additions & 0 deletions
39
mmagic/configs/guided_diffusion/adm-g_ddim25_8xb32_imagenet_64x64.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .adm_ddim250_8xb32_imagenet_64x64 import * | ||
|
||
from mmagic.evaluation.metrics import FrechetInceptionDistance | ||
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel | ||
|
||
model.update( | ||
dict( | ||
classifier=dict( | ||
type=EncoderUNetModel, | ||
image_size=64, | ||
in_channels=3, | ||
model_channels=128, | ||
out_channels=1000, | ||
num_res_blocks=4, | ||
attention_resolutions=(2, 4, 8), | ||
channel_mult=(1, 2, 3, 4), | ||
use_fp16=False, | ||
num_head_channels=64, | ||
use_scale_shift_norm=True, | ||
resblock_updown=True, | ||
pool='attention'))) | ||
|
||
metrics = [ | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Full-50k', | ||
fake_nums=50000, | ||
inception_style='StyleGAN', | ||
sample_model='orig', | ||
sample_kwargs=dict( | ||
num_inference_steps=250, show_progress=True, classifier_scale=1.)) | ||
] | ||
|
||
val_evaluator = dict(metrics=metrics) | ||
test_evaluator = dict(metrics=metrics) |
61 changes: 61 additions & 0 deletions
61
mmagic/configs/guided_diffusion/adm_ddim250_8xb32_imagenet_256x256.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .._base_.datasets.imagenet_64 import * | ||
from .._base_.gen_default_runtime import * | ||
|
||
from mmagic.engine.hooks.visualization_hook import VisualizationHook | ||
from mmagic.evaluation.metrics import FrechetInceptionDistance | ||
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor | ||
from mmagic.models.diffusion_schedulers.ddim_scheduler import EditDDIMScheduler | ||
from mmagic.models.editors.ddpm.denoising_unet import (DenoisingUnet, | ||
MultiHeadAttentionBlock) | ||
from mmagic.models.editors.guided_diffusion.adm import AblatedDiffusionModel | ||
|
||
model = dict( | ||
type=AblatedDiffusionModel, | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
unet=dict( | ||
type=DenoisingUnet, | ||
image_size=256, | ||
in_channels=3, | ||
base_channels=256, | ||
resblocks_per_downsample=2, | ||
attention_res=(32, 16, 8), | ||
norm_cfg=dict(type='GN32', num_groups=32), | ||
dropout=0.1, | ||
num_classes=1000, | ||
use_fp16=False, | ||
resblock_updown=True, | ||
attention_cfg=dict( | ||
type=MultiHeadAttentionBlock, | ||
num_heads=4, | ||
num_head_channels=64, | ||
use_new_attention_order=False), | ||
use_scale_shift_norm=True), | ||
diffusion_scheduler=dict( | ||
type=EditDDIMScheduler, | ||
variance_type='learned_range', | ||
beta_schedule='linear'), | ||
rgb2bgr=True, | ||
use_fp16=False) | ||
|
||
test_dataloader.update(dict(batch_size=32, num_workers=8)) | ||
train_cfg = dict(max_iters=100000) | ||
metrics = [ | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Full-50k', | ||
fake_nums=50000, | ||
inception_style='StyleGAN', | ||
sample_model='orig', | ||
sample_kwargs=dict( | ||
num_inference_steps=250, show_progress=True, classifier_scale=1.)) | ||
] | ||
|
||
val_evaluator = dict(metrics=metrics) | ||
test_evaluator = dict(metrics=metrics) | ||
|
||
# VIS_HOOK | ||
custom_hooks = [dict(type=VisualizationHook, interval=5000, fixed_input=True)] |
57 changes: 57 additions & 0 deletions
57
mmagic/configs/guided_diffusion/adm_ddim250_8xb32_imagenet_512x512.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .._base_.datasets.imagenet_512 import * | ||
from .._base_.gen_default_runtime import * | ||
|
||
from mmagic.evaluation.metrics import FrechetInceptionDistance | ||
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor | ||
from mmagic.models.diffusion_schedulers.ddim_scheduler import EditDDIMScheduler | ||
from mmagic.models.editors.ddpm.denoising_unet import (DenoisingUnet, | ||
MultiHeadAttentionBlock) | ||
from mmagic.models.editors.guided_diffusion import AblatedDiffusionModel | ||
|
||
model = dict( | ||
type=AblatedDiffusionModel, | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
unet=dict( | ||
type=DenoisingUnet, | ||
image_size=512, | ||
in_channels=3, | ||
base_channels=256, | ||
resblocks_per_downsample=2, | ||
attention_res=(32, 16, 8), | ||
norm_cfg=dict(type='GN32', num_groups=32), | ||
dropout=0.1, | ||
num_classes=1000, | ||
use_fp16=False, | ||
resblock_updown=True, | ||
attention_cfg=dict( | ||
type=MultiHeadAttentionBlock, | ||
num_heads=4, | ||
num_head_channels=64, | ||
use_new_attention_order=False), | ||
use_scale_shift_norm=True), | ||
diffusion_scheduler=dict( | ||
type=EditDDIMScheduler, | ||
variance_type='learned_range', | ||
beta_schedule='linear'), | ||
rgb2bgr=True, | ||
use_fp16=False) | ||
|
||
test_dataloader.update(dict(batch_size=32, num_workers=8)) | ||
train_cfg = dict(max_iters=100000) | ||
metrics = [ | ||
dict( | ||
type=FrechetInceptionDistance, | ||
prefix='FID-Full-50k', | ||
fake_nums=50000, | ||
inception_style='StyleGAN', | ||
sample_model='orig', | ||
sample_kwargs=dict( | ||
num_inference_steps=250, show_progress=True, classifier_scale=1.)) | ||
] | ||
|
||
val_evaluator = dict(metrics=metrics) | ||
test_evaluator = dict(metrics=metrics) |
Oops, something went wrong.