📝 Add VAE tutorial
francois-rozet committed Nov 29, 2023
1 parent a7de3eb commit 8812e04
Showing 2 changed files with 309 additions and 0 deletions.
docs/tutorials.rst: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ Tutorials
    tutorials/basics
    tutorials/forward_kl
    tutorials/reverse_kl
+   tutorials/vae
docs/tutorials/vae.ipynb: 308 additions & 0 deletions
@@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Variational autoencoders\n",
"\n",
"This notebook walks you through implementing a variational autoencoder (VAE) for the MNIST dataset with a normalizing flow as prior."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.utils.data as data\n",
"import zuko\n",
"\n",
"from torch import Tensor\n",
"from torch.distributions import Distribution, Normal, Bernoulli, Independent\n",
"from torchvision.datasets import MNIST\n",
"from torchvision.transforms.functional import to_tensor, to_pil_image\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data\n",
"\n",
"The [MNIST](https://wikipedia.org/wiki/MNIST_database) dataset consists of 28 x 28 grayscale images representing handwritten digits (0 to 9)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"trainset = MNIST(root='', download=True, train=True, transform=to_tensor)\n",
"trainloader = data.DataLoader(trainset, batch_size=256, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAAcCAAAAADTxTBPAAAKc0lEQVR4nO1aa1iVVRZejiJ5Q0IUqBAbDc1LYWZGmeRlaszSzPCS2sPjNGqlZY6XYNI0zVs0ojGUUNKTkWJ2AU2rCVJT8zKEt1S8oCKJoHFTRDnvWmd+HBDO+dY+xDNdZp54f328a6+9F9/7fd9ea+1DVI961KMev0sEZ5/+rUOox3+BNwrw6a+6YI9ETrzjV13RgGX2A0G/3Oxp6S5E5xnHPpgxo/HPvIzfdsbeVj/vnA19fHxmL/rkhg/s5S9brSGFAH5059//XEeD5SW2h9U5nBYBT0V5Kny7HxkPGnyCuzxtZ2bmj5Ub7hG2vfZVl5avcCYmlIqISL/aXSvR/C+j/5FweWNCQkLCvDtNg4JTwdP7qqYGa0pv+smLXUPbDk/Gr2Vm5tPruPSb+y0D7jrDKMpHaI0b02eo85AXNxgEjLhk4z51DOjmN/cCWK5Ymn1iErDLa6dywAAArPSymH3lrH9tyy4qLx3uzPicExGRwgd+UthEtESqgf1RN6uDQgEepfs3zZWnfupa19C9kCthG/vYY3dbhGja+xQYux8HR1WTs1c5jfnDiv2d9NlftqkC9ordz/zCiNheVlOnty6BTx3AOW3GGJOAqQCqBMS9FrOvSIgeYDU241+u1MRLckpEXnfrF9RpbkFBIhERHReR82lpaWlxaRkiMkgbHpzNPMQ0V7pEmUxERH+LfI/5exfS5xgzM+/YWF6iOq1y3JmINF5dTR53FvBGeU9fccAF24Gg6yz0iHPgtP0Ar3G1tHyrCMCRoFuA3tbpvL82CTgZODv/1fnz00wC2kP0CIn6fOFDRDTqQtbtFlum7BeRP5pciWhAXJEw82EiImr/UPv2AQ6+xSmRFZrDPKy/0TjbMNONJKKwZ9faAKDikIvh0befZc5oRl3iNbcehczpUzn39iH2Gnc721nATTJLXbN3rs32pIVtFFqK9L4ezTeCp7naIgAgK5B0AW84wXhJzWIaBQb6ExF55QDrrNunr0ioGiERHeHeREQH7UOttse/ExG51eRKb+8SkZK4cdZn9AmR8p6Kx47Lx28xTkeBciVA4wM25+SUgHcDACwViFeDeH7CMGNIIbC++aDI1kR88VoieluZs4A75G7VO4E5zcpGAJu8iMYAp1u72j4Djq9uS/SIKiDNYmCSIVQiIgq/CMRYaV8Ro9t3GEBEIaXqxuS/T0Q+NHi2ipcLe4Z1bGu1NH7rskh3xWUIY6GbPCXQLhMUesBJAEDHVh37ngI2WQe8xul/UCcMTuL8vY87rhlJVfSL4iSgX54Eat6+bDtvTeHmM5Z7EdFhwLoX3DDnnjZERE/pAlItAo5MA2DNYci7SJYaXObZDrYmarYa2z2sxtGLWUSmGFxjeFlz1dBvpcjVCdbXkryjGTOIiJ6PjtYcA0UmK/SXAMom9SSiOOC4r3VAs3RWMy3PVBQ/2KryiWF8U8Unyos1h62SI96Kd7sMts22sLO5/NMmRNcNLuO52poOvGMQ0M5sFnD0wXIA/26imFJNAgaeKw8johXIsdo6HaoQMe2BTedmDx6iaEREd9lE5MpDygPR4ms730009YVsZruyFeoCPlAKZDs29lRATanal5x+d1IDCx0KVBdxTgL2v0Z7DU+9IqO1SSfa+POWrqT3OXxKRB12AcnNNC+i5yKjIndja0PNZnoD283asmULAygcr36hTAJ2O4alRDTtKpRP19ByR0WgVTS0gFfr8hFFO9x2z+rmankYfPIWCvkYKD3E31p3c13AL4Ct/YmIrn+iuPLKGmsx80zL/rmDa/Qm7Lyt6jJRwono9u7TYuJKLhasL4GW8j9abNviZ2HbAG3bzNxewrA9ogXStOd6Zjtzbns1ToOA3bKryogU1Y1S1fSuUQTbeWeUZ8CuipWa13OXRUx7oF0G60sR3fNZQWUtGN3GydBiMnJfoeAkzn+/exgfUgS0awIOy0h3VLGRwD5TPdvtS+Y4l5f64cuYUv0XI7bqMo5/zMzMZKko3rF09E0e+RXKhO2YOdFKe+eBAeScQZ7i5HHXGVzMXVsK5E1XO1gmAU8yM9uZmR/S3ChVihV2DMBZwE49GCIaOGrU2GJdwF2c8yfdiYja3jEwgUVEvnbKLQYCs8kvFcWxnl0PF8da/Qx7YCUeKceVZ4xW77Fgl0o2HGevvZSeC/nL6h17ZkpKSkrKOEfuOV6OK/O9abPZtPZMr/OctaSz/2Yo37TGg4FZ95LPXgAYoTXT7MxrteiD/n5n165duy4FdAFf0AQcYSvP6xuSBjBsZ/Q3nqjBHDlmeVV6NSafOVxiri+IaPROEZEZNamZANF2IIxCAS2LCRRx03hkYLy7Fa/y1fudiHCcrLr0nIfTpi5ksiy2kiEnbLZ15rX62Nn6qHksBDZ4U+s9XD73I+Dzft0tmTgD6GyetqVJwGFSZv1gpZ8YR0Sdt4EBYwHtKXLIZVsNyDg/hsiX+R5zJETUaLOIOBXXC/ljCvmBp1BwNk/RXAJFzEXGAjuzqeVMdNsrm5gznYuJcCyrvApJwkdG12StSC6w2bbpOTYRET3IsNSADReh5JnrqedOHOlLXn9eVYLqB6gK/4Ra51VhuEnAIXI52EI+H0hE1KcIwzt3VmoPB5aIuLYb8somEdF8/sLo5MDr4uK7EB9RyBmszCnYdaN6c9wJ2HgT41m92iPqGPsDM1dsdGaHc2XRP7WQzS0eXUC22Qwd20q7VcCnUTrSZ+DaS5jtqCpHbdhgaVpM1gT0GFRZOYwrNQlIhyRON7SM5aMK3SrF0dsIKLaWEZFlIpIlJ/UTtoDZlT3xhl+JVNxX0xQKhE4sBjjfEGWgiOlT3nQ8eJW3bvOfeoKZeZdrWhWOq8tDAsNTT/PJ1XqrhYiIku3WdlmindndyZ32BuahLOMIgJfUAqISR5ntLv/jfZsQSETkM6YIuKgf1FBMiSHpj0Se9tC/L4fDOlCPUd+JvGbxnJaUn1+woaMap/8+cdROfotFZJ+TrUepo+Gu5S9E5C6JabEGeE5///z6fc/MvGOoxRwO4IfDALa9YliQiIiSJcKVCsnl8mhTmURE9IwiYCYApE7r0MjdYp8A7CLgXuCNBQsWLNjDwFfDDH4xxfp/H3TCpvYTQreLZG8oEeHvDeWqAWtEQpoQNXm5RMRe6pKSDEpj4J0pxkSl8UGTgLcCWarB58NjzMzfPKr0L276FmAgf5nVVBPJ1rb7/TbWMtNqdLOzRcAWY5dG+tV2/D1QFdABPrvC+NDEyGMqfxTv6g7RTztquQu1BOSKv4pIRnp6hohIqV50u8EeSVX5TglQykaiXutymJkvvao/ZQFzwHjdTfeciIiS7XUXkI7CzUfZDYIOWATs/g4AICtzuaXzUY2z5fqJZRR0YYk8p09PEimq609Jbv6g6ki3Yoly3FkLEsT1Zx0OJAFqBbiImQ8unO9d54VqIsL6BvpvqU3ACKS5KQfqBs8J57Fugvsj9zX7fsEf0zhHMzJ+2
tat8fEjtdOI2tDu24ka3WU94swVxG8Cr8+xtm6by+8Zi3Hif0w/Iq833FXk9XBCf+XkrR7/5/gP6yvFvKvDVPUAAAAASUVORK5CYII=",
"text/plain": [
"<PIL.Image.Image image mode=L size=448x28>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = [trainset[i][0] for i in range(16)]\n",
"x = torch.cat(x, dim=-1)\n",
"\n",
"to_pil_image(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evidence lower bound\n",
"\n",
"As usual with variational inference, we wish to find the parameters $\\phi$ for which a model $p_\\phi(x)$ is most similar to a target distribution $p(x)$, which leads to the objective\n",
"\n",
"$$ \\arg \\max_\\phi \\mathbb{E}_{p(x)} \\big[ \\log p_\\phi(x) \\big] $$\n",
"\n",
"However, variational autoencoders have latent random variables $z$ and model the joint distribution of $z$ and $x$ as a factorization\n",
"\n",
"$$ p_\\phi(x, z) = p_\\phi(x | z) \\, p_\\phi(z) $$\n",
"\n",
"where $p_\\phi(x | z)$ is the decoder (sometimes called likelihood) and $p_\\phi(z)$ the prior. In this case, maximizing the log-evidence $\\log p_\\phi(x)$ becomes an issue as the integral\n",
"\n",
"$$ p_\\phi(x) = \\int p_\\phi(z, x) \\, \\mathrm{d}z $$\n",
"\n",
"is often intractable, not to mention its gradients. To solve this issue, VAEs introduce an encoder $q_\\psi(z | x)$ (sometimes called proposal or guide) to define a lower bound for the evidence (ELBO) for which unbiased Monte Carlo estimates of the gradients are available.\n",
"\n",
"$$\n",
"\\begin{align}\n",
" \\log p_\\phi(x) & \\geq \\log p_\\phi(x) - \\mathrm{KL} \\big( q_\\psi(z | x) \\, || \\, p_\\phi(z | x) \\big) \\\\\n",
" & \\geq \\log p_\\phi(x) + \\mathbb{E}_{q_\\psi(z | x)} \\left[ \\log \\frac{p_\\phi(z | x)}{q_\\psi(z | x)} \\right] \\\\\n",
" & \\geq \\mathbb{E}_{q_\\psi(z | x)} \\left[ \\log \\frac{p_\\phi(z, x)}{q_\\psi(z | x)} \\right] = \\mathrm{ELBO}(x, \\phi, \\psi)\n",
"\\end{align}\n",
"$$\n",
"\n",
"Importantly, if $p_\\phi(x, z)$ and $q_\\psi(z | x)$ are expressive enough, the bound can become tight and maximizing the ELBO for $\\phi$ and $\\psi$ will lead to the same model as maximizing the log-evidence."
]
},
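{
"cell_type": "markdown",
"metadata": {},
"source": [
"In practice, the expectation over $q_\\psi(z | x)$ is estimated with a single reparameterized sample $z \\sim q_\\psi(z | x)$, that is\n",
"\n",
"$$ \\mathrm{ELBO}(x, \\phi, \\psi) \\approx \\log p_\\phi(x | z) + \\log p_\\phi(z) - \\log q_\\psi(z | x) $$\n",
"\n",
"which the following module implements."
]
},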
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class ELBO(nn.Module):\n",
" def __init__(\n",
" self,\n",
" encoder: zuko.flows.LazyDistribution,\n",
" decoder: zuko.flows.LazyDistribution,\n",
" prior: zuko.flows.LazyDistribution,\n",
" ):\n",
" super().__init__()\n",
"\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.prior = prior\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" q = self.encoder(x)\n",
" z = q.rsample()\n",
"\n",
" return self.decoder(z).log_prob(x) + self.prior().log_prob(z) - q.log_prob(z)"
]
},
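{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `rsample` call is what keeps the estimate differentiable: it implements the [reparameterization trick](https://arxiv.org/abs/1312.6114), which expresses the sample $z$ as a deterministic function of the distribution's parameters and of parameter-free noise. As a minimal illustration for a diagonal Gaussian (purely didactic, `rsample` already does this internally):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mu = torch.zeros(3, requires_grad=True)\n",
"log_sigma = torch.zeros(3, requires_grad=True)\n",
"eps = torch.randn(3)  # noise drawn independently of the parameters\n",
"z = mu + log_sigma.exp() * eps  # differentiable with respect to mu and log_sigma"
]
},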
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model\n",
"\n",
"We choose a (diagonal) Gaussian model as encoder, a Bernoulli model as decoder, and a [Masked Autoregressive Flow](https://arxiv.org/abs/1705.07057) (MAF) as prior. We use 16 features for the latent space."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class GaussianModel(zuko.flows.LazyDistribution):\n",
" def __init__(self, features: int, context: int):\n",
" super().__init__()\n",
"\n",
" self.hyper = nn.Sequential(\n",
" nn.Linear(context, 1024),\n",
" nn.ReLU(),\n",
" nn.Linear(1024, 1024),\n",
" nn.ReLU(),\n",
" nn.Linear(1024, 2 * features),\n",
" )\n",
"\n",
" def forward(self, c: Tensor) -> Distribution:\n",
" phi = self.hyper(c)\n",
" mu, log_sigma = phi.chunk(2, dim=-1)\n",
"\n",
" return Independent(Normal(mu, log_sigma.exp()), 1)\n",
"\n",
"\n",
"class BernoulliModel(zuko.flows.LazyDistribution):\n",
" def __init__(self, features: int, context: int):\n",
" super().__init__()\n",
"\n",
" self.hyper = nn.Sequential(\n",
" nn.Linear(context, 1024),\n",
" nn.ReLU(),\n",
" nn.Linear(1024, 1024),\n",
" nn.ReLU(),\n",
" nn.Linear(1024, features),\n",
" )\n",
"\n",
" def forward(self, c: Tensor) -> Distribution:\n",
" phi = self.hyper(c)\n",
" rho = torch.sigmoid(phi)\n",
"\n",
" return Independent(Bernoulli(rho), 1)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"encoder = GaussianModel(16, 784)\n",
"decoder = BernoulliModel(784, 16)\n",
"\n",
"prior = zuko.flows.MAF(\n",
" features=16,\n",
" transforms=3,\n",
" hidden_features=(256, 256),\n",
")"
]
},
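{
"cell_type": "markdown",
"metadata": {},
"source": [
"A `LazyDistribution` is a module whose forward pass returns a `Distribution` object. As a quick sanity check on dummy data, calling the encoder with a flattened image yields a 16-dimensional Gaussian from which we can sample latent variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.rand(784)  # dummy flattened image\n",
"q = encoder(x)  # posterior q(z | x)\n",
"z = q.rsample()\n",
"\n",
"z.shape, q.log_prob(z).shape"
]
},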
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that because the decoder is a Bernoulli model, the data $x$ should be binary."
]
},
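{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x, _ = trainset[0]\n",
"x = x.round()  # threshold grayscale pixels to {0, 1}\n",
"\n",
"to_pil_image(x)"
]
},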
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"As explained earlier, our objective is to maximize the ELBO for all $x$.\n",
"\n",
"$$\n",
"\\arg \\max_{\\phi, \\, \\psi} \\mathbb{E}_{p(x)} \\big[ \\text{ELBO}(x, \\phi, \\psi) \\big]\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 64/64 [07:09<00:00, 6.71s/it, loss=65.8]\n"
]
}
],
"source": [
"elbo = ELBO(encoder, decoder, prior).cuda()\n",
"optimizer = torch.optim.Adam(elbo.parameters(), lr=1e-3)\n",
"\n",
"for epoch in (bar := tqdm(range(64))):\n",
" losses = []\n",
"\n",
" for x, _ in trainloader:\n",
" x = x.round().flatten(-3).cuda()\n",
"\n",
" loss = -elbo(x).mean()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" losses.append(loss.detach())\n",
"\n",
" losses = torch.stack(losses)\n",
"\n",
" bar.set_postfix(loss=losses.mean().item())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, we can generate MNIST images by sampling latent variables from the prior and decode them."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAAcCAAAAADTxTBPAAAMh0lEQVR4nO1aa5AU1RU+t3tmdvbBLrsLuyywCqgIKOALFSoiovhCVFRMNETRUomJJmJhULR8kQTRFCsSBINiiWLElOITjLFU5CEvHwgIuMs+gWVhH7Mz0zM93fd8lR+zj+np26tsTKxU8f2aOafP49577j3nnm6iYziGYziGo4Y2YfZxP7UPnchY+9j5P7UP3cLmUt+MjP9AXohuCvaptuvHdVf4x8en5rbzfmofiIiEf97WSPTAZO2HChyABVjL9W7au+rg18pFOPWdcQNKc7sQPC7C4RtUXoqCRTV7l5ye2U2HuoDoNT7oNc7RbF7RRTTpGdk/vj8qiOFf2ABg3pjmTc75+e2+DAum0Is2tJg7W0272d8te70t2aSSzJxp2jYiG4q8BAfGuVXJFINeDMUi4YYVCq4miEjkDejZlUs+XdeVgZE7t85uHqyO7AXA/i5CpvSs/r6ubP54GFkPAAA3Od3pV/vxkKTrGZeuu8gVhoH54B7eWgvfaDXrz1RxpgFmsYLum7whwQDM3mqNQRtmoZKT/3HTkW8rQrF49blpHD3vxExNaMHryy7w9HTQ5DcfXvHKPFVO6P1WxJIc/pNPsdHuBlqyPJXSBQs2zO9yAbVcVzYPTJ3ZN++oM8SIKABmBvgeB+NN6+BZgohIzDF2FrgFRRTSQ6eoYQCA/LPCnVqAlZGr9Xs5ZgHYrVa6E9aNSoavrGHP6keffPOA5MRoJytjyhOZGulFv3mxl4ertHbv5cH8P8bWKw6Fnq+1Rg1DyoOXuE/Rt4Fa75yTNfm6g7VzVb5mnzzg3Fv/+t6uz9a+9Ys0XmkN2D78Xf23xcpDW+vdTzVtwyMAOGYmGLAcHAOHAkREVBKVf1NqBPKU7lcBAIeNmNWomLcYYHukFdH/KgOAkncnc3+10KOh0KeTThs/vyrCCPd18LIrV2YS+S985yuv4/6QtZGIbqj6XBFqq0y7acOcT6q29Hct1Sagynv9vuLbJmyqzk+jiuDxf9lwxGYpE3ak7sYlxzkVXNECgG3JMh5wafRvlgzZeoKLcY4JAC0VlWEGOHUQQ4FkjZzfwHF1zbkGryqoQgIoL9L14KTdLa+5llgDYHuP/fdpbrSjt8RVaomBUbvp+l4Dx855dFkE2OrgjYuv0Ym0iRv/7mGO8R0R0cyvFMfdhRasBUOyT5m4cUd6HMaBw0G3RBuAfb0eiG9wErXci3furD9cv2PxXVNGZ+lC9HWmitNNKW2zNV59cPWQjgnwEZEomRpCW5Y7mD5xN8UAcPTBCx/eLAGkjuIbyAARkbiWeabSTy2MRxVkC7B0IiIt6/lwy8R09kAA0vsS0gBEVfSDOKgW0Mtl/Mm8QF5B7+LhSwDDMcRKOUMQiaHhWrUw4xkioh4fTVMETRj4oCgj2GNZ7OPjnZwa4H3PGPQDFeSP8ggnue+C2vKb+ubndFIcJk8+IK2a1+ed0bOoo7AIjq3cvzJstS0erPoKlmn74fYoAHPdmTl9T34oBiD1jE20zZh/P6KFRAp/V0P2c1PfAba3jaS4iu0N6TPzWwAJzwX0A5imoJ8D2XHBcGq8w5LrMpI3S63URMKxlZp5PRHpE7lBZSwTuIGIKOepN1S5QMIe16PwV2+YiQ+cl5tWYKug3PQjMonLgSqi6TjsdFO7Z/9bA5QCRERUGAFMMxKaUZAy0a0xm9tWL3I7Ef2B4dz2AQvg2H3ZmvAVPhgFOKUMDyAZnDSNsUaQ5r58n2cnytzHSCYj1Pbzwq0ScmH6Aw0ADMXCJ/GqOgVmSUwlIiL/zC3bVl6Z4oswEe/Ie1mHEHek1zDuISIapdzV2UAJEVGgfN85Kl8Y4XGv1EWkXFHoCN91QGJwwduWDCmK4neBUqKA5HFOul4eXqoykkSf5D5jaZUP76T+c+PiqgYprU1t61IOOCJfNwDYl/sECa3o0iiAlOTZQ2I5EdF5NnCJIN1dMX9pt77vLk5PAvYQEZHvkRYAliu0YwDbUz0G4gPwcwV9B0xBRPq9FfGG3cbGlInrD6zvWE9xAM2OQDOssUREPg4pbEk+gYgofw1vUp2HYyTijQnI8O5BDnoWwP+oAwA0p1/VSxOQGpFYycYI581SfBWZ7a5N2lEJADK037K4NmW3+/x6TlZnRJqIO0b3GQB+KZNIaP4esxKATInePAvLBGmPGAy+WHk3mRS1Wxe4fMph8Mt5J90SZTAAVwrUbACY43HbOaLegIUmX0NEfertA2fkDN4VStnAW8GdCUqrR8whuMt6nogoKONupTuxligwfdsB2KVurjbeACC54f6Tnasr6oEd6wEZlY4ji4i0W2zIW3y+wDyLd92b6ZDTV9R/PsWrp5O1l60Fw3QSxdu4VeFMEmcDjlrsGgD2wgARCS27dBUDZqrJSjYaakxORNqKGRfylr26LrbvlHTygfakK5FmkIiIshkAPlLfcU8DMFxBf07O14iOqzGWZhDlVVen9AEaEemcln4x7HUIjrdrNCIq4RbX3PWRsW82m+AKC/NdBvVfxhgMq+kZVyPNxzCjQGJ88Uqbn3WOrV7WHT7SGjMZiQ/7pZ1OOWNqw197XO2Dpw5NTr221LZGq58hHwOTUwmNAJ7ViIhEoO9dRwC5PXUBi2IMcOPvNiPisV2Enj2j0W5Iu2FpuwFANu4GsMMt1MMGgDplm1AD8JjKUEW8D1Hhd4evJSLf0qqyFJaN2k737rRt511D1Bs6EfUz9rqmrk8cYK4feRnCLoOZrQDb0eiWQYo9Y8I20VRM5KuQZQ5O/vRpS2ptKQGsHpqXfizrD9Q1fDRN2bHr+HW9xWaJ4gkiokVA2BEUAP6VVCcGfp4AED3PsVLBR9bOOMuvVWCXh0Yiyr7A4i9c1ECenzJCSDvO2rAfABpHqlgWUKMKlp7GezqJu54bTiROj0ZXpEyCJtHc8TfjE46m5eoJxt0a9fpE1rmaGNrVS5ffpxNVw9XXK44C8Zcer2kerZruCNhuLCWiwHY5K40nSOs59OJGu663YhzaqC9j1jvpVOHLDbQ/XAbwHIVJomRuui9VawCQA4mIyDe2GQDsF5XdigRcNlOxjDlHRV/k1VGZFJXA4asUAxQAOsvBlAdyIvs0ypx4VkAffojl6w4hG9wesuLXlvwgXe9C27KiTbDP9vB/NGyXH9vY2nd2YDYsZVd2PrO9ewQJ/+aENUj1wAuJiMclX99jVP8sjTa1bum0TCIK5F4WBfCkh59UB+zqT0Rae0IrAXh6pk/oI2uTJezT6hTLuMhLJxGJutQiohPnAPDoXj0VY2k8rLhIBAAMa//jT/GmT3NkVHDE0kllIUYkrZtkAE8kHxU3mgi5Zi4nZO0qnskJj/fAeYBrl11nmQv8IqsZIeVl3dcCcKJ6hwUYqpw2AVB0jpMYsMY4lHZ9LImbdqNptt0jpnhJ5jLgbJ5mAOA9DU12suQwTlUL6syXqsilBRlCC46usmVcMYogAA+F1G+PBbNpjJtRAOBdnURGw
bC5T1zRu3MFM9cb8bqEKaWM3Jwu9C4gt5yiB4IjN0vwKLdWTRAVyJg6On0G3G3e2ZGmoVrOQpZ3qAfgj7WXaaqCqwR8v1qOiHy99oQ/cZJyHwrJNn2JMzwFaQgQTksPzehEYpZXb0ST/I2COqFx35ZthgS4SnXam8B0T18uMwGscosFAHCLYQEcL3+6rHOP6jc3GtK2rK1DFaYsAJywLBvgMjefiIhy4qaSLjZ2dB1SsDwRXvXEulC82fN10CNfA7AfV7Vv9Chv9xIjX+6wt1vmpQn0H7940xfRLdd2/XZ1KfNTaWtU2B5IkMtKPDt7IoxIkYurfSMBwJb7xqqEqoAmb1+0kLpApcltjSOu2XlvRqpNkTPkmlH56vnMMNuHEe7jZbHAbFUOcCxYQc/+zmZmNm7wHgL5cnuoa/Mr48pzNYmsUeMq4y9065uQQ5CuqRZF+yXbxmsDuvyWYjHLWaUun3yDx1w9O+jh6iVeBUwbskIcH6JiHF+eqK788uvbvVsWKty2n5nN5SO8h1HQrHyfokXwuoJMY2rjbB26vVsT/XjTh95M/YzVlduf7I5eEbWrlBXT96M43Dzv+i5evbsxDIC60dvhjva//TIpY/tqVfd1JXtdcXWtmx/8aBcv7CL6fLPK6z7y7LR0hVLbeE/9OvT7MWDGiT/4cyciEv5VwLJu2vovQR9/uoIqpm4+moH9MOR2tfJZfR+8Nb9bsZtXH150dOdSCtyvgbtG1pnpn6YcQxJi7pnd2n9E2ssv9fr//Ab1GJJwHur/BorGqSV9w5JjAAAAAElFTkSuQmCC",
"text/plain": [
"<PIL.Image.Image image mode=L size=448x28>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"z = prior().sample((16,))\n",
"x = decoder(z).mean.reshape(-1, 28, 28)\n",
"\n",
"to_pil_image(x.movedim(0, 1).reshape(28, -1))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "zuko",
"language": "python",
"name": "zuko"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
