Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions profiling/mttkrp_example/implementations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import numpy as np


def old_mttkrp(Ul, Y, Ur, szl, szr, szn, R):
    """Baseline (existing) MTTKRP kernel that fills the result column by column.

    Parameters
    ----------
    Ul:
        Matrix right-multiplied onto ``Y`` after ``Y`` is flattened to
        ``szr`` columns.
    Y:
        Array that is reshaped to ``(-1, szr)`` and, after the product with
        ``Ul``, to ``(szl, szn, R)``.
    Ur:
        Array reshaped to ``(szl, 1, R)``; slice ``r`` supplies the weights
        for output column ``r``.
    szl, szr, szn:
        Sizes used by the reshapes above.
    R:
        Number of output columns.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(szn, R)``.
    """
    unfolded = np.reshape(Y, (-1, szr))
    projected = np.reshape(unfolded @ Ul, (szl, szn, R))
    right_cols = Ur.reshape((szl, 1, R))
    V = np.zeros((szn, R))
    # One small matrix-vector product per output column; this loop is the
    # behaviour being benchmarked, so it is intentionally not vectorised.
    for col in range(R):
        V[:, [col]] = projected[:, :, col].T @ right_cols[:, :, col]
    return V


def new_mttkrp(Ul, Y, Ur, szl, szr, szn, R):
    """Proposed MTTKRP kernel that performs the final contraction with einsum.

    Parameters
    ----------
    Ul:
        Matrix right-multiplied onto ``Y`` after ``Y`` is flattened to
        ``szr`` columns.
    Y:
        Array that is reshaped to ``(-1, szr)`` and, after the product with
        ``Ul``, to ``(szl, szn, R)``.
    Ur:
        Two-dimensional array indexed as ``Ur[i, k]`` in the contraction,
        with ``i`` running over the ``szl`` axis and ``k`` over the ``R`` axis.
    szl, szr, szn:
        Sizes used by the reshapes above.
    R:
        Number of output columns.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(szn, R)``.
    """
    # A single reshape to (-1, szr) is sufficient before the product with Ul.
    Y = np.reshape(Y, (-1, szr)) @ Ul
    Y = np.reshape(Y, (szl, szn, R))
    # V[j, k] = sum_i Y[i, j, k] * Ur[i, k]
    return np.einsum("ijk, ik -> jk", Y, Ur)
152 changes: 152 additions & 0 deletions profiling/mttkrp_example/mttkrp_profiling.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import time\n",
"\n",
"import numpy as np\n",
"from implementations import new_mttkrp, old_mttkrp\n",
"from memory_profiler import memory_usage\n",
"\n",
"%load_ext memory_profiler"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"def get_tens(n, R, shape, seed=123):\n",
" szl = np.prod(shape[0:n])\n",
" szr = np.prod(shape[n + 1 :])\n",
" szn = shape[n]\n",
"\n",
" np_rng = np.random.default_rng(seed)\n",
" Ul = np_rng.random((szr, R))\n",
" Ur = np_rng.random((szl, R))\n",
" Y = np_rng.random(shape)\n",
" return Ul, Ur, Y, szl, szr, szn"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"def time_mttkrp(version, n, r, shape=[20, 30, 40, 3]):\n",
" Ul, Ur, Y, szl, szr, szn = get_tens(n, r, shape)\n",
" times = []\n",
" result = None\n",
" for _ in range(10):\n",
" start = time.time()\n",
" result = version(Ul, Y, Ur, szl, szr, szn, r)\n",
" end = time.time()\n",
" times.append(end - start)\n",
" return times, result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mem_mttkrp(version, n, r, shape=[20, 30, 40, 3]):\n",
" Ul, Ur, Y, szl, szr, szn = get_tens(n, r, shape)\n",
" mem_measurements = []\n",
" result = None\n",
" for _ in range(10):\n",
" mem_usage = memory_usage(\n",
" (version, (Ul, Y, Ur, szl, szr, szn, r)), max_usage=True\n",
" )\n",
" result = version(Ul, Y, Ur, szl, szr, szn, r)\n",
" mem_measurements.append(mem_usage)\n",
" return mem_measurements, result"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"## Memory\n",
"This runs before timing to avoid certain Python high-water-mark effects.\n",
"Technically, we should run these as separate processes to have more control and precision.\n",
"There is also a slight memory bias from the order in which we run things, but that is only a concern if we are trying to be extra precise."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"print(\"shape\", [20, 30, 40, 3])\n",
"print(\"---------------------\")\n",
"for n in [1, 2]:\n",
"    print(f\"mode-{n} mttkrp\")\n",
" for r in [2, 10, 50, 100, 200]:\n",
" print(f\"rank-{r}\")\n",
" old_mem, old_result = mem_mttkrp(old_mttkrp, n, r)\n",
" new_mem, new_result = mem_mttkrp(new_mttkrp, n, r)\n",
" print(\"results equal:\", np.all(old_result == new_result))\n",
" print(\"old\", np.array(old_mem)[1:].mean())\n",
" print(\"new\", np.array(new_mem)[1:].mean())\n",
" print(\"ratio\", np.array(old_mem)[1:].mean() / np.array(new_mem)[1:].mean())\n",
" print(\"-----------------------------------------------------------------------\")"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Timings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"print(\"shape\", [20, 30, 40, 3])\n",
"print(\"---------------------\")\n",
"for n in [1, 2]:\n",
"    print(f\"mode-{n} mttkrp\")\n",
" for r in [2, 10, 50, 100, 200]:\n",
" print(f\"rank-{r}\")\n",
" old_time, old_result = time_mttkrp(old_mttkrp, n, r)\n",
" new_time, new_result = time_mttkrp(new_mttkrp, n, r)\n",
" print(\"results equal:\", np.all(old_result == new_result))\n",
" print(\"old\", np.array(old_time)[1:].mean())\n",
" print(\"new\", np.array(new_time)[1:].mean())\n",
" print(\"speedup\", np.array(old_time)[1:].mean() / np.array(new_time)[1:].mean())\n",
" print(\"-----------------------------------------------------------------------\")"
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ doc = [
profiling = [
"gprof2dot",
"graphviz",
"memory-profiler"
]

[tool.setuptools.packages.find]
Expand Down Expand Up @@ -98,6 +99,7 @@ convention = "numpy"
# Ignore everything for now
"docs/**.py" = ["E", "F", "PL", "W", "N", "NPY", "RUF", "B", "D"]
"docs/**.ipynb" = ["E", "F", "PL", "W", "N", "NPY", "RUF", "B", "D"]
"profiling/**.py" = ["E", "F", "PL", "W", "N", "NPY", "RUF", "B", "D"]
"profiling/**.ipynb" = ["E", "F", "PL", "W", "N", "NPY", "RUF", "B", "D"]

[tool.ruff.format]
Expand Down