
Commit baa64bf

Merge pull request #4 from kcg-ml/dev-matias
fix: refactor https://github.com/kcg-ml/kcg-ml-diffae/blob/main/diffa…
2 parents 2e7bac0 + aa3ba4a commit baa64bf

File tree

5 files changed: +1213 -1151 lines changed


model/unet_model.py

Lines changed: 116 additions & 0 deletions
import copy
import torch
from torch import nn
from torch.cuda import amp


class UNetModel:
    """Core model architecture implementation for diffusion models."""

    def __init__(self, conf):
        """
        Initialize the UNet model.

        Args:
            conf: Configuration object containing model parameters
        """
        self.conf = conf
        self.model = conf.make_model_conf().make_model()
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        # Calculate model size
        model_size = 0
        for param in self.model.parameters():
            model_size += param.data.nelement()
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        # Initialize samplers
        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
        self.T_sampler = conf.make_T_sampler()

        # Initialize latent samplers if needed
        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

    def update_ema(self, decay):
        """
        Update the exponential moving average model.

        Args:
            decay: EMA decay rate
        """
        self._ema(self.model, self.ema_model, decay)

    def _ema(self, source, target, decay):
        """
        Apply exponential moving average update.

        Args:
            source: Source model
            target: Target model (EMA)
            decay: EMA decay rate
        """
        source_dict = source.state_dict()
        target_dict = target.state_dict()
        for key in source_dict.keys():
            target_dict[key].data.copy_(target_dict[key].data * decay +
                                        source_dict[key].data * (1 - decay))

    def encode(self, x):
        """
        Encode input using the model's encoder.

        Args:
            x: Input tensor

        Returns:
            Encoded representation
        """
        assert self.conf.model_type.has_autoenc()
        cond = self.ema_model.encoder.forward(x)
        return cond

    def encode_stochastic(self, x, cond, T=None):
        """
        Stochastically encode input.

        Args:
            x: Input tensor
            cond: Conditioning tensor
            T: Number of diffusion steps

        Returns:
            Stochastically encoded sample
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        out = sampler.ddim_reverse_sample_loop(self.ema_model,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample']

    def forward(self, noise=None, x_start=None, use_ema=False):
        """
        Forward pass through the model.

        Args:
            noise: Input noise
            x_start: Starting point for diffusion
            use_ema: Whether to use EMA model

        Returns:
            Generated sample
        """
        with amp.autocast(False):
            model = self.ema_model if use_ema else self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
        return gen
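
For reference, a minimal sketch (not part of this commit) of the update rule that _ema applies: each EMA parameter moves a step of size (1 - decay) toward the corresponding training parameter, i.e. target = decay * target + (1 - decay) * source. The two nn.Linear modules below are throwaway stand-ins for the real model and its EMA copy:

from torch import nn

source = nn.Linear(4, 4)   # stands in for the training model
target = nn.Linear(4, 4)   # stands in for its EMA copy
decay = 0.999

# Same update as UNetModel._ema: state_dict() tensors share storage with
# the module parameters, so copy_ mutates the target module in place.
source_dict = source.state_dict()
target_dict = target.state_dict()
for key in source_dict.keys():
    target_dict[key].data.copy_(target_dict[key].data * decay +
                                source_dict[key].data * (1 - decay))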
Lines changed: 115 additions & 0 deletions
import torch
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from dataset import *
from dist_utils import get_world_size, get_rank
import numpy as np


class UNetPreprocessor:
    """Handles data preprocessing and dataset creation for UNet models."""

    def __init__(self, conf):
        """
        Initialize the preprocessor.

        Args:
            conf: Configuration object
        """
        self.conf = conf
        self.train_data = None
        self.val_data = None

    def setup(self, seed=None, global_rank=0):
        """
        Set up datasets with proper seeding.

        Args:
            seed: Random seed
            global_rank: Current process rank
        """
        # Set seed for each worker separately
        if seed is not None:
            seed_worker = seed * get_world_size() + global_rank
            np.random.seed(seed_worker)
            torch.manual_seed(seed_worker)
            torch.cuda.manual_seed(seed_worker)
            print('local seed:', seed_worker)

        # Create datasets
        self.train_data = self.conf.make_dataset()
        print('train data:', len(self.train_data))
        self.val_data = self.train_data
        print('val data:', len(self.val_data))

    def create_train_dataloader(self, batch_size, drop_last=True, shuffle=True):
        """
        Create training dataloader.

        Args:
            batch_size: Batch size
            drop_last: Whether to drop the last incomplete batch
            shuffle: Whether to shuffle the data

        Returns:
            DataLoader for training
        """
        if not hasattr(self, "train_data") or self.train_data is None:
            self.setup()

        # Create a DataLoader directly
        dataloader = torch.utils.data.DataLoader(
            self.train_data,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=0,  # Use 0 to avoid pickling issues
            persistent_workers=False
        )
        return SizedIterableWrapper(dataloader, len(self.train_data))

    def create_val_dataloader(self, batch_size, drop_last=False):
        """
        Create validation dataloader.

        Args:
            batch_size: Batch size
            drop_last: Whether to drop the last incomplete batch

        Returns:
            DataLoader for validation
        """
        if not hasattr(self, "val_data") or self.val_data is None:
            self.setup()

        dataloader = torch.utils.data.DataLoader(
            self.val_data,
            batch_size=batch_size,
            shuffle=False,
            drop_last=drop_last,
            num_workers=0,
            persistent_workers=False
        )
        return dataloader

    def create_latent_dataset(self, conds):
        """
        Create a dataset from latent conditions.

        Args:
            conds: Latent conditions tensor

        Returns:
            TensorDataset containing the conditions
        """
        return TensorDataset(conds)


class SizedIterableWrapper:
    """Wrapper for iterables that provides a __len__ method."""

    def __init__(self, dataloader, length):
        self.dataloader = dataloader
        self._length = length

    def __iter__(self):
        return iter(self.dataloader)

    def __len__(self):
        return self._length
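
A hypothetical usage sketch (not part of this commit) showing why the wrapper exists: create_train_dataloader returns a SizedIterableWrapper so that len() reports the dataset size it was given rather than the batch count a DataLoader reports. The tensor shapes and batch size below are made up, and the SizedIterableWrapper class defined above is assumed to be in scope:

import torch
from torch.utils.data import DataLoader, TensorDataset

conds = torch.randn(100, 512)   # e.g. 100 latent conditions
dataset = TensorDataset(conds)  # same kind of dataset as create_latent_dataset returns
loader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)

wrapped = SizedIterableWrapper(loader, len(dataset))
print(len(wrapped))             # 100: the dataset size, not len(loader) == 6
for (batch,) in wrapped:        # iteration is delegated to the DataLoader
    print(batch.shape)          # torch.Size([16, 512])
    break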
