From 9cb38d2d3b636ef5f0a99a9ac4171faeea141213 Mon Sep 17 00:00:00 2001
From: Quentin Raquet
Date: Thu, 8 Oct 2020 14:41:35 +0200
Subject: [PATCH] feat: update readme and notebooks

---
 README.md                        | 50 ++++++++++++++++++++++++++++----
 census_example.ipynb             | 36 +++++++++++++++--------
 forest_example.ipynb             | 33 ++++++++++++++-------
 multi_regression_example.ipynb   | 42 ++++++++++++++++++---------
 multi_task_example.ipynb         | 35 ++++++++++++++++------
 pytorch_tabnet/abstract_model.py |  2 +-
 regression_example.ipynb         | 18 ++++++++++--
 7 files changed, 162 insertions(+), 54 deletions(-)

diff --git a/README.md b/README.md
index ef081023..89cf7c15 100644
--- a/README.md
+++ b/README.md
@@ -51,7 +51,10 @@ TabNet is now scikit-compatible, training a TabNetClassifier or TabNetRegressor
 from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
 
 clf = TabNetClassifier()  #TabNetRegressor()
-clf.fit(X_train, Y_train, X_valid, y_valid)
+clf.fit(
+    X_train, Y_train,
+    eval_set=[(X_valid, y_valid)]
+)
 preds = clf.predict(X_test)
 ```
 
@@ -60,10 +63,37 @@ or for TabNetMultiTaskClassifier:
 ```
 from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
 clf = TabNetMultiTaskClassifier()
-clf.fit(X_train, Y_train, X_valid, y_valid)
+clf.fit(
+    X_train, Y_train,
+    eval_set=[(X_valid, y_valid)]
+)
 preds = clf.predict(X_test)
 ```
 
+### Custom early_stopping_metrics
+
+```
+from pytorch_tabnet.metrics import Metric
+from sklearn.metrics import roc_auc_score
+
+class Gini(Metric):
+    def __init__(self):
+        self._name = "gini"
+        self._maximize = True
+
+    def __call__(self, y_true, y_score):
+        auc = roc_auc_score(y_true, y_score[:, 1])
+        return max(2*auc - 1, 0.)
+
+clf = TabNetClassifier()
+clf.fit(
+    X_train, Y_train,
+    eval_set=[(X_valid, y_valid)],
+    eval_metric=[Gini]
+)
+```
+
 # Useful links
 
 - explanatory video: https://youtu.be/ysBaZO8YmX8
@@ -175,13 +205,18 @@ preds = clf.predict(X_test)
 
     Training targets
 
-- X_valid : np.array
+- eval_set: list of tuple
 
-    Validation features for early stopping
+    List of evaluation tuple sets (X, y).
+    The last one is used for early stopping.
 
-- y_valid : np.array for early stopping
+- eval_name: list of str
+
+    List of eval set names.
+
+- eval_metric : list of str
+
+    List of evaluation metrics.
+    The last metric is used for early stopping.
 
-    Validation targets
 
 - max_epochs : int (default = 200)
 
     Maximum number of epochs for training.
@@ -218,3 +253,6 @@ preds = clf.predict(X_test)
 - drop_last : bool (default=False)
 
     Whether to drop last batch if not complete during training
+
+- callbacks : list of callback functions
+
+    List of custom callbacks
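The `callbacks` parameter documented above pairs naturally with the custom metric example earlier in the README. Below is a minimal sketch of a custom callback — it assumes the `Callback` base class from `pytorch_tabnet.callbacks` and its `on_epoch_end(epoch, logs)` hook; the `EpochLogger` name is illustrative and not part of this patch:

```
from pytorch_tabnet.callbacks import Callback
from pytorch_tabnet.tab_model import TabNetClassifier

class EpochLogger(Callback):
    # Print whatever the trainer reports at the end of each epoch.
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print(f"epoch {epoch}: {logs}")

clf = TabNetClassifier()
clf.fit(
    X_train, Y_train,
    eval_set=[(X_valid, y_valid)],
    callbacks=[EpochLogger()]
)
```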
diff --git a/census_example.ipynb b/census_example.ipynb
index 2b59060a..899fa14e 100755
--- a/census_example.ipynb
+++ b/census_example.ipynb
@@ -194,14 +194,14 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "scrolled": false
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "clf.fit(\n",
     "    X_train=X_train, y_train=y_train,\n",
-    "    X_valid=X_valid, y_valid=y_valid,\n",
+    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
+    "    eval_name=['train', 'valid'],\n",
+    "    eval_metric=['auc'],\n",
     "    max_epochs=max_epochs , patience=20,\n",
     "    batch_size=1024, virtual_batch_size=128,\n",
     "    num_workers=0,\n",
@@ -217,8 +217,7 @@
    "outputs": [],
    "source": [
     "# plot losses\n",
-    "plt.plot(clf.history['train']['loss'])\n",
-    "plt.plot(clf.history['valid']['loss'])"
+    "plt.plot(clf.history['loss'])"
    ]
   },
   {
@@ -228,8 +227,8 @@
    "outputs": [],
    "source": [
     "# plot auc\n",
-    "plt.plot([-x for x in clf.history['train']['metric']])\n",
-    "plt.plot([-x for x in clf.history['valid']['metric']])"
+    "plt.plot(clf.history['train_auc'])\n",
+    "plt.plot(clf.history['valid_auc'])"
    ]
   },
   {
@@ -239,7 +238,7 @@
    "outputs": [],
    "source": [
     "# plot learning rates\n",
-    "plt.plot([x for x in clf.history['train']['lr']])"
+    "plt.plot(clf.history['lr'])"
    ]
   },
   {
@@ -421,9 +420,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": ".shap",
    "language": "python",
-   "name": "python3"
+   "name": ".shap"
   },
   "language_info": {
    "codemirror_mode": {
@@ -435,7 +434,20 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.6"
-  }
+   "version": "3.6.8"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
  },
  "nbformat": 4,
diff --git a/forest_example.ipynb b/forest_example.ipynb
index c3395747..3cc2e82a 100644
--- a/forest_example.ipynb
+++ b/forest_example.ipynb
@@ -237,7 +237,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
+    "max_epochs = 5 if not os.getenv(\"CI\", False) else 2"
    ]
   },
   {
@@ -250,7 +250,8 @@
    "source": [
     "clf.fit(\n",
     "    X_train=X_train, y_train=y_train,\n",
-    "    X_valid=X_valid, y_valid=y_valid,\n",
+    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
+    "    eval_name=['train', 'valid'],\n",
     "    max_epochs=max_epochs, patience=100,\n",
     "    batch_size=16384, virtual_batch_size=256\n",
     ") "
@@ -263,8 +264,7 @@
    "outputs": [],
    "source": [
     "# plot losses\n",
-    "plt.plot(clf.history['train']['loss'])\n",
-    "plt.plot(clf.history['valid']['loss'])"
+    "plt.plot(clf.history['loss'])"
    ]
   },
   {
@@ -273,9 +273,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# plot accuracies\n",
-    "plt.plot([-x for x in clf.history['train']['metric']])\n",
-    "plt.plot([-x for x in clf.history['valid']['metric']])"
+    "# plot accuracy\n",
+    "plt.plot(clf.history['train_accuracy'])\n",
+    "plt.plot(clf.history['valid_accuracy'])"
    ]
   },
   {
@@ -495,9 +495,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": ".shap",
    "language": "python",
-   "name": "python3"
+   "name": ".shap"
   },
   "language_info": {
    "codemirror_mode": {
@@ -509,7 +509,20 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5"
-  }
+   "version": "3.6.8"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
  },
  "nbformat": 4,
diff --git a/multi_regression_example.ipynb b/multi_regression_example.ipynb
index fd09b2dc..139beacc 100644
--- a/multi_regression_example.ipynb
+++ b/multi_regression_example.ipynb
@@ -19,7 +19,12 @@
     "\n",
     "import os\n",
     "import wget\n",
-    "from pathlib import Path"
+    "from pathlib import Path\n",
+    "\n",
+    "\n",
+    "%load_ext autoreload\n",
+    "\n",
+    "%autoreload 2"
    ]
   },
   {
@@ -188,20 +193,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
+    "max_epochs = 10 if not os.getenv(\"CI\", False) else 2"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "scrolled": true
+    "scrolled": false
    },
    "outputs": [],
    "source": [
     "clf.fit(\n",
     "    X_train=X_train, y_train=y_train,\n",
-    "    X_valid=X_valid, y_valid=y_valid,\n",
+    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
+    "    eval_name=['train', 'valid'],\n",
     "    max_epochs=max_epochs,\n",
     "    patience=50,\n",
     "    batch_size=1024, virtual_batch_size=128,\n",
@@ -216,17 +222,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Deprecated : best model is automatically loaded at end of fit\n",
-    "# clf.load_best_model()\n",
-    "\n",
     "preds = clf.predict(X_test)\n",
     "\n",
-    "y_true = y_test\n",
-    "\n",
-    "test_score = mean_squared_error(y_pred=preds, y_true=y_true)\n",
+    "test_mse = mean_squared_error(y_pred=preds, y_true=y_test)\n",
     "\n",
     "print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
-    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_score}\")"
+    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_mse}\")"
    ]
   },
@@ -296,9 +297,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": ".shap",
    "language": "python",
-   "name": "python3"
+   "name": ".shap"
   },
   "language_info": {
    "codemirror_mode": {
@@ -310,7 +311,20 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5"
-  }
+   "version": "3.6.8"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
  },
  "nbformat": 4,
diff --git a/multi_task_example.ipynb b/multi_task_example.ipynb
index 435c6448..21acf83b 100644
--- a/multi_task_example.ipynb
+++ b/multi_task_example.ipynb
@@ -22,7 +22,11 @@
     "from pathlib import Path\n",
     "\n",
     "from matplotlib import pyplot as plt\n",
-    "%matplotlib inline"
+    "%matplotlib inline\n",
+    "\n",
+    "%load_ext autoreload\n",
+    "\n",
+    "%autoreload 2"
    ]
   },
   {
@@ -212,7 +216,8 @@
    "source": [
     "clf.fit(\n",
     "    X_train=X_train, y_train=y_train,\n",
-    "    X_valid=X_valid, y_valid=y_valid,\n",
+    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
+    "    eval_name=['train', 'valid'],\n",
     "    max_epochs=max_epochs , patience=20,\n",
     "    batch_size=1024, virtual_batch_size=128,\n",
     "    num_workers=0,\n",
@@ -228,8 +233,7 @@
    "outputs": [],
    "source": [
     "# plot losses\n",
-    "plt.plot(clf.history['train']['loss'])\n",
-    "plt.plot(clf.history['valid']['loss'])"
+    "plt.plot(clf.history['loss'])"
    ]
   },
   {
@@ -238,9 +242,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# plot auc\n",
-    "plt.plot([-x for x in clf.history['train']['metric']])\n",
-    "plt.plot([-x for x in clf.history['valid']['metric']])"
+    "# plot logloss\n",
+    "plt.plot(clf.history['train_logloss'])\n",
+    "plt.plot(clf.history['valid_logloss'])"
    ]
   },
   {
@@ -250,7 +254,7 @@
    "outputs": [],
    "source": [
     "# plot learning rates\n",
-    "plt.plot([x for x in clf.history['train']['lr']])"
+    "plt.plot(clf.history['lr'])"
    ]
   },
   {
@@ -444,7 +448,20 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.6"
-  }
+   "version": "3.6.9"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
  },
  "nbformat": 4,
diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py
index d3c8ec31..149e46be 100644
--- a/pytorch_tabnet/abstract_model.py
+++ b/pytorch_tabnet/abstract_model.py
@@ -257,7 +257,7 @@ def explain(self, X):
             for key, value in masks.items():
                 res_masks[key] = np.vstack([res_masks[key], value])
 
-            res_explain = np.vstack(res_explain)
+        res_explain = np.vstack(res_explain)
 
         return res_explain, res_masks
diff --git a/regression_example.ipynb b/regression_example.ipynb
index 1c89f390..d2983e55 100644
--- a/regression_example.ipynb
+++ b/regression_example.ipynb
@@ -189,7 +189,8 @@
    "source": [
     "clf.fit(\n",
     "    X_train=X_train, y_train=y_train,\n",
-    "    X_valid=X_valid, y_valid=y_valid,\n",
+    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
+    "    eval_name=['train', 'valid'],\n",
     "    max_epochs=max_epochs,\n",
     "    patience=50,\n",
     "    batch_size=1024, virtual_batch_size=128,\n",
@@ -350,7 +351,20 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5"
-  }
+   "version": "3.6.9"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
  },
  "nbformat": 4,