Skip to content

Commit

Permalink
Training to predict x0 in training example (open-mmlab#1031)
Browse files Browse the repository at this point in the history
* changed training example to add option to train model that predicts x0 (instead of eps), changed DDPM pipeline accordingly

* Revert "changed training example to add option to train model that predicts x0 (instead of eps), changed DDPM pipeline accordingly"

This reverts commit c5efb525648885f2e7df71f4483a9f248515ad61.

* changed training example to add option to train model that predicts x0 (instead of eps), changed DDPM pipeline accordingly

* fixed code style

Co-authored-by: lukovnikov <lukovnikov@users.noreply.github.com>
  • Loading branch information
lukovnikov and lukovnikov authored Nov 2, 2022
1 parent 0b61cea commit cbcd051
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
55 changes: 51 additions & 4 deletions examples/unconditional_image_generation/train_unconditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@
logger = get_logger(__name__)


def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
if not isinstance(arr, torch.Tensor):
arr = torch.from_numpy(arr)
res = arr[timesteps].float().to(timesteps.device)
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
Expand Down Expand Up @@ -171,6 +189,16 @@ def parse_args():
),
)

parser.add_argument(
"--predict_mode",
type=str,
default="eps",
help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
)

parser.add_argument("--ddpm_num_steps", type=int, default=1000)
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
Expand Down Expand Up @@ -224,7 +252,7 @@ def main(args):
"UpBlock2D",
),
)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
Expand Down Expand Up @@ -257,6 +285,8 @@ def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}

logger.info(f"Dataset size: {len(dataset)}")

dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
Expand Down Expand Up @@ -319,8 +349,20 @@ def transforms(examples):

with accelerator.accumulate(model):
# Predict the noise residual
noise_pred = model(noisy_images, timesteps).sample
loss = F.mse_loss(noise_pred, noise)
model_output = model(noisy_images, timesteps).sample

if args.predict_mode == "eps":
loss = F.mse_loss(model_output, noise) # this could have different weights!
elif args.predict_mode == "x0":
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
snr_weights = alpha_t / (1 - alpha_t)
loss = snr_weights * F.mse_loss(
model_output, clean_images, reduction="none"
) # use SNR weighting from distillation paper
loss = loss.mean()

accelerator.backward(loss)

if accelerator.sync_gradients:
Expand Down Expand Up @@ -355,7 +397,12 @@ def transforms(examples):

generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
images = pipeline(
generator=generator,
batch_size=args.eval_batch_size,
output_type="numpy",
predict_epsilon=args.predict_mode == "eps",
).images

# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __call__(
num_inference_steps: int = 1000,
output_type: Optional[str] = "pil",
return_dict: bool = True,
predict_epsilon: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Expand Down Expand Up @@ -84,7 +85,9 @@ def __call__(
model_output = self.unet(image, t).sample

# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
image = self.scheduler.step(
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
).prev_sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
Expand Down

0 comments on commit cbcd051

Please sign in to comment.