Skip to content

Commit

Permalink
make forward logic more clear for GAN models
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Feb 28, 2023
1 parent 799e1f8 commit ba85fbd
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 173 deletions.
87 changes: 56 additions & 31 deletions mmedit/models/base_models/base_conditional_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,46 +185,71 @@ def forward(self,
labels = self.label_fn(num_batches=num_batches)

sample_model = self._get_valid_model(inputs)
if sample_model in ['ema', 'ema/orig']:
generator = self.generator_ema
else: # sample model is `orig`
generator = self.generator
outputs = generator(noise, label=labels, return_noise=False)
outputs = self.data_preprocessor.destruct(outputs, data_samples)

if sample_model == 'ema/orig':
generator = self.generator
outputs_orig = generator(noise, label=labels, return_noise=False)
batch_sample_list = []
if sample_model in ['ema', 'orig']:
if sample_model == 'ema':
generator = self.generator_ema
else:
generator = self.generator
outputs = generator(noise, label=labels, return_noise=False)
outputs = self.data_preprocessor.destruct(outputs, data_samples)

if data_samples:
data_samples = data_samples.split()
# save to data sample
for idx in range(num_batches):
gen_sample = EditDataSample()
# save inputs to data sample
if data_samples:
gen_sample.update(data_samples[idx])
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img'][idx]
# save outputs to data sample
gen_sample.fake_img = outputs[idx]
gen_sample.sample_model = sample_model

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.set_gt_label(labels[idx])
gen_sample.sample_kwargs = deepcopy(sample_kwargs)

batch_sample_list.append(gen_sample)
else: # sample model in 'ema/orig'
outputs_orig = self.generator(
noise, label=labels, return_noise=False, **sample_kwargs)
outputs_ema = self.generator_ema(
noise, label=labels, return_noise=False, **sample_kwargs)
outputs_orig = self.data_preprocessor.destruct(
outputs_orig, data_samples)
outputs = dict(ema=outputs, orig=outputs_orig)
outputs_ema = self.data_preprocessor.destruct(
outputs_ema, data_samples)

batch_sample_list = []
if data_samples:
data_samples = data_samples.split()
for idx in range(num_batches):
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples[idx])
if sample_model == 'ema/orig':
data_samples = data_samples.split()
# save to data sample
for idx in range(num_batches):
gen_sample = EditDataSample()
# save inputs to data sample
if data_samples:
gen_sample.update(data_samples[idx])
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img'][idx]
# save outputs to data sample
gen_sample.ema = EditDataSample(
fake_img=outputs['ema'][idx], sample_model='ema')
fake_img=outputs_ema[idx], sample_model='ema')
gen_sample.orig = EditDataSample(
fake_img=outputs['orig'][idx], sample_model='orig')
fake_img=outputs_orig[idx], sample_model='orig')
gen_sample.sample_model = 'ema/orig'

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.set_gt_label(labels[idx])
gen_sample.ema.set_gt_label(labels[idx])
gen_sample.orig.set_gt_label(labels[idx])
else:
gen_sample.fake_img = outputs[idx]
gen_sample.sample_model = sample_model
gen_sample.set_gt_label(labels[idx])
gen_sample.sample_kwargs = deepcopy(sample_kwargs)

batch_sample_list.append(gen_sample)

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.sample_kwargs = deepcopy(sample_kwargs)
batch_sample_list.append(gen_sample)
return batch_sample_list

def train_generator(self, inputs: dict, data_samples: List[EditDataSample],
Expand Down
81 changes: 51 additions & 30 deletions mmedit/models/base_models/base_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,48 +343,69 @@ def forward(self,
num_batches = noise.shape[0]

sample_model = self._get_valid_model(inputs)
if sample_model in ['ema', 'ema/orig']:
generator = self.generator_ema
else: # sample model is 'orig'
generator = self.generator
batch_sample_list = []
if sample_model in ['ema', 'orig']:
if sample_model == 'ema':
generator = self.generator_ema
else:
generator = self.generator
outputs = generator(noise, return_noise=False, **sample_kwargs)
outputs = self.data_preprocessor.destruct(outputs, data_samples)

num_batches = noise.shape[0]
outputs = generator(noise, return_noise=False, **sample_kwargs)
outputs = self.data_preprocessor.destruct(outputs, data_samples)
if data_samples:
data_samples = data_samples.split()
# save to data sample
for idx in range(num_batches):
gen_sample = EditDataSample()
# save inputs to data sample
if data_samples:
gen_sample.update(data_samples[idx])
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img'][idx]
# save outputs to data sample
gen_sample.fake_img = outputs[idx]
gen_sample.sample_model = sample_model

if sample_model == 'ema/orig':
generator = self.generator
outputs_orig = generator(
# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.sample_kwargs = deepcopy(sample_kwargs)

batch_sample_list.append(gen_sample)

else: # sample model is 'ema/orig
outputs_orig = self.generator(
noise, return_noise=False, **sample_kwargs)
outputs_ema = self.generator_ema(
noise, return_noise=False, **sample_kwargs)
outputs_orig = self.data_preprocessor.destruct(
outputs_orig, data_samples)
outputs = dict(ema=outputs, orig=outputs_orig)
outputs_ema = self.data_preprocessor.destruct(
outputs_ema, data_samples)

if data_samples:
data_samples = data_samples.split()
batch_sample_list = []
for idx in range(num_batches):
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples[idx])
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img'][idx]
if isinstance(outputs, dict):
data_samples = data_samples.split()
# save to data sample
for idx in range(num_batches):
gen_sample = EditDataSample()
# save inputs to data sample
if data_samples:
gen_sample.update(data_samples[idx])
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img'][idx]
# save outputs to data sample
gen_sample.ema = EditDataSample(
fake_img=outputs['ema'][idx], sample_model='ema')
fake_img=outputs_ema[idx], sample_model='ema')
gen_sample.orig = EditDataSample(
fake_img=outputs['orig'][idx], sample_model='orig')
fake_img=outputs_orig[idx], sample_model='orig')
gen_sample.sample_model = 'ema/orig'
else:
gen_sample.fake_img = outputs[idx]
gen_sample.sample_model = sample_model

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.sample_kwargs = deepcopy(sample_kwargs)
# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.sample_kwargs = deepcopy(sample_kwargs)

