def old_mttkrp(Ul, Y, Ur, szl, szr, szn, R):
    """Existing MTTKRP kernel: contracts one rank column at a time.

    Viewing ``Y`` as a C-ordered array of shape ``(szl, szn, szr)``, computes
    ``V[j, r] = sum_{i, k} Y[i, j, k] * Ul[k, r] * Ur[i, r]``.

    This is the reference implementation being profiled, so its statements
    are intentionally left as they were.

    Parameters
    ----------
    Ul : np.ndarray
        Array of shape ``(szr, R)``; multiplied onto the flattened ``Y``.
    Y : np.ndarray
        Array with ``szl * szn * szr`` elements (any shape that reshapes to
        ``(szl * szn, szr)``).
    Ur : np.ndarray
        Array with ``szl * R`` elements; reshaped to ``(szl, 1, R)``.
    szl, szr, szn : int
        Sizes of the leading, trailing and target axes of ``Y``.
    R : int
        Number of columns in ``Ul`` / ``Ur``.

    Returns
    -------
    np.ndarray
        Array of shape ``(szn, R)``.
    """
    Y = np.reshape(Y, (-1, szr))
    Y = Y @ Ul
    Y = np.reshape(Y, (szl, szn, R))
    Ur = Ur.reshape((szl, 1, R))
    V = np.zeros((szn, R))
    for r in range(R):
        # (szn, szl) @ (szl, 1) -> one (szn, 1) output column per rank index.
        V[:, [r]] = Y[:, :, r].T @ Ur[:, :, r]
    return V


def new_mttkrp(Ul, Y, Ur, szl, szr, szn, R):
    """Proposed MTTKRP kernel: same contraction as ``old_mttkrp`` via einsum.

    Viewing ``Y`` as a C-ordered array of shape ``(szl, szn, szr)``, computes
    ``V[j, r] = sum_{i, k} Y[i, j, k] * Ul[k, r] * Ur[i, r]``.

    Parameters
    ----------
    Ul : np.ndarray
        Array of shape ``(szr, R)``; multiplied onto the flattened ``Y``.
    Y : np.ndarray
        Array with ``szl * szn * szr`` elements (any shape that reshapes to
        ``(szl * szn, szr)``).
    Ur : np.ndarray
        Array of shape ``(szl, R)``.
    szl, szr, szn : int
        Sizes of the leading, trailing and target axes of ``Y``.
    R : int
        Number of columns in ``Ul`` / ``Ur``.

    Returns
    -------
    np.ndarray
        Array of shape ``(szn, R)``.
    """
    # Contract the trailing axis first: (szl * szn, szr) @ (szr, R).
    # The original reshaped to (-1, szr) twice in a row; once is enough.
    partial = np.reshape(Y, (-1, szr)) @ Ul
    partial = np.reshape(partial, (szl, szn, R))
    # Sum over the leading axis i while keeping the rank index k aligned.
    return np.einsum("ijk, ik -> jk", partial, Ur)


# Notebook helper (mttkrp_profiling.ipynb): builds the benchmark inputs.
def get_tens(n, R, shape, seed=123):
    """Create seeded random inputs for an MTTKRP call on axis ``n``.

    Parameters
    ----------
    n : int
        Index into ``shape`` of the target axis.
    R : int
        Number of columns of the two generated matrices.
    shape : sequence of int
        Shape of the generated array ``Y``.
    seed : int, optional
        Seed passed to ``np.random.default_rng``.

    Returns
    -------
    tuple
        ``(Ul, Ur, Y, szl, szr, szn)`` where ``Ul`` is ``(szr, R)``, ``Ur`` is
        ``(szl, R)``, ``Y`` has shape ``shape``, ``szl`` / ``szr`` are the
        products of the sizes before / after axis ``n`` and ``szn`` is
        ``shape[n]``.
    """
    # np.prod of an empty slice is the float 1.0, which the generator rejects
    # as an array dimension; force integer products so the first and last
    # axes (empty leading / trailing slice) work too.
    szl = int(np.prod(shape[:n], dtype=np.int64))
    szr = int(np.prod(shape[n + 1 :], dtype=np.int64))
    szn = shape[n]

    # Draw order (Ul, Ur, Y) is kept so a given seed yields the same data.
    np_rng = np.random.default_rng(seed)
    Ul = np_rng.random((szr, R))
    Ur = np_rng.random((szl, R))
    Y = np_rng.random(shape)
    return Ul, Ur, Y, szl, szr, szn
