Skip to content

Commit

Permalink
updated multimodal models
Browse files Browse the repository at this point in the history
  • Loading branch information
oguiza committed Feb 16, 2025
1 parent 3f461e5 commit d751602
Show file tree
Hide file tree
Showing 11 changed files with 399 additions and 402 deletions.
128 changes: 106 additions & 22 deletions nbs/022_tslearner.ipynb

Large diffs are not rendered by default.

30 changes: 15 additions & 15 deletions nbs/029_models.layers.ipynb

Large diffs are not rendered by default.

70 changes: 51 additions & 19 deletions nbs/030_models.utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,46 @@
"create_model = build_ts_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tsai.data.core import get_ts_dls, TSClassification\n",
"from tsai.models.TSiTPlus import TSiTPlus\n",
"from fastai.losses import CrossEntropyLossFlat"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"arch: TSiTPlus(c_in=3 c_out=2 seq_len=128 arch_config={} kwargs={'custom_head': functools.partial(<class 'tsai.models.layers.lin_nd_head'>, d=3)})\n",
"torch.Size([13, 3, 2])\n",
"TensorBase(0.8796, grad_fn=<AliasBackward0>)\n"
]
}
],
"source": [
"X = np.random.rand(16, 3, 128)\n",
"y = np.random.randint(0, 2, (16, 3))\n",
"tfms = [None, [TSClassification()]]\n",
"dls = get_ts_dls(X, y, splits=RandomSplitter()(range_of(X)), tfms=tfms)\n",
"model = build_ts_model(TSiTPlus, dls=dls, pretrained=False, verbose=True)\n",
"xb, yb = dls.one_batch()\n",
"output = model(xb)\n",
"print(output.shape)\n",
"loss = CrossEntropyLossFlat()(output, yb)\n",
"print(loss)\n",
"assert output.shape == (dls.bs, dls.d, dls.c)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -495,15 +535,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.\n"
]
}
],
"outputs": [],
"source": [
"c_in = 3\n",
"seq_len = 30\n",
Expand Down Expand Up @@ -555,14 +587,14 @@
{
"data": {
"text/plain": [
"(array([ 0.74775537, 1.41245663, 2.12445924, 2.8943163 , 3.56384351,\n",
" 4.23789602, 4.83134182, 5.18560431, 5.30551186, 6.29076506,\n",
" 6.58873471, 7.03661275, 7.0884361 , 7.57927022, 8.21911791,\n",
" 8.59726773, 9.37382718, 10.17298849, 10.40118308, 10.82265631]),\n",
" array([ 6.29076506, 6.58873471, 7.03661275, 7.0884361 , 7.57927022,\n",
" 8.21911791, 8.59726773, 9.37382718, 10.17298849, 10.40118308]),\n",
" array([ 6.58873471, 7.03661275, 7.0884361 , 7.57927022, 8.21911791,\n",
" 8.59726773, 9.37382718, 10.17298849, 10.40118308, 10.82265631]))"
"(array([0.99029138, 1.68463991, 2.21744589, 2.65448222, 2.85159354,\n",
" 3.26171729, 3.67986707, 4.04343956, 4.3077458 , 4.44585435,\n",
" 4.76876866, 4.85844441, 4.93256093, 5.52415845, 6.10704489,\n",
" 6.74848957, 7.31920741, 8.20198208, 8.78954283, 9.0402 ]),\n",
" array([4.44585435, 4.76876866, 4.85844441, 4.93256093, 5.52415845,\n",
" 6.10704489, 6.74848957, 7.31920741, 8.20198208, 8.78954283]),\n",
" array([4.76876866, 4.85844441, 4.93256093, 5.52415845, 6.10704489,\n",
" 6.74848957, 7.31920741, 8.20198208, 8.78954283, 9.0402 ]))"
]
},
"execution_count": null,
Expand Down Expand Up @@ -595,9 +627,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/nacho/notebooks/tsai/nbs/030_models.utils.ipynb saved at 2024-01-31 13:03:06\n",
"/Users/nacho/notebooks/tsai/nbs/030_models.utils.ipynb saved at 2025-01-22 18:23:18\n",
"Correct notebook to script conversion! 😃\n",
"Wednesday 31/01/24 13:03:08 CET\n"
"Wednesday 22/01/25 18:23:21 CET\n"
]
},
{
Expand Down
122 changes: 70 additions & 52 deletions nbs/068_models.TSiTPlus.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit d751602

Please sign in to comment.