From 63600f52b23d2dadf59add3ebbba5c6b61a907d2 Mon Sep 17 00:00:00 2001 From: andyElking Date: Mon, 29 Jul 2024 18:22:26 +0100 Subject: [PATCH] Added an example of how to compute the number of gradient evaluations used by NUTS. --- notebooks/source/bayesian_regression.ipynb | 470 +++++++++++++++------ 1 file changed, 337 insertions(+), 133 deletions(-) diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 1c7a8ae40b..c0da8b60d9 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -35,22 +35,38 @@ }, { "cell_type": "code", - "execution_count": 1, "metadata": { - "id": "FlhcyvtqN7l1" + "id": "FlhcyvtqN7l1", + "ExecuteTime": { + "end_time": "2024-07-29T17:12:15.637112Z", + "start_time": "2024-07-29T17:12:05.726504Z" + } }, - "outputs": [], "source": [ "!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m24.0\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m24.2\u001B[0m\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n" + ] + } + ], + "execution_count": 1 }, { "cell_type": "code", - "execution_count": 2, "metadata": { - "id": "B_9Gru7DN7l3" + "id": "B_9Gru7DN7l3", + "ExecuteTime": { + "end_time": "2024-07-29T17:16:21.100577Z", + "start_time": "2024-07-29T17:16:20.871186Z" + } }, - "outputs": [], "source": [ "import os\n", "\n", @@ -74,7 +90,9 @@ " set_matplotlib_formats(\"svg\")\n", "\n", "assert numpyro.__version__.startswith(\"0.15.1\")" - ] + ], + "outputs": [], + "execution_count": 1 }, { "cell_type": "markdown", @@ -89,18 +107,183 @@ }, { "cell_type": "code", - "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "KsCe9ruUN7l4", - "outputId": "f26f9596-9f21-46d6-ec58-33dfc92b58a0" + "outputId": "f26f9596-9f21-46d6-ec58-33dfc92b58a0", + "ExecuteTime": { + "end_time": "2024-07-29T17:16:22.125032Z", + "start_time": "2024-07-29T17:16:22.058625Z" + } }, + "source": [ + "DATASET_URL = \"https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv\"\n", + "dset = pd.read_csv(DATASET_URL, sep=\";\")\n", + "dset" + ], "outputs": [ { "data": { + "text/plain": [ + " Location Loc Population MedianAgeMarriage Marriage \\\n", + "0 Alabama AL 4.78 25.3 20.2 \n", + "1 Alaska AK 0.71 25.2 26.0 \n", + "2 Arizona AZ 6.33 25.8 20.3 \n", + "3 Arkansas AR 2.92 24.3 26.4 \n", + "4 California CA 37.25 26.8 19.1 \n", + "5 Colorado CO 5.03 25.7 23.5 \n", + "6 Connecticut CT 3.57 27.6 17.1 \n", + "7 Delaware DE 0.90 26.6 23.1 \n", + "8 District of Columbia DC 0.60 29.7 17.7 \n", + "9 Florida FL 18.80 26.4 17.0 \n", + "10 Georgia GA 9.69 25.9 22.1 \n", + "11 Hawaii HI 1.36 26.9 24.9 \n", + "12 Idaho ID 1.57 23.2 25.8 \n", + "13 Illinois IL 12.83 27.0 17.9 \n", + "14 Indiana IN 6.48 25.7 19.8 \n", + "15 Iowa IA 3.05 25.4 21.5 \n", + "16 Kansas KS 2.85 25.0 22.1 \n", + "17 Kentucky KY 4.34 24.8 22.2 \n", + "18 Louisiana LA 4.53 25.9 20.6 \n", + "19 Maine ME 1.33 26.4 13.5 \n", + "20 Maryland MD 5.77 27.3 18.3 \n", + "21 Massachusetts MA 6.55 28.5 15.8 \n", + "22 Michigan MI 9.88 26.4 16.5 \n", + "23 Minnesota MN 5.30 26.3 15.3 \n", + "24 Mississippi MS 2.97 25.8 19.3 \n", + "25 Missouri MO 5.99 25.6 18.6 \n", + "26 Montana MT 0.99 25.7 18.5 \n", + "27 Nebraska NE 1.83 25.4 19.6 \n", + "28 New Hampshire NH 1.32 26.8 16.7 \n", + "29 New Jersey NJ 8.79 27.7 14.8 \n", + "30 New Mexico NM 2.06 25.8 20.4 \n", + "31 New York NY 19.38 28.4 16.8 \n", + "32 North Carolina NC 9.54 25.7 20.4 \n", + "33 North Dakota ND 0.67 25.3 26.7 \n", + "34 Ohio OH 11.54 26.3 16.9 \n", + "35 Oklahoma OK 3.75 24.4 23.8 \n", + "36 Oregon OR 3.83 26.0 18.9 \n", + "37 Pennsylvania PA 12.70 27.1 15.5 \n", + "38 Rhode Island RI 1.05 28.2 15.0 \n", + "39 South Carolina SC 4.63 26.4 18.1 \n", + "40 South Dakota SD 0.81 25.6 20.1 \n", + "41 Tennessee TN 6.35 25.2 19.4 \n", + "42 Texas TX 25.15 25.2 21.5 \n", + "43 Utah UT 2.76 23.3 29.6 \n", + "44 Vermont VT 0.63 26.9 16.4 \n", + "45 Virginia VA 8.00 26.4 20.5 \n", + "46 Washington WA 6.72 25.9 21.4 \n", + "47 West Virginia WV 1.85 25.0 22.2 \n", + "48 Wisconsin WI 5.69 26.3 17.2 \n", + "49 Wyoming WY 0.56 24.2 30.7 \n", + "\n", + " Marriage SE Divorce Divorce SE WaffleHouses South Slaves1860 \\\n", + "0 1.27 12.7 0.79 128 1 435080 \n", + "1 2.93 12.5 2.05 0 0 0 \n", + "2 0.98 10.8 0.74 18 0 0 \n", + "3 1.70 13.5 1.22 41 1 111115 \n", + "4 0.39 8.0 0.24 0 0 0 \n", + "5 1.24 11.6 0.94 11 0 0 \n", + "6 1.06 6.7 0.77 0 0 0 \n", + "7 2.89 8.9 1.39 3 0 1798 \n", + "8 2.53 6.3 1.89 0 0 0 \n", + "9 0.58 8.5 0.32 133 1 61745 \n", + "10 0.81 11.5 0.58 381 1 462198 \n", + "11 2.54 8.3 1.27 0 0 0 \n", + "12 1.84 7.7 1.05 0 0 0 \n", + "13 0.58 8.0 0.45 2 0 0 \n", + "14 0.81 11.0 0.63 17 0 0 \n", + "15 1.46 10.2 0.91 0 0 0 \n", + "16 1.48 10.6 1.09 6 0 2 \n", + "17 1.11 12.6 0.75 64 1 225483 \n", + "18 1.19 11.0 0.89 66 1 331726 \n", + "19 1.40 13.0 1.48 0 0 0 \n", + "20 1.02 8.8 0.69 11 0 87189 \n", + "21 0.70 7.8 0.52 0 0 0 \n", + "22 0.69 9.2 0.53 0 0 0 \n", + "23 0.77 7.4 0.60 0 0 0 \n", + "24 1.54 11.1 1.01 72 1 436631 \n", + "25 0.81 9.5 0.67 39 1 114931 \n", + "26 2.31 9.1 1.71 0 0 0 \n", + "27 1.44 8.8 0.94 0 0 15 \n", + "28 1.76 10.1 1.61 0 0 0 \n", + "29 0.59 6.1 0.46 0 0 18 \n", + "30 1.90 10.2 1.11 2 0 0 \n", + "31 0.47 6.6 0.31 0 0 0 \n", + "32 0.98 9.9 0.48 142 1 331059 \n", + "33 2.93 8.0 1.44 0 0 0 \n", + "34 0.61 9.5 0.45 64 0 0 \n", + "35 1.29 12.8 1.01 16 0 0 \n", + "36 1.10 10.4 0.80 0 0 0 \n", + "37 0.48 7.7 0.43 11 0 0 \n", + "38 2.11 9.4 1.79 0 0 0 \n", + "39 1.18 8.1 0.70 144 1 402406 \n", + "40 2.64 10.9 2.50 0 0 0 \n", + "41 0.85 11.4 0.75 103 1 275719 \n", + "42 0.61 10.0 0.35 99 1 182566 \n", + "43 1.77 10.2 0.93 0 0 0 \n", + "44 2.40 9.6 1.87 0 0 0 \n", + "45 0.83 8.9 0.52 40 1 490865 \n", + "46 1.00 10.0 0.65 0 0 0 \n", + "47 1.69 10.9 1.34 4 1 18371 \n", + "48 0.79 8.3 0.57 0 0 0 \n", + "49 3.92 10.3 1.90 0 0 0 \n", + "\n", + " Population1860 PropSlaves1860 \n", + "0 964201 0.450000 \n", + "1 0 0.000000 \n", + "2 0 0.000000 \n", + "3 435450 0.260000 \n", + "4 379994 0.000000 \n", + "5 34277 0.000000 \n", + "6 460147 0.000000 \n", + "7 112216 0.016000 \n", + "8 75080 0.000000 \n", + "9 140424 0.440000 \n", + "10 1057286 0.440000 \n", + "11 0 0.000000 \n", + "12 0 0.000000 \n", + "13 1711951 0.000000 \n", + "14 1350428 0.000000 \n", + "15 674913 0.000000 \n", + "16 107206 0.000019 \n", + "17 1155684 0.000000 \n", + "18 708002 0.470000 \n", + "19 628279 0.000000 \n", + "20 687049 0.130000 \n", + "21 1231066 0.000000 \n", + "22 749113 0.000000 \n", + "23 172023 0.000000 \n", + "24 791305 0.550000 \n", + "25 1182012 0.097000 \n", + "26 0 0.000000 \n", + "27 28841 0.000520 \n", + "28 326073 0.000000 \n", + "29 672035 0.000027 \n", + "30 93516 0.000000 \n", + "31 3880735 0.000000 \n", + "32 992622 0.330000 \n", + "33 0 0.000000 \n", + "34 2339511 0.000000 \n", + "35 0 0.000000 \n", + "36 52465 0.000000 \n", + "37 2906215 0.000000 \n", + "38 174620 0.000000 \n", + "39 703708 0.570000 \n", + "40 4837 0.000000 \n", + "41 1109801 0.200000 \n", + "42 604215 0.300000 \n", + "43 40273 0.000000 \n", + "44 315098 0.000000 \n", + "45 1219630 0.400000 \n", + "46 11594 0.000000 \n", + "47 376688 0.049000 \n", + "48 775881 0.000000 \n", + "49 0 0.000000 " + ], "text/html": [ "
\n", "