batch_sample_list.append(gen_sample)
batch_sample_list.append(gen_sample)

return batch_sample_list

Expand Down
88 changes: 57 additions & 31 deletions mmedit/models/editors/eg3d/eg3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,43 +184,69 @@ def forward(self,
labels = self.label_fn(num_batches=num_batches)

sample_model = self._get_valid_model(inputs)
if sample_model in ['ema', 'ema/orig']:
generator = self.generator_ema
else: # sample model is `orig`
generator = self.generator
outputs = generator(noise, label=labels)

if sample_model == 'ema/orig':
generator = self.generator
outputs_orig = generator(noise, label=labels)

outputs = dict(ema=outputs, orig=outputs_orig)

if data_samples is not None:
data_samples = data_samples.split()
batch_sample_list = []
for idx in range(num_batches):
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples[idx])
if sample_model == 'ema/orig':
gen_sample.ema = self.pack_to_data_sample(outputs['ema'], idx)
gen_sample.orig = self.pack_to_data_sample(
outputs['orig'], idx)
gen_sample.sample_model = 'ema/orig'
gen_sample.set_gt_label(labels[idx])
gen_sample.ema.set_gt_label(labels[idx])
gen_sample.orig.set_gt_label(labels[idx])
if sample_model in ['ema', 'orig']:
if sample_model == 'ema':
generator = self.generator_ema
else:
generator = self.generator
outputs = generator(noise, label=labels)
outputs['fake_img'] = self.data_preprocessor.destruct(
outputs['fake_img'], data_samples)

if data_samples is not None:
data_samples = data_samples.split()
# save to data sample
for idx in range(num_batches):
gen_sample = EditDataSample()
# save inputs to data sample
if data_samples:
gen_sample.update(data_samples[idx])
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img'][idx]
# save outputs to data sample
gen_sample = self.pack_to_data_sample(outputs, idx, gen_sample)
gen_sample.sample_model = sample_model
gen_sample.set_gt_label(labels[idx])

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.sample_kwargs = deepcopy(sample_kwargs)
batch_sample_list.append(gen_sample)
# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.sample_kwargs = deepcopy(sample_kwargs)

batch_sample_list.append(gen_sample)

else:
outputs_orig = self.generator(noise, label=labels)
outputs_ema = self.generator_ema(noise, label=labels)
outputs_orig['fake_img'] = self.data_preprocessor.destruct(
outputs_orig['fake_img'], data_samples)
outputs_ema['fake_img'] = self.data_preprocessor.destruct(
outputs_ema['fake_img'], data_samples)

if data_samples is not None:
data_samples = data_samples.split()
# save to data sample
for idx in range(num_batches):
gen_sample = EditDataSample()
# save inputs to data sample
if data_samples:
gen_sample.update(data_samples[idx])
if isinstance(inputs, dict) and 'img' in inputs:
gen_sample.gt_img = inputs['img'][idx]
# save outputs to data sample
gen_sample.ema = self.pack_to_data_sample(outputs_ema, idx)
gen_sample.orig = self.pack_to_data_sample(outputs_orig, idx)
gen_sample.sample_model = sample_model
gen_sample.set_gt_label(labels[idx])

# Append input condition (noise and sample_kwargs) to
# batch_sample_list
gen_sample.noise = noise[idx]
gen_sample.sample_kwargs = deepcopy(sample_kwargs)

batch_sample_list.append(gen_sample)

return batch_sample_list

@torch.no_grad()
Expand Down
Loading

0 comments on commit ba85fbd

Please sign in to comment.