train controlxs 21 and xl #13

Open · wants to merge 1 commit into main
15 changes: 15 additions & 0 deletions README.md
@@ -1,3 +1,18 @@
# try train ctxs

1. Prepare the config files: configs/sdxl_encD_depth_48m.yaml and configs/sd21_encD_depth_14m.yaml.
2. Prepare the dataset. For testing, I use the Pokémon BLIP-captions dataset.
3. Train: main.py trains SD 2.1 and runs OK (some of the config is written in __main__);
   mainxl.py still hits errors when training SDXL.

Code modifications:
- Trains with a 3-channel control hint, pix2pix style (the input image is used as the hint); see the dataset code and config.yaml.

Help checking this would be appreciated. A config smoke test is sketched below.
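Before launching a run, the configs from step 1 can be smoke-tested in isolation. A minimal sketch, assuming this repo follows the usual latent-diffusion pattern of OmegaConf configs resolved through instantiate_from_config (the exact helper location is an assumption based on upstream ldm):

```python
# Config smoke test -- assumes the standard latent-diffusion helpers are available.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # assumption: same helper as upstream ldm

config = OmegaConf.load("configs/sd21_encD_depth_14m.yaml")
# Instantiating the model catches bad `target` paths and missing params early.
# Note: this may also trigger checkpoint loading via sync_path / ckpt_path_ctr.
model = instantiate_from_config(config.model)
print(type(model).__name__)  # expected: TwoStreamControlLDM
```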



# ControlNet-XS

![](./ControlNet-XS_files/teaser_small.gif)
7 changes: 5 additions & 2 deletions configs/inference/sd/sd21_encD_depth_14m.yaml
@@ -17,8 +17,11 @@ model:
 monitor: val/loss_simple_ema
 scale_factor: 0.18215
 use_ema: false
-sync_path: /PATH/TO/STABLEDIFFUSION/v2-1_512-ema-pruned.ckpt # path to the StableDiffusion v 2.1 weights
-ckpt_path_ctr: ./checkpoints/sd21_encD_depth_14m.ckpt # path to the ControlNet-XS weights
+# sync_path: /PATH/TO/STABLEDIFFUSION/v2-1_512-ema-pruned.ckpt # path to the StableDiffusion v 2.1 weights
+# sync_path: /home/dell/workspace/models/v2-1_512-ema-pruned.ckpt # path to the StableDiffusion v 2.1 weights
+sync_path: ./logs/2023-10-10T09-58-01_ctxs/checkpoints/last-v2.ckpt
+# ckpt_path_ctr: /home/dell/workspace/models/sd21_encD_depth_14m.ckpt # path to the ControlNet-XS weights
+# ckpt_path_ctr: ./logs/2023-10-10T09-58-01_ctxs/checkpoints/last.ckpt # path to the ControlNet-XS weights
 synch_control: false
 control_mode: midas
 control_stage_config:
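Note that this hunk repurposes sync_path to point at a checkpoint from an earlier training run instead of the published SD 2.1 weights. A quick way to inspect such a checkpoint before wiring it in (a sketch; the path is the one from the diff):

```python
# Inspect a Lightning-style .ckpt before using it as sync_path.
import torch

ckpt = torch.load("./logs/2023-10-10T09-58-01_ctxs/checkpoints/last-v2.ckpt", map_location="cpu")
state = ckpt.get("state_dict", ckpt)  # Lightning wraps weights under "state_dict"
print(len(state), "tensors")
# Peek at top-level key prefixes to see whether control weights are included.
print(sorted({k.split(".")[0] for k in state}))
```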
5 changes: 3 additions & 2 deletions configs/inference/sdxl/sdxl_encD_depth_48m.yaml
@@ -16,8 +16,9 @@ model:
 target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
 sd_locked: true
 skip_wrapper: true
-ckpt_path: /PATH/TO/STABLEDIFFUSION/sd_xl_base_1.0_0.9vae.safetensors # path to the StableDiffusion-XL weights
-ckpt_path_control: ./checkpoints/sdxl_encD_depth_48m.safetensors # path to the ControlNet-XS weights
+ckpt_path: /home/dell/workspace/models/sd_xl_base_1.0_0.9vae.safetensors # path to the StableDiffusion-XL weights
+# ckpt_path_control: ./checkpoints/sdxl_encD_depth_48m.safetensors # path to the ControlNet-XS weights
+ckpt_path_control: /home/dell/workspace/models/sdxl_encD_depth_48m.safetensors # path to the ControlNet-XS weights
 network_config:
 target: sgm.modules.diffusionmodules.twoStreamControl.TwoStreamControlNet
 params:
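Both SDXL paths now point at local safetensors files, which can be verified independently of the trainer. A sketch using the safetensors API (paths are the ones from the diff; the ~48M parameter figure for the control model is inferred from the file name):

```python
# Check that the SDXL base and ControlNet-XS weights load cleanly.
from safetensors.torch import load_file

base = load_file("/home/dell/workspace/models/sd_xl_base_1.0_0.9vae.safetensors")
ctrl = load_file("/home/dell/workspace/models/sdxl_encD_depth_48m.safetensors")
print(f"base: {len(base)} tensors, control: {len(ctrl)} tensors")
# The control checkpoint should be far smaller than the base model (~48M params).
print(f"{sum(t.numel() for t in ctrl.values()) / 1e6:.1f}M control params")
```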
248 changes: 248 additions & 0 deletions configs/sd21_encD_depth_14m.yaml
@@ -0,0 +1,248 @@
model:
base_learning_rate: 1.0e-05
target: ldm.models.diffusion.ddpm.TwoStreamControlLDM
params:
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
# cond_stage_key: captions
cond_stage_key: txt
# control_key: hint
control_key: image
image_size: 64
channels: 4
cond_stage_trainable: True
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: false
# sync_path: /PATH/TO/STABLEDIFFUSION/v2-1_512-ema-pruned.ckpt # path to the StableDiffusion v 2.1 weights
sync_path: /home/dell/workspace/models/v2-1_512-ema-pruned.ckpt # path to the StableDiffusion v 2.1 weights
ckpt_path_ctr: /home/dell/workspace/models/sd21_encD_depth_14m.ckpt # path to the ControlNet-XS weights
synch_control: false
control_mode: midas
control_stage_config:
target: ldm.modules.diffusionmodules.twoStreamControl.TwoStreamControlNet
params:
# num_idx: 1000

# weighting_config:
# target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
# scaling_config:
# target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
# discretization_config:
# target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
use_checkpoint: true
image_size: 32
in_channels: 4
out_channels: 4
hint_channels: 3
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
num_head_channels: 8
use_spatial_transformer: true
use_linear_in_transformer: true
transformer_depth: 1
context_dim: 1024
legacy: false
infusion2control: cat
infusion2base: add
guiding: encoder_double
two_stream_mode: cross
control_model_ratio: 0.0125

unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
# target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: true
image_size: 32
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
num_head_channels: 64
use_spatial_transformer: true
use_linear_in_transformer: true
transformer_depth: 1
context_dim: 1024
legacy: false
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: true
layer: penultimate
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
num_idx: 1000

discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50

discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 7.5

# data:
# target: sgm.data.dataset.StableDataModuleFromConfig
# params:
# train:
# datapipeline:
# urls:
# # USER: adapt this path the root of your custom dataset
# # - "DATA_PATH"
# # dataset_name='/home/dell/workspace/dataset/laion_data/',
# # dataset_name='/home/dell/workspace/controlnet/sampleset/',
# # dataset_name='/home/dell/workspace/T2IAdapter-SDXL/face3w/face3w.py',
# /home/dell/workspace/T2IAdapter-SDXL/face3w/face1set.py
# pipeline_config:
# shardshuffle: 10000
# sample_shuffle: 10000


# decoders:
# - "pil"

# postprocessors:
# - target: sdata.mappers.TorchVisionImageTransforms
# params:
# key: 'jpg' # USER: you might wanna adapt this for your custom dataset
# transforms:
# - target: torchvision.transforms.Resize
# params:
# size: 256
# interpolation: 3
# - target: torchvision.transforms.ToTensor
# - target: sdata.mappers.Rescaler
# # USER: you might wanna use non-default parameters due to your custom dataset
# - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
# # USER: you might wanna use non-default parameters due to your custom dataset

# loader:
# batch_size: 2
# num_workers: 1
data:
target: data.datam.DataModuleFromConfig
params:
batch_size: 2
num_workers: 1
num_val_workers: 0 # Avoid a weird val dataloader issue
train:
# target: ldm.data.simple.hf_dataset
target: data.simple.hf_dataset
params:
# name: lambdalabs/pokemon-blip-captions
name: /home/dell/workspace/dataset/lambdalabs___pokemon-blip-captions/
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
- target: torchvision.transforms.RandomHorizontalFlip
validation:
target: data.simple.TextOnly
params:
captions:
- "A pokemon with green eyes, large wings, and a hat"
- "A cute bunny rabbit"
- "Yoda"
- "An epic landscape photo of a mountain"
output_size: 512
n_gpus: 1 # small hack to make sure we see all our samples

# loader:
# batch_size: 2
# num_workers: 1

lightning:
modelcheckpoint:
params:
# every_n_train_steps: 30
every_n_train_steps: 300

callbacks:
metrics_over_trainsteps_checkpoint:
params:
# every_n_train_steps: 25
every_n_train_steps: 250

image_logger:
target: main.ImageLogger
params:
disabled: False
enable_autocast: False
batch_frequency: 200
max_images: 2
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 8
n_rows: 2

trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 3
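The train section points data.simple.hf_dataset at a local copy of lambdalabs/pokemon-blip-captions. Assuming that wrapper builds on Hugging Face datasets (an assumption; the wrapper itself ships with this PR), the dataset can be sanity-checked outside the Lightning pipeline:

```python
# Sanity-check the Pokémon BLIP-captions dataset outside the trainer.
from datasets import load_dataset

# Either the hub name or the local path from the config should resolve.
ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
print(len(ds), "examples")
print(ds[0]["text"])        # caption string
print(ds[0]["image"].size)  # PIL image dimensions
```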