def time_mttkrp(version, n, r, shape=(20, 30, 40, 3)):
    """Time ten consecutive calls of an MTTKRP implementation.

    Parameters
    ----------
    version : callable
        Implementation to time; called as
        ``version(Ul, Y, Ur, szl, szr, szn, r)``.
    n : int
        Target axis forwarded to ``get_tens``.
    r : int
        Rank forwarded to ``get_tens`` and to ``version``.
    shape : sequence of int, optional
        Shape forwarded to ``get_tens``.

    Returns
    -------
    tuple
        ``(times, result)``: the ten elapsed times in seconds and the value
        returned by the last call.
    """
    Ul, Ur, Y, szl, szr, szn = get_tens(n, r, shape)
    times = []
    result = None
    for _ in range(10):
        # perf_counter is monotonic and high resolution, unlike time.time().
        start = time.perf_counter()
        result = version(Ul, Y, Ur, szl, szr, szn, r)
        times.append(time.perf_counter() - start)
    return times, result


def mem_mttkrp(version, n, r, shape=(20, 30, 40, 3)):
    """Record peak memory of ten consecutive calls of an MTTKRP implementation.

    Parameters
    ----------
    version : callable
        Implementation to measure; called as
        ``version(Ul, Y, Ur, szl, szr, szn, r)``.
    n : int
        Target axis forwarded to ``get_tens``.
    r : int
        Rank forwarded to ``get_tens`` and to ``version``.
    shape : sequence of int, optional
        Shape forwarded to ``get_tens``.

    Returns
    -------
    tuple
        ``(mem_measurements, result)``: the ten ``memory_usage`` readings
        (taken with ``max_usage=True``) and the value returned by the last
        call.
    """
    Ul, Ur, Y, szl, szr, szn = get_tens(n, r, shape)
    mem_measurements = []
    result = None
    for _ in range(10):
        # retval=True hands back the function's return value alongside the
        # measurement, so the kernel no longer has to be run a second,
        # unmeasured time just to obtain the result.
        mem_usage, result = memory_usage(
            (version, (Ul, Y, Ur, szl, szr, szn, r)),
            max_usage=True,
            retval=True,
        )
        mem_measurements.append(mem_usage)
    return mem_measurements, result
# Workload handed to both benchmark sections.
# NOTE(review): the original banner printed [20, 30, 40, 50] while the helpers
# actually ran with their default shape [20, 30, 40, 3]. The workload is kept
# unchanged and the label now reports it; confirm which shape was intended.
shape = [20, 30, 40, 3]
modes = [1, 2]
ranks = [2, 10, 50, 100, 200]

# ## Memory
# Memory is measured before timing to avoid certain Python high-water-mark
# effects (per the notebook's own note).
print("shape", shape)
print("---------------------")
for n in modes:
    # Report the mode actually being run (the original always printed 1).
    print(f"mode-{n} mttkrp")
    for r in ranks:
        print(f"rank-{r}")
        old_mem, old_result = mem_mttkrp(old_mttkrp, n, r, shape)
        new_mem, new_result = mem_mttkrp(new_mttkrp, n, r, shape)
        # The two kernels sum in different orders, so compare with a
        # floating-point tolerance rather than exact equality.
        print("results equal:", np.allclose(old_result, new_result))
        # The first sample of each series is discarded before averaging.
        old_mean = np.array(old_mem)[1:].mean()
        new_mean = np.array(new_mem)[1:].mean()
        print("old", old_mean)
        print("new", new_mean)
        print("ratio", old_mean / new_mean)
    print("-----------------------------------------------------------------------")

# ## Timings
print("shape", shape)
print("---------------------")
for n in modes:
    print(f"mode-{n} mttkrp")
    for r in ranks:
        print(f"rank-{r}")
        old_time, old_result = time_mttkrp(old_mttkrp, n, r, shape)
        new_time, new_result = time_mttkrp(new_mttkrp, n, r, shape)
        print("results equal:", np.allclose(old_result, new_result))
        # The first sample of each series is discarded before averaging.
        old_mean = np.array(old_time)[1:].mean()
        new_mean = np.array(new_time)[1:].mean()
        print("old", old_mean)
        print("new", new_mean)
        print("speedup", old_mean / new_mean)
    print("-----------------------------------------------------------------------")
[tool.setuptools.packages.find] @@ -98,6 +99,7 @@ convention = "numpy" # Ignore everything for now "docs/**.py" = ["E", "F", "PL", "W", "N", "NPY", "RUF", "B", "D"] "docs/**.ipynb" = ["E", "F", "PL", "W", "N", "NPY", "RUF", "B", "D"] +"profiling/**.py" = ["E", "F", "PL", "W", "N", "NPY", "RUF", "B", "D"] "profiling/**.ipynb" = ["E", "F", "PL", "W", "N", "NPY", "RUF", "B", "D"] [tool.ruff.format]