Skip to content

Commit c486f7f

Browse files
added commands for saving the whole model
1 parent 668d4ab commit c486f7f

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

01_pytorch_workflow.ipynb

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,19 +1379,32 @@
13791379
"cell_type": "markdown",
13801380
"metadata": {},
13811381
"source": [
1382+
"## Only model parameters\n",
1383+
"\n",
13821384
"### Save\n",
13831385
"`torch.save(model.state_dict(), PATH)`\n",
13841386
"\n",
13851387
"### Load\n",
13861388
"\n",
13871389
"`model = TheModelClass(*args, **kwargs)` \n",
13881390
"`model.load_state_dict(torch.load(PATH))` \n",
1389-
"`model.eval()` "
1391+
"`model.eval()` \n",
1392+
"\n",
1393+
"## The entire model\n",
1394+
"\n",
1395+
"### Save\n",
1396+
"\n",
1397+
"`torch.save(model, PATH)`\n",
1398+
"\n",
1399+
"### Load\n",
1400+
"\n",
1401+
"`model = torch.load(PATH)` \n",
1402+
"`model.eval()`"
13901403
]
13911404
},
13921405
{
13931406
"cell_type": "code",
1394-
"execution_count": 313,
1407+
"execution_count": 1,
13951408
"metadata": {
13961409
"colab": {
13971410
"base_uri": "https://localhost:8080/"
@@ -1406,6 +1419,17 @@
14061419
"text": [
14071420
"Saving model to: models\\01_pytorch_workflow_model_0.pth\n"
14081421
]
1422+
},
1423+
{
1424+
"ename": "NameError",
1425+
"evalue": "name 'torch' is not defined",
1426+
"output_type": "error",
1427+
"traceback": [
1428+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
1429+
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
1430+
"Cell \u001b[1;32mIn[1], line 13\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;66;03m# 3. Save the model state dict\u001b[39;00m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSaving model to: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mMODEL_SAVE_PATH\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m---> 13\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39msave(obj\u001b[38;5;241m=\u001b[39mmodel_0\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;66;03m# only saving the state_dict() only saves the models learned parameters\u001b[39;00m\n\u001b[0;32m 14\u001b[0m f\u001b[38;5;241m=\u001b[39mMODEL_SAVE_PATH)\n",
1431+
"\u001b[1;31mNameError\u001b[0m: name 'torch' is not defined"
1432+
]
14091433
}
14101434
],
14111435
"source": [

0 commit comments

Comments
 (0)