From f01c9a4fb5c1cdb0a9fe6eaf4211cd1f44cfa8a1 Mon Sep 17 00:00:00 2001 From: Martin Jankowiak Date: Sun, 6 Aug 2023 11:54:05 -0700 Subject: [PATCH] fix intro_long.ipynb --- tutorial/source/intro_long.ipynb | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tutorial/source/intro_long.ipynb b/tutorial/source/intro_long.ipynb index dc2a986dd1..4e9095b6ae 100644 --- a/tutorial/source/intro_long.ipynb +++ b/tutorial/source/intro_long.ipynb @@ -990,8 +990,7 @@ " a_loc = pyro.param('a_loc', lambda: torch.tensor(0.))\n", " a_scale = pyro.param('a_scale', lambda: torch.tensor(1.),\n", " constraint=constraints.positive)\n", - " sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(1.),\n", - " constraint=constraints.positive)\n", + " sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(0.))\n", " weights_loc = pyro.param('weights_loc', lambda: torch.randn(3))\n", " weights_scale = pyro.param('weights_scale', lambda: torch.ones(3),\n", " constraint=constraints.positive)\n", @@ -999,7 +998,7 @@ " b_a = pyro.sample(\"bA\", dist.Normal(weights_loc[0], weights_scale[0]))\n", " b_r = pyro.sample(\"bR\", dist.Normal(weights_loc[1], weights_scale[1]))\n", " b_ar = pyro.sample(\"bAR\", dist.Normal(weights_loc[2], weights_scale[2]))\n", - " sigma = pyro.sample(\"sigma\", dist.Normal(sigma_loc, torch.tensor(0.05)))\n", + " sigma = pyro.sample(\"sigma\", dist.LogNormal(sigma_loc, torch.tensor(0.05))) # fixed scale for simplicity\n", " return {\"a\": a, \"b_a\": b_a, \"b_r\": b_r, \"b_ar\": b_ar, \"sigma\": sigma}" ] }, @@ -1963,9 +1962,9 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python [conda env:root] *", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "conda-root-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1977,7 +1976,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.9.13" } }, "nbformat": 4,