diff --git a/_static/img/DDIM/flowers.gif b/_static/img/DDIM/flowers.gif
new file mode 100644
index 00000000000..571cbc9b3f0
Binary files /dev/null and b/_static/img/DDIM/flowers.gif differ
diff --git a/advanced_source/DDIM_tutorial.py b/advanced_source/DDIM_tutorial.py
new file mode 100644
index 00000000000..00b12219dfa
--- /dev/null
+++ b/advanced_source/DDIM_tutorial.py
@@ -0,0 +1,843 @@
+
+"""
+Unconditional Image Generation With DDIM Model Tutorial
+========================================================
+**Author:** `Muhammed Ayman `_
+
+This is an implementation of `Denoising Diffusion Implicit Models (DDIM) `__.
+
+DDIM belongs to the family of denoising diffusion probabilistic models,
+but the key difference is that it does not require **a large number of
+reverse diffusion time steps** to produce samples.
+"""
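+
+######################################################################
+# DDIM keeps the same training objective as DDPM: a network
+# :math:`\epsilon_\theta(x_t, t)` is trained to predict the noise that was
+# added to an image. What changes is the sampler. The reverse update used by
+# ``DDIMSampler.p_sample`` below follows the DDIM paper:
+#
+# .. math::
+#
+#    x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0
+#    + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t, t)
+#    + \sigma_t \epsilon,
+#    \qquad
+#    \hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}
+#
+# with :math:`\sigma_t = \eta\sqrt{\tfrac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\left(1-\tfrac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}\right)}`.
+# Setting :math:`\eta = 0` (the default in this tutorial) makes the update
+# deterministic, which is what lets us use far fewer inference steps than
+# training steps.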
+
+
+######################################################################
+# .. figure:: /_static/img/DDIM/flowers.gif
+# :align: center
+# :alt: DDIM
+
+
+######################################################################
+# Setup
+# =====
+#
+
+import torch
+from torch import nn
+import numpy as np
+import matplotlib.pyplot as plt
+import torchvision.transforms.functional as F
+from torchmetrics.image.kid import KernelInceptionDistance
+from PIL import Image
+
+
+
+######################################################################
+# Downloading Data and Preparing the Pipeline
+# ===========================================
+#
+
+
+######################################################################
+# Here we download the **Oxford Flowers Dataset**, which we will use to
+# generate images of flowers. It is a diverse natural image dataset
+# containing around 8,000 images across 102 categories.
+#
+
+import torchvision.datasets as data
+data.Flowers102("./data", download=True)
+
+
+######################################################################
+# Next we build the data pipeline using the ``Dataset`` and ``DataLoader``
+# classes from ``torch.utils.data``.
+#
+
+from torch.utils.data import Dataset,DataLoader
+from PIL import Image
+from torchvision import transforms
+class PetDataset(Dataset):
+ def __init__(self,pathes,img_size=(64,64),train=True):
+ self.pathes = pathes
+ self.img_size = img_size
+ self.aug = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ # transforms.RandomAdjustSharpness(2)
+ ])
+ self.processor = transforms.Compose(
+ [
+ transforms.Resize(self.img_size, antialias=True),
+ transforms.ToTensor(),
+              transforms.Normalize([0.5], [0.5]), # ToTensor gives [0, 1]; this maps the images to [-1, 1]
+ ]
+ )
+ self.train = train
+  def _center_crop(self,img):
+    w, h = img.size  # PIL's Image.size is (width, height)
+    crop_size = min(w, h)
+    left = (w - crop_size) // 2
+    top = (h - crop_size) // 2
+    return img.crop((left, top, left + crop_size, top + crop_size))
+ def __len__(self):
+ return len(self.pathes)
+ def __getitem__(self,idx):
+ img = Image.open(self.pathes[idx]).convert("RGB")
+ img = self._center_crop(img)
+ img = self.processor(img)
+ if self.train:
+ img = self.aug(img)
+ return img
+
+
+
+import os
+import random
+all_flowers_pathes = [os.path.join('./data/flowers-102/jpg', x) for x in os.listdir('./data/flowers-102/jpg')
+                      if x.endswith('.jpg')] # gather all image paths
+
+random.shuffle(all_flowers_pathes)
+train_pathes = all_flowers_pathes[:-500]
+val_pathes = all_flowers_pathes[-500:]
+train_ds = PetDataset(train_pathes) # training dataset
+val_ds = PetDataset(val_pathes,train=False) # validation dataset
+
+# helper function to display the image after generation
+def display_img(img):
+ img = (img+1)*0.5
+ img= img.permute(1,2,0)
+ plt.imshow(img)
+ plt.axis('off')
+
+test = val_ds[101] # grab a sample
+display_img(test)
+
+train_iter = DataLoader(train_ds,150,shuffle=True,num_workers=2,pin_memory=True)
+val_iter = DataLoader(val_ds,20,num_workers=2,pin_memory=True)
+
+
+######################################################################
+# Model Architecture and Modules
+# ==============================
+#
+
+import math
+MAX_FREQ = 1000
+def get_timestep_embedding(timesteps, embedding_dim): # sinusoidal embedding like in Transformers
+ assert len(timesteps.shape) == 1
+ half_dim = embedding_dim // 2
+ emb = math.log(MAX_FREQ) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32,
+ num_channels=in_channels,
+ eps=1e-6, affine=True)
+
+def nonlinearity(name):
+ return getattr(nn,name)()
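+
+# Quick optional check of the helper above: for a batch of 4 timesteps and an
+# embedding size of 32, the sinusoidal embedding has shape (4, 32).
+assert get_timestep_embedding(torch.arange(4), 32).shape == (4, 32)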
+
+
+# Resnet Block
+class ResidualBlock(nn.Module):
+ def __init__(self,in_chs,
+ out_chs,
+ temb_dim,
+ act='SiLU',dropout=0.2):
+ super().__init__()
+ self.time_proj = nn.Sequential(nonlinearity(act),
+ nn.Linear(temb_dim,out_chs))
+
+ dims = [in_chs]+2*[out_chs]
+ blocks =[]
+ for i in range(1,3):
+ blc = nn.Sequential(Normalize(dims[i-1]),
+ nonlinearity(act),
+ nn.Conv2d(dims[i-1],dims[i],3,padding=1),)
+ if i>1:
+ blc.insert(2,nn.Dropout(dropout))
+ blocks.append(blc)
+ self.blocks= nn.ModuleList(blocks)
+ self.short_cut = False
+ if in_chs!= out_chs:
+ self.short_cut = True
+ self.conv_short = nn.Conv2d(in_chs,out_chs,1)
+ def forward(self,x,temb):
+ h =x
+ for i,blc in enumerate(self.blocks):
+ h = blc(h)
+ if i==0:
+ h = h+self.time_proj(temb)[:,:,None,None]
+ if self.short_cut:
+ x = self.conv_short(x)
+ return x+h
+
+
+# Attention Module
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h*w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (c**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
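+
+# Optional sanity check (a small sketch, not required for the tutorial): the
+# block above is plain dot-product self-attention over the H*W spatial
+# positions, so on PyTorch 2.0+ it should agree with
+# ``torch.nn.functional.scaled_dot_product_attention`` up to numerical precision.
+_attn = AttnBlock(32).eval()
+_x = torch.randn(1, 32, 8, 8)
+with torch.no_grad():
+    _h = _attn.norm(_x)
+    _q = _attn.q(_h).flatten(2).transpose(1, 2)   # (B, H*W, C)
+    _k = _attn.k(_h).flatten(2).transpose(1, 2)
+    _v = _attn.v(_h).flatten(2).transpose(1, 2)
+    _sdpa = torch.nn.functional.scaled_dot_product_attention(_q, _k, _v)
+    _ref = _x + _attn.proj_out(_sdpa.transpose(1, 2).reshape(1, 32, 8, 8))
+    print("max |AttnBlock - SDPA|:", (_attn(_x) - _ref).abs().max().item())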
+
+# Downsize Block
+class DownBlock(nn.Module):
+ def __init__(self,
+ out_chs,
+ with_conv=True):
+ super().__init__()
+ self.with_conv = with_conv
+ if with_conv:
+ self.down_conv = nn.Conv2d(out_chs,out_chs,3,stride=2)
+ else:
+ self.down_conv = nn.AvgPool2d(2,2)
+
+
+ def _down(self,x):
+ if self.with_conv:
+          pad = (0,1,0,1) # asymmetric padding so the strided convolution halves the spatial size exactly
+          x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+          x = self.down_conv(x)
+ else:
+ x = self.down_conv(x)
+ return x
+
+
+ def forward(self,x):
+
+ return self._down(x)
+
+# Upsample BLock
+class UpBlock(nn.Module):
+ def __init__(self,out_chs,
+ with_conv=True,
+ mode='nearest',):
+ super().__init__()
+ self.with_conv = with_conv
+ self.mode = mode
+ if with_conv:
+ self.up_conv = nn.Conv2d(out_chs,out_chs,3,padding=1)
+
+ def _up(self,x):
+ x = torch.nn.functional.interpolate(
+ x, scale_factor=2.0, mode=self.mode)
+ if self.with_conv:
+ x = self.up_conv(x)
+ return x
+ def forward(self,x):
+ return self._up(x)
+
+
+#Unet Model
+class DiffUnet(nn.Module):
+ def __init__(self,chs=32,
+ chs_mult=[2,2,4,4,8],
+ attn_res=[16,8,4],
+ block_depth=2,
+ act='SiLU',
+ temb_dim=256,
+ with_conv=True,
+ res=64,dropout=0.3):
+ super().__init__()
+ self.chs = chs
+ self.conv_in = nn.Conv2d(3,chs,3,padding=1)
+ self.time_proj = nn.Sequential(nn.Linear(chs,temb_dim),
+ nonlinearity(act),
+ nn.Linear(temb_dim,temb_dim))
+ chs_mult = [1]+chs_mult
+ #down block
+ down_dims = [] # to store the down features
+ downs = [] # to store the down blocks of the unet model
+ for i in range(1,len(chs_mult)-1):
+ in_ch = chs*chs_mult[i-1]
+ out_ch = chs*chs_mult[i]
+ down = nn.Module()
+ down.res = nn.ModuleList([ResidualBlock(in_ch,out_ch,temb_dim,act,dropout)]+
+ [ResidualBlock(out_ch,out_ch,temb_dim,act,dropout) for _ in range(1,block_depth)])
+ attn = AttnBlock(out_ch) if res in attn_res else nn.Identity()
+ down.attn = attn
+ down.down_blc = DownBlock(out_ch,with_conv)
+ downs.append(down)
+ down_dims.append(out_ch)
+ res = res//2
+
+ self.downs = nn.ModuleList(downs)
+
+ #mid block
+ last_ch_dim= chs*chs_mult[-1]
+ self.mid_res1 = ResidualBlock(out_ch,
+ last_ch_dim,
+ temb_dim,act,dropout)
+ self.mid_attn = AttnBlock(last_ch_dim)
+ self.mid_res2 = ResidualBlock(last_ch_dim,
+ last_ch_dim,
+ temb_dim,act,dropout)
+
+ #up block
+ down_dims = down_dims[1:]+[last_ch_dim]
+ ups = []
+ for i,skip_ch in zip(reversed(range(1,len(chs_mult)-1)),reversed(down_dims)):
+ out_ch = chs*chs_mult[i]
+ in_ch = out_ch+skip_ch
+ up = nn.Module()
+
+ up.res = nn.ModuleList([ResidualBlock(in_ch,out_ch,temb_dim,act,dropout)]+
+ [ResidualBlock(out_ch*2,out_ch,temb_dim,act,dropout) for _ in range(1,block_depth)])
+ attn = AttnBlock(out_ch) if res in attn_res else nn.Identity()
+ up.attn = attn
+ up.up_blc = UpBlock(skip_ch,with_conv) if i!=0 else nn.Identity()
+ ups.append(up)
+ res = int(res*2)
+ self.ups = nn.ModuleList(ups)
+ self.out = nn.Sequential(Normalize(out_ch),
+ nonlinearity(act),
+ nn.Conv2d(out_ch,3,3,padding=1))
+ self.res = res
+ def forward(self,x,timestep):
+ t = get_timestep_embedding(timestep,self.chs)
+ t = self.time_proj(t)
+ h = self.conv_in(x)
+ hs =[]
+ #Down
+ for blc in self.downs:
+ for res_block in blc.res:
+ h = res_block(h,t)
+ h = blc.attn(h)
+ hs.append(h)
+ h = blc.down_blc(h)
+ #Mid
+ h = self.mid_res1(h,t)
+ h = self.mid_attn(h)
+ h = self.mid_res2(h,t)
+ #Up
+ for blc in self.ups:
+ h = blc.up_blc(h)
+ for res_block in blc.res:
+ h = torch.cat([h,hs.pop()],axis=1)
+ h = res_block(h,t)
+ h = blc.attn(h)
+ return self.out(h)
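+
+# A quick shape check (optional): the U-Net maps a noisy batch plus its
+# timesteps back to a tensor with the same shape as the input. A small
+# throw-away configuration is used here purely for illustration.
+_net = DiffUnet(chs=32, chs_mult=[1, 1, 2], attn_res=[16], block_depth=1)
+_x = torch.randn(2, 3, 64, 64)
+_t = torch.randint(0, 1000, (2,))
+print(_net(_x, _t).shape)   # torch.Size([2, 3, 64, 64])
+del _net, _x, _t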
+
+
+
+######################################################################
+# Diffusion Model and noise scheduler
+# ===================================
+#
+
+
+import math
+import numpy as np
+from typing import Optional, Tuple, List,Union
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ def sigmoid(x):
+ return 1 / (np.exp(-x) + 1)
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float32,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "sigmoid":
+ betas = np.linspace(-6, 6, num_diffusion_timesteps)
+ betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
+ elif beta_schedule == "cosv2":
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), beta_end))
+ betas = np.array(betas)
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
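+
+# Optional: a small illustrative snippet comparing how the schedules above
+# allocate noise over time, by plotting the cumulative product of (1 - beta)
+# (the quantity the sampler below stores as ``alpha_cumprod``), using the
+# same beta range as the sampler defaults.
+for _name in ("linear", "quad", "sigmoid", "cosv2"):
+    _betas = get_beta_schedule(_name, beta_start=0.001, beta_end=0.2,
+                               num_diffusion_timesteps=1000)
+    plt.plot(np.cumprod(1 - _betas), label=_name)
+plt.xlabel("timestep")
+plt.ylabel("alpha_bar")
+plt.legend()
+plt.show()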
+
+
+
+
+class DDIMSampler:
+
+ def __init__(self,
+ schedule_name: str,
+ diff_train_steps: int,
+ beta_start: float = 0.001,
+ beta_end: float = 0.2):
+ betas = get_beta_schedule(schedule_name,
+ beta_start=beta_start,
+ beta_end=beta_end,
+ num_diffusion_timesteps=diff_train_steps)
+ self.betas = torch.tensor(betas).to(torch.float32)
+ self.alpha = 1 - self.betas
+ self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)
+
+ self.timesteps = np.arange(0, diff_train_steps)[::-1]
+ self.num_train_steps = diff_train_steps
+ self._num_inference_steps = 20
+ self.eta = 0
+
+ def _get_variance(self,
+ timestep: Union[torch.Tensor, int],
+ prev_timestep: Union[torch.Tensor, int] ):
+ alpha_t = self.alpha_cumprod[timestep]
+ alpha_prev = self.alpha_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
+ beta_t = (1 - alpha_t)
+ beta_prev = (1 - alpha_prev)
+        # DDIM variance: (1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)
+        return (beta_prev / beta_t) * (1 - alpha_t / alpha_prev)
+
+ @staticmethod
+ def treshold_sample(sample: torch.Tensor,
+ threshold: float = 0.9956,
+ max_clip: float = 1):
+ batch_size, channels, height, width = sample.shape
+ dtype = sample.dtype
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, threshold, dim=1)
+ s = torch.clamp(
+ s, min=1, max=max_clip
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_infer_steps(self,
+ num_steps: int,
+                        device: torch.device):
+ self._num_inference_steps = num_steps
+ step_ratio = self.num_train_steps // self._num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ @torch.no_grad()
+ def p_sample(self,
+ x_t: torch.Tensor,
+ t_now: Union[torch.Tensor, int],
+ pred_net):
+ prev_timestep = t_now - self.num_train_steps // self._num_inference_steps
+ alpha_t = self.alpha_cumprod[t_now]
+ alpha_prev = self.alpha_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
+ var = self._get_variance(t_now, prev_timestep)
+ eps = torch.randn_like(x_t).to(x_t.device)
+ t_now = (torch.ones((x_t.shape[0],),
+ device=x_t.device,
+ dtype=torch.int32) * t_now).to(x_t.device)
+ eta_t = pred_net(x_t, t_now)
+
+ x0_t = (x_t - eta_t * (1 - alpha_t).sqrt()) / alpha_t.sqrt()
+
+ c1 = self.eta * var.sqrt()
+ c2 = ((1 - alpha_prev) - c1 ** 2).sqrt()
+ x_tminus = alpha_prev.sqrt() * x0_t + c2 * eta_t + c1 * eps
+ return x_tminus, x0_t
+
+ def q_sample(self,
+ x_t: torch.Tensor,
+ timesteps: Union[torch.Tensor, int]):
+
+ alpha_t = self.alpha_cumprod[timesteps].to(timesteps.device)
+ alpha_t = alpha_t.flatten().to(x_t.device)[:, None, None, None]
+ eps = torch.randn(*list(x_t.shape)).to(x_t.device)
+ x_t = alpha_t.sqrt() * x_t + (1 - alpha_t).sqrt() * eps
+ return x_t, eps
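+
+
+# A minimal usage sketch of the sampler on its own (nothing here is needed
+# for training; it just illustrates the two directions): forward-noise a
+# dummy batch at random timesteps, then inspect the subsampled timestep
+# schedule that would be used at inference time.
+_sampler = DDIMSampler('cosv2', diff_train_steps=1000)
+_x0 = torch.randn(4, 3, 64, 64)           # stands in for a batch of images scaled to [-1, 1]
+_t = torch.randint(0, 1000, (4,))
+_x_t, _eps = _sampler.q_sample(_x0, _t)   # noised batch and the noise that was added
+_sampler.set_infer_steps(10, torch.device('cpu'))
+print(_sampler.timesteps)                 # tensor([900, 800, ..., 100, 0])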
+
+
+import copy
+
+class DiffusionModel:
+ def __init__(self,
+ main_net: DiffUnet,
+ ema_net: Optional[DiffUnet] = None,
+ num_steps: int = 100,
+ input_res: Union[Tuple[int, int], List[int]] = (32, 32),
+ emma: float = 0.999,
+ noise_sch_name: str = 'cosv2',
+ **noise_sch_kwargs):
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.eps_net = main_net.to(self.device)
+ self.ema_net = ema_net if ema_net is not None else copy.deepcopy(main_net)
+ self.ema_net = self.ema_net.to(self.device)
+ self.ema_net.eval()
+ self.steps = num_steps
+ self.res = (3,) + input_res if isinstance(input_res, tuple) else [3] + input_res
+ self.num_steps = num_steps
+ self.scheduler = DDIMSampler(noise_sch_name,
+ diff_train_steps=num_steps,
+ **noise_sch_kwargs)
+ self.emma = emma
+
+ @torch.no_grad()
+ def generate(self,
+ num_samples: int = 1,
+ num_infer_steps: int = 25,
+ pred_net: Optional[str] = 'ema',
+ return_list: bool = False,
+ x_t: Optional[torch.Tensor] = None):
+ shape = (num_samples,) + self.res if isinstance(self.res, tuple) else [num_samples] + self.res
+ x_t = torch.randn(*shape).to(self.device) if x_t is None else x_t
+ self.scheduler.set_infer_steps(num_infer_steps, x_t.device)
+ pred_net = getattr(self, pred_net + "_net")
+ xs = [x_t.cpu()]
+ for step in range(num_infer_steps):
+ t = self.scheduler.timesteps[step]
+ x_t, _ = self.scheduler.p_sample(x_t, t, pred_net)
+ xs.append(x_t.cpu())
+ return xs[-1] if not return_list else xs
+
+ @staticmethod
+ def inverse_transform(img):
+ """ Inverse transform the images after generation"""
+ img = (img + 1) / 2
+ img = np.clip(img, 0.0, 1.0)
+ img = np.transpose(img, (1, 2, 0)) if len(img.shape) == 3 else np.transpose(img, (0, 2, 3, 1))
+ return img
+
+ @staticmethod
+ def transform(img):
+ """Transform the image before training converting the pixels values from [0, 255] to [-1, 1]"""
+ img = img.to(torch.float32) / 127.5
+ img = img - 1
+ if len(img.shape) == 3: # one sample
+ img = torch.permute(img, (2, 0, 1))
+ else: # batch of samples
+ img = torch.permute(img, (0, 3, 1, 2))
+ return img
+
+ def train_loss(self,
+ input_batch: torch.Tensor,
+ loss_type: Optional[str] = 'l1_loss',
+ **losskwargs):
+ """Training loss"""
+ bs, _, _, _ = input_batch.shape
+ t = torch.randint(0, self.num_steps, size=(bs,))
+ x_t, eps = self.scheduler.q_sample(input_batch, t)
+ t = t.int().to(input_batch.device)
+ eps_pred = self.eps_net(x_t, t)
+ loss = getattr(torch.nn.functional, loss_type)(eps_pred, eps, **losskwargs)
+ return loss
+
+ def update_emma(self):
+ for p_ema, p in zip(self.ema_net.parameters(), self.eps_net.parameters()):
+ p_ema.data = (1 - self.emma) * p.data + p_ema.data * self.emma
+
+ def train(self):
+ self.eps_net.train()
+
+ def eval(self):
+ self.eps_net.eval()
+
+ def parameters(self):
+ return self.eps_net.parameters()
+
+ def save(self,
+ file_name: str):
+ if not os.path.exists(file_name):
+ os.makedirs(file_name)
+ ema_path = file_name + '/ema.pt'
+ net_path = file_name + "/eps.pt"
+ torch.save(self.ema_net.state_dict(), ema_path)
+ torch.save(self.eps_net.state_dict(), net_path)
+
+ def load(self,
+ path_nets: str):
+ pathes = [os.path.join(path_nets, p) for p in os.listdir(path_nets) if ("ema" in p or "eps" in p)]
+ for index in range(len(pathes)):
+ if "ema" in pathes[index]:
+ break
+ ema_p = pathes[index]
+ eps_p = pathes[int(not index)]
+ map_loc = 'cpu' if not torch.cuda.is_available() else 'cuda'
+ self.eps_net.load_state_dict(torch.load(eps_p, map_location=map_loc))
+ self.ema_net.load_state_dict(torch.load(ema_p, map_location=map_loc))
+
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+Unet = DiffUnet(block_depth=2)
+Unet = Unet.to(DEVICE)
+Model = DiffusionModel(Unet,num_steps=1000,input_res=(64,64))
+optim = torch.optim.AdamW(Unet.parameters(),5e-4,weight_decay=0.01)
+
+print("number of parameters:{}".format(sum([p.numel()for p in Unet.parameters()])))
+
+
+kid = KernelInceptionDistance(subset_size=100,normalize=True).to(DEVICE)
+
+# preprocess the data before KID
+def prepare_kid(real,pred):
+
+ real = F.resize(((real+1)*0.5).clamp(0,1),(299,299))
+ pred = F.resize(((pred+1)*0.5).clamp(0,1),(299,299))
+ return real,pred
+
+
+
+
+######################################################################
+# Training utilities
+# ==================
+#
+
+mean = lambda x:sum(x)/len(x)
+
+def mul(args):
+ res = 1
+ for i in args:
+ res*=i
+ return res
+
+def train_epoch(model,train_ds,opt,loss_type='mse_loss',num=5,max_norm=None,**kwargs):
+ model.train()
+ losses = []
+ for i,inputs in enumerate(train_ds):
+ inputs = inputs.to(DEVICE)
+ opt.zero_grad()
+
+ loss = model.train_loss(inputs,loss_type,**kwargs)
+ loss.backward()
+ if max_norm is not None:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+ opt.step()
+ loss = loss.item()
+ losses.append(loss)
+ model.update_emma()
+ if (i+1)%(len(train_ds)//num)==0:
+ print(f"Finished training on {100*(i+1)/len(train_ds):.1f}% of the dataset and loss:{loss:.3f}")
+ return mean(losses)
+
+def plot_grid_images(imgs,grid_shape):
+ n_rows,n_cols = grid_shape
+ plt.figure(figsize=(n_cols * 2, n_rows * 2))
+ plt.title('Generated Images')
+ for row in range(n_rows):
+ for col in range(n_cols):
+ index = row * n_cols + col
+ plt.subplot(n_rows, n_cols, index + 1)
+ img = imgs[index]
+ img = (img+1)/2
+ img= img.permute(1,2,0)
+ img = torch.clamp(img,0.0,1.0)
+ plt.imshow(img)
+ plt.axis('off')
+
+ plt.tight_layout()
+ plt.show()
+ plt.close()
+
+
+def val_epoch(model,val_ds,loss_type='mse_loss',grid_shape=[2,3],infer_steps=20,**kwargs):
+ model.eval()
+ with torch.no_grad():
+ losses = []
+ for i,inputs in enumerate(val_ds):
+ inputs = inputs.to(DEVICE)
+ loss = model.train_loss(inputs,loss_type,**kwargs)
+ loss = loss.item()
+ losses.append(loss)
+            samples = model.generate(int(inputs.shape[0]),infer_steps).to(DEVICE)
+ inputs ,samples = prepare_kid(inputs,samples)
+ kid.update(inputs,real=True)
+ kid.update(samples,real=False)
+ mean_kid,std_kid = kid.compute()
+ kid.reset()
+ return mean(losses),mean_kid,std_kid
+
+
+######################################################################
+# Training Loop
+# =============
+#
+
+import time
+EPOCHS = 10
+loss_type='l1_loss'
+max_kid= 0.5
+path_to_save = './Pretrained/DDIM'
+
+for e in range(EPOCHS):
+ st = time.time()
+    print(f"Started training epoch {e+1}/{EPOCHS}")
+ train_loss = train_epoch(Model,train_iter,optim,loss_type,max_norm=1.0)
+ val_loss,mean_kid,std_kid = val_epoch(Model,val_iter,loss_type)
+    if mean_kid < max_kid:  # keep a checkpoint whenever the KID metric improves
+        max_kid = mean_kid
+        Model.save(path_to_save)
+    print(f"Epoch {e+1}: train loss {train_loss:.3f} | val loss {val_loss:.3f} | "
+          f"KID {mean_kid:.4f} +/- {std_kid:.4f} | took {time.time()-st:.1f}s")
+
+
+######################################################################
+# Conclusion
+# ==========
+#
+# To get sharper samples, you can train for more epochs, tune the noise
+# schedule, or increase the model capacity, and retrain the model.
+#
+# Check out the other tutorials for more cool deep learning applications
+# in PyTorch!
+#
diff --git a/index.rst b/index.rst
index 776f7ac912e..1d3f46a33bc 100644
--- a/index.rst
+++ b/index.rst
@@ -134,6 +134,13 @@ What's new in PyTorch tutorials?
:link: beginner/dcgan_faces_tutorial.html
:tags: Image/Video
+.. customcarditem::
+ :header: Unconditional Image Generation With DDIM Model
+   :card_description: A step-by-step guide to building a DDIM diffusion model for unconditional image generation.
+   :image: _static/img/thumbnails/cropped/60-min-blitz.png
+   :link: advanced/DDIM_tutorial.html
+ :tags: Image/Video
+
.. customcarditem::
:header: Spatial Transformer Networks Tutorial
:card_description: Learn how to augment your network using a visual attention mechanism.
@@ -854,6 +861,7 @@ Additional Resources
beginner/dcgan_faces_tutorial
intermediate/spatial_transformer_tutorial
beginner/vt_tutorial
+ advanced/DDIM_tutorial
.. toctree::
:maxdepth: 2
diff --git a/requirements.txt b/requirements.txt
index 0811ded54c6..6ed86a673af 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,6 +22,8 @@ awscliv2==2.1.1
flask
spacy==3.4.1
ray[tune]==2.4.0
+torchmetrics
+torch-fidelity
tensorboard
jinja2==3.0.3
pytorch-lightning