Merge pull request #827 from pydata/update-finch-notebook
Update `sparse_finch.ipynb`
mtsokol authored Dec 12, 2024
2 parents 68eb00f + 77c155d commit bf8236e
Showing 4 changed files with 45 additions and 47 deletions.
2 changes: 1 addition & 1 deletion ci/environment.yml
@@ -13,6 +13,6 @@ dependencies:
- pytest-cov
- pytest-xdist
- pip:
- finch-tensor>=0.2.1
- finch-tensor>=0.2.2
- finch-mlir>=0.0.2
- pytest-codspeed
86 changes: 42 additions & 44 deletions examples/sparse_finch.ipynb
@@ -82,7 +82,7 @@
"X = sparse.asarray(X, format=\"csc\")\n",
"X_lazy = sparse.lazy(X)\n",
"\n",
"X_X = sparse.compute(sparse.permute_dims(X_lazy, (1, 0)) @ X_lazy, verbose=True)\n",
"X_X = sparse.compute(sparse.permute_dims(X_lazy, (1, 0)) @ X_lazy)\n",
"\n",
"X_X = sparse.asarray(X_X, format=\"csc\") # move back from dense to CSC format\n",
"\n",
@@ -171,11 +171,8 @@
"finch_galley_times = []\n",
"\n",
"for config in configs:\n",
" B_sps = sparse.random(\n",
" (config[\"I_\"], config[\"K_\"], config[\"L_\"]),\n",
" density=config[\"DENSITY\"],\n",
" random_state=rng,\n",
" )\n",
" B_shape = (config[\"I_\"], config[\"K_\"], config[\"L_\"])\n",
" B_sps = sparse.random(B_shape, density=config[\"DENSITY\"], random_state=rng)\n",
" D_sps = rng.random((config[\"L_\"], config[\"J_\"]))\n",
" C_sps = rng.random((config[\"K_\"], config[\"J_\"]))\n",
"\n",
@@ -204,14 +201,14 @@
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
"\n",
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
" def mttkrp_finch(B, D, C):\n",
" @sparse.compiled(opt=sparse.GalleyScheduler(), tag=sum(B_shape))\n",
" def mttkrp_finch_galley(B, D, C):\n",
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
"\n",
" # Compile\n",
" result_finch_galley = mttkrp_finch(B, D, C)\n",
" result_finch_galley = mttkrp_finch_galley(B, D, C)\n",
" # Benchmark\n",
" time_finch_galley = benchmark(mttkrp_finch, info=\"Finch\", args=[B, D, C])\n",
" time_finch_galley = benchmark(mttkrp_finch_galley, info=\"Finch Galley\", args=[B, D, C])\n",
"\n",
" # ======= Numba =======\n",
" os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
@@ -276,11 +273,12 @@
"configs = [\n",
" {\"LEN\": 5000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 10000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 15000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 20000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 25000, \"DENSITY\": 0.00001},\n",
" {\"LEN\": 30000, \"DENSITY\": 0.00001},\n",
"]\n",
"size_n = [5000, 10000, 20000, 25000, 30000]\n",
"size_n = [5000, 10000, 15000, 20000, 25000, 30000]\n",
"\n",
"if CI_MODE:\n",
" configs = configs[:1]\n",
@@ -306,15 +304,12 @@
" importlib.reload(sparse)\n",
"\n",
" s = sparse.asarray(s_sps)\n",
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
" a = sparse.asarray(a_sps)\n",
" b = sparse.asarray(b_sps)\n",
"\n",
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
" def sddmm_finch(s, a, b):\n",
" return sparse.sum(\n",
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
" axis=-1,\n",
" )\n",
" return s * (a @ b)\n",
"\n",
" # Compile\n",
" result_finch = sddmm_finch(s, a, b)\n",
@@ -327,21 +322,17 @@
" importlib.reload(sparse)\n",
"\n",
" s = sparse.asarray(s_sps)\n",
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
" a = sparse.asarray(a_sps)\n",
" b = sparse.asarray(b_sps)\n",
"\n",
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
" def sddmm_finch(s, a, b):\n",
" # return s * (a @ b)\n",
" return sparse.sum(\n",
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
" axis=-1,\n",
" )\n",
" @sparse.compiled(opt=sparse.GalleyScheduler(), tag=LEN)\n",
" def sddmm_finch_galley(s, a, b):\n",
" return s * (a @ b)\n",
"\n",
" # Compile\n",
" result_finch_galley = sddmm_finch(s, a, b)\n",
" result_finch_galley = sddmm_finch_galley(s, a, b)\n",
" # Benchmark\n",
" time_finch_galley = benchmark(sddmm_finch, info=\"Finch\", args=[s, a, b])\n",
" time_finch_galley = benchmark(sddmm_finch_galley, info=\"Finch Galley\", args=[s, a, b])\n",
"\n",
" # ======= Numba =======\n",
" print(\"numba\")\n",
@@ -402,6 +393,13 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Counting Triangles"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -411,13 +409,17 @@
"print(\"Counting Triangles Example:\\n\")\n",
"\n",
"configs = [\n",
" {\"LEN\": 1000, \"DENSITY\": 0.1},\n",
" {\"LEN\": 2000, \"DENSITY\": 0.1},\n",
" {\"LEN\": 3000, \"DENSITY\": 0.1},\n",
" {\"LEN\": 4000, \"DENSITY\": 0.1},\n",
" {\"LEN\": 5000, \"DENSITY\": 0.1},\n",
" {\"LEN\": 10000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 15000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 20000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 25000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 30000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 35000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 40000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 45000, \"DENSITY\": 0.001},\n",
" {\"LEN\": 50000, \"DENSITY\": 0.001},\n",
"]\n",
"size_n = [1000, 2000, 3000, 4000, 5000]\n",
"size_n = [10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000]\n",
"\n",
"if CI_MODE:\n",
" configs = configs[:1]\n",
@@ -444,9 +446,7 @@
"\n",
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
" def ct_finch(a):\n",
" return sparse.sum(\n",
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
" ) / sparse.asarray(6)\n",
" return sparse.sum(a @ a * a) / sparse.asarray(6)\n",
"\n",
" # Compile\n",
" result_finch = ct_finch(a)\n",
@@ -460,16 +460,14 @@
"\n",
" a = sparse.asarray(a_sps)\n",
"\n",
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
" def ct_finch(a):\n",
" return sparse.sum(\n",
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
" ) / sparse.asarray(6)\n",
" @sparse.compiled(opt=sparse.GalleyScheduler(), tag=LEN)\n",
" def ct_finch_galley(a):\n",
" return sparse.sum(a @ a * a) / sparse.asarray(6)\n",
"\n",
" # Compile\n",
" result_finch_galley = ct_finch(a)\n",
" result_finch_galley = ct_finch_galley(a)\n",
" # Benchmark\n",
" time_finch_galley = benchmark(ct_finch, info=\"Finch\", args=[a])\n",
" time_finch_galley = benchmark(ct_finch_galley, info=\"Finch Galley\", args=[a])\n",
"\n",
" # ======= SciPy =======\n",
" print(\"scipy\")\n",
2 changes: 1 addition & 1 deletion pixi.toml
@@ -49,7 +49,7 @@ precompile = "python -c 'import finch'"

[feature.finch.pypi-dependencies]
scipy = ">=1.13"
finch-tensor = ">=0.2.1"
finch-tensor = ">=0.2.2"

[feature.finch.activation.env]
SPARSE_BACKEND = "Finch"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -54,7 +54,7 @@ tests = [
tox = ["sparse[tests]", "tox"]
notebooks = ["sparse[tests]", "nbmake", "matplotlib"]
all = ["sparse[docs,tox,notebooks]", "matrepr"]
finch = ["finch-tensor>=0.2.1"]
finch = ["finch-tensor>=0.2.2"]

[project.urls]
Documentation = "https://sparse.pydata.org/"