Skip to content

Commit

Permalink
Update compiled() calls in notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Dec 6, 2024
1 parent 19a26af commit 1c4d5c6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
12 changes: 6 additions & 6 deletions examples/sparse_finch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
"\n",
" @sparse.compiled(opt=\"default\")\n",
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
" def mttkrp_finch(B, D, C):\n",
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
"\n",
Expand All @@ -204,7 +204,7 @@
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
"\n",
" @sparse.compiled(opt=\"galley\")\n",
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
" def mttkrp_finch(B, D, C):\n",
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
"\n",
Expand Down Expand Up @@ -309,7 +309,7 @@
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
"\n",
" @sparse.compiled(opt=\"default\")\n",
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
" def sddmm_finch(s, a, b):\n",
" return sparse.sum(\n",
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
Expand All @@ -330,7 +330,7 @@
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
"\n",
" @sparse.compiled(opt=\"galley\")\n",
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
" def sddmm_finch(s, a, b):\n",
" # return s * (a @ b)\n",
" return sparse.sum(\n",
Expand Down Expand Up @@ -442,7 +442,7 @@
"\n",
" a = sparse.asarray(a_sps)\n",
"\n",
" @sparse.compiled(opt=\"default\")\n",
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
" def ct_finch(a):\n",
" return sparse.sum(\n",
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
Expand All @@ -460,7 +460,7 @@
"\n",
" a = sparse.asarray(a_sps)\n",
"\n",
" @sparse.compiled(opt=\"galley\")\n",
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
" def ct_finch(a):\n",
" return sparse.sum(\n",
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
Expand Down
6 changes: 6 additions & 0 deletions sparse/finch_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
raise ImportError("Finch not installed. Run `pip install sparse[finch]` to enable Finch backend") from e

from finch import (
DefaultScheduler,
GalleyScheduler,
SparseArray,
abs,
acos,
Expand Down Expand Up @@ -97,6 +99,7 @@
remainder,
reshape,
round,
set_optimizer,
sign,
sin,
sinh,
Expand All @@ -119,6 +122,8 @@
)

__all__ = [
"DefaultScheduler",
"GalleyScheduler",
"SparseArray",
"abs",
"acos",
Expand Down Expand Up @@ -231,4 +236,5 @@
"empty_like",
"arange",
"linspace",
"set_optimizer",
]

0 comments on commit 1c4d5c6

Please sign in to comment.