From d84d07caad8e778dc8e034911f1a47acdf51f4e6 Mon Sep 17 00:00:00 2001 From: Richard Stanton Date: Thu, 14 Dec 2023 18:07:16 +0000 Subject: [PATCH] classifier uncertainty notebook --- .../exploration.ipynb | 2853 +++++++++++++++++ 1 file changed, 2853 insertions(+) create mode 100644 neural_networks/unfinished-dirichlet_distributions/exploration.ipynb diff --git a/neural_networks/unfinished-dirichlet_distributions/exploration.ipynb b/neural_networks/unfinished-dirichlet_distributions/exploration.ipynb new file mode 100644 index 0000000..445181b --- /dev/null +++ b/neural_networks/unfinished-dirichlet_distributions/exploration.ipynb @@ -0,0 +1,2853 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Modelling uncertainty in classification neural networks\n", + "\n", + "Classification neural networks are typically over confident. Particularly when the class predictionsare close to 0 or 1 [ref].\n", + "Measuring confidence is represented by measuring the prediction uncertainty.\n", + "We can build neural networks to model uncertainty as well.\n", + "\n", + "Ref:\n", + "* https://arxiv.org/pdf/2107.03342.pdf\n", + "* https://arxiv.org/pdf/2011.06225.pdf\n", + "* https://en.wikipedia.org/wiki/Dirichlet_distribution\n", + "* Pytorch distributions - https://pytorch.org/docs/stable/distributions.html\n", + "* https://www.youtube.com/watch?v=p1EnIbDItTc" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import scipy.stats\n", + "\n", + "plt.style.use(\"seaborn-v0_8-whitegrid\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Non-contextual version\n", + "We can fit a Beta distribution to random binary data with pytorch.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate data\n", + "\n", + "We need data for a classification problem.\n", + "\n", + "We will first make a non-contextual version.\n", + "We generate data from Bernouli trials with a known success probability.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1, 1, 1, 1, 0, 0, 1, 0, 1, 0])" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average: 0.6999\n" + ] + } + ], + "source": [ + "# true distribution parameters\n", + "p_success = 0.7\n", + "n_samples = 10000\n", + "\n", + "rand_gen = np.random.default_rng(seed=0)\n", + "\n", + "rand_samples = rand_gen.binomial(n=1, p=p_success, size=(n_samples, 1))\n", + "\n", + "display(rand_samples[:10].flatten())\n", + "print(f\"Average: {rand_samples.mean()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fit pytorch model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The uninformed prior distribution for a Beta is commonly parameterised as [0.5, 0.5].\n", + "Within a pytorch model we need to ensure the parameters remain in the correct domain.\n", + "Beta parameters must be positive, therefore we use the soft plus function to convert them.\n", + "\n", + "\n", + "> Notes\n", + "* How do we fit the loglikelihood of the beta? Common for the PDF to be 0 (logloss of -inf) at 0 and 1.\n", + "* Truncate outcomes to 0.01 and 0.99 for a hack?\n", + "* Conceptually it makes sense?\n", + " * data generating only 1s. beta concentrating at 1 will give higher prob of success than one less certain\n", + " * find log loss of mean outcome? should work. we find the average success and find the log loss of that given our beta" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/_v/nlh4h1yx2n1gd6f3szjlgxt40000gr/T/ipykernel_42343/90226784.py:9: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown\n", + " fig.show()\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch\n", + "import matplotlib.pyplot as plt\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "x = torch.linspace(-10, 10, 100)\n", + "y = torch.nn.functional.softplus(x)\n", + "ax.plot(x, y)\n", + "ax.set(xlabel=\"x\", ylabel=\"y\", title=\"\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To get the equivalent uninformed prior distributino parameters we can find the inverse soft plus, [ref](https://github.com/pytorch/pytorch/issues/72759)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-0.4328)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def inv_softplus(x):\n", + " return x + torch.log(-torch.expm1(-x))\n", + "\n", + "\n", + "torch.nn.functional.softplus(inv_softplus(torch.tensor(0.5)))\n", + "inv_softplus(torch.tensor(0.5))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now construct our Beta distribution in pytorch." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torch.distributions\n", + "from typing import List\n", + "\n", + "\n", + "class BetaModel(pl.LightningModule):\n", + " def __init__(self, learning_rate=1e-3, prior_parameters: List[float] = None):\n", + " super().__init__()\n", + "\n", + " if prior_parameters is None:\n", + " prior_parameters = [0.5, 0.5]\n", + " else:\n", + " assert len(prior_parameters) == 2\n", + " assert prior_parameters[0] > 0\n", + " assert prior_parameters[1] > 0\n", + "\n", + " prior_parameters_invsp = inv_softplus(torch.tensor(prior_parameters))\n", + " self.beta_params_invsp = torch.nn.Parameter(prior_parameters_invsp)\n", + "\n", + " self.train_log_error = []\n", + " self.val_log_error = []\n", + " self.beta_params_log = []\n", + " self.learning_rate = learning_rate\n", + "\n", + " def forward(self):\n", + " beta_params_pos = torch.nn.functional.softplus(self.beta_params_invsp)\n", + " dist = torch.distributions.beta.Beta(beta_params_pos[0], beta_params_pos[1])\n", + " return dist\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " y = batch[0]\n", + "\n", + " # find loglikelihood of y given our current distribution estimate\n", + " dist = self.forward()\n", + " negloglik = -dist.log_prob(y.mean())\n", + " loss = torch.mean(negloglik)\n", + "\n", + " # log training results\n", + " self.train_log_error.append(loss.detach().numpy())\n", + " self.beta_params_log.append(\n", + " torch.nn.functional.softplus(self.beta_params_invsp).detach().numpy()\n", + " )\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " y = batch[0]\n", + "\n", + " # find loglikelihood of y given our current distribution estimate\n", + " dist = self.forward()\n", + " negloglik = -dist.log_prob(y)\n", + " loss = torch.mean(negloglik)\n", + "\n", + " # log training results\n", + " self.val_log_error.append(loss.detach().numpy())\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(\n", + " self.parameters(),\n", + " lr=self.learning_rate,\n", + " )\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets create the model:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(1)\n", + "model = BetaModel(learning_rate=1e1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can generate distribution objects from the model by calling the forward pass method.\n", + "We see the model does not yet know anything about the true success probability." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/_v/nlh4h1yx2n1gd6f3szjlgxt40000gr/T/ipykernel_42343/4198434410.py:10: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown\n", + " fig.show()\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = np.linspace(0, 1, 1000)\n", + "output_dist = model()\n", + "y = scipy.stats.beta(output_dist.concentration0.detach().numpy(), output_dist.concentration1.detach().numpy()).pdf(x)\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "ax.plot(x, y, label=\"Prior distribution\")\n", + "ax.axvline(x=p_success, linestyle=\"--\", color=\"b\", alpha=0.5, label=\"True P(success)\")\n", + "ax.legend()\n", + "ax.set(xlabel=\"P(success)\", ylabel=\"PDF\", title=\"Prior distribution\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fitting the model distribution\n", + "To train the model with our random sample data created above, we need to setup a dataloader to pass to the trainer." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([10000, 1])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# training on the whole dataset each batch\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "\n", + "rand_samples_t = torch.Tensor(rand_samples)\n", + "dataset_train = TensorDataset(rand_samples_t)\n", + "dataloader_train = DataLoader(dataset_train, batch_size=len(rand_samples))\n", + "\n", + "# test loading a batch\n", + "rand_samples_batch = next(iter(dataloader_train))\n", + "rand_samples_batch[0].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can train the model via PyTorch Lightning's Trainer object." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/stantoon/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n", + "/Users/stantoon/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "/Users/stantoon/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------\n", + " | other params | n/a | 2 \n", + "--------------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n", + "/Users/stantoon/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n", + "/Users/stantoon/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49: 100%|██████████| 1/1 [00:00<00:00, 57.75it/s, v_num=10]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=50` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 49: 100%|██████████| 1/1 [00:00<00:00, 51.30it/s, v_num=10]\n" + ] + } + ], + "source": [ + "# fit network\n", + "trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\")\n", + "trainer.fit(model, dataloader_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see the training loss has converged ok." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/_v/nlh4h1yx2n1gd6f3szjlgxt40000gr/T/ipykernel_42343/2207435402.py:10: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown\n", + " fig.show()\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def moving_average(a, n=10):\n", + " ret = np.cumsum(a, dtype=float)\n", + " ret[n:] = ret[n:] - ret[:-n]\n", + " return ret[n - 1 :] / n\n", + "\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "ax.plot(moving_average(np.array(model.train_log_error)))\n", + "ax.set(ylabel=\"Loss\",xlabel=\"Training iteration\", title='Training loss')\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The beta parameters are moving towards appropriate values." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/_v/nlh4h1yx2n1gd6f3szjlgxt40000gr/T/ipykernel_42343/4205145043.py:8: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown\n", + " fig.show()\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "beta_params_log = np.array(model.beta_params_log)\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "ax.plot(beta_params_log[:, 0], label=\"param0\")\n", + "ax.plot(beta_params_log[:, 1], label=\"param1\")\n", + "ax.legend()\n", + "ax.set(xlabel=\"Iteration\", ylabel=\"Coefficient value\", title='Distribution parameters through training')\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plotting distribution against interation" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/_v/nlh4h1yx2n1gd6f3szjlgxt40000gr/T/ipykernel_42343/1794095813.py:13: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown\n", + " fig.show()\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "beta_params_log = np.array(model.beta_params_log)\n", + "x = np.linspace(0, 1, 1000)\n", + "iterations = range(0, beta_params_log.shape[0], 10)\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "for idx in iterations:\n", + " y = scipy.stats.beta(*beta_params_log[idx, :]).pdf(x)\n", + " ax.plot(x, y, alpha=0.5, label=idx)\n", + "\n", + "ax.axvline(x=p_success, linestyle=\"--\", color=\"b\", alpha=0.5, label=\"True P(success)\")\n", + "ax.legend()\n", + "ax.set(xlabel=\"P(success)\", ylabel=\"PDF\", title=\"Distributions through training\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comparison with maximum likelihood estimation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The true beta parameters for this data can be calculated as the sum of the successes and the sum of the failures." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLE parameters\n", + "6999\n", + "3001\n" + ] + } + ], + "source": [ + "print(\"MLE parameters\")\n", + "print((rand_samples == 1).sum())\n", + "print((rand_samples == 0).sum())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The MLE distribution is a very tight fit around the true success probability." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/_v/nlh4h1yx2n1gd6f3szjlgxt40000gr/T/ipykernel_42343/2078173807.py:9: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown\n", + " fig.show()\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = np.linspace(0, 1, 1000)\n", + "y = scipy.stats.beta((rand_samples == 1).sum(), (rand_samples == 0).sum()).pdf(x)\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "ax.plot(x, y, label=\"MLE distribution\")\n", + "ax.axvline(x=p_success, linestyle=\"--\", color=\"b\", alpha=0.5, label=\"True P(success)\")\n", + "ax.set(xlabel=\"P(success)\", ylabel=\"PDF\", title=\"MLE distribution\")\n", + "ax.legend()\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is shown also in the log likelihood of observing the data given the distribution.\n", + "The MLE LL is expectedly better (higher):" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLE LL: 4.47\n", + "NN LL: 2.28\n" + ] + } + ], + "source": [ + "print(\n", + " f\"MLE LL: {scipy.stats.beta((rand_samples==1).sum(), (rand_samples==0).sum()).logpdf(rand_samples.mean()):.2f}\"\n", + ")\n", + "print(\n", + " f\"NN LL: {scipy.stats.beta(*beta_params_log[-1,:]).logpdf(rand_samples.mean()):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, when we want to make the distribution a function of various features we cannot use the MLE approach in the same way.\n", + "Gradient descent however can still be used." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Appendix\n", + "\n", + "TODO:\n", + "\n", + "* Quantile regression from a multi head NN\n", + "* " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exploring distributions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[7.9465e-01],\n", + " [9.8314e-02],\n", + " [5.3686e-01],\n", + " [1.2176e-01],\n", + " [4.7895e-01],\n", + " [1.8919e-02],\n", + " [7.2673e-01],\n", + " [6.2900e-01],\n", + " [5.8163e-01],\n", + " [7.8585e-01],\n", + " [1.1850e-01],\n", + " [1.4547e-01],\n", + " [2.6722e-01],\n", + " [9.9544e-01],\n", + " [9.9937e-01],\n", + " [9.6382e-01],\n", + " [9.8782e-01],\n", + " [9.0896e-01],\n", + " [3.3455e-02],\n", + " [9.9977e-01],\n", + " [1.7100e-02],\n", + " [7.8972e-01],\n", + " [5.8074e-02],\n", + " [9.1087e-01],\n", + " [4.7398e-01],\n", + " [7.3772e-01],\n", + " [9.8253e-01],\n", + " [6.9342e-01],\n", + " [9.6179e-01],\n", + " [5.5020e-01],\n", + " [6.0733e-01],\n", + " [3.9311e-01],\n", + " [6.2002e-04],\n", + " [3.0071e-01],\n", + " [5.8508e-01],\n", + " [2.9016e-01],\n", + " [3.4136e-01],\n", + " [5.7220e-01],\n", + " [4.2474e-01],\n", + " [9.9973e-01],\n", + " [9.4924e-01],\n", + " [5.2341e-01],\n", + " [4.3132e-01],\n", + " [9.4580e-01],\n", + " [9.0959e-01],\n", + " [9.9789e-01],\n", + " [9.9586e-01],\n", + " [1.1715e-01],\n", + " [5.0077e-01],\n", + " [1.5136e-01],\n", + " [8.9403e-01],\n", + " [2.2254e-01],\n", + " [5.7264e-01],\n", + " [6.9465e-01],\n", + " [9.4916e-01],\n", + " [9.9788e-01],\n", + " [5.1308e-01],\n", + " [9.2236e-01],\n", + " [9.8785e-01],\n", + " [6.4955e-01],\n", + " [8.3602e-02],\n", + " [3.9582e-01],\n", + " [3.2508e-01],\n", + " [1.7559e-01],\n", + " [9.3873e-01],\n", + " [8.8563e-01],\n", + " [9.6166e-01],\n", + " [6.7665e-02],\n", + " [7.5508e-01],\n", + " [4.6497e-01],\n", + " [9.4712e-01],\n", + " [6.6969e-02],\n", + " [2.7491e-01],\n", + " [5.3678e-01],\n", + " [6.1143e-02],\n", + " [2.5732e-01],\n", + " [8.5025e-01],\n", + " [1.2072e-01],\n", + " [5.2869e-01],\n", + " [2.0114e-01],\n", + " [4.7868e-01],\n", + " [1.2813e-01],\n", + " [9.9808e-01],\n", + " [4.4948e-01],\n", + " [2.7319e-01],\n", + " [9.9187e-01],\n", + " [9.0806e-01],\n", + " [6.3824e-01],\n", + " [7.9136e-01],\n", + " [9.0642e-01],\n", + " [1.7599e-02],\n", + " [1.1454e-01],\n", + " [9.8142e-01],\n", + " [2.5868e-01],\n", + " [9.0499e-01],\n", + " [9.5528e-01],\n", + " [2.2938e-02],\n", + " [4.8639e-01],\n", + " [9.7827e-01],\n", + " [8.2148e-01]])\n", + "tensor([0.5669, 0.3317, 0.1014])\n", + "tensor([0.4172, 0.2689, 0.3139])\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.distributions\n", + "\n", + "dist = torch.distributions.beta.Beta(torch.tensor([0.5]), torch.tensor([0.5]))\n", + "print(dist.sample(sample_shape=torch.Size([100])))\n", + "\n", + "dist = torch.distributions.dirichlet.Dirichlet(torch.tensor([0.5, 0.5, 0.5]))\n", + "print(dist.sample())\n", + "\n", + "dist = torch.distributions.dirichlet.Dirichlet(torch.tensor([15.0, 15.0, 15.0]))\n", + "print(dist.sample())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate random data\n", + "\n", + "I'll make random data from two beta distributions and then select from them randomly to get two peaks.\n", + "I'll deliberately make the distributions fairly distinct to make the initial analysis easier to visualise." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "ename": "OSError", + "evalue": "'seaborn-whitegrid' is not a valid package style, path of style file, URL of style file, or library style name (library styles are listed in `style.available`)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/matplotlib/style/core.py:137\u001b[0m, in \u001b[0;36muse\u001b[0;34m(style)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 137\u001b[0m style \u001b[39m=\u001b[39m _rc_params_in_file(style)\n\u001b[1;32m 138\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mOSError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n", + "File \u001b[0;32m~/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/matplotlib/__init__.py:866\u001b[0m, in \u001b[0;36m_rc_params_in_file\u001b[0;34m(fname, transform, fail_on_error)\u001b[0m\n\u001b[1;32m 865\u001b[0m rc_temp \u001b[39m=\u001b[39m {}\n\u001b[0;32m--> 866\u001b[0m \u001b[39mwith\u001b[39;49;00m _open_file_or_url(fname) \u001b[39mas\u001b[39;49;00m fd:\n\u001b[1;32m 867\u001b[0m \u001b[39mtry\u001b[39;49;00m:\n", + "File \u001b[0;32m~/miniconda3/envs/pytorch_env/lib/python3.12/contextlib.py:137\u001b[0m, in \u001b[0;36m_GeneratorContextManager.__enter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 137\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgen)\n\u001b[1;32m 138\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n", + "File \u001b[0;32m~/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/matplotlib/__init__.py:843\u001b[0m, in \u001b[0;36m_open_file_or_url\u001b[0;34m(fname)\u001b[0m\n\u001b[1;32m 842\u001b[0m fname \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mexpanduser(fname)\n\u001b[0;32m--> 843\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39;49m(fname, encoding\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mutf-8\u001b[39;49m\u001b[39m'\u001b[39;49m) \u001b[39mas\u001b[39;00m f:\n\u001b[1;32m 844\u001b[0m \u001b[39myield\u001b[39;00m f\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'seaborn-whitegrid'", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/stantoon/Documents/VariousProjects/github/data-analysis/neural_networks/unfinished-dirichlet_distributions/exploration.ipynb Cell 38\u001b[0m line \u001b[0;36m4\n\u001b[1;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mnp\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mmatplotlib\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mpyplot\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mplt\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m plt\u001b[39m.\u001b[39;49mstyle\u001b[39m.\u001b[39;49muse(\u001b[39m\"\u001b[39;49m\u001b[39mseaborn-whitegrid\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[1;32m 6\u001b[0m \u001b[39m# true distribution parameters\u001b[39;00m\n\u001b[1;32m 7\u001b[0m p_d1 \u001b[39m=\u001b[39m \u001b[39m0.5\u001b[39m\n", + "File \u001b[0;32m~/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/matplotlib/style/core.py:139\u001b[0m, in \u001b[0;36muse\u001b[0;34m(style)\u001b[0m\n\u001b[1;32m 137\u001b[0m style \u001b[39m=\u001b[39m _rc_params_in_file(style)\n\u001b[1;32m 138\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mOSError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n\u001b[0;32m--> 139\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mOSError\u001b[39;00m(\n\u001b[1;32m 140\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mstyle\u001b[39m!r}\u001b[39;00m\u001b[39m is not a valid package style, path of style \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 141\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mfile, URL of style file, or library style name (library \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 142\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mstyles are listed in `style.available`)\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mfrom\u001b[39;00m \u001b[39merr\u001b[39;00m\n\u001b[1;32m 143\u001b[0m filtered \u001b[39m=\u001b[39m {}\n\u001b[1;32m 144\u001b[0m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m style: \u001b[39m# don't trigger RcParams.__getitem__('backend')\u001b[39;00m\n", + "\u001b[0;31mOSError\u001b[0m: 'seaborn-whitegrid' is not a valid package style, path of style file, URL of style file, or library style name (library styles are listed in `style.available`)" + ] + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.style.use(\"seaborn-whitegrid\")\n", + "\n", + "# true distribution parameters\n", + "p_d1 = 0.5\n", + "p_lambda = 5\n", + "\n", + "n = 10000\n", + "rand_gen = np.random.default_rng(seed=0)\n", + "\n", + "poisson_samples = rand_gen.poisson(lam=p_lambda, size=(n, 1))\n", + "mix_samples = rand_gen.binomial(n=1, p=p_d1, size=(n, 1))\n", + "rand_samples = mix_samples * poisson_samples\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "ax.hist(rand_samples, bins=50)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a PyTorch model\n", + "\n", + "Assuming we know the underlying generating model (...strong assumption...?), we can construct a network that builds the equivalent distribution objects in PyTorch.\n", + "\n", + "We are fitting the distribution parameters.\n", + "As such we have no input features for the forward pass, only the output values.\n", + "We use negative log likelihood as the loss function to optimise." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "\n", + "\n", + "class ZIPModel(pl.LightningModule):\n", + " def __init__(\n", + " self,\n", + " learning_rate=1e-3,\n", + " init_mix_parameter: float = 0.5,\n", + " init_poisson_lambda: float = 1.0,\n", + " ):\n", + " super().__init__()\n", + " self.mixture_prob = torch.nn.Parameter(torch.tensor([init_mix_parameter]))\n", + " self.poisson_lambda = torch.nn.Parameter(torch.tensor([init_poisson_lambda]))\n", + "\n", + " self.train_log_error = []\n", + " self.val_log_error = []\n", + " self.mixture_prob_log = []\n", + " self.poisson_lambda_log = []\n", + " self.learning_rate = learning_rate\n", + "\n", + " def forward(self):\n", + " # ensure correct domain for params\n", + " mixture_prob_norm = torch.sigmoid(self.mixture_prob)\n", + " poisson_lambda_norm = torch.nn.functional.softplus(self.poisson_lambda)\n", + " poisson_lambda_norm = torch.concat(\n", + " (torch.tensor([0.0]), poisson_lambda_norm)\n", + " ) # maintains grad\n", + "\n", + " mix = torch.distributions.Categorical(\n", + " torch.concat((mixture_prob_norm, 1 - mixture_prob_norm))\n", + " )\n", + " poissons = torch.distributions.Poisson(poisson_lambda_norm)\n", + "\n", + " mixture_dist = torch.distributions.MixtureSameFamily(mix, poissons)\n", + " return mixture_dist\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(\n", + " self.parameters(),\n", + " lr=self.learning_rate,\n", + " )\n", + " return optimizer\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " y = batch[0]\n", + "\n", + " mixture_dist = self.forward()\n", + "\n", + " negloglik = -mixture_dist.log_prob(y)\n", + " loss = torch.mean(negloglik)\n", + "\n", + " self.train_log_error.append(loss.detach().numpy())\n", + " self.poisson_lambda_log.append(\n", + " torch.nn.functional.softplus(self.poisson_lambda).detach().numpy()\n", + " )\n", + " self.mixture_prob_log.append(torch.sigmoid(self.mixture_prob).detach().numpy())\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " y = batch[0]\n", + "\n", + " mixture_dist = self.forward()\n", + "\n", + " negloglik = -mixture_dist.log_prob(y)\n", + " loss = torch.mean(negloglik)\n", + "\n", + " self.train_log_error.append(loss.detach().numpy())\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets create the model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create model\n", + "torch.manual_seed(1)\n", + "model = ZIPModel(learning_rate=1e-0, init_mix_parameter=0.5, init_poisson_lambda=4.5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can generate distribution objects from the model by calling the forward pass method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# get some samples\n", + "output_dist = model()\n", + "output_samples = output_dist.sample((n, 1)).numpy().squeeze()\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "ax.hist(output_samples, bins=50)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fitting the model distribution\n", + "To train the model with our random sample data created above, we need to setup a dataloader to pass to the trainer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([10000, 1])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# training on the whole dataset each batch\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "\n", + "rand_samples_t = torch.Tensor(rand_samples)\n", + "dataset_train = TensorDataset(rand_samples_t)\n", + "dataloader_train = DataLoader(dataset_train, batch_size=len(rand_samples))\n", + "\n", + "# test loading a batch\n", + "rand_samples_batch = next(iter(dataloader_train))\n", + "rand_samples_batch[0].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can train the model via PyTorch Lightning's Trainer object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/rich/Developer/miniconda3/envs/pytorch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1789: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.\n", + " rank_zero_warn(\n", + "/Users/rich/Developer/miniconda3/envs/pytorch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:107: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + " rank_zero_warn(\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n", + "/Users/rich/Developer/miniconda3/envs/pytorch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:236: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n", + "/Users/rich/Developer/miniconda3/envs/pytorch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + " rank_zero_warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "02d579ea152f40efa80d4e61b7c925ef", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + } + ], + "source": [ + "# fit network\n", + "trainer = pl.Trainer(\n", + " max_epochs=100,\n", + ")\n", + "trainer.fit(model, dataloader_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see the training loss has converged ok." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/ky/4qby95090jbbq38_mh94x72r0000gn/T/ipykernel_82119/2590399248.py:11: UserWarning: Matplotlib is currently using module://matplotlib_inline.backend_inline, which is a non-GUI backend, so cannot show the figure.\n", + " fig.show()\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def moving_average(a, n=10):\n", + " ret = np.cumsum(a, dtype=float)\n", + " ret[n:] = ret[n:] - ret[:-n]\n", + " return ret[n - 1 :] / n\n", + "\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "ax.plot(moving_average(np.array(model.train_log_error)))\n", + "ax.set_ylabel(\"Loss\")\n", + "ax.set_xlabel(\"Training iteration\")\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/ky/4qby95090jbbq38_mh94x72r0000gn/T/ipykernel_82119/2571154232.py:9: UserWarning: Matplotlib is currently using module://matplotlib_inline.backend_inline, which is a non-GUI backend, so cannot show the figure.\n", + " fig.show()\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "ax.plot(model.poisson_lambda_log, color=\"b\", label=\"lambda\")\n", + "ax.plot(model.mixture_prob_log, color=\"orange\", label=\"p\")\n", + "ax.axhline(y=p_lambda, linestyle=\"--\", color=\"b\", alpha=0.5)\n", + "ax.axhline(y=p_d1, linestyle=\"--\", color=\"orange\", alpha=0.5)\n", + "ax.legend()\n", + "ax.set_xlabel(\"Iteration\")\n", + "ax.set_ylabel(\"Coefficient value\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Checking results\n", + "Now we can check the resulting distribution that comes out of our model and compare that directly to the random samples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# plot pdf\n", + "output_dist = model()\n", + "x = torch.arange(0, 20)\n", + "y = torch.exp(output_dist.log_prob(x))\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "ax.hist(\n", + " rand_samples,\n", + " bins=rand_samples.max() - rand_samples.min(),\n", + " density=True,\n", + " label=\"raw data\",\n", + ")\n", + "ax.plot(x.detach().numpy(), y.detach().numpy(), label=\"model pdf\")\n", + "fig.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the trained distribution parameters now are close to the underlying parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final mix: 0.5104645\n", + "Final lambda: 5.013678\n", + "True mix: 0.5\n", + "True lambda: 5\n" + ] + } + ], + "source": [ + "print(\"Final mix:\", model.mixture_prob_log[-1][0])\n", + "print(\"Final lambda:\", model.poisson_lambda_log[-1][0])\n", + "\n", + "print(\"True mix:\", p_d1)\n", + "print(\"True lambda:\", p_lambda)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Random seeding\n", + "We can start the training from different random initialisations to see how smooth the objective function surface is." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "605a1889cc094909a5f7e54c8e482ae1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7d3c515e6aae4ae18a6e9070581b2de6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "84aae68403154d368f0642658647bbd8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "14b3b745289d41abab2188420f79accd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d6c47f4ffd714f1aac9f216cab5c0b99", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d77e5dfd2c5e441588e97d77aa8b0e98", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9f988731de19470daf91d8290febfeb3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "de8f04d053d043109eff2f5520321ab6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6292a6facca84d3ba87b3f1cf08ff16d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9ab3e9840e744a7585c16126e561638d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "seed_model_params = []\n", + "for _idx in range(10):\n", + " print(_idx)\n", + " init_params = [random.random(), random.random() * 10]\n", + " model = ZIPModel(\n", + " learning_rate=1e-0,\n", + " init_mix_parameter=init_params[0],\n", + " init_poisson_lambda=init_params[1],\n", + " )\n", + "\n", + " # fit network\n", + " trainer = pl.Trainer(\n", + " max_epochs=100,\n", + " )\n", + " trainer.fit(model, dataloader_train)\n", + "\n", + " seed_model_params.append(\n", + " {\n", + " \"err\": model.train_log_error[-1],\n", + " \"mixture\": model.mixture_prob_log[-1],\n", + " \"lambda\": model.poisson_lambda_log[-1],\n", + " \"mixture_init\": init_params[0],\n", + " \"lambda_init\": init_params[1],\n", + " }\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
errmixturelambdamixture_initlambda_init
01.7549713[0.5084596][4.990373]0.2061261.271737
11.7549436[0.5095606][5.0154147]0.9573755.609975
21.75495[0.50945556][5.0249567]0.1133689.019589
31.7549486[0.51113284][5.019183]0.0243602.504726
41.7549453[0.50945634][5.0075474]0.4229566.300903
51.7549646[0.5087664][4.992894]0.3170831.335294
61.7549437[0.50981426][5.016553]0.5690973.339710
71.7549523[0.51095116][5.002224]0.1460558.153884
81.7549479[0.50936157][5.022941]0.5381953.024650
91.7549556[0.50825804][5.0001493]0.0753639.901132
\n", + "
" + ], + "text/plain": [ + " err mixture lambda mixture_init lambda_init\n", + "0 1.7549713 [0.5084596] [4.990373] 0.206126 1.271737\n", + "1 1.7549436 [0.5095606] [5.0154147] 0.957375 5.609975\n", + "2 1.75495 [0.50945556] [5.0249567] 0.113368 9.019589\n", + "3 1.7549486 [0.51113284] [5.019183] 0.024360 2.504726\n", + "4 1.7549453 [0.50945634] [5.0075474] 0.422956 6.300903\n", + "5 1.7549646 [0.5087664] [4.992894] 0.317083 1.335294\n", + "6 1.7549437 [0.50981426] [5.016553] 0.569097 3.339710\n", + "7 1.7549523 [0.51095116] [5.002224] 0.146055 8.153884\n", + "8 1.7549479 [0.50936157] [5.022941] 0.538195 3.024650\n", + "9 1.7549556 [0.50825804] [5.0001493] 0.075363 9.901132" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "pd.DataFrame(seed_model_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compared with the true parameters, it looks like we're converging well in each case.\n", + "\n", + "So seems successful overall! This approach makes it quite easy to train fairly complex distributions without having to understand the particular methods for fitting that distribution type (if they even exist)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a PyTorch distribution object\n", + "\n", + "The solution is a bit hacky - we need to define the mixture of two of the same types, so we used a Poisson distribution with a $\\lambda$ of 0.\n", + "Can we reformulate this is a pytorch distribution class instead?\n", + "\n", + "The poisson class in PyTorch is based around the following (simplified) form:\n", + "```python\n", + "class Poisson(ExponentialFamily):\n", + " def __init__(self, rate, validate_args=None):\n", + " self.rate = rate\n", + " super(Poisson, self).__init__(batch_shape, validate_args=validate_args)\n", + "\n", + " def sample(self, sample_shape=torch.Size()):\n", + " with torch.no_grad():\n", + " return torch.poisson(self.rate.expand(shape))\n", + "\n", + " def log_prob(self, value):\n", + " return value.xlogy(self.rate) - self.rate - (value + 1).lgamma()\n", + "\n", + " ...\n", + "```\n", + "\n", + "Can we build the equivalent zero-inflated Poisson distribution?\n", + "\n", + "We need to specify a `sample` and `log-prob` method. For both of those we need to find an expression for the product mass function (PMF).\n", + "\n", + "### PMF\n", + "Poisson PMF:\n", + "$$\\mathrm{P_{s}}(Y=k | \\lambda) = \\frac{\\lambda^{k} e^{-\\lambda }} {k!}$$\n", + "\n", + "Bernoulli PMF:\n", + "$$\n", + "\\begin{align*}\n", + "\\mathrm{P_{d}}(Y=0) &= \\pi\\\\\n", + "\\mathrm{P_{d}}(Y!=0) &= 1-\\pi\n", + "\\end{align*}\n", + "$$\n", + "\n", + "Mixture distribution:\n", + "$$\\mathrm{P}(Y) = \\mathrm{P_{d}}(Y) \\mathrm{P_{s}}(Y)$$\n", + "\n", + "This is the expression we implement to create the `sample` method.\n", + "\n", + "\n", + "We need to dig further to get the full expression.\n", + "The zero case:\n", + "$$\n", + "\\begin{align*}\n", + "\\mathrm{P}(Y=0) &= \\mathrm{P_{d}}(Y=0) + \\mathrm{P_{d}}(Y=1) \\mathrm{P_{s}}(Y=0 | \\lambda)\\\\\n", + " &= \\pi + (1-\\pi) \\frac{\\lambda^{0} e^{-\\lambda }} {0!}\\\\\n", + " &= \\pi + (1-\\pi) e^{-\\lambda }\n", + "\\end{align*}\n", + "$$\n", + "\n", + "The non-zero case:\n", + "$$\\mathrm{P}(Y=k) = (1-\\pi) \\frac{\\lambda^{k} e^{-\\lambda }} {k!}$$\n", + "\n", + "### Log probs\n", + "Log of the above:\n", + "$$\\log{\\mathrm{P}(Y=0)} = \\log(\\pi + (1-\\pi) e^{-\\lambda })$$\n", + "\n", + "Non zero case:\n", + "$$\n", + "\\begin{align*}\n", + "\\log{\\mathrm{P}(Y=k)} &= \\log((1-\\pi) \\frac{\\lambda^{k} e^{-\\lambda }} {k!})\\\\\n", + " &= \\log(1-\\pi) + \\log(\\frac{\\lambda^{k} e^{-\\lambda }} {k!})\\\\\n", + " &= \\log(1-\\pi) + \\log(\\lambda^{k} e^{-\\lambda }) -\\log(k!)\\\\\n", + " &= \\log(1-\\pi) + k\\log(\\lambda) -\\lambda -\\log(\\Gamma(k+1))\n", + "\\end{align*}\n", + "$$\n", + "\n", + "\n", + "We can create those expression inside a new class:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from numbers import Number\n", + "\n", + "import torch\n", + "from torch.distributions import constraints\n", + "from torch.distributions.exp_family import ExponentialFamily\n", + "from torch.distributions.utils import broadcast_all\n", + "\n", + "\n", + "class ZeroInflatedPoisson(ExponentialFamily):\n", + " arg_constraints = {\"p\": constraints.unit_interval, \"rate\": constraints.nonnegative}\n", + " support = constraints.nonnegative_integer\n", + "\n", + " def __init__(self, p, rate, validate_args=None):\n", + " self.p, self.rate = broadcast_all(p, rate)\n", + " if isinstance(rate, Number):\n", + " batch_shape = torch.Size()\n", + " else:\n", + " batch_shape = self.rate.size()\n", + " super(ZeroInflatedPoisson, self).__init__(\n", + " batch_shape, validate_args=validate_args\n", + " )\n", + "\n", + " def sample(self, sample_shape=torch.Size()):\n", + " shape = self._extended_shape(sample_shape)\n", + " with torch.no_grad():\n", + " zero = torch.bernoulli(self.p.expand(shape))\n", + " poisson = torch.poisson(self.rate.expand(shape))\n", + " return zero * poisson\n", + "\n", + " def log_prob(self, value):\n", + " if self._validate_args:\n", + " self._validate_sample(value)\n", + " rate, p, value = broadcast_all(self.rate, self.p, value)\n", + " poisson_log_prob = value.xlogy(rate) - rate - (value + 1).lgamma()\n", + " poisson_log_prob[torch.where(value == 0.0)] = torch.log(\n", + " p + (1 - p) * torch.exp(-rate)\n", + " )[torch.where(value == 0.0)]\n", + " poisson_log_prob[value > 0.0] = (\n", + " torch.log(1 - p)[value > 0.0] + poisson_log_prob[value > 0.0]\n", + " )\n", + " return poisson_log_prob" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can sample from the distribution object and it looks good:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "dist = ZeroInflatedPoisson(p=0.5, rate=10.0)\n", + "samples = dist.sample(sample_shape=(10000,))\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "ax.hist(\n", + " samples, bins=int(samples.max() - samples.min()), density=True, label=\"raw data\"\n", + ")\n", + "\n", + "# plot pdf\n", + "x = torch.arange(0, 20)\n", + "y = torch.exp(dist.log_prob(x))\n", + "\n", + "ax.plot(x.detach().numpy(), y.detach().numpy(), label=\"model pdf\")\n", + "fig.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Can we fit the same model with the output as a `ZeroInflatedPoisson`?\n", + "\n", + "We can build the model class with the new distribution object as the output:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class DistZIPModel(ZIPModel):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + "\n", + " def forward(self):\n", + " # ensure correct domain for params\n", + " mixture_prob_norm = torch.sigmoid(self.mixture_prob)\n", + " poisson_lambda_norm = torch.nn.functional.softplus(self.poisson_lambda)\n", + "\n", + " zip_dist = ZeroInflatedPoisson(p=mixture_prob_norm, rate=poisson_lambda_norm)\n", + " return zip_dist" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similarly to before, we can then fit the model multiple times from random starting seeds." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/rich/Developer/miniconda3/envs/pytorch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1789: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.\n", + " rank_zero_warn(\n", + "/Users/rich/Developer/miniconda3/envs/pytorch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:107: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + " rank_zero_warn(\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n", + "/Users/rich/Developer/miniconda3/envs/pytorch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:236: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n", + "/Users/rich/Developer/miniconda3/envs/pytorch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + " rank_zero_warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2f75aeb222b94cbdab78dee8d6789180", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d17ead5b225e4a0c88d6f7b0e845159c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d798af0ee7f547faac29030d3397a256", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6882e4d50091428b9b7e677c5cbf8a33", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ad50ee6e37054b11b0c2a1d56410f6e9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b6910ac0d21f47bbb7cb5f9e47e7b07c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "01e7d1a92d2a4abc9297c24caec6e4d9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f6feedbe320f45e1b21e565dec839ef7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6600e6dc4f5e414686f0c67c78086900", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + " | Name | Type | Params\n", + "------------------------------\n", + "------------------------------\n", + "2 Trainable params\n", + "0 Non-trainable params\n", + "2 Total params\n", + "0.000 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a565ffe920614df0b79e28adc5eb56bf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "seed_model_params = []\n", + "for _idx in range(10):\n", + " print(_idx)\n", + " init_params = [random.random(), random.random() * 10]\n", + " model = DistZIPModel(\n", + " learning_rate=1e-0,\n", + " init_mix_parameter=init_params[0],\n", + " init_poisson_lambda=init_params[1],\n", + " )\n", + "\n", + " # fit network\n", + " trainer = pl.Trainer(\n", + " max_epochs=100,\n", + " )\n", + " trainer.fit(model, dataloader_train)\n", + "\n", + " seed_model_params.append(\n", + " {\n", + " \"err\": model.train_log_error[-1],\n", + " \"mixture\": model.mixture_prob_log[-1][0],\n", + " \"lambda\": model.poisson_lambda_log[-1][0],\n", + " \"mixture_init\": init_params[0],\n", + " \"lambda_init\": init_params[1],\n", + " }\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We get similar results to last time, so the training is working:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
errmixturelambdamixture_initlambda_init
01.75495620.5091874.9969970.9036352.128501
11.754950.5102705.0023540.4655128.292895
21.75494430.5090345.0136550.5271305.674198
31.75494950.5107385.0042870.9605277.272479
41.75494850.5103415.0229510.5437477.638310
51.75496140.5092035.0323930.3782670.567890
61.75494520.5098375.0078500.2416341.943764
71.7549770.5081375.0377240.5975360.653239
81.75494650.5105585.0194300.8072176.436518
91.7549440.5092885.0110040.8718268.402232
\n", + "
" + ], + "text/plain": [ + " err mixture lambda mixture_init lambda_init\n", + "0 1.7549562 0.509187 4.996997 0.903635 2.128501\n", + "1 1.75495 0.510270 5.002354 0.465512 8.292895\n", + "2 1.7549443 0.509034 5.013655 0.527130 5.674198\n", + "3 1.7549495 0.510738 5.004287 0.960527 7.272479\n", + "4 1.7549485 0.510341 5.022951 0.543747 7.638310\n", + "5 1.7549614 0.509203 5.032393 0.378267 0.567890\n", + "6 1.7549452 0.509837 5.007850 0.241634 1.943764\n", + "7 1.754977 0.508137 5.037724 0.597536 0.653239\n", + "8 1.7549465 0.510558 5.019430 0.807217 6.436518\n", + "9 1.754944 0.509288 5.011004 0.871826 8.402232" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "pd.DataFrame(seed_model_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So in summary, we can create a PyTorch distribution class by constructing the `log_prob` of the distribution product mass function. We can then fit that distribution to data via gradient descent without needing to know any particular fitting proceedure for that distribution." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.6 64-bit ('pytorch_env')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "155ead7e3e7cc87f1c8ac5b93741665d5692c3c80ccfbc75b55ba980f115c04c" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}