-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a7de3eb
commit 8812e04
Showing
2 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,4 @@ Tutorials | |
tutorials/basics | ||
tutorials/forward_kl | ||
tutorials/reverse_kl | ||
tutorials/vae |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,308 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Variational autoencoders\n", | ||
"\n", | ||
"This notebook walks you through implementing a variational autoencoder (VAE) for the MNIST dataset with a normalizing flow as prior." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.utils.data as data\n", | ||
"import zuko\n", | ||
"\n", | ||
"from torch import Tensor\n", | ||
"from torch.distributions import Distribution, Normal, Bernoulli, Independent\n", | ||
"from torchvision.datasets import MNIST\n", | ||
"from torchvision.transforms.functional import to_tensor, to_pil_image\n", | ||
"from tqdm import tqdm" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Data\n", | ||
"\n", | ||
"The [MNIST](https://wikipedia.org/wiki/MNIST_database) dataset consists of 28 x 28 grayscale images representing handwritten digits (0 to 9)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainset = MNIST(root='', download=True, train=True, transform=to_tensor)\n", | ||
"trainloader = data.DataLoader(trainset, batch_size=256, shuffle=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAAcCAAAAADTxTBPAAAKc0lEQVR4nO1aa1iVVRZejiJ5Q0IUqBAbDc1LYWZGmeRlaszSzPCS2sPjNGqlZY6XYNI0zVs0ojGUUNKTkWJ2AU2rCVJT8zKEt1S8oCKJoHFTRDnvWmd+HBDO+dY+xDNdZp54f328a6+9F9/7fd9ea+1DVI961KMev0sEZ5/+rUOox3+BNwrw6a+6YI9ETrzjV13RgGX2A0G/3Oxp6S5E5xnHPpgxo/HPvIzfdsbeVj/vnA19fHxmL/rkhg/s5S9brSGFAH5059//XEeD5SW2h9U5nBYBT0V5Kny7HxkPGnyCuzxtZ2bmj5Ub7hG2vfZVl5avcCYmlIqISL/aXSvR/C+j/5FweWNCQkLCvDtNg4JTwdP7qqYGa0pv+smLXUPbDk/Gr2Vm5tPruPSb+y0D7jrDKMpHaI0b02eo85AXNxgEjLhk4z51DOjmN/cCWK5Ymn1iErDLa6dywAAArPSymH3lrH9tyy4qLx3uzPicExGRwgd+UthEtESqgf1RN6uDQgEepfs3zZWnfupa19C9kCthG/vYY3dbhGja+xQYux8HR1WTs1c5jfnDiv2d9NlftqkC9ordz/zCiNheVlOnty6BTx3AOW3GGJOAqQCqBMS9FrOvSIgeYDU241+u1MRLckpEXnfrF9RpbkFBIhERHReR82lpaWlxaRkiMkgbHpzNPMQ0V7pEmUxERH+LfI/5exfS5xgzM+/YWF6iOq1y3JmINF5dTR53FvBGeU9fccAF24Gg6yz0iHPgtP0Ar3G1tHyrCMCRoFuA3tbpvL82CTgZODv/1fnz00wC2kP0CIn6fOFDRDTqQtbtFlum7BeRP5pciWhAXJEw82EiImr/UPv2AQ6+xSmRFZrDPKy/0TjbMNONJKKwZ9faAKDikIvh0befZc5oRl3iNbcehczpUzn39iH2Gnc721nATTJLXbN3rs32pIVtFFqK9L4ezTeCp7naIgAgK5B0AW84wXhJzWIaBQb6ExF55QDrrNunr0ioGiERHeHeREQH7UOttse/ExG51eRKb+8SkZK4cdZn9AmR8p6Kx47Lx28xTkeBciVA4wM25+SUgHcDACwViFeDeH7CMGNIIbC++aDI1kR88VoieluZs4A75G7VO4E5zcpGAJu8iMYAp1u72j4Djq9uS/SIKiDNYmCSIVQiIgq/CMRYaV8Ro9t3GEBEIaXqxuS/T0Q+NHi2ipcLe4Z1bGu1NH7rskh3xWUIY6GbPCXQLhMUesBJAEDHVh37ngI2WQe8xul/UCcMTuL8vY87rhlJVfSL4iSgX54Eat6+bDtvTeHmM5Z7EdFhwLoX3DDnnjZERE/pAlItAo5MA2DNYci7SJYaXObZDrYmarYa2z2sxtGLWUSmGFxjeFlz1dBvpcjVCdbXkryjGTOIiJ6PjtYcA0UmK/SXAMom9SSiOOC4r3VAs3RWMy3PVBQ/2KryiWF8U8Unyos1h62SI96Kd7sMts22sLO5/NMmRNcNLuO52poOvGMQ0M5sFnD0wXIA/26imFJNAgaeKw8johXIsdo6HaoQMe2BTedmDx6iaEREd9lE5MpDygPR4ms730009YVsZruyFeoCPlAKZDs29lRATanal5x+d1IDCx0KVBdxTgL2v0Z7DU+9IqO1SSfa+POWrqT3OXxKRB12AcnNNC+i5yKjIndja0PNZnoD283asmULAygcr36hTAJ2O4alRDTtKpRP19ByR0WgVTS0gFfr8hFFO9x2z+rmankYfPIWCvkYKD3E31p3c13AL4Ct/YmIrn+iuPLKGmsx80zL/rmDa/Qm7Lyt6jJRwono9u7TYuJKLhasL4GW8j9abNviZ2HbAG3bzNxewrA9ogXStOd6Zjtzbns1ToOA3bKryogU1Y1S1fSuUQTbeWeUZ8CuipWa13OXRUx7oF0G60sR3fNZQWUtGN3GydBiMnJfoeAkzn+/exgfUgS0awIOy0h3VLGRwD5TPdvtS+Y4l5f64cuYUv0XI7bqMo5/zMzMZKko3rF09E0e+RXKhO2YOdFKe+eBAeScQZ7i5HHXGVzMXVsK5E1XO1gmAU8yM9uZmR/S3ChVihV2DMBZwE49GCIaOGrU2GJdwF2c8yfdiYja3jEwgUVEvnbKLQYCs8kvFcWxnl0PF8da/Qx7YCUeKceVZ4xW77Fgl0o2HGevvZSeC/nL6h17ZkpKSkrKOEfuOV6OK/O9abPZtPZMr/OctaSz/2Yo37TGg4FZ95LPXgAYoTXT7MxrteiD/n5n165duy4FdAFf0AQcYSvP6xuSBjBsZ/Q3nqjBHDlmeVV6NSafOVxiri+IaPROEZEZNamZANF2IIxCAS2LCRRx03hkYLy7Fa/y1fudiHCcrLr0nIfTpi5ksiy2kiEnbLZ15rX62Nn6qHksBDZ4U+s9XD73I+Dzft0tmTgD6GyetqVJwGFSZv1gpZ8YR0Sdt4EBYwHtKXLIZVsNyDg/hsiX+R5zJETUaLOIOBXXC/ljCvmBp1BwNk/RXAJFzEXGAjuzqeVMdNsrm5gznYuJcCyrvApJwkdG12StSC6w2bbpOTYRET3IsNSADReh5JnrqedOHOlLXn9eVYLqB6gK/4Ra51VhuEnAIXI52EI+H0hE1KcIwzt3VmoPB5aIuLYb8somEdF8/sLo5MDr4uK7EB9RyBmszCnYdaN6c9wJ2HgT41m92iPqGPsDM1dsdGaHc2XRP7WQzS0eXUC22Qwd20q7VcCnUTrSZ+DaS5jtqCpHbdhgaVpM1gT0GFRZOYwrNQlIhyRON7SM5aMK3SrF0dsIKLaWEZFlIpIlJ/UTtoDZlT3xhl+JVNxX0xQKhE4sBjjfEGWgiOlT3nQ8eJW3bvOfeoKZeZdrWhWOq8tDAsNTT/PJ1XqrhYiIku3WdlmindndyZ32BuahLOMIgJfUAqISR5ntLv/jfZsQSETkM6YIuKgf1FBMiSHpj0Se9tC/L4fDOlCPUd+JvGbxnJaUn1+woaMap/8+cdROfotFZJ+TrUepo+Gu5S9E5C6JabEGeE5///z6fc/MvGOoxRwO4IfDALa9YliQiIiSJcKVCsnl8mhTmURE9IwiYCYApE7r0MjdYp8A7CLgXuCNBQsWLNjDwFfDDH4xxfp/H3TCpvYTQreLZG8oEeHvDeWqAWtEQpoQNXm5RMRe6pKSDEpj4J0pxkSl8UGTgLcCWarB58NjzMzfPKr0L276FmAgf5nVVBPJ1rb7/TbWMtNqdLOzRcAWY5dG+tV2/D1QFdABPrvC+NDEyGMqfxTv6g7RTztquQu1BOSKv4pIRnp6hohIqV50u8EeSVX5TglQykaiXutymJkvvao/ZQFzwHjdTfeciIiS7XUXkI7CzUfZDYIOWATs/g4AICtzuaXzUY2z5fqJZRR0YYk8p09PEimq609Jbv6g6ki3Yoly3FkLEsT1Zx0OJAFqBbiImQ8unO9d54VqIsL6BvpvqU3ACKS5KQfqBs8J57Fugvsj9zX7fsEf0zhHMzJ+2tat8fEjtdOI2tDu24ka3WU94swVxG8Cr8+xtm6by+8Zi3Hif0w/Iq833FXk9XBCf+XkrR7/5/gP6yvFvKvDVPUAAAAASUVORK5CYII=", | ||
"text/plain": [ | ||
"<PIL.Image.Image image mode=L size=448x28>" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"x = [trainset[i][0] for i in range(16)]\n", | ||
"x = torch.cat(x, dim=-1)\n", | ||
"\n", | ||
"to_pil_image(x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Evidence lower bound\n", | ||
"\n", | ||
"As usual with variational inference, we wish to find the parameters $\\phi$ for which a model $p_\\phi(x)$ is most similar to a target distribution $p(x)$, which leads to the objective\n", | ||
"\n", | ||
"$$ \\arg \\max_\\phi \\mathbb{E}_{p(x)} \\big[ \\log p_\\phi(x) \\big] $$\n", | ||
"\n", | ||
"However, variational autoencoders have latent random variables $z$ and model the joint distribution of $z$ and $x$ as a factorization\n", | ||
"\n", | ||
"$$ p_\\phi(x, z) = p_\\phi(x | z) \\, p_\\phi(z) $$\n", | ||
"\n", | ||
"where $p_\\phi(x | z)$ is the decoder (sometimes called likelihood) and $p_\\phi(z)$ the prior. In this case, maximizing the log-evidence $\\log p_\\phi(x)$ becomes an issue as the integral\n", | ||
"\n", | ||
"$$ p_\\phi(x) = \\int p_\\phi(z, x) \\, \\mathrm{d}z $$\n", | ||
"\n", | ||
"is often intractable, not to mention its gradients. To solve this issue, VAEs introduce an encoder $q_\\psi(z | x)$ (sometimes called proposal or guide) to define a lower bound for the evidence (ELBO) for which unbiased Monte Carlo estimates of the gradients are available.\n", | ||
"\n", | ||
"$$\n", | ||
"\\begin{align}\n", | ||
" \\log p_\\phi(x) & \\geq \\log p_\\phi(x) - \\mathrm{KL} \\big( q_\\psi(z | x) \\, || \\, p_\\phi(z | x) \\big) \\\\\n", | ||
" & \\geq \\log p_\\phi(x) + \\mathbb{E}_{q_\\psi(z | x)} \\left[ \\log \\frac{p_\\phi(z | x)}{q_\\psi(z | x)} \\right] \\\\\n", | ||
" & \\geq \\mathbb{E}_{q_\\psi(z | x)} \\left[ \\log \\frac{p_\\phi(z, x)}{q_\\psi(z | x)} \\right] = \\mathrm{ELBO}(x, \\phi, \\psi)\n", | ||
"\\end{align}\n", | ||
"$$\n", | ||
"\n", | ||
"Importantly, if $p_\\phi(x, z)$ and $q_\\psi(z | x)$ are expressive enough, the bound can become tight and maximizing the ELBO for $\\phi$ and $\\psi$ will lead to the same model as maximizing the log-evidence." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class ELBO(nn.Module):\n", | ||
" def __init__(\n", | ||
" self,\n", | ||
" encoder: zuko.flows.LazyDistribution,\n", | ||
" decoder: zuko.flows.LazyDistribution,\n", | ||
" prior: zuko.flows.LazyDistribution,\n", | ||
" ):\n", | ||
" super().__init__()\n", | ||
"\n", | ||
" self.encoder = encoder\n", | ||
" self.decoder = decoder\n", | ||
" self.prior = prior\n", | ||
"\n", | ||
" def forward(self, x: Tensor) -> Tensor:\n", | ||
" q = self.encoder(x)\n", | ||
" z = q.rsample()\n", | ||
"\n", | ||
" return self.decoder(z).log_prob(x) + self.prior().log_prob(z) - q.log_prob(z)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Model\n", | ||
"\n", | ||
"We choose a (diagonal) Gaussian model as encoder, a Bernoulli model as decoder, and a [Masked Autoregressive Flow](https://arxiv.org/abs/1705.07057) (MAF) as prior. We use 16 features for the latent space." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class GaussianModel(zuko.flows.LazyDistribution):\n", | ||
" def __init__(self, features: int, context: int):\n", | ||
" super().__init__()\n", | ||
"\n", | ||
" self.hyper = nn.Sequential(\n", | ||
" nn.Linear(context, 1024),\n", | ||
" nn.ReLU(),\n", | ||
" nn.Linear(1024, 1024),\n", | ||
" nn.ReLU(),\n", | ||
" nn.Linear(1024, 2 * features),\n", | ||
" )\n", | ||
"\n", | ||
" def forward(self, c: Tensor) -> Distribution:\n", | ||
" phi = self.hyper(c)\n", | ||
" mu, log_sigma = phi.chunk(2, dim=-1)\n", | ||
"\n", | ||
" return Independent(Normal(mu, log_sigma.exp()), 1)\n", | ||
"\n", | ||
"\n", | ||
"class BernoulliModel(zuko.flows.LazyDistribution):\n", | ||
" def __init__(self, features: int, context: int):\n", | ||
" super().__init__()\n", | ||
"\n", | ||
" self.hyper = nn.Sequential(\n", | ||
" nn.Linear(context, 1024),\n", | ||
" nn.ReLU(),\n", | ||
" nn.Linear(1024, 1024),\n", | ||
" nn.ReLU(),\n", | ||
" nn.Linear(1024, features),\n", | ||
" )\n", | ||
"\n", | ||
" def forward(self, c: Tensor) -> Distribution:\n", | ||
" phi = self.hyper(c)\n", | ||
" rho = torch.sigmoid(phi)\n", | ||
"\n", | ||
" return Independent(Bernoulli(rho), 1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"encoder = GaussianModel(16, 784)\n", | ||
"decoder = BernoulliModel(784, 16)\n", | ||
"\n", | ||
"prior = zuko.flows.MAF(\n", | ||
" features=16,\n", | ||
" transforms=3,\n", | ||
" hidden_features=(256, 256),\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Note that because the decoder is a Bernoulli model, the data $x$ should be binary." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Training\n", | ||
"\n", | ||
"As explained earlier, our objective is to maximize the ELBO for all $x$.\n", | ||
"\n", | ||
"$$\n", | ||
"\\arg \\max_{\\phi, \\, \\psi} \\mathbb{E}_{p(x)} \\big[ \\text{ELBO}(x, \\phi, \\psi) \\big]\n", | ||
"$$" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|██████████| 64/64 [07:09<00:00, 6.71s/it, loss=65.8]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"elbo = ELBO(encoder, decoder, prior).cuda()\n", | ||
"optimizer = torch.optim.Adam(elbo.parameters(), lr=1e-3)\n", | ||
"\n", | ||
"for epoch in (bar := tqdm(range(64))):\n", | ||
" losses = []\n", | ||
"\n", | ||
" for x, _ in trainloader:\n", | ||
" x = x.round().flatten(-3).cuda()\n", | ||
"\n", | ||
" loss = -elbo(x).mean()\n", | ||
" loss.backward()\n", | ||
"\n", | ||
" optimizer.step()\n", | ||
" optimizer.zero_grad()\n", | ||
"\n", | ||
" losses.append(loss.detach())\n", | ||
"\n", | ||
" losses = torch.stack(losses)\n", | ||
"\n", | ||
" bar.set_postfix(loss=losses.mean().item())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"After training, we can generate MNIST images by sampling latent variables from the prior and decode them." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAAcCAAAAADTxTBPAAAMh0lEQVR4nO1aa5AU1RU+t3tmdvbBLrsLuyywCqgIKOALFSoiovhCVFRMNETRUomJJmJhULR8kQTRFCsSBINiiWLElOITjLFU5CEvHwgIuMs+gWVhH7Mz0zM93fd8lR+zj+np26tsTKxU8f2aOafP49577j3nnm6iYziGYziGo4Y2YfZxP7UPnchY+9j5P7UP3cLmUt+MjP9AXohuCvaptuvHdVf4x8en5rbzfmofiIiEf97WSPTAZO2HChyABVjL9W7au+rg18pFOPWdcQNKc7sQPC7C4RtUXoqCRTV7l5ye2U2HuoDoNT7oNc7RbF7RRTTpGdk/vj8qiOFf2ABg3pjmTc75+e2+DAum0Is2tJg7W0272d8te70t2aSSzJxp2jYiG4q8BAfGuVXJFINeDMUi4YYVCq4miEjkDejZlUs+XdeVgZE7t85uHqyO7AXA/i5CpvSs/r6ubP54GFkPAAA3Od3pV/vxkKTrGZeuu8gVhoH54B7eWgvfaDXrz1RxpgFmsYLum7whwQDM3mqNQRtmoZKT/3HTkW8rQrF49blpHD3vxExNaMHryy7w9HTQ5DcfXvHKPFVO6P1WxJIc/pNPsdHuBlqyPJXSBQs2zO9yAbVcVzYPTJ3ZN++oM8SIKABmBvgeB+NN6+BZgohIzDF2FrgFRRTSQ6eoYQCA/LPCnVqAlZGr9Xs5ZgHYrVa6E9aNSoavrGHP6keffPOA5MRoJytjyhOZGulFv3mxl4ertHbv5cH8P8bWKw6Fnq+1Rg1DyoOXuE/Rt4Fa75yTNfm6g7VzVb5mnzzg3Fv/+t6uz9a+9Ys0XmkN2D78Xf23xcpDW+vdTzVtwyMAOGYmGLAcHAOHAkREVBKVf1NqBPKU7lcBAIeNmNWomLcYYHukFdH/KgOAkncnc3+10KOh0KeTThs/vyrCCPd18LIrV2YS+S985yuv4/6QtZGIbqj6XBFqq0y7acOcT6q29Hct1Sagynv9vuLbJmyqzk+jiuDxf9lwxGYpE3ak7sYlxzkVXNECgG3JMh5wafRvlgzZeoKLcY4JAC0VlWEGOHUQQ4FkjZzfwHF1zbkGryqoQgIoL9L14KTdLa+5llgDYHuP/fdpbrSjt8RVaomBUbvp+l4Dx855dFkE2OrgjYuv0Ym0iRv/7mGO8R0R0cyvFMfdhRasBUOyT5m4cUd6HMaBw0G3RBuAfb0eiG9wErXci3furD9cv2PxXVNGZ+lC9HWmitNNKW2zNV59cPWQjgnwEZEomRpCW5Y7mD5xN8UAcPTBCx/eLAGkjuIbyAARkbiWeabSTy2MRxVkC7B0IiIt6/lwy8R09kAA0vsS0gBEVfSDOKgW0Mtl/Mm8QF5B7+LhSwDDMcRKOUMQiaHhWrUw4xkioh4fTVMETRj4oCgj2GNZ7OPjnZwa4H3PGPQDFeSP8ggnue+C2vKb+ubndFIcJk8+IK2a1+ed0bOoo7AIjq3cvzJstS0erPoKlmn74fYoAHPdmTl9T34oBiD1jE20zZh/P6KFRAp/V0P2c1PfAba3jaS4iu0N6TPzWwAJzwX0A5imoJ8D2XHBcGq8w5LrMpI3S63URMKxlZp5PRHpE7lBZSwTuIGIKOepN1S5QMIe16PwV2+YiQ+cl5tWYKug3PQjMonLgSqi6TjsdFO7Z/9bA5QCRERUGAFMMxKaUZAy0a0xm9tWL3I7Ef2B4dz2AQvg2H3ZmvAVPhgFOKUMDyAZnDSNsUaQ5r58n2cnytzHSCYj1Pbzwq0ScmH6Aw0ADMXCJ/GqOgVmSUwlIiL/zC3bVl6Z4oswEe/Ie1mHEHek1zDuISIapdzV2UAJEVGgfN85Kl8Y4XGv1EWkXFHoCN91QGJwwduWDCmK4neBUqKA5HFOul4eXqoykkSf5D5jaZUP76T+c+PiqgYprU1t61IOOCJfNwDYl/sECa3o0iiAlOTZQ2I5EdF5NnCJIN1dMX9pt77vLk5PAvYQEZHvkRYAliu0YwDbUz0G4gPwcwV9B0xBRPq9FfGG3cbGlInrD6zvWE9xAM2OQDOssUREPg4pbEk+gYgofw1vUp2HYyTijQnI8O5BDnoWwP+oAwA0p1/VSxOQGpFYycYI581SfBWZ7a5N2lEJADK037K4NmW3+/x6TlZnRJqIO0b3GQB+KZNIaP4esxKATInePAvLBGmPGAy+WHk3mRS1Wxe4fMph8Mt5J90SZTAAVwrUbACY43HbOaLegIUmX0NEfertA2fkDN4VStnAW8GdCUqrR8whuMt6nogoKONupTuxligwfdsB2KVurjbeACC54f6Tnasr6oEd6wEZlY4ji4i0W2zIW3y+wDyLd92b6ZDTV9R/PsWrp5O1l60Fw3QSxdu4VeFMEmcDjlrsGgD2wgARCS27dBUDZqrJSjYaakxORNqKGRfylr26LrbvlHTygfakK5FmkIiIshkAPlLfcU8DMFxBf07O14iOqzGWZhDlVVen9AEaEemcln4x7HUIjrdrNCIq4RbX3PWRsW82m+AKC/NdBvVfxhgMq+kZVyPNxzCjQGJ88Uqbn3WOrV7WHT7SGjMZiQ/7pZ1OOWNqw197XO2Dpw5NTr221LZGq58hHwOTUwmNAJ7ViIhEoO9dRwC5PXUBi2IMcOPvNiPisV2Enj2j0W5Iu2FpuwFANu4GsMMt1MMGgDplm1AD8JjKUEW8D1Hhd4evJSLf0qqyFJaN2k737rRt511D1Bs6EfUz9rqmrk8cYK4feRnCLoOZrQDb0eiWQYo9Y8I20VRM5KuQZQ5O/vRpS2ptKQGsHpqXfizrD9Q1fDRN2bHr+HW9xWaJ4gkiokVA2BEUAP6VVCcGfp4AED3PsVLBR9bOOMuvVWCXh0Yiyr7A4i9c1ECenzJCSDvO2rAfABpHqlgWUKMKlp7GezqJu54bTiROj0ZXpEyCJtHc8TfjE46m5eoJxt0a9fpE1rmaGNrVS5ffpxNVw9XXK44C8Zcer2kerZruCNhuLCWiwHY5K40nSOs59OJGu663YhzaqC9j1jvpVOHLDbQ/XAbwHIVJomRuui9VawCQA4mIyDe2GQDsF5XdigRcNlOxjDlHRV/k1VGZFJXA4asUAxQAOsvBlAdyIvs0ypx4VkAffojl6w4hG9wesuLXlvwgXe9C27KiTbDP9vB/NGyXH9vY2nd2YDYsZVd2PrO9ewQJ/+aENUj1wAuJiMclX99jVP8sjTa1bum0TCIK5F4WBfCkh59UB+zqT0Rae0IrAXh6pk/oI2uTJezT6hTLuMhLJxGJutQiohPnAPDoXj0VY2k8rLhIBAAMa//jT/GmT3NkVHDE0kllIUYkrZtkAE8kHxU3mgi5Zi4nZO0qnskJj/fAeYBrl11nmQv8IqsZIeVl3dcCcKJ6hwUYqpw2AVB0jpMYsMY4lHZ9LImbdqNptt0jpnhJ5jLgbJ5mAOA9DU12suQwTlUL6syXqsilBRlCC46usmVcMYogAA+F1G+PBbNpjJtRAOBdnURGwbC5T1zRu3MFM9cb8bqEKaWM3Jwu9C4gt5yiB4IjN0vwKLdWTRAVyJg6On0G3G3e2ZGmoVrOQpZ3qAfgj7WXaaqCqwR8v1qOiHy99oQ/cZJyHwrJNn2JMzwFaQgQTksPzehEYpZXb0ST/I2COqFx35ZthgS4SnXam8B0T18uMwGscosFAHCLYQEcL3+6rHOP6jc3GtK2rK1DFaYsAJywLBvgMjefiIhy4qaSLjZ2dB1SsDwRXvXEulC82fN10CNfA7AfV7Vv9Chv9xIjX+6wt1vmpQn0H7940xfRLdd2/XZ1KfNTaWtU2B5IkMtKPDt7IoxIkYurfSMBwJb7xqqEqoAmb1+0kLpApcltjSOu2XlvRqpNkTPkmlH56vnMMNuHEe7jZbHAbFUOcCxYQc/+zmZmNm7wHgL5cnuoa/Mr48pzNYmsUeMq4y9065uQQ5CuqRZF+yXbxmsDuvyWYjHLWaUun3yDx1w9O+jh6iVeBUwbskIcH6JiHF+eqK788uvbvVsWKty2n5nN5SO8h1HQrHyfokXwuoJMY2rjbB26vVsT/XjTh95M/YzVlduf7I5eEbWrlBXT96M43Dzv+i5evbsxDIC60dvhjva//TIpY/tqVfd1JXtdcXWtmx/8aBcv7CL6fLPK6z7y7LR0hVLbeE/9OvT7MWDGiT/4cyciEv5VwLJu2vovQR9/uoIqpm4+moH9MOR2tfJZfR+8Nb9bsZtXH150dOdSCtyvgbtG1pnpn6YcQxJi7pnd2n9E2ssv9fr//Ab1GJJwHur/BorGqSV9w5JjAAAAAElFTkSuQmCC", | ||
"text/plain": [ | ||
"<PIL.Image.Image image mode=L size=448x28>" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"z = prior().sample((16,))\n", | ||
"x = decoder(z).mean.reshape(-1, 28, 28)\n", | ||
"\n", | ||
"to_pil_image(x.movedim(0, 1).reshape(28, -1))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "zuko", | ||
"language": "python", | ||
"name": "zuko" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |