diff --git a/examples/sparse_finch.ipynb b/examples/sparse_finch.ipynb index 4d100215..d1790288 100644 --- a/examples/sparse_finch.ipynb +++ b/examples/sparse_finch.ipynb @@ -39,6 +39,7 @@ "import sparse\n", "\n", "import matplotlib.pyplot as plt\n", + "import networkx as nx\n", "\n", "import numpy as np\n", "import scipy.sparse as sps\n", @@ -105,7 +106,7 @@ "metadata": {}, "outputs": [], "source": [ - "ITERS = 3\n", + "ITERS = 1\n", "rng = np.random.default_rng(0)" ] }, @@ -134,6 +135,13 @@ " return elapsed / ITERS" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MTTKRP" + ] + }, { "cell_type": "code", "execution_count": null, @@ -146,14 +154,13 @@ "importlib.reload(sparse)\n", "\n", "configs = [\n", - " {\"I_\": 100, \"J_\": 25, \"K_\": 10, \"L_\": 10, \"DENSITY\": 0.001},\n", " {\"I_\": 100, \"J_\": 25, \"K_\": 100, \"L_\": 10, \"DENSITY\": 0.001},\n", " {\"I_\": 100, \"J_\": 25, \"K_\": 100, \"L_\": 100, \"DENSITY\": 0.001},\n", " {\"I_\": 1000, \"J_\": 25, \"K_\": 100, \"L_\": 100, \"DENSITY\": 0.001},\n", " {\"I_\": 1000, \"J_\": 25, \"K_\": 1000, \"L_\": 100, \"DENSITY\": 0.001},\n", " {\"I_\": 1000, \"J_\": 25, \"K_\": 1000, \"L_\": 1000, \"DENSITY\": 0.001},\n", "]\n", - "nonzeros = [10000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]\n", + "nonzeros = [100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]\n", "\n", "if CI_MODE:\n", " configs = configs[:1]\n", @@ -161,11 +168,16 @@ "\n", "finch_times = []\n", "numba_times = []\n", + "finch_galley_times = []\n", "\n", "for config in configs:\n", - " B_sps = sparse.random((config[\"I_\"], config[\"K_\"], config[\"L_\"]), density=config[\"DENSITY\"], random_state=rng) * 10\n", - " D_sps = rng.random((config[\"L_\"], config[\"J_\"])) * 10\n", - " C_sps = rng.random((config[\"K_\"], config[\"J_\"])) * 10\n", + " B_sps = sparse.random(\n", + " (config[\"I_\"], config[\"K_\"], config[\"L_\"]),\n", + " density=config[\"DENSITY\"],\n", + " random_state=rng,\n", + " 
)\n", + " D_sps = rng.random((config[\"L_\"], config[\"J_\"]))\n", + " C_sps = rng.random((config[\"K_\"], config[\"J_\"]))\n", "\n", " # ======= Finch =======\n", " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n", @@ -175,7 +187,7 @@ " D = sparse.asarray(np.array(D_sps, order=\"F\"))\n", " C = sparse.asarray(np.array(C_sps, order=\"F\"))\n", "\n", - " @sparse.compiled\n", + " @sparse.compiled(opt=\"default\")\n", " def mttkrp_finch(B, D, C):\n", " return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n", "\n", @@ -184,6 +196,23 @@ " # Benchmark\n", " time_finch = benchmark(mttkrp_finch, info=\"Finch\", args=[B, D, C])\n", "\n", + " # ======= Finch Galley =======\n", + " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n", + " importlib.reload(sparse)\n", + "\n", + " B = sparse.asarray(B_sps.todense(), format=\"csf\")\n", + " D = sparse.asarray(np.array(D_sps, order=\"F\"))\n", + " C = sparse.asarray(np.array(C_sps, order=\"F\"))\n", + "\n", + " @sparse.compiled(opt=\"galley\")\n", + " def mttkrp_finch(B, D, C):\n", + " return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n", + "\n", + " # Compile\n", + " result_finch_galley = mttkrp_finch(B, D, C)\n", + " # Benchmark\n", + " time_finch_galley = benchmark(mttkrp_finch, info=\"Finch Galley\", args=[B, D, C])\n", + "\n", " # ======= Numba =======\n", " os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n", " importlib.reload(sparse)\n", @@ -201,8 +230,10 @@ " time_numba = benchmark(mttkrp_numba, info=\"Numba\", args=[B, D, C])\n", "\n", " np.testing.assert_allclose(result_finch.todense(), result_numba.todense())\n", + "\n", " finch_times.append(time_finch)\n", - " numba_times.append(time_numba)" + " numba_times.append(time_numba)\n", + " finch_galley_times.append(time_finch_galley)" ] }, { @@ -215,6 +246,7 @@ "\n", "ax.plot(nonzeros, finch_times, \"o-\", label=\"Finch\")\n", "ax.plot(nonzeros, numba_times, \"o-\", label=\"Numba\")\n", + "ax.plot(nonzeros, 
finch_galley_times, \"o-\", label=\"Finch - Galley\")\n", "ax.grid(True)\n", "ax.set_xlabel(\"no. of elements\")\n", "ax.set_ylabel(\"time (sec)\")\n", @@ -226,6 +258,13 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SDDMM" + ] + }, { "cell_type": "code", "execution_count": null, @@ -235,15 +274,13 @@ "print(\"SDDMM Example:\\n\")\n", "\n", "configs = [\n", - " {\"LEN\": 10, \"DENSITY\": 0.1},\n", - " {\"LEN\": 50, \"DENSITY\": 0.05},\n", - " {\"LEN\": 100, \"DENSITY\": 0.01},\n", - " {\"LEN\": 500, \"DENSITY\": 0.005},\n", - " {\"LEN\": 1000, \"DENSITY\": 0.001},\n", - " {\"LEN\": 5000, \"DENSITY\": 0.00005},\n", + " {\"LEN\": 5000, \"DENSITY\": 0.00001},\n", " {\"LEN\": 10000, \"DENSITY\": 0.00001},\n", + " {\"LEN\": 20000, \"DENSITY\": 0.00001},\n", + " {\"LEN\": 25000, \"DENSITY\": 0.00001},\n", + " {\"LEN\": 30000, \"DENSITY\": 0.00001},\n", "]\n", - "size_n = [10, 50, 100, 500, 1000, 5000, 10000]\n", + "size_n = [5000, 10000, 20000, 25000, 30000]\n", "\n", "if CI_MODE:\n", " configs = configs[:1]\n", @@ -252,17 +289,19 @@ "finch_times = []\n", "numba_times = []\n", "scipy_times = []\n", + "finch_galley_times = []\n", "\n", "for config in configs:\n", " LEN = config[\"LEN\"]\n", " DENSITY = config[\"DENSITY\"]\n", "\n", - " a_sps = rng.random((LEN, LEN)) * 10\n", - " b_sps = rng.random((LEN, LEN)) * 10\n", - " s_sps = sps.random(LEN, LEN, format=\"coo\", density=DENSITY, random_state=rng) * 10\n", + " a_sps = rng.random((LEN, LEN))\n", + " b_sps = rng.random((LEN, LEN))\n", + " s_sps = sps.random(LEN, LEN, format=\"coo\", density=DENSITY, random_state=rng)\n", " s_sps.sum_duplicates()\n", "\n", " # ======= Finch =======\n", + " print(\"finch\")\n", " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n", " importlib.reload(sparse)\n", "\n", @@ -270,7 +309,7 @@ " a = sparse.asarray(np.array(a_sps, order=\"F\"))\n", " b = sparse.asarray(np.array(b_sps, order=\"C\"))\n", "\n", - " @sparse.compiled\n", + " 
@sparse.compiled(opt=\"default\")\n", " def sddmm_finch(s, a, b):\n", " return sparse.sum(\n", " s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n", @@ -282,7 +321,30 @@ " # Benchmark\n", " time_finch = benchmark(sddmm_finch, info=\"Finch\", args=[s, a, b])\n", "\n", + " # ======= Finch Galley =======\n", + " print(\"finch galley\")\n", + " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n", + " importlib.reload(sparse)\n", + "\n", + " s = sparse.asarray(s_sps)\n", + " a = sparse.asarray(np.array(a_sps, order=\"F\"))\n", + " b = sparse.asarray(np.array(b_sps, order=\"C\"))\n", + "\n", + " @sparse.compiled(opt=\"galley\")\n", + " def sddmm_finch(s, a, b):\n", + " # return s * (a @ b)\n", + " return sparse.sum(\n", + " s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n", + " axis=-1,\n", + " )\n", + "\n", + " # Compile\n", + " result_finch_galley = sddmm_finch(s, a, b)\n", + " # Benchmark\n", + " time_finch_galley = benchmark(sddmm_finch, info=\"Finch Galley\", args=[s, a, b])\n", + "\n", " # ======= Numba =======\n", + " print(\"numba\")\n", " os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n", " importlib.reload(sparse)\n", "\n", @@ -299,6 +361,8 @@ " time_numba = benchmark(sddmm_numba, info=\"Numba\", args=[s, a, b])\n", "\n", " # ======= SciPy =======\n", + " print(\"scipy\")\n", + "\n", " def sddmm_scipy(s, a, b):\n", " return s.multiply(a @ b)\n", "\n", @@ -312,7 +376,8 @@ "\n", " finch_times.append(time_finch)\n", " numba_times.append(time_numba)\n", - " scipy_times.append(time_scipy)" + " scipy_times.append(time_scipy)\n", + " finch_galley_times.append(time_finch_galley)" ] }, { @@ -326,13 +391,134 @@ "ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n", "ax.plot(size_n, numba_times, \"o-\", label=\"Numba\")\n", "ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n", + "ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n", "\n", "ax.grid(True)\n", "ax.set_xlabel(\"size N\")\n", 
"ax.set_ylabel(\"time (sec)\")\n", "ax.set_title(\"SDDMM\")\n", - "ax.set_xscale(\"log\")\n", - "# ax.set_yscale('log')\n", + "# ax.set_xscale(\"log\")\n", + "# ax.set_yscale(\"log\")\n", + "ax.legend(loc=\"best\", numpoints=1)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Counting Triangles Example:\\n\")\n", + "\n", + "configs = [\n", + " {\"LEN\": 1000, \"DENSITY\": 0.1},\n", + " {\"LEN\": 2000, \"DENSITY\": 0.1},\n", + " {\"LEN\": 3000, \"DENSITY\": 0.1},\n", + " {\"LEN\": 4000, \"DENSITY\": 0.1},\n", + " {\"LEN\": 5000, \"DENSITY\": 0.1},\n", + "]\n", + "size_n = [1000, 2000, 3000, 4000, 5000]\n", + "\n", + "if CI_MODE:\n", + " configs = configs[:1]\n", + " size_n = size_n[:1]\n", + "\n", + "finch_times = []\n", + "finch_galley_times = []\n", + "networkx_times = []\n", + "scipy_times = []\n", + "\n", + "for config in configs:\n", + " LEN = config[\"LEN\"]\n", + " DENSITY = config[\"DENSITY\"]\n", + "\n", + " G = nx.gnp_random_graph(n=LEN, p=DENSITY)\n", + " a_sps = nx.to_scipy_sparse_array(G)\n", + "\n", + " # ======= Finch =======\n", + " print(\"finch\")\n", + " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n", + " importlib.reload(sparse)\n", + "\n", + " a = sparse.asarray(a_sps)\n", + "\n", + " @sparse.compiled(opt=\"default\")\n", + " def ct_finch(a):\n", + " return sparse.sum(\n", + " a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n", + " ) / sparse.asarray(6)\n", + "\n", + " # Compile\n", + " result_finch = ct_finch(a)\n", + " # Benchmark\n", + " time_finch = benchmark(ct_finch, info=\"Finch\", args=[a])\n", + "\n", + " # ======= Finch Galley =======\n", + " print(\"finch galley\")\n", + " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n", + " importlib.reload(sparse)\n", + "\n", + " a = sparse.asarray(a_sps)\n", + "\n", + " @sparse.compiled(opt=\"galley\")\n", + " def ct_finch(a):\n", + " return sparse.sum(\n", + " a[:, :, 
None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n", + " ) / sparse.asarray(6)\n", + "\n", + " # Compile\n", + " result_finch_galley = ct_finch(a)\n", + " # Benchmark\n", + " time_finch_galley = benchmark(ct_finch, info=\"Finch Galley\", args=[a])\n", + "\n", + " # ======= SciPy =======\n", + " print(\"scipy\")\n", + "\n", + " def ct_scipy(a):\n", + " return (a @ a * a).sum() / 6\n", + "\n", + " a = a_sps\n", + "\n", + " # Benchmark\n", + " time_scipy = benchmark(ct_scipy, info=\"SciPy\", args=[a])\n", + "\n", + " # ======= NetworkX =======\n", + " print(\"networkx\")\n", + "\n", + " def ct_networkx(a):\n", + " return sum(nx.triangles(a).values()) / 3\n", + "\n", + " a = G\n", + "\n", + " time_networkx = benchmark(ct_networkx, info=\"NetworkX\", args=[a])\n", + "\n", + " finch_times.append(time_finch)\n", + " finch_galley_times.append(time_finch_galley)\n", + " networkx_times.append(time_networkx)\n", + " scipy_times.append(time_scipy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(nrows=1, ncols=1)\n", + "\n", + "ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n", + "ax.plot(size_n, networkx_times, \"o-\", label=\"NetworkX\")\n", + "ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n", + "ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n", + "\n", + "ax.grid(True)\n", + "ax.set_xlabel(\"size N\")\n", + "ax.set_ylabel(\"time (sec)\")\n", + "ax.set_title(\"Counting Triangles\")\n", + "# ax.set_xscale(\"log\")\n", + "# ax.set_yscale(\"log\")\n", "ax.legend(loc=\"best\", numpoints=1)\n", "\n", "plt.show()" @@ -355,7 +541,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.10.14" } }, "nbformat": 4,