Commit a45e710

initial commit
1 parent 8aa59b5 commit a45e710

19 files changed, +3089 -0 lines changed

01-Diffusion-Sandbox.ipynb

Lines changed: 554 additions & 0 deletions
Large diffs are not rendered by default.

02-Pixel-Diffusion.ipynb

Lines changed: 286 additions & 0 deletions
@@ -0,0 +1,286 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "55e50b25",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "ce528bb2",
"metadata": {},
"source": [
"# Unconditional Pixel Diffusion Training\n",
"\n",
"In this notebook, we will train a simple `PixelDiffusion` model in low resolution (64 by 64).\n",
"\n",
"The training should take about 9 hours.\n",
"\n",
"---\n",
"\n",
"Maps dataset from the pix2pix paper:\n",
"```bash\n",
"wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz\n",
"tar -xvf maps.tar.gz\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec7d6dc4",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import torchvision.transforms as T\n",
"from torchvision.transforms import ToTensor\n",
"from torch.utils.data import Dataset\n",
"import pytorch_lightning as pl\n",
"\n",
"import numpy as np\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import imageio\n",
"from skimage import io\n",
"import os\n",
"\n",
"from src import *\n",
"\n",
"mpl.rcParams['figure.figsize'] = (8, 8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "effe3e3a",
"metadata": {},
"outputs": [],
"source": [
"import kornia\n",
"from kornia.utils import image_to_tensor\n",
"import kornia.augmentation as KA\n",
"\n",
"class SimpleImageDataset(Dataset):\n",
"    \"\"\"Dataset returning images in a folder.\"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 root_dir,\n",
"                 transforms=None,\n",
"                 paired=True,\n",
"                 return_pair=False):\n",
"        self.root_dir = root_dir\n",
"        self.transforms = transforms\n",
"        self.paired=paired\n",
"        self.return_pair=return_pair\n",
"\n",
"        # set up transforms\n",
"        if self.transforms is not None:\n",
"            if self.paired:\n",
"                data_keys=2*['input']\n",
"            else:\n",
"                data_keys=['input']\n",
"\n",
"            self.input_T=KA.container.AugmentationSequential(\n",
"                *self.transforms,\n",
"                data_keys=data_keys,\n",
"                same_on_batch=False\n",
"            )\n",
"\n",
"        # check files\n",
"        supported_formats=['webp','jpg']\n",
"        self.files=[el for el in os.listdir(self.root_dir) if el.split('.')[-1] in supported_formats]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.files)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        if torch.is_tensor(idx):\n",
"            idx = idx.tolist()\n",
"\n",
"        img_name = os.path.join(self.root_dir,\n",
"                                self.files[idx])\n",
"        image = image_to_tensor(io.imread(img_name))/255\n",
"\n",
"        if self.paired:\n",
"            c,h,w=image.shape\n",
"            slice=int(w/2)\n",
"            image2=image[:,:,slice:]\n",
"            image=image[:,:,:slice]\n",
"            if self.transforms is not None:\n",
"                out = self.input_T(image,image2)\n",
"                image=out[0][0]\n",
"                image2=out[1][0]\n",
"        elif self.transforms is not None:\n",
"            image = self.input_T(image)[0]\n",
"\n",
"        if self.return_pair:\n",
"            return image2,image\n",
"        else:\n",
"            return image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "528e3f70",
"metadata": {},
"outputs": [],
"source": [
"CROP_SIZE=64\n",
"\n",
"inp_T=[\n",
"    KA.RandomCrop((CROP_SIZE,CROP_SIZE)),\n",
"]\n",
"\n",
"train_ds=SimpleImageDataset('./data/maps/train',\n",
"                            transforms=inp_T\n",
"                            )\n",
"\n",
"test_ds=SimpleImageDataset('./data/maps/val',\n",
"                           transforms=inp_T\n",
"                           )\n",
"\n",
"for idx in range(16):\n",
"    plt.subplot(4,4,1+idx)\n",
"    plt.imshow(train_ds[idx].permute(1,2,0))\n",
"    plt.axis('off')\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"id": "82e372cc",
"metadata": {},
"source": [
"### Model Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04156f52",
"metadata": {},
"outputs": [],
"source": [
"model=PixelDiffusion(train_ds,\n",
"                     lr=1e-4,\n",
"                     batch_size=16)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a83cab8",
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer(\n",
"    max_steps=int(2e5),\n",
"    callbacks=[EMA(0.9999)],  # exponential moving average of the model weights\n",
"    gpus = [0]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "deafb040",
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4faf7bd",
"metadata": {},
"outputs": [],
"source": [
"B=8 # number of samples\n",
"\n",
"model.cuda()\n",
"out=model(batch_size=B,shape=(64,64),verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aeddad22",
"metadata": {},
"outputs": [],
"source": [
"for idx in range(out.shape[0]):\n",
"    plt.subplot(1,len(out),idx+1)\n",
"    plt.imshow(out[idx].detach().cpu().permute(1,2,0))\n",
"    plt.axis('off')"
]
},
{
"cell_type": "markdown",
"id": "65df248d",
"metadata": {},
"source": [
"By default, the `DDPM` sampler contained in the model is used, as above.\n",
"\n",
"However, you can use a `DDIM` sampler just as well to reduce the number of inference steps:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fac0abca",
"metadata": {},
"outputs": [],
"source": [
"B=8 # number of samples\n",
"STEPS=200 # ddim steps\n",
"\n",
"ddim_sampler=DDIM_Sampler(STEPS,model.model.num_timesteps)\n",
"\n",
"model.cuda()\n",
"out=model(batch_size=B,sampler=ddim_sampler,shape=(64,64),verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22c834ed",
"metadata": {},
"outputs": [],
"source": [
"for idx in range(out.shape[0]):\n",
"    plt.subplot(1,len(out),idx+1)\n",
"    plt.imshow(out[idx].detach().cpu().permute(1,2,0))\n",
"    plt.axis('off')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91e24258",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
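
The final markdown cell of the notebook notes that a `DDIM` sampler reduces the number of inference steps. A minimal sketch of how that trade-off could be measured in the same session, reusing only the calls that already appear in the notebook (`model(...)` and `DDIM_Sampler(...)`); the batch size and the 50-step count below are illustrative choices, not values from this commit:

```python
import time

model.cuda()

# Full DDPM chain: runs all model.model.num_timesteps reverse steps.
t0 = time.time()
out_ddpm = model(batch_size=4, shape=(64, 64), verbose=True)
t_ddpm = time.time() - t0

# DDIM with 50 steps (illustrative; the notebook cell above uses 200).
ddim_50 = DDIM_Sampler(50, model.model.num_timesteps)
t0 = time.time()
out_ddim = model(batch_size=4, sampler=ddim_50, shape=(64, 64), verbose=True)
t_ddim = time.time() - t0

print(f"DDPM: {t_ddpm:.1f}s, DDIM (50 steps): {t_ddim:.1f}s")
```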
