diff --git a/02_pytorch_classification.ipynb b/02_pytorch_classification.ipynb index 03064f5d..addff66e 100644 --- a/02_pytorch_classification.ipynb +++ b/02_pytorch_classification.ipynb @@ -1366,6 +1366,11 @@ "# Set the number of epochs\n", "epochs = 100\n", "\n", + "def accuracy_fn(y_true, y_pred):\n" + " correct = (y_true == y_pred).sum().item()\n", + " total = len(y_true)\n", + " return (correct / total) * 100\n", + "\n", "# Put data to target device\n", "X_train, y_train = X_train.to(device), y_train.to(device)\n", "X_test, y_test = X_test.to(device), y_test.to(device)\n",