Skip to content

Commit

Permalink
Add Colab test notebooks for CPU, GPU, and TPU (jax-ml#3000)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored and Jamie Townsend committed May 14, 2020
1 parent 2edc5fb commit 6e14bee
Show file tree
Hide file tree
Showing 3 changed files with 736 additions and 0 deletions.
250 changes: 250 additions & 0 deletions tests/notebooks/colab_cpu.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX Colab CPU Test",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/google/jax/blob/master/tests/notebooks/colab_cpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WkadOyTDCAWD",
"colab_type": "text"
},
"source": [
"# JAX Colab CPU Test\n",
"\n",
"This notebook is meant to be run in a [Colab](http://colab.research.google.com) CPU runtime as a basic check for JAX updates."
]
},
{
"cell_type": "code",
"metadata": {
"id": "_tKNrbqqBHwu",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"outputId": "071fb360-ddf5-41ae-d772-acc08ec71d9b"
},
"source": [
"import jax\n",
"import jaxlib\n",
"\n",
"!cat /var/colab/hostname\n",
"print(jax.__version__)\n",
"print(jaxlib.__version__)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"m-s-1p12yf76kgzz\n",
"0.1.64\n",
"0.1.45\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oqEG21rADO1F",
"colab_type": "text"
},
"source": [
"## Confirm Device"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"from jaxlib import xla_extension\n",
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device_buffer.device()\n",
"print(f\"JAX device type: {device}\")\n",
"assert isinstance(device, xla_extension.CpuDevice), \"unexpected JAX device type\""
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"JAX device type: cpu:0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0FUY9yUC4k1",
"colab_type": "text"
},
"source": [
"## Matrix Multiplication"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"import jax\n",
"import numpy as np\n",
"\n",
"# matrix multiplication on GPU\n",
"key = jax.random.PRNGKey(0)\n",
"x = jax.random.normal(key, (3000, 3000))\n",
"result = jax.numpy.dot(x, x.T).mean()\n",
"print(result)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"1.0216691\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0zTA2Q19DW4G",
"colab_type": "text"
},
"source": [
"## Linear Algebra"
]
},
{
"cell_type": "code",
"metadata": {
"id": "uW9j84_UDYof",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "3dd5d7c0-9d47-4be1-c6f7-068b432b69f7"
},
"source": [
"import jax.numpy as jnp\n",
"import jax.random as rand\n",
"\n",
"N = 10\n",
"M = 20\n",
"key = rand.PRNGKey(1701)\n",
"\n",
"X = rand.normal(key, (N, M))\n",
"u, s, vt = jnp.linalg.svd(X)\n",
"assert u.shape == (N, N)\n",
"assert vt.shape == (M, M)\n",
"print(s)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n",
" 2.8664916 1.8229378 1.5478933]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCyKUn4-DCXn",
"colab_type": "text"
},
"source": [
"## XLA Compilation"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"@jax.jit\n",
"def selu(x, alpha=1.67, lmbda=1.05):\n",
" return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n",
"x = jax.random.normal(key, (5000,))\n",
"result = selu(x).block_until_ready()\n",
"print(result)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"[ 0.34676832 -0.7532232 1.7060695 ... 2.1208048 -0.42621925\n",
" 0.13093236]\n"
],
"name": "stdout"
}
]
}
]
}
Loading

0 comments on commit 6e14bee

Please sign in to comment.