From 6e14beed38dd11dd68301c69e98ecd581d1eb4fb Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Fri, 8 May 2020 11:11:42 -0700 Subject: [PATCH] Add Colab test notebooks for CPU, GPU, and TPU (#3000) --- tests/notebooks/colab_cpu.ipynb | 250 ++++++++++++++++++++++++++++++++ tests/notebooks/colab_gpu.ipynb | 243 +++++++++++++++++++++++++++++++ tests/notebooks/colab_tpu.ipynb | 243 +++++++++++++++++++++++++++++++ 3 files changed, 736 insertions(+) create mode 100644 tests/notebooks/colab_cpu.ipynb create mode 100644 tests/notebooks/colab_gpu.ipynb create mode 100644 tests/notebooks/colab_tpu.ipynb diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb new file mode 100644 index 000000000000..e32d7fd381fb --- /dev/null +++ b/tests/notebooks/colab_cpu.ipynb @@ -0,0 +1,250 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "JAX Colab CPU Test", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WkadOyTDCAWD", + "colab_type": "text" + }, + "source": [ + "# JAX Colab CPU Test\n", + "\n", + "This notebook is meant to be run in a [Colab](http://colab.research.google.com) CPU runtime as a basic check for JAX updates." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_tKNrbqqBHwu", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "071fb360-ddf5-41ae-d772-acc08ec71d9b" + }, + "source": [ + "import jax\n", + "import jaxlib\n", + "\n", + "!cat /var/colab/hostname\n", + "print(jax.__version__)\n", + "print(jaxlib.__version__)" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "text": [ + "m-s-1p12yf76kgzz\n", + "0.1.64\n", + "0.1.45\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oqEG21rADO1F", + "colab_type": "text" + }, + "source": [ + "## Confirm Device" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "8BwzMYhKGQj6", + "outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + } + }, + "source": [ + "from jaxlib import xla_extension\n", + "import jax\n", + "key = jax.random.PRNGKey(1701)\n", + "arr = jax.random.normal(key, (1000,))\n", + "device = arr.device_buffer.device()\n", + "print(f\"JAX device type: {device}\")\n", + "assert isinstance(device, xla_extension.CpuDevice), \"unexpected JAX device type\"" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n", + " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "JAX device type: cpu:0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z0FUY9yUC4k1", + "colab_type": "text" + }, + "source": [ + "## Matrix Multiplication" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "eXn8GUl6CG5N", + "outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "import jax\n", + "import numpy as np\n", + "\n", + "# matrix multiplication on GPU\n", + "key = jax.random.PRNGKey(0)\n", + "x = jax.random.normal(key, (3000, 3000))\n", + "result = jax.numpy.dot(x, x.T).mean()\n", + "print(result)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "text": [ + "1.0216691\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0zTA2Q19DW4G", + "colab_type": "text" + }, + "source": [ + "## Linear Algebra" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "uW9j84_UDYof", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "outputId": "3dd5d7c0-9d47-4be1-c6f7-068b432b69f7" + }, + "source": [ + "import jax.numpy as jnp\n", + "import jax.random as rand\n", + "\n", + "N = 10\n", + "M = 20\n", + "key = rand.PRNGKey(1701)\n", + "\n", + "X = rand.normal(key, (N, M))\n", + "u, s, vt = jnp.linalg.svd(X)\n", + "assert u.shape == (N, N)\n", + "assert vt.shape == (M, M)\n", + "print(s)" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n", + " 2.8664916 1.8229378 1.5478933]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jCyKUn4-DCXn", + "colab_type": "text" + }, + "source": [ + "## XLA Compilation" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "2GOn_HhDPuEn", + "outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "@jax.jit\n", + "def selu(x, alpha=1.67, lmbda=1.05):\n", + " return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n", + "x = jax.random.normal(key, (5000,))\n", + "result = selu(x).block_until_ready()\n", + "print(result)" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 0.34676832 -0.7532232 1.7060695 ... 2.1208048 -0.42621925\n", + " 0.13093236]\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/notebooks/colab_gpu.ipynb b/tests/notebooks/colab_gpu.ipynb new file mode 100644 index 000000000000..5cce333da07e --- /dev/null +++ b/tests/notebooks/colab_gpu.ipynb @@ -0,0 +1,243 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "JAX Colab GPU Test", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WkadOyTDCAWD", + "colab_type": "text" + }, + "source": [ + "# JAX Colab GPU Test\n", + "\n", + "This notebook is meant to be run in a [Colab](http://colab.research.google.com) GPU runtime as a basic check for JAX updates." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_tKNrbqqBHwu", + "colab_type": "code", + "outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + } + }, + "source": [ + "import jax\n", + "import jaxlib\n", + "\n", + "!cat /var/colab/hostname\n", + "print(jax.__version__)\n", + "print(jaxlib.__version__)" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "gpu-t4-s-kbefivsjoreh\n", + "0.1.64\n", + "0.1.45\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oqEG21rADO1F", + "colab_type": "text" + }, + "source": [ + "## Confirm Device" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "8BwzMYhKGQj6", + "outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "from jaxlib import xla_extension\n", + "import jax\n", + "key = jax.random.PRNGKey(1701)\n", + "arr = jax.random.normal(key, (1000,))\n", + "device = arr.device_buffer.device()\n", + "print(f\"JAX device type: {device}\")\n", + "assert isinstance(device, xla_extension.GpuDevice), \"unexpected JAX device type\"" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "text": [ + "JAX device type: gpu:0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z0FUY9yUC4k1", + "colab_type": "text" + }, + "source": [ + "## Matrix Multiplication" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "eXn8GUl6CG5N", + "outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "import jax\n", + "import numpy as np\n", + "\n", + "# matrix multiplication on GPU\n", + "key = jax.random.PRNGKey(0)\n", + "x = jax.random.normal(key, (3000, 3000))\n", + "result = jax.numpy.dot(x, x.T).mean()\n", + "print(result)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "text": [ + "1.0216676\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0zTA2Q19DW4G", + "colab_type": "text" + }, + "source": [ + "## Linear Algebra" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "uW9j84_UDYof", + "colab_type": "code", + "outputId": "80069760-12ab-4df2-9f5c-be2536de59b7", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "import jax.numpy as jnp\n", + "import jax.random as rand\n", + "\n", + "N = 10\n", + "M = 20\n", + "key = rand.PRNGKey(1701)\n", + "\n", + "X = rand.normal(key, (N, M))\n", + "u, s, vt = jnp.linalg.svd(X)\n", + "assert u.shape == (N, N)\n", + "assert vt.shape == (M, M)\n", + "print(s)" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n", + " 2.866489 1.8229384 1.5478926]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jCyKUn4-DCXn", + "colab_type": "text" + }, + "source": [ + "## XLA Compilation" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "2GOn_HhDPuEn", + "outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "@jax.jit\n", + "def selu(x, alpha=1.67, lmbda=1.05):\n", + " return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n", + "x = jax.random.normal(key, (5000,))\n", + "result = selu(x).block_until_ready()\n", + "print(result)" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n", + " 0.13093245]\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/notebooks/colab_tpu.ipynb b/tests/notebooks/colab_tpu.ipynb new file mode 100644 index 000000000000..f22b8d6766d0 --- /dev/null +++ b/tests/notebooks/colab_tpu.ipynb @@ -0,0 +1,243 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "JAX Colab TPU Test", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "TPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WkadOyTDCAWD", + "colab_type": "text" + }, + "source": [ + "# JAX Colab TPU Test\n", + "\n", + "This notebook is meant to be run in a [Colab](http://colab.research.google.com) TPU runtime as a basic check for JAX updates." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_tKNrbqqBHwu", + "colab_type": "code", + "outputId": "bf0043b0-6f2b-44e4-9822-4f426b3d158e", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + } + }, + "source": [ + "import jax\n", + "import jaxlib\n", + "\n", + "!cat /var/colab/hostname\n", + "print(jax.__version__)\n", + "print(jaxlib.__version__)" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "tpu-s-2dna7uebo6z96\n", + "0.1.64\n", + "0.1.45\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DzVStuLobcoG", + "colab_type": "text" + }, + "source": [ + "## TPU Setup" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "IXF0_gNCRH08", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "e66e0f60-6e57-44ed-ba6a-2d3fa906b101" + }, + "source": [ + "import requests\n", + "import os\n", + "if 'TPU_DRIVER_MODE' not in globals():\n", + " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20200416'\n", + " resp = requests.post(url)\n", + " assert resp.status_code == 200\n", + " TPU_DRIVER_MODE = 1\n", + "\n", + "# The following is required to use TPU Driver as JAX's backend.\n", + "from jax.config import config\n", + "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", + "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", + "print(config.FLAGS.jax_backend_target)" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "text": [ + "grpc://10.69.129.170:8470\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oqEG21rADO1F", + "colab_type": "text" + }, + "source": [ + "## Confirm Device" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "8BwzMYhKGQj6", + "outputId": "d51b7f21-d300-4420-8c5c-483bace8617d", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "from jaxlib import tpu_client_extension\n", + "import jax\n", + "key = jax.random.PRNGKey(1701)\n", + "arr = jax.random.normal(key, (1000,))\n", + "device = arr.device_buffer.device()\n", + "print(f\"JAX device type: {device}\")\n", + "assert isinstance(device, tpu_client_extension.TpuDevice), \"unexpected JAX device type\"" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "text": [ + "JAX device type: TPU_0(host=0,(0,0,0,0))\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z0FUY9yUC4k1", + "colab_type": "text" + }, + "source": [ + "## Matrix Multiplication" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "eXn8GUl6CG5N", + "outputId": "9954a064-ef8b-4db3-aad7-85d07b50f678", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "import jax\n", + "import numpy as np\n", + "\n", + "# matrix multiplication on GPU\n", + "key = jax.random.PRNGKey(0)\n", + "x = jax.random.normal(key, (3000, 3000))\n", + "result = jax.numpy.dot(x, x.T).mean()\n", + "print(result)" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "text": [ + "1.021576\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jCyKUn4-DCXn", + "colab_type": "text" + }, + "source": [ + "## XLA Compilation" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "2GOn_HhDPuEn", + "outputId": "a4384c55-41fb-44be-845d-17b86b152068", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "@jax.jit\n", + "def selu(x, alpha=1.67, lmbda=1.05):\n", + " return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n", + "x = jax.random.normal(key, (5000,))\n", + "result = selu(x).block_until_ready()\n", + "print(result)" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 0.34676817 -0.7532211 1.7060809 ... 2.120809 -0.42622015\n", + " 0.13093244]\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file