From c6a2c21dfb0c698b0ac1655f87278ac1a6abc522 Mon Sep 17 00:00:00 2001 From: Kilian Lieret Date: Sat, 14 Oct 2023 15:44:12 +0200 Subject: [PATCH] Testing tiger implementation --- notebooks/010_torch_benchmark.ipynb | 257 +++++++++++++++++++--------- 1 file changed, 177 insertions(+), 80 deletions(-) diff --git a/notebooks/010_torch_benchmark.ipynb b/notebooks/010_torch_benchmark.ipynb index 76876d8..573709a 100644 --- a/notebooks/010_torch_benchmark.ipynb +++ b/notebooks/010_torch_benchmark.ipynb @@ -5,12 +5,13 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2023-10-13T14:07:06.963098Z", - "start_time": "2023-10-13T14:07:04.994889Z" + "end_time": "2023-10-14T12:43:28.741581Z", + "start_time": "2023-10-14T12:43:24.827711Z" } }, "outputs": [], "source": [ + "import pandas as pd\n", "import torch\n", "from gnn_tracking.training.tc import TCModule\n", "from pathlib import Path\n", @@ -34,7 +35,7 @@ ")\n", "assert chkpt_path.is_file()\n", "data_home = Path(\n", - " \"/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/\"\n", + " \"/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/\"\n", ")\n", "assert data_home.is_dir()\n", "data_path = data_home / \"part_1\" / \"data21000_s0.pt\"\n", @@ -43,8 +44,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:06.963646Z", - "start_time": "2023-10-13T14:07:06.955942Z" + "end_time": "2023-10-14T12:43:28.741878Z", + "start_time": "2023-10-14T12:43:28.741127Z" } } }, @@ -58,8 +59,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:07.010300Z", - "start_time": "2023-10-13T14:07:06.980576Z" + "end_time": "2023-10-14T12:43:28.855227Z", + "start_time": "2023-10-14T12:43:28.741311Z" } } }, @@ -71,13 +72,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[36m[10:07:07] DEBUG: Getting class PreTrainedECGraphTCN from module gnn_tracking.models.track_condensation_networks\u001b[0m\n", + "\u001b[36m[08:43:28] DEBUG: Getting class PreTrainedECGraphTCN from module gnn_tracking.models.track_condensation_networks\u001b[0m\n", "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:196: UserWarning: Attribute 'hc_in' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['hc_in'])`.\n", " rank_zero_warn(\n", - "\u001b[36m[10:07:07] DEBUG: Getting class MLGraphConstruction from module gnn_tracking.models.graph_construction\u001b[0m\n", - "\u001b[36m[10:07:07] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction\u001b[0m\n", - "\u001b[36m[10:07:07] DEBUG: Getting class PotentialLoss from module gnn_tracking.metrics.losses\u001b[0m\n", - "\u001b[36m[10:07:07] DEBUG: Getting class DBSCANHyperParamScanner from module gnn_tracking.postprocessing.dbscanscanner\u001b[0m\n" + "\u001b[36m[08:43:29] DEBUG: Getting class MLGraphConstruction from module gnn_tracking.models.graph_construction\u001b[0m\n", + "\u001b[36m[08:43:29] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction\u001b[0m\n", + "\u001b[36m[08:43:29] DEBUG: Getting class PotentialLoss from module gnn_tracking.metrics.losses\u001b[0m\n", + "\u001b[36m[08:43:29] DEBUG: Getting class DBSCANHyperParamScanner from module gnn_tracking.postprocessing.dbscanscanner\u001b[0m\n" ] } ], @@ -87,14 +88,14 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:07.848961Z", - "start_time": "2023-10-13T14:07:07.051270Z" + "end_time": "2023-10-14T12:43:29.403060Z", + "start_time": "2023-10-14T12:43:28.860180Z" } } }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "outputs": [], "source": [ "data = torch.load(data_path)\n", @@ -105,14 +106,14 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:07.874688Z", - "start_time": "2023-10-13T14:07:07.861655Z" + "end_time": "2023-10-14T12:43:46.164940Z", + "start_time": "2023-10-14T12:43:46.145387Z" } } }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "outputs": [], "source": [ "dp = lmodel.preproc(data)" @@ -120,44 +121,70 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:08.110693Z", - "start_time": "2023-10-13T14:07:07.863722Z" + "end_time": "2023-10-14T12:43:46.629034Z", + "start_time": "2023-10-14T12:43:46.304463Z" } } }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 39, "outputs": [], "source": [ + "from timeit import default_timer as timer\n", + "\n", + "\n", "class MemLogger:\n", " def __init__(self):\n", " self.mem = 0\n", + " self._cols = [\"desc\", \"Δmax\", \"max\", \"Δpersistent\", \"persistent\", \"Δtime\"]\n", + " self.data = []\n", + " self._active = True\n", + " self.time = timer()\n", + "\n", + " def deactivate(self):\n", + " self._active = False\n", + "\n", + " def activate(self):\n", + " self._active = True\n", "\n", " def log(self, desc=\"\"):\n", + " if not self._active:\n", + " return\n", " current = torch.cuda.memory_allocated() / 1e9\n", + " current_max = torch.cuda.max_memory_allocated() / 1e9\n", " added = current - self.mem\n", - " print(f\"{desc:<30} added {added:>8.2f} GB, total {current:>8.2f} GB\")\n", - " self.mem = current" + " added_max = current_max - self.mem\n", + " self.data.append(\n", + " (desc, added_max, current_max, added, current, timer() - self.time)\n", + " )\n", + " self.mem = current\n", + " self.time = timer()\n", + " torch.cuda.empty_cache()\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + " def get_df(self):\n", + " return pd.DataFrame.from_records(self.data, columns=self._cols)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:08.115666Z", - "start_time": "2023-10-13T14:07:08.108381Z" + "end_time": "2023-10-14T13:06:10.808783Z", + "start_time": "2023-10-14T13:06:10.226247Z" } } }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 47, "outputs": [], "source": [ "from torch import Tensor as T\n", "from torch.nn.functional import relu\n", "\n", "\n", - "def condensation_loss(\n", + "# @torch.compile\n", + "def condensation_loss_tiger(\n", " *,\n", " beta: T,\n", " x: T,\n", @@ -165,6 +192,7 @@ " weights: T,\n", " q_min: float,\n", " noise_threshold: int,\n", + " max_n_rep: int,\n", ") -> dict[str, T]:\n", " # To protect against nan in divisions\n", " eps = 1e-9\n", @@ -195,24 +223,29 @@ "\n", " qw_j_k = weights.view(-1, 1) * q.view(-1, 1) * q_k\n", "\n", - " repulsive_mask = (~attractive_mask) & (dist_j_k < 1)\n", - " # We have to include the hits-per-object normalization factor here, because\n", - " # after applying the mask we only have a 1D tensor anymore\n", - " qw_att_j_k = (qw_j_k / (attractive_mask.sum(dim=0) + eps))[attractive_mask]\n", - " qw_rep_j_k = (qw_j_k / ((~attractive_mask).sum(dim=0) + eps))[repulsive_mask]\n", + " att_norm_k = (attractive_mask.sum(dim=0) + eps) * len(unique_oids)\n", + " qw_att = (qw_j_k / att_norm_k)[attractive_mask]\n", "\n", " # Attractive potential/loss\n", - " v_att_j_k = qw_att_j_k * torch.square(dist_j_k)[attractive_mask]\n", - " # It's important to directly do the .mean here so we don't keep these large\n", - " # matrices in memory longer than we need them\n", - " # Attractive potential per object normalized over number of hits in object\n", - " v_att_k = torch.sum(v_att_j_k, dim=0)\n", - " v_att = torch.sum(v_att_k) / len(unique_oids)\n", - "\n", - " # Repulsive potential/loss\n", - " v_rep_j_k = qw_rep_j_k * (1 - dist_j_k[repulsive_mask])\n", - " v_rep_k = torch.sum(v_rep_j_k, dim=0)\n", - " v_rep = torch.sum(v_rep_k) / len(unique_oids)\n", + " v_att = (qw_att * torch.square(dist_j_k[attractive_mask])).sum()\n", + "\n", + " repulsive_mask = (~attractive_mask) & (dist_j_k < 1)\n", + " n_rep_k = (~attractive_mask).sum(dim=0)\n", + " n_rep = repulsive_mask.sum()\n", + " # Don't normalize to repulsive_mask, it includes the dist < 1 count,\n", + " # (less points within the radius 1 ball should translate to lower loss)\n", + " rep_norm = (n_rep_k + eps) * len(unique_oids)\n", + " if n_rep > max_n_rep:\n", + " sampling_freq = max_n_rep / n_rep\n", + " sampling_mask = (\n", + " torch.rand_like(repulsive_mask, dtype=torch.float16) < sampling_freq\n", + " )\n", + " repulsive_mask &= sampling_mask\n", + " sampling_scale = n_rep / max_n_rep\n", + " print(f\"Sampling {sampling_scale} of repulsive points\")\n", + " rep_norm *= sampling_freq\n", + " qw_rep = (qw_j_k / rep_norm)[repulsive_mask]\n", + " v_rep = (qw_rep * (1 - dist_j_k[repulsive_mask])).sum()\n", "\n", " l_coward = torch.mean(1 - beta[alphas])\n", " l_noise = torch.mean(beta[~not_noise])\n", @@ -222,25 +255,44 @@ " \"repulsive\": v_rep,\n", " \"coward\": l_coward,\n", " \"noise\": l_noise,\n", + " \"n_rep\": n_rep,\n", " }" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:08.117687Z", - "start_time": "2023-10-13T14:07:08.112548Z" + "end_time": "2023-10-14T13:06:58.676012Z", + "start_time": "2023-10-14T13:06:58.251376Z" } } }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 25, + "outputs": [], + "source": [ + "# from importlib import reload\n", + "# import object_condensation.pytorch\n", + "# reload(object_condensation.pytorch.losses)\n", + "# from object_condensation.pytorch.losses import condensation_loss" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-14T12:47:34.886120Z", + "start_time": "2023-10-14T12:47:34.873652Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 26, "outputs": [ { "data": { - "text/plain": "{'attractive': tensor(1.7951, dtype=torch.float64),\n 'repulsive': tensor(1.9509, dtype=torch.float64),\n 'coward': tensor(0.2157, dtype=torch.float64),\n 'noise': tensor(0.7748, dtype=torch.float64)}" + "text/plain": "{'attractive': tensor(1.7951, dtype=torch.float64),\n 'repulsive': tensor(1.9509, dtype=torch.float64),\n 'coward': tensor(0.2157, dtype=torch.float64),\n 'noise': tensor(0.7748, dtype=torch.float64),\n 'n_rep': tensor(220768)}" }, - "execution_count": 11, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -255,60 +307,51 @@ "td = generate_test_data()\n", "\n", "td = TorchCondensationMockData.from_numpy(td)\n", - "cl = condensation_loss(\n", + "cl = condensation_loss_tiger(\n", " beta=td.beta.squeeze(),\n", " x=td.x,\n", " object_id=td.object_id.squeeze(),\n", " weights=td.weights.squeeze(),\n", " q_min=td.q_min,\n", " noise_threshold=0,\n", + " ml=MemLogger(),\n", + " max_n_rep=1000_000,\n", ")\n", "cl" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:21.611605Z", - "start_time": "2023-10-13T14:07:21.516071Z" + "end_time": "2023-10-14T12:47:36.228860Z", + "start_time": "2023-10-14T12:47:36.098183Z" } } }, { - "cell_type": "code", - "execution_count": 12, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'attractive': tensor(1.7951, dtype=torch.float64), 'repulsive': tensor(1.9509, dtype=torch.float64), 'coward': tensor(0.2157, dtype=torch.float64), 'noise': tensor(0.7748, dtype=torch.float64)}\n" - ] - } - ], + "cell_type": "markdown", "source": [ - "print(cl, flush=True)" + "Sampling 2.2076799869537354 of repulsive points\n", + "\n", + "{'attractive': tensor(1.7951, dtype=torch.float64),\n", + " 'repulsive': tensor(1.9543, dtype=torch.float64),\n", + " 'coward': tensor(0.2157, dtype=torch.float64),\n", + " 'noise': tensor(0.7748, dtype=torch.float64),\n", + " 'n_rep': tensor(220768)}" ], "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-10-13T14:07:22.476753Z", - "start_time": "2023-10-13T14:07:22.361876Z" - } + "collapsed": false } }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 51, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "empty added 0.17 GB, total 0.17 GB\n", - "model evaluated added 14.75 GB, total 14.92 GB\n", - "loss evaluated added 3.99 GB, total 18.91 GB\n", - "backward done evaluated added -18.74 GB, total 0.17 GB\n", - "step done added 0.02 GB, total 0.18 GB\n" + "Sampling 832.9146118164062 of repulsive points\n", + "tensor(0.8953, device='cuda:0', grad_fn=) tensor(83291466, device='cuda:0')\n" ] } ], @@ -318,33 +361,86 @@ "ml.log(\"empty\")\n", "out = model(dp)\n", "ml.log(\"model evaluated\")\n", - "loss = condensation_loss(\n", + "loss = condensation_loss_tiger(\n", " beta=out[\"B\"],\n", " x=out[\"H\"],\n", " object_id=data.particle_id,\n", " q_min=0.1,\n", " noise_threshold=0,\n", " weights=torch.ones_like(data.particle_id),\n", + " # ml=ml,\n", + " max_n_rep=100_000,\n", ")\n", "total_loss = loss[\"attractive\"] + loss[\"repulsive\"] + loss[\"noise\"] + loss[\"coward\"]\n", "ml.log(\"loss evaluated\")\n", "optimizer.zero_grad()\n", "total_loss.backward()\n", - "ml.log(\"backward done evaluated\")\n", + "ml.log(\"backward evaluated\")\n", "optimizer.step()\n", - "ml.log(\"step done\")" + "ml.log(\"Optimizer stepped\")\n", + "print(total_loss, loss[\"n_rep\"])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-14T13:07:29.391207Z", + "start_time": "2023-10-14T13:07:27.868595Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 52, + "outputs": [ + { + "data": { + "text/plain": " desc Δmax max Δpersistent persistent Δtime\n0 empty 22.180506 22.180506 22.165249 22.165249 0.001223\n1 model evaluated 25.827639 47.992888 23.333476 45.498725 0.236748\n2 loss evaluated 26.004328 71.503054 8.212371 53.711096 0.203532\n3 backward evaluated 19.140668 72.851764 -31.545847 22.165249 0.499658\n4 Optimizer stepped 0.030507 22.195756 0.015254 22.180503 0.390120", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
descΔmaxmaxΔpersistentpersistentΔtime
0empty22.18050622.18050622.16524922.1652490.001223
1model evaluated25.82763947.99288823.33347645.4987250.236748
2loss evaluated26.00432871.5030548.21237153.7110960.203532
3backward evaluated19.14066872.851764-31.54584722.1652490.499658
4Optimizer stepped0.03050722.1957560.01525422.1805030.390120
\n
" + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "ml.get_df()" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-13T14:07:29.241361Z", - "start_time": "2023-10-13T14:07:28.933839Z" + "end_time": "2023-10-14T13:07:29.904544Z", + "start_time": "2023-10-14T13:07:29.354699Z" } } }, { "cell_type": "markdown", "source": [ + "```\n", + "tiger:\n", + "\n", + "Description | Δ | Σ | Δ max | Σ max\n", + "empty | 0.17 | 0.17 | 0.18 | 0.18\n", + "model evaluated | 14.75 | 14.92 | 16.38 | 16.55\n", + "loss evaluated | 4.54 | 19.46 | 12.30 | 27.21\n", + "backward evaluated | -19.29 | 0.17 | 8.84 | 28.30\n", + "step done | 0.02 | 0.18 | 0.03 | 0.20\n", + "\n", + "\n", + "default: \n", + "\n", + "Description | Δ | Σ | Δ max | Σ max\n", + "empty | 0.17 | 0.17 | 0.18 | 0.18\n", + "model evaluated | 14.75 | 14.92 | 16.38 | 16.55\n", + "loss evaluated | 14.18 | 29.10 | 25.19 | 40.11\n", + "backward evaluated | -28.93 | 0.17 | 5.15 | 34.25\n", + "step done | 0.02 | 0.18 | 0.03 | 0.20\n", + "\n", + "\n", + "---\n", + "\n", "empty added 0.14 GB, total 0.14 GB\n", "model evaluated added 14.76 GB, total 14.90 GB\n", "loss evaluated added 3.87 GB, total 18.77 GB\n", @@ -356,7 +452,8 @@ "model evaluated added 14.76 GB, total 40.63 GB\n", "loss evaluated added 14.18 GB, total 54.82 GB\n", "backward done evaluated added -28.93 GB, total 25.88 GB\n", - "step done added 0.02 GB, total 25.90 GB" + "step done added 0.02 GB, total 25.90 GB\n", + "```" ], "metadata": { "collapsed": false