Skip to content

Commit

Permalink
Add chained matmul example benchmark plot.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Dec 11, 2024
1 parent 68eb00f commit ece625f
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ results/

# Notebooks converted to scripts.
docs/examples_ipynb/
*/.ipynb_checkpoints/

# Pixi envs
.pixi/
Expand Down
161 changes: 158 additions & 3 deletions examples/sparse_finch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@
" LEN = config[\"LEN\"]\n",
" DENSITY = config[\"DENSITY\"]\n",
"\n",
" G = nx.gnp_random_graph(n=LEN, p=DENSITY)\n",
" G = nx.gnp_random_graph(n=LEN, p=DENSITY, seed=rng.integers(2**64))\n",
" a_sps = nx.to_scipy_sparse_array(G)\n",
"\n",
" # ======= Finch =======\n",
Expand Down Expand Up @@ -498,6 +498,161 @@
" scipy_times.append(time_scipy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(nrows=1, ncols=1)\n",
"\n",
"ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n",
"ax.plot(size_n, networkx_times, \"o-\", label=\"NetworkX\")\n",
"ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n",
"ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
"\n",
"ax.grid(True)\n",
"ax.set_xlabel(\"size N\")\n",
"ax.set_ylabel(\"time (sec)\")\n",
"ax.set_title(\"Counting Triangles\")\n",
"ax.legend(loc=\"best\", numpoints=1)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Chained Matmul Example:\\n\")\n",
"\n",
"os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
"importlib.reload(sparse)\n",
"\n",
"CHAINED_MATMUL_LEN = 1000\n",
"\n",
"configs = [\n",
" {\n",
" \"I_\": CHAINED_MATMUL_LEN,\n",
" \"J_\": CHAINED_MATMUL_LEN,\n",
" \"K_\": CHAINED_MATMUL_LEN,\n",
" \"L_\": CHAINED_MATMUL_LEN,\n",
" \"DENSITY\": 1.0,\n",
" },\n",
" {\n",
" \"I_\": CHAINED_MATMUL_LEN,\n",
" \"J_\": CHAINED_MATMUL_LEN,\n",
" \"K_\": CHAINED_MATMUL_LEN,\n",
" \"L_\": CHAINED_MATMUL_LEN,\n",
" \"DENSITY\": 0.1,\n",
" },\n",
" {\n",
" \"I_\": CHAINED_MATMUL_LEN,\n",
" \"J_\": CHAINED_MATMUL_LEN,\n",
" \"K_\": CHAINED_MATMUL_LEN,\n",
" \"L_\": CHAINED_MATMUL_LEN,\n",
" \"DENSITY\": 0.01,\n",
" },\n",
" {\n",
" \"I_\": CHAINED_MATMUL_LEN,\n",
" \"J_\": CHAINED_MATMUL_LEN,\n",
" \"K_\": CHAINED_MATMUL_LEN,\n",
" \"L_\": CHAINED_MATMUL_LEN,\n",
" \"DENSITY\": 0.001,\n",
" },\n",
" {\n",
" \"I_\": CHAINED_MATMUL_LEN,\n",
" \"J_\": CHAINED_MATMUL_LEN,\n",
" \"K_\": CHAINED_MATMUL_LEN,\n",
" \"L_\": CHAINED_MATMUL_LEN,\n",
" \"DENSITY\": 0.0001,\n",
" },\n",
" {\n",
" \"I_\": CHAINED_MATMUL_LEN,\n",
" \"J_\": CHAINED_MATMUL_LEN,\n",
" \"K_\": CHAINED_MATMUL_LEN,\n",
" \"L_\": CHAINED_MATMUL_LEN,\n",
" \"DENSITY\": 0.00001,\n",
" },\n",
"]\n",
"densities = [x[\"DENSITY\"] for x in configs]\n",
"\n",
"if CI_MODE:\n",
" configs = configs[:1]\n",
" nonzeros = nonzeros[:1]\n",
"\n",
"finch_times = []\n",
"numba_times = []\n",
"finch_galley_times = []\n",
"\n",
"for config in configs:\n",
" A_np = sparse.random((config[\"I_\"], config[\"J_\"]), density=0.5).todense()\n",
" B_np = sparse.random((config[\"J_\"], config[\"K_\"]), density=0.5).todense()\n",
" C_sps = sparse.random(\n",
" (config[\"K_\"], config[\"L_\"]),\n",
" density=config[\"DENSITY\"],\n",
" random_state=rng,\n",
" )\n",
"\n",
" # ======= Finch =======\n",
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
" importlib.reload(sparse)\n",
"\n",
" A = sparse.asarray(A_np, format=\"dense\")\n",
" B = sparse.asarray(B_np, format=\"dense\")\n",
" C = sparse.asarray(C_sps.todense(), format=\"csf\")\n",
"\n",
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
" def chained_matmul_finch(A, B, C):\n",
" return A @ B @ C\n",
"\n",
" # Compile\n",
" result_finch = mttkrp_finch(A, B, C)\n",
" # Benchmark\n",
" time_finch = benchmark(chained_matmul_finch, info=\"Finch\", args=[A, B, C])\n",
"\n",
" # ======= Finch Galley =======\n",
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
" importlib.reload(sparse)\n",
"\n",
" A = sparse.asarray(A_np, format=\"dense\")\n",
" B = sparse.asarray(B_np, format=\"dense\")\n",
" C = sparse.asarray(C_sps.todense(), format=\"csf\")\n",
"\n",
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
" def chained_matmul_finch(A, B, C):\n",
" return A @ B @ C\n",
"\n",
" # Compile\n",
" result_finch_galley = chained_matmul_finch(A, B, C)\n",
" # Benchmark\n",
" time_finch_galley = benchmark(chained_matmul_finch, info=\"Finch\", args=[A, B, C])\n",
"\n",
" # ======= Numba =======\n",
" os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
" importlib.reload(sparse)\n",
"\n",
" A = A_np\n",
" B = B_np\n",
" C = C_sps.asformat(\"csr\")\n",
"\n",
" def chained_matmul_numba(A, B, C):\n",
" return A @ B @ C\n",
"\n",
" # Compile\n",
" result_numba = chained_matmul_numba(A, B, C)\n",
" # Benchmark\n",
" time_numba = benchmark(chained_matmul_numba, info=\"Numba\", args=[A, B, C])\n",
"\n",
" np.testing.assert_allclose(result_finch.todense(), result_numba.todense())\n",
"\n",
" finch_times.append(time_finch)\n",
" numba_times.append(time_numba)\n",
" finch_galley_times.append(time_finch_galley)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -523,7 +678,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "sparse-dev",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -537,7 +692,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.12.8"
}
},
"nbformat": 4,
Expand Down
4 changes: 4 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ pre-commit = "*"
pytest-codspeed = "*"

[feature.notebooks.dependencies]
ipykernel = "*"
nbmake = "*"
matplotlib = "*"
networkx = "*"
jupyterlab = "*"

[feature.matrepr.dependencies]
matrepr = "*"
Expand Down Expand Up @@ -71,3 +74,4 @@ tests = ["tests", "extras"]
docs = ["docs", "extras"]
mlir-dev = {features = ["tests", "mlir"], no-default-feature = true}
finch-dev = {features = ["tests", "finch"], no-default-feature = true}
notebooks = ["extras", "mlir", "finch", "notebooks"]

0 comments on commit ece625f

Please sign in to comment.