diff --git a/Makefile b/Makefile index f7c39f98..6a0e5ab6 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ NO_COLOR=\\e[39m OK_COLOR=\\e[32m ERROR_COLOR=\\e[31m WARN_COLOR=\\e[33m -PORT=8889 +PORT=8887 .SILENT: ; default: help; # default target diff --git a/README.md b/README.md index 14ab004c..607622c5 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,12 @@ A complete example can be found within the notebook `pretraining_example.ipynb`. /!\ : current implementation is trying to reconstruct the original inputs, but Batch Normalization applies a random transformation that can't be deduced by a single line, making the reconstruction harder. Lowering the `batch_size` might make the pretraining easier. +# Data augmentation on the fly + +It is now possible to apply a custom data augmentation pipeline during training. +Templates for ClassificationSMOTE and RegressionSMOTE have been added in `pytorch_tabnet/augmentations.py` and can be used as is. + + # Easy saving and loading It's really easy to save and re-load a trained model, this makes TabNet production ready. 
diff --git a/census_example.ipynb b/census_example.ipynb index 7ef1caf8..de257ba8 100755 --- a/census_example.ipynb +++ b/census_example.ipynb @@ -205,6 +205,16 @@ "max_epochs = 100 if not os.getenv(\"CI\", False) else 2" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_tabnet.augmentations import ClassificationSMOTE\n", + "aug = ClassificationSMOTE(p=0.2)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -225,10 +235,11 @@ " batch_size=1024, virtual_batch_size=128,\n", " num_workers=0,\n", " weights=1,\n", - " drop_last=False\n", + " drop_last=False,\n", + " augmentations=aug, #aug, None\n", " )\n", " save_history.append(clf.history[\"valid_auc\"])\n", - " \n", + "\n", "assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))" ] }, diff --git a/forest_example.ipynb b/forest_example.ipynb index 5ba8bd1d..ccb4ac69 100644 --- a/forest_example.ipynb +++ b/forest_example.ipynb @@ -237,7 +237,7 @@ "metadata": {}, "outputs": [], "source": [ - "max_epochs = 5 if not os.getenv(\"CI\", False) else 2" + "max_epochs = 50 if not os.getenv(\"CI\", False) else 2" ] }, { @@ -248,12 +248,16 @@ }, "outputs": [], "source": [ + "from pytorch_tabnet.augmentations import ClassificationSMOTE\n", + "aug = ClassificationSMOTE(p=0.2)\n", + "\n", "clf.fit(\n", " X_train=X_train, y_train=y_train,\n", " eval_set=[(X_train, y_train), (X_valid, y_valid)],\n", " eval_name=['train', 'valid'],\n", " max_epochs=max_epochs, patience=100,\n", - " batch_size=16384, virtual_batch_size=256\n", + " batch_size=16384, virtual_batch_size=256,\n", + " augmentations=aug\n", ") " ] }, diff --git a/poetry.lock b/poetry.lock index c0e3ba6b..6a2912e0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -65,17 +65,17 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] name = "attrs" -version = "21.2.0" +version = "21.4.0" description = "Classes Without Boilerplate" category = "dev" optional = false 
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.extras] -dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] -tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"] -tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] [[package]] name = "babel" @@ -130,7 +130,7 @@ pycparser = "*" [[package]] name = "charset-normalizer" -version = "2.0.9" +version = "2.0.12" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
category = "dev" optional = false @@ -176,7 +176,7 @@ python-versions = ">=3.6, <3.7" [[package]] name = "decorator" -version = "5.1.0" +version = "5.1.1" description = "Decorators for Humans" category = "dev" optional = false @@ -253,6 +253,21 @@ docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] perf = ["ipython"] testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] +[[package]] +name = "importlib-resources" +version = "5.4.0" +description = "Read resources from Python packages" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-black (>=0.3.7)", "pytest-mypy"] + [[package]] name = "iniconfig" version = "1.1.1" @@ -282,7 +297,7 @@ test = ["pytest (!=5.3.4)", "pytest-cov", "flaky", "nose", "jedi (<=0.17.2)"] [[package]] name = "ipython" -version = "7.16.2" +version = "7.16.3" description = "IPython: Productive Interactive Computing" category = "dev" optional = false @@ -431,7 +446,7 @@ test = ["async-generator", "ipykernel", "ipython", "mock", "pytest-asyncio", "py [[package]] name = "jupyter-console" -version = "6.4.0" +version = "6.4.2" description = "Jupyter terminal console" category = "dev" optional = false @@ -449,7 +464,7 @@ test = ["pexpect"] [[package]] name = "jupyter-core" -version = "4.9.1" +version = "4.9.2" description = "Jupyter core package. A base package on which Jupyter projects rely." 
category = "dev" optional = false @@ -603,7 +618,7 @@ python-versions = ">=3.5" [[package]] name = "notebook" -version = "6.4.6" +version = "6.4.8" description = "A web-based notebook environment for interactive computing" category = "dev" optional = false @@ -721,11 +736,11 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "prometheus-client" -version = "0.12.0" +version = "0.13.1" description = "Python client for the Prometheus monitoring system." category = "dev" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.6" [package.extras] twisted = ["twisted"] @@ -783,7 +798,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] name = "pygments" -version = "2.10.0" +version = "2.11.2" description = "Pygments is a syntax highlighting package written in Python." category = "dev" optional = false @@ -791,7 +806,7 @@ python-versions = ">=3.5" [[package]] name = "pyparsing" -version = "3.0.6" +version = "3.0.7" description = "Python parsing module" category = "dev" optional = false @@ -851,7 +866,7 @@ python-versions = "*" [[package]] name = "pywin32" -version = "302" +version = "303" description = "Python for Window Extensions" category = "dev" optional = false @@ -859,7 +874,7 @@ python-versions = "*" [[package]] name = "pywinpty" -version = "1.1.6" +version = "2.0.3" description = "Pseudo terminal support for Windows from Python." category = "dev" optional = false @@ -901,11 +916,17 @@ test = ["flaky", "pytest", "pytest-qt"] [[package]] name = "qtpy" -version = "1.11.3" -description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5, PyQt4 and PySide) and additional custom QWidgets." +version = "2.0.1" +description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." 
category = "dev" optional = false -python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" +python-versions = ">=3.6" + +[package.dependencies] +packaging = "*" + +[package.extras] +test = ["pytest (>=6.0.0)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "recommonmark" @@ -922,7 +943,7 @@ sphinx = ">=1.3.1" [[package]] name = "requests" -version = "2.26.0" +version = "2.27.1" description = "Python HTTP for Humans." category = "dev" optional = false @@ -1115,11 +1136,11 @@ test = ["pytest"] [[package]] name = "terminado" -version = "0.12.1" +version = "0.13.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." category = "dev" optional = false -python-versions = ">=3.6" +python-versions = "*" [package.dependencies] ptyprocess = {version = "*", markers = "os_name != \"nt\""} @@ -1131,18 +1152,18 @@ test = ["pytest"] [[package]] name = "testpath" -version = "0.5.0" +version = "0.6.0" description = "Test utilities for code working with files and commands" category = "dev" optional = false python-versions = ">= 3.5" [package.extras] -test = ["pytest", "pathlib2"] +test = ["pytest"] [[package]] name = "threadpoolctl" -version = "3.0.0" +version = "3.1.0" description = "threadpoolctl" category = "main" optional = false @@ -1174,7 +1195,7 @@ python-versions = ">= 3.5" [[package]] name = "tqdm" -version = "4.62.3" +version = "4.63.0" description = "Fast, Extensible Progress Meter" category = "main" optional = false @@ -1182,6 +1203,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +importlib-resources = {version = "*", markers = "python_version < \"3.7\""} [package.extras] dev = ["py-make (>=0.1.0)", "twine", "wheel"] @@ -1206,7 +1228,7 @@ test = ["pytest", "mock"] [[package]] name = "typing-extensions" -version = "4.0.1" +version = "4.1.1" description = "Backported and Experimental Type Hints 
for Python 3.6+" category = "dev" optional = false @@ -1214,14 +1236,13 @@ python-versions = ">=3.6" [[package]] name = "urllib3" -version = "1.26.7" +version = "1.22" description = "HTTP library with thread-safe connection pooling, file post, and more." category = "dev" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" +python-versions = "*" [package.extras] -brotli = ["brotlipy (>=0.6.0)"] secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] @@ -1276,7 +1297,7 @@ scipy = "*" name = "zipp" version = "3.6.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -1286,8 +1307,8 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" -python-versions = "^3.6" -content-hash = "924b692f42c5b45e28c771ce997fce9ecb16db9f77c05b22bc0620360130926a" +python-versions = ">=3.6" +content-hash = "d38522a7948f24326479797abe2f75ad8b77200c58152779d423040f13f0b0a3" [metadata.files] alabaster = [ @@ -1334,8 +1355,8 @@ atomicwrites = [ {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, ] attrs = [ - {file = "attrs-21.2.0-py2.py3-none-any.whl", hash = "sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1"}, - {file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"}, + {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, + {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, ] babel = [ {file = "Babel-2.9.1-py2.py3-none-any.whl", hash = "sha256:ab49e12b91d937cd11f0b67cb259a57ab4ad2b59ac7a3b41d6c06c0ac5b0def9"}, @@ 
-1406,8 +1427,8 @@ cffi = [ {file = "cffi-1.15.0.tar.gz", hash = "sha256:920f0d66a896c2d99f0adbb391f990a84091179542c205fa53ce5787aff87954"}, ] charset-normalizer = [ - {file = "charset-normalizer-2.0.9.tar.gz", hash = "sha256:b0b883e8e874edfdece9c28f314e3dd5badf067342e42fb162203335ae61aa2c"}, - {file = "charset_normalizer-2.0.9-py3-none-any.whl", hash = "sha256:1eecaa09422db5be9e29d7fc65664e6c33bd06f9ced7838578ba40d58bdf3721"}, + {file = "charset-normalizer-2.0.12.tar.gz", hash = "sha256:2857e29ff0d34db842cd7ca3230549d1a697f96ee6d3fb071cfa6c7393832597"}, + {file = "charset_normalizer-2.0.12-py3-none-any.whl", hash = "sha256:6881edbebdb17b39b4eaaa821b438bf6eddffb4468cf344f09f89def34a8b1df"}, ] colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, @@ -1426,8 +1447,8 @@ dataclasses = [ {file = "dataclasses-0.8.tar.gz", hash = "sha256:8479067f342acf957dc82ec415d355ab5edb7e7646b90dc6e2fd1d96ad084c97"}, ] decorator = [ - {file = "decorator-5.1.0-py3-none-any.whl", hash = "sha256:7b12e7c3c6ab203a29e157335e9122cb03de9ab7264b137594103fd4a683b374"}, - {file = "decorator-5.1.0.tar.gz", hash = "sha256:e59913af105b9860aa2c8d3272d9de5a56a4e608db9a2f167a8480b323d529a7"}, + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] defusedxml = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, @@ -1457,6 +1478,10 @@ importlib-metadata = [ {file = "importlib_metadata-4.8.3-py3-none-any.whl", hash = "sha256:65a9576a5b2d58ca44d133c42a241905cc45e34d2c06fd5ba2bafa221e5d7b5e"}, {file = "importlib_metadata-4.8.3.tar.gz", hash = "sha256:766abffff765960fcc18003801f7044eb6755ffae4521c8e8ce8e83b9c9b0668"}, ] +importlib-resources = [ + 
{file = "importlib_resources-5.4.0-py3-none-any.whl", hash = "sha256:33a95faed5fc19b4bc16b29a6eeae248a3fe69dd55d4d229d2b480e23eeaad45"}, + {file = "importlib_resources-5.4.0.tar.gz", hash = "sha256:d756e2f85dd4de2ba89be0b21dba2a3bbec2e871a42a3a16719258a11f87506b"}, +] iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, @@ -1466,8 +1491,8 @@ ipykernel = [ {file = "ipykernel-5.5.6.tar.gz", hash = "sha256:4ea44b90ae1f7c38987ad58ea0809562a17c2695a0499644326f334aecd369ec"}, ] ipython = [ - {file = "ipython-7.16.2-py3-none-any.whl", hash = "sha256:2f644313be4fdc5c8c2a17467f2949c29423c9e283a159d1fc9bf450a1a300af"}, - {file = "ipython-7.16.2.tar.gz", hash = "sha256:613085f8acb0f35f759e32bea35fba62c651a4a2e409a0da11414618f5eec0c4"}, + {file = "ipython-7.16.3-py3-none-any.whl", hash = "sha256:c0427ed8bc33ac481faf9d3acf7e84e0010cdaada945e0badd1e2e74cc075833"}, + {file = "ipython-7.16.3.tar.gz", hash = "sha256:5ac47dc9af66fc2f5530c12069390877ae372ac905edca75a92a6e363b5d7caa"}, ] ipython-genutils = [ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, @@ -1503,12 +1528,12 @@ jupyter-client = [ {file = "jupyter_client-6.1.13.tar.gz", hash = "sha256:d03558bc9b7955d8b4a6df604a8d9d257e00bcea7fb364fe41cdef81d998a966"}, ] jupyter-console = [ - {file = "jupyter_console-6.4.0-py3-none-any.whl", hash = "sha256:7799c4ea951e0e96ba8260575423cb323ea5a03fcf5503560fa3e15748869e27"}, - {file = "jupyter_console-6.4.0.tar.gz", hash = "sha256:242248e1685039cd8bff2c2ecb7ce6c1546eb50ee3b08519729e6e881aec19c7"}, + {file = "jupyter_console-6.4.2-py3-none-any.whl", hash = "sha256:1d52cf1a80f0c7accaa2b8c68ca3e6fa2311ec33ac9651d6cb6b9168cca1dad9"}, + {file = "jupyter_console-6.4.2.tar.gz", hash = 
"sha256:fce5bccac926c690924168ad46cae33a7d78d643a7b60af0f260af25d38ecf26"}, ] jupyter-core = [ - {file = "jupyter_core-4.9.1-py3-none-any.whl", hash = "sha256:1c091f3bbefd6f2a8782f2c1db662ca8478ac240e962ae2c66f0b87c818154ea"}, - {file = "jupyter_core-4.9.1.tar.gz", hash = "sha256:dce8a7499da5a53ae3afd5a9f4b02e5df1d57250cf48f3ad79da23b4778cd6fa"}, + {file = "jupyter_core-4.9.2-py3-none-any.whl", hash = "sha256:f875e4d27e202590311d468fa55f90c575f201490bd0c18acabe4e318db4a46d"}, + {file = "jupyter_core-4.9.2.tar.gz", hash = "sha256:d69baeb9ffb128b8cd2657fcf2703f89c769d1673c851812119e3a2a0e93ad9a"}, ] jupyterlab-pygments = [ {file = "jupyterlab_pygments-0.1.2-py2.py3-none-any.whl", hash = "sha256:abfb880fd1561987efaefcb2d2ac75145d2a5d0139b1876d5be806e32f630008"}, @@ -1664,8 +1689,8 @@ nest-asyncio = [ {file = "nest_asyncio-1.5.4.tar.gz", hash = "sha256:f969f6013a16fadb4adcf09d11a68a4f617c6049d7af7ac2c676110169a63abd"}, ] notebook = [ - {file = "notebook-6.4.6-py3-none-any.whl", hash = "sha256:5cad068fa82cd4fb98d341c052100ed50cd69fbfb4118cb9b8ab5a346ef27551"}, - {file = "notebook-6.4.6.tar.gz", hash = "sha256:7bcdf79bd1cda534735bd9830d2cbedab4ee34d8fe1df6e7b946b3aab0902ba3"}, + {file = "notebook-6.4.8-py3-none-any.whl", hash = "sha256:3e702fcc54b8ae597533c3864793b7a1e971dec9e112f67235828d8a798fd654"}, + {file = "notebook-6.4.8.tar.gz", hash = "sha256:1e985c9dc6f678bdfffb9dc657306b5469bfa62d73e03f74e8defbf76d284312"}, ] numpy = [ {file = "numpy-1.19.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cc6bd4fd593cb261332568485e20a0712883cf631f6f5e8e86a52caa8b2b50ff"}, @@ -1749,8 +1774,8 @@ pluggy = [ {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, ] prometheus-client = [ - {file = "prometheus_client-0.12.0-py2.py3-none-any.whl", hash = "sha256:317453ebabff0a1b02df7f708efbab21e3489e7072b61cb6957230dd004a0af0"}, - {file = "prometheus_client-0.12.0.tar.gz", hash = 
"sha256:1b12ba48cee33b9b0b9de64a1047cbd3c5f2d0ab6ebcead7ddda613a750ec3c5"}, + {file = "prometheus_client-0.13.1-py3-none-any.whl", hash = "sha256:357a447fd2359b0a1d2e9b311a0c5778c330cfbe186d880ad5a6b39884652316"}, + {file = "prometheus_client-0.13.1.tar.gz", hash = "sha256:ada41b891b79fca5638bd5cfe149efa86512eaa55987893becd2c6d8d0a5dfc5"}, ] prompt-toolkit = [ {file = "prompt_toolkit-3.0.3-py3-none-any.whl", hash = "sha256:c93e53af97f630f12f5f62a3274e79527936ed466f038953dfa379d4941f651a"}, @@ -1777,12 +1802,12 @@ pyflakes = [ {file = "pyflakes-2.1.1.tar.gz", hash = "sha256:d976835886f8c5b31d47970ed689944a0262b5f3afa00a5a7b4dc81e5449f8a2"}, ] pygments = [ - {file = "Pygments-2.10.0-py3-none-any.whl", hash = "sha256:b8e67fe6af78f492b3c4b3e2970c0624cbf08beb1e493b2c99b9fa1b67a20380"}, - {file = "Pygments-2.10.0.tar.gz", hash = "sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6"}, + {file = "Pygments-2.11.2-py3-none-any.whl", hash = "sha256:44238f1b60a76d78fc8ca0528ee429702aae011c265fe6a8dd8b63049ae41c65"}, + {file = "Pygments-2.11.2.tar.gz", hash = "sha256:4e426f72023d88d03b2fa258de560726ce890ff3b630f88c21cbb8b2503b8c6a"}, ] pyparsing = [ - {file = "pyparsing-3.0.6-py3-none-any.whl", hash = "sha256:04ff808a5b90911829c55c4e26f75fa5ca8a2f5f36aa3a51f68e27033341d3e4"}, - {file = "pyparsing-3.0.6.tar.gz", hash = "sha256:d9bdec0013ef1eb5a84ab39a3b3868911598afa494f5faa038647101504e2b81"}, + {file = "pyparsing-3.0.7-py3-none-any.whl", hash = "sha256:a6c06a88f252e6c322f65faf8f418b16213b51bdfaece0524c1c1bc30c63c484"}, + {file = "pyparsing-3.0.7.tar.gz", hash = "sha256:18ee9022775d270c55187733956460083db60b37d0d0fb357445f3094eed3eea"}, ] pyrsistent = [ {file = "pyrsistent-0.18.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f4c8cabb46ff8e5d61f56a037974228e978f26bfefce4f61a4b1ac0ba7a2ab72"}, @@ -1820,24 +1845,25 @@ pytz = [ {file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"}, ] pywin32 = [ - 
{file = "pywin32-302-cp310-cp310-win32.whl", hash = "sha256:251b7a9367355ccd1a4cd69cd8dd24bd57b29ad83edb2957cfa30f7ed9941efa"}, - {file = "pywin32-302-cp310-cp310-win_amd64.whl", hash = "sha256:79cf7e6ddaaf1cd47a9e50cc74b5d770801a9db6594464137b1b86aa91edafcc"}, - {file = "pywin32-302-cp36-cp36m-win32.whl", hash = "sha256:fe21c2fb332d03dac29de070f191bdbf14095167f8f2165fdc57db59b1ecc006"}, - {file = "pywin32-302-cp36-cp36m-win_amd64.whl", hash = "sha256:d3761ab4e8c5c2dbc156e2c9ccf38dd51f936dc77e58deb940ffbc4b82a30528"}, - {file = "pywin32-302-cp37-cp37m-win32.whl", hash = "sha256:48dd4e348f1ee9538dd4440bf201ea8c110ea6d9f3a5010d79452e9fa80480d9"}, - {file = "pywin32-302-cp37-cp37m-win_amd64.whl", hash = "sha256:496df89f10c054c9285cc99f9d509e243f4e14ec8dfc6d78c9f0bf147a893ab1"}, - {file = "pywin32-302-cp38-cp38-win32.whl", hash = "sha256:e372e477d938a49266136bff78279ed14445e00718b6c75543334351bf535259"}, - {file = "pywin32-302-cp38-cp38-win_amd64.whl", hash = "sha256:543552e66936378bd2d673c5a0a3d9903dba0b0a87235ef0c584f058ceef5872"}, - {file = "pywin32-302-cp39-cp39-win32.whl", hash = "sha256:2393c1a40dc4497fd6161b76801b8acd727c5610167762b7c3e9fd058ef4a6ab"}, - {file = "pywin32-302-cp39-cp39-win_amd64.whl", hash = "sha256:af5aea18167a31efcacc9f98a2ca932c6b6a6d91ebe31f007509e293dea12580"}, + {file = "pywin32-303-cp310-cp310-win32.whl", hash = "sha256:6fed4af057039f309263fd3285d7b8042d41507343cd5fa781d98fcc5b90e8bb"}, + {file = "pywin32-303-cp310-cp310-win_amd64.whl", hash = "sha256:51cb52c5ec6709f96c3f26e7795b0bf169ee0d8395b2c1d7eb2c029a5008ed51"}, + {file = "pywin32-303-cp311-cp311-win32.whl", hash = "sha256:d9b5d87ca944eb3aa4cd45516203ead4b37ab06b8b777c54aedc35975dec0dee"}, + {file = "pywin32-303-cp311-cp311-win_amd64.whl", hash = "sha256:fcf44032f5b14fcda86028cdf49b6ebdaea091230eb0a757282aa656e4732439"}, + {file = "pywin32-303-cp36-cp36m-win32.whl", hash = "sha256:aad484d52ec58008ca36bd4ad14a71d7dd0a99db1a4ca71072213f63bf49c7d9"}, + {file = 
"pywin32-303-cp36-cp36m-win_amd64.whl", hash = "sha256:2a09632916b6bb231ba49983fe989f2f625cea237219530e81a69239cd0c4559"}, + {file = "pywin32-303-cp37-cp37m-win32.whl", hash = "sha256:b1675d82bcf6dbc96363fca747bac8bff6f6e4a447a4287ac652aa4b9adc796e"}, + {file = "pywin32-303-cp37-cp37m-win_amd64.whl", hash = "sha256:c268040769b48a13367221fced6d4232ed52f044ffafeda247bd9d2c6bdc29ca"}, + {file = "pywin32-303-cp38-cp38-win32.whl", hash = "sha256:5f9ec054f5a46a0f4dfd72af2ce1372f3d5a6e4052af20b858aa7df2df7d355b"}, + {file = "pywin32-303-cp38-cp38-win_amd64.whl", hash = "sha256:793bf74fce164bcffd9d57bb13c2c15d56e43c9542a7b9687b4fccf8f8a41aba"}, + {file = "pywin32-303-cp39-cp39-win32.whl", hash = "sha256:7d3271c98434617a11921c5ccf74615794d97b079e22ed7773790822735cc352"}, + {file = "pywin32-303-cp39-cp39-win_amd64.whl", hash = "sha256:79cbb862c11b9af19bcb682891c1b91942ec2ff7de8151e2aea2e175899cda34"}, ] pywinpty = [ - {file = "pywinpty-1.1.6-cp310-none-win_amd64.whl", hash = "sha256:5f526f21b569b5610a61e3b6126259c76da979399598e5154498582df3736ade"}, - {file = "pywinpty-1.1.6-cp36-none-win_amd64.whl", hash = "sha256:7576e14f42b31fa98b62d24ded79754d2ea4625570c016b38eb347ce158a30f2"}, - {file = "pywinpty-1.1.6-cp37-none-win_amd64.whl", hash = "sha256:979ffdb9bdbe23db3f46fc7285fd6dbb86b80c12325a50582b211b3894072354"}, - {file = "pywinpty-1.1.6-cp38-none-win_amd64.whl", hash = "sha256:2308b1fc77545427610a705799d4ead5e7f00874af3fb148a03e202437456a7e"}, - {file = "pywinpty-1.1.6-cp39-none-win_amd64.whl", hash = "sha256:c703bf569a98ab7844b9daf37e88ab86f31862754ef6910a8b3824993a525c72"}, - {file = "pywinpty-1.1.6.tar.gz", hash = "sha256:8808f07350c709119cc4464144d6e749637f98e15acc1e5d3c37db1953d2eebc"}, + {file = "pywinpty-2.0.3-cp310-none-win_amd64.whl", hash = "sha256:7a330ef7a2ce284370b1a1fdd2a80c523585464fa5e5ab934c9f27220fa7feab"}, + {file = "pywinpty-2.0.3-cp37-none-win_amd64.whl", hash = "sha256:6455f1075f978942d318f95616661c605d5e0f991c5b176c0c852d237aafefc0"}, + {file = 
"pywinpty-2.0.3-cp38-none-win_amd64.whl", hash = "sha256:2e7a288a8121393c526d4e6ec7d65edef75d68c7787ab9560e438df867b75a5d"}, + {file = "pywinpty-2.0.3-cp39-none-win_amd64.whl", hash = "sha256:def51627e6aa659f33ea7a0ea4c6b68365c83af4aad7940600f844746817a0ed"}, + {file = "pywinpty-2.0.3.tar.gz", hash = "sha256:6b29a826e896105370c38d53904c3aaac6c36146a50448fc0ed5082cf9d092bc"}, ] pyzmq = [ {file = "pyzmq-22.3.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:6b217b8f9dfb6628f74b94bdaf9f7408708cb02167d644edca33f38746ca12dd"}, @@ -1893,16 +1919,16 @@ qtconsole = [ {file = "qtconsole-5.2.2.tar.gz", hash = "sha256:8f9db97b27782184efd0a0f2d57ea3bd852d053747a2e442a9011329c082976d"}, ] qtpy = [ - {file = "QtPy-1.11.3-py2.py3-none-any.whl", hash = "sha256:e121fbee8e95645af29c5a4aceba8d657991551fc1aa3b6b6012faf4725a1d20"}, - {file = "QtPy-1.11.3.tar.gz", hash = "sha256:d427addd37386a8d786db81864a5536700861d95bf085cb31d1bea855d699557"}, + {file = "QtPy-2.0.1-py3-none-any.whl", hash = "sha256:d93f2c98e97387fcc9d623d509772af5b6c15ab9d8f9f4c5dfbad9a73ad34812"}, + {file = "QtPy-2.0.1.tar.gz", hash = "sha256:adfd073ffbd2de81dc7aaa0b983499ef5c59c96adcfdcc9dea60d42ca885eb8f"}, ] recommonmark = [ {file = "recommonmark-0.6.0-py2.py3-none-any.whl", hash = "sha256:2ec4207a574289355d5b6ae4ae4abb29043346ca12cdd5f07d374dc5987d2852"}, {file = "recommonmark-0.6.0.tar.gz", hash = "sha256:29cd4faeb6c5268c633634f2d69aef9431e0f4d347f90659fd0aab20e541efeb"}, ] requests = [ - {file = "requests-2.26.0-py2.py3-none-any.whl", hash = "sha256:6c1246513ecd5ecd4528a0906f910e8f0f9c6b8ec72030dc9fd154dc1a6efd24"}, - {file = "requests-2.26.0.tar.gz", hash = "sha256:b8aa58f8cf793ffd8782d3d8cb19e66ef36f7aba4353eec859e74678b01b07a7"}, + {file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"}, + {file = "requests-2.27.1.tar.gz", hash = "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61"}, ] scikit-learn = [ 
{file = "scikit-learn-0.24.2.tar.gz", hash = "sha256:d14701a12417930392cd3898e9646cf5670c190b933625ebe7511b1f7d7b8736"}, @@ -2011,16 +2037,16 @@ sphinxcontrib-serializinghtml = [ {file = "sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl", hash = "sha256:352a9a00ae864471d3a7ead8d7d79f5fc0b57e8b3f95e9867eb9eb28999b92fd"}, ] terminado = [ - {file = "terminado-0.12.1-py3-none-any.whl", hash = "sha256:09fdde344324a1c9c6e610ee4ca165c4bb7f5bbf982fceeeb38998a988ef8452"}, - {file = "terminado-0.12.1.tar.gz", hash = "sha256:b20fd93cc57c1678c799799d117874367cc07a3d2d55be95205b1a88fa08393f"}, + {file = "terminado-0.13.0-py3-none-any.whl", hash = "sha256:50a18654ad0cff153fdeb58711c9a7b25e0e2b74fb721dbaddd9e80d5612fac6"}, + {file = "terminado-0.13.0.tar.gz", hash = "sha256:713531ccb5db7d4f544651f14050da79809030f00d1afa21462088cf32fb143a"}, ] testpath = [ - {file = "testpath-0.5.0-py3-none-any.whl", hash = "sha256:8044f9a0bab6567fc644a3593164e872543bb44225b0e24846e2c89237937589"}, - {file = "testpath-0.5.0.tar.gz", hash = "sha256:1acf7a0bcd3004ae8357409fc33751e16d37ccc650921da1094a86581ad1e417"}, + {file = "testpath-0.6.0-py3-none-any.whl", hash = "sha256:8ada9f80a2ac6fb0391aa7cdb1a7d11cfa8429f693eda83f74dde570fe6fa639"}, + {file = "testpath-0.6.0.tar.gz", hash = "sha256:2f1b97e6442c02681ebe01bd84f531028a7caea1af3825000f52345c30285e0f"}, ] threadpoolctl = [ - {file = "threadpoolctl-3.0.0-py3-none-any.whl", hash = "sha256:4fade5b3b48ae4b1c30f200b28f39180371104fccc642e039e0f2435ec8cc211"}, - {file = "threadpoolctl-3.0.0.tar.gz", hash = "sha256:d03115321233d0be715f0d3a5ad1d6c065fe425ddc2d671ca8e45e9fd5d7a52a"}, + {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, + {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, ] toml = [ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = 
"sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, @@ -2083,20 +2109,20 @@ tornado = [ {file = "tornado-6.1.tar.gz", hash = "sha256:33c6e81d7bd55b468d2e793517c909b139960b6c790a60b7991b9b6b76fb9791"}, ] tqdm = [ - {file = "tqdm-4.62.3-py2.py3-none-any.whl", hash = "sha256:8dd278a422499cd6b727e6ae4061c40b48fce8b76d1ccbf5d34fca9b7f925b0c"}, - {file = "tqdm-4.62.3.tar.gz", hash = "sha256:d359de7217506c9851b7869f3708d8ee53ed70a1b8edbba4dbcb47442592920d"}, + {file = "tqdm-4.63.0-py2.py3-none-any.whl", hash = "sha256:e643e071046f17139dea55b880dc9b33822ce21613b4a4f5ea57f202833dbc29"}, + {file = "tqdm-4.63.0.tar.gz", hash = "sha256:1d9835ede8e394bb8c9dcbffbca02d717217113adc679236873eeaac5bc0b3cd"}, ] traitlets = [ {file = "traitlets-4.3.3-py2.py3-none-any.whl", hash = "sha256:70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44"}, {file = "traitlets-4.3.3.tar.gz", hash = "sha256:d023ee369ddd2763310e4c3eae1ff649689440d4ae59d7485eb4cfbbe3e359f7"}, ] typing-extensions = [ - {file = "typing_extensions-4.0.1-py3-none-any.whl", hash = "sha256:7f001e5ac290a0c0401508864c7ec868be4e701886d5b573a9528ed3973d9d3b"}, - {file = "typing_extensions-4.0.1.tar.gz", hash = "sha256:4ca091dea149f945ec56afb48dae714f21e8692ef22a395223bcd328961b6a0e"}, + {file = "typing_extensions-4.1.1-py3-none-any.whl", hash = "sha256:21c85e0fe4b9a155d0799430b0ad741cdce7e359660ccbd8b530613e8df88ce2"}, + {file = "typing_extensions-4.1.1.tar.gz", hash = "sha256:1a9462dcc3347a79b1f1c0271fbe79e844580bb598bafa1ed208b94da3cdcd42"}, ] urllib3 = [ - {file = "urllib3-1.26.7-py2.py3-none-any.whl", hash = "sha256:c4fdf4019605b6e5423637e01bc9fe4daef873709a7973e195ceba0a62bbc844"}, - {file = "urllib3-1.26.7.tar.gz", hash = "sha256:4987c65554f7a2dbf30c18fd48778ef124af6fab771a377103da0585e2336ece"}, + {file = "urllib3-1.22-py2.py3-none-any.whl", hash = "sha256:06330f386d6e4b195fbfc736b297f58c5a892e4440e54d294d7004e3a9bbea1b"}, + {file = "urllib3-1.22.tar.gz", hash = 
"sha256:cc44da8e1145637334317feebd728bd869a35285b93cbb4cca2577da7e62db4f"}, ] wcwidth = [ {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, diff --git a/pyproject.toml b/pyproject.toml index 16118657..1f96b264 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ keywords = ["tabnet", "pytorch", "neural-networks" ] exclude = ["tabnet/*.ipynb"] [tool.poetry.dependencies] -python = "^3.6" +python = ">=3.6" numpy="^1.17" torch="^1.2" diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 8f48e29f..915631f5 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -125,7 +125,8 @@ def fit( callbacks=None, pin_memory=True, from_unsupervised=None, - warm_start=False + warm_start=False, + augmentations=None, ): """Train a neural network stored in self.network Using train_dataloader for training data and @@ -183,6 +184,11 @@ def fit( self.input_dim = X_train.shape[1] self._stop_training = False self.pin_memory = pin_memory and (self.device.type != "cpu") + self.augmentations = augmentations + + if self.augmentations is not None: + # This ensure reproducibility + self.augmentations._set_seed() eval_set = eval_set if eval_set else [] @@ -480,6 +486,9 @@ def _train_batch(self, X, y): X = X.to(self.device).float() y = y.to(self.device).float() + if self.augmentations is not None: + X, y = self.augmentations(X, y) + for param in self.network.parameters(): param.grad = None diff --git a/pytorch_tabnet/augmentations.py b/pytorch_tabnet/augmentations.py new file mode 100644 index 00000000..287fa365 --- /dev/null +++ b/pytorch_tabnet/augmentations.py @@ -0,0 +1,85 @@ +import torch +from pytorch_tabnet.utils import define_device +import numpy as np + + +class RegressionSMOTE(): + """ + Apply SMOTE + + This will average a percentage p of the elements in the batch with other elements. 
+ The target will be averaged as well (this might work with binary classification + and certain loss), following a beta distribution. + """ + def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): + "" + self.seed = seed + self._set_seed() + self.device = define_device(device_name) + self.alpha = alpha + self.beta = beta + self.p = p + if (p < 0.) or (p > 1.0): + raise ValueError("Value of p should be between 0. and 1.") + + def _set_seed(self): + torch.manual_seed(self.seed) + np.random.seed(self.seed) + return + + def __call__(self, X, y): + batch_size = X.shape[0] + random_values = torch.rand(batch_size, device=self.device) + idx_to_change = random_values < self.p + + # ensure that first element to switch has probability > 0.5 + np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5 + random_betas = torch.from_numpy(np_betas).to(self.device).float() + index_permute = torch.randperm(batch_size, device=self.device) + + X[idx_to_change] = random_betas[idx_to_change, None] * X[idx_to_change] + X[idx_to_change] += (1 - random_betas[idx_to_change, None]) * X[index_permute][idx_to_change].view(X[idx_to_change].size()) # noqa + + y[idx_to_change] = random_betas[idx_to_change, None] * y[idx_to_change] + y[idx_to_change] += (1 - random_betas[idx_to_change, None]) * y[index_permute][idx_to_change].view(y[idx_to_change].size()) # noqa + + return X, y + + +class ClassificationSMOTE(): + """ + Apply SMOTE for classification tasks. + + This will average a percentage p of the elements in the batch with other elements. + The target will stay unchanged and keep the value of the most important row in the mix. + """ + def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): + "" + self.seed = seed + self._set_seed() + self.device = define_device(device_name) + self.alpha = alpha + self.beta = beta + self.p = p + if (p < 0.) or (p > 1.0): + raise ValueError("Value of p should be between 0. 
and 1.") + + def _set_seed(self): + torch.manual_seed(self.seed) + np.random.seed(self.seed) + return + + def __call__(self, X, y): + batch_size = X.shape[0] + random_values = torch.rand(batch_size, device=self.device) + idx_to_change = random_values < self.p + + # ensure that first element to switch has probability > 0.5 + np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5 + random_betas = torch.from_numpy(np_betas).to(self.device).float() + index_permute = torch.randperm(batch_size, device=self.device) + + X[idx_to_change] = random_betas[idx_to_change, None] * X[idx_to_change] + X[idx_to_change] += (1 - random_betas[idx_to_change, None]) * X[index_permute][idx_to_change].view(X[idx_to_change].size()) # noqa + + return X, y diff --git a/regression_example.ipynb b/regression_example.ipynb index efd0f8c1..c6096f20 100644 --- a/regression_example.ipynb +++ b/regression_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -42,9 +42,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "File already exists.\n" + ] + } + ], "source": [ "out.parent.mkdir(parents=True, exist_ok=True)\n", "if out.exists():\n", @@ -63,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -88,9 +96,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " State-gov 9\n", + " Bachelors 16\n", + " Never-married 7\n", + " Adm-clerical 15\n", + " Not-in-family 6\n", + " White 
5\n", + " Male 2\n", + " United-States 42\n", + " <=50K 2\n", + "Set 3\n" + ] + } + ], "source": [ "categorical_columns = []\n", "categorical_dims = {}\n", @@ -115,7 +140,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -140,9 +165,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/work/pytorch_tabnet/abstract_model.py:75: UserWarning: Device used : cuda\n", + " warnings.warn(f\"Device used : {self.device}\")\n" + ] + } + ], "source": [ "clf = TabNetRegressor(cat_dims=cat_dims, cat_emb_dim=cat_emb_dim, cat_idxs=cat_idxs)" ] @@ -156,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -172,20 +206,158 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "max_epochs = 1000 if not os.getenv(\"CI\", False) else 2" + "max_epochs = 100 if not os.getenv(\"CI\", False) else 2" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_tabnet.augmentations import RegressionSMOTE\n", + "aug = RegressionSMOTE(p=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0 | loss: 0.16632 | train_rmsle: 0.08172 | train_mae: 0.32434 | train_rmse: 0.42037 | train_mse: 0.17671 | valid_rmsle: 0.08229 | valid_mae: 0.32509 | valid_rmse: 0.41638 | valid_mse: 0.17337 | 0:00:01s\n", + "epoch 1 | loss: 0.13043 | train_rmsle: 0.08851 | train_mae: 0.34694 | train_rmse: 0.43319 | train_mse: 0.18765 | valid_rmsle: 0.08877 | valid_mae: 0.34742 | valid_rmse: 0.43215 | valid_mse: 0.18676 | 
0:00:02s\n", + "epoch 2 | loss: 0.12421 | train_rmsle: 0.07367 | train_mae: 0.31208 | train_rmse: 0.38054 | train_mse: 0.14481 | valid_rmsle: 0.07219 | valid_mae: 0.30876 | valid_rmse: 0.37698 | valid_mse: 0.14211 | 0:00:04s\n", + "epoch 3 | loss: 0.11857 | train_rmsle: 0.08045 | train_mae: 0.30774 | train_rmse: 0.38188 | train_mse: 0.14583 | valid_rmsle: 0.08033 | valid_mae: 0.30664 | valid_rmse: 0.38204 | valid_mse: 0.14595 | 0:00:05s\n", + "epoch 4 | loss: 0.11478 | train_rmsle: 0.0657 | train_mae: 0.2798 | train_rmse: 0.35083 | train_mse: 0.12308 | valid_rmsle: 0.06471 | valid_mae: 0.2762 | valid_rmse: 0.34809 | valid_mse: 0.12116 | 0:00:07s\n", + "epoch 5 | loss: 0.11121 | train_rmsle: 0.05961 | train_mae: 0.25979 | train_rmse: 0.34053 | train_mse: 0.11596 | valid_rmsle: 0.05839 | valid_mae: 0.2566 | valid_rmse: 0.33742 | valid_mse: 0.11385 | 0:00:08s\n", + "epoch 6 | loss: 0.11029 | train_rmsle: 0.06378 | train_mae: 0.25423 | train_rmse: 0.35531 | train_mse: 0.12624 | valid_rmsle: 0.06213 | valid_mae: 0.25054 | valid_rmse: 0.35461 | valid_mse: 0.12575 | 0:00:09s\n", + "epoch 7 | loss: 0.11003 | train_rmsle: 0.057 | train_mae: 0.25219 | train_rmse: 0.34344 | train_mse: 0.11795 | valid_rmsle: 0.05512 | valid_mae: 0.2475 | valid_rmse: 0.33776 | valid_mse: 0.11408 | 0:00:11s\n", + "epoch 8 | loss: 0.10988 | train_rmsle: 0.05512 | train_mae: 0.24226 | train_rmse: 0.33481 | train_mse: 0.1121 | valid_rmsle: 0.05343 | valid_mae: 0.23781 | valid_rmse: 0.32937 | valid_mse: 0.10848 | 0:00:12s\n", + "epoch 9 | loss: 0.10829 | train_rmsle: 0.05693 | train_mae: 0.24596 | train_rmse: 0.33503 | train_mse: 0.11225 | valid_rmsle: 0.05497 | valid_mae: 0.24018 | valid_rmse: 0.32932 | valid_mse: 0.10845 | 0:00:13s\n", + "epoch 10 | loss: 0.10833 | train_rmsle: 0.05375 | train_mae: 0.23563 | train_rmse: 0.33305 | train_mse: 0.11092 | valid_rmsle: 0.05185 | valid_mae: 0.23062 | valid_rmse: 0.32643 | valid_mse: 0.10656 | 0:00:15s\n", + "epoch 11 | loss: 0.10764 | train_rmsle: 0.0526 
| train_mae: 0.22899 | train_rmse: 0.33287 | train_mse: 0.1108 | valid_rmsle: 0.05076 | valid_mae: 0.22411 | valid_rmse: 0.3265 | valid_mse: 0.1066 | 0:00:16s\n", + "epoch 12 | loss: 0.10698 | train_rmsle: 0.05624 | train_mae: 0.23993 | train_rmse: 0.33312 | train_mse: 0.11097 | valid_rmsle: 0.05461 | valid_mae: 0.2353 | valid_rmse: 0.32751 | valid_mse: 0.10726 | 0:00:18s\n", + "epoch 13 | loss: 0.10679 | train_rmsle: 0.05276 | train_mae: 0.22877 | train_rmse: 0.33555 | train_mse: 0.1126 | valid_rmsle: 0.05128 | valid_mae: 0.22473 | valid_rmse: 0.33015 | valid_mse: 0.109 | 0:00:19s\n", + "epoch 14 | loss: 0.10577 | train_rmsle: 0.05284 | train_mae: 0.22906 | train_rmse: 0.32791 | train_mse: 0.10752 | valid_rmsle: 0.05154 | valid_mae: 0.22525 | valid_rmse: 0.32335 | valid_mse: 0.10455 | 0:00:20s\n", + "epoch 15 | loss: 0.10497 | train_rmsle: 0.05111 | train_mae: 0.22535 | train_rmse: 0.32761 | train_mse: 0.10733 | valid_rmsle: 0.05009 | valid_mae: 0.22221 | valid_rmse: 0.32379 | valid_mse: 0.10484 | 0:00:22s\n", + "epoch 16 | loss: 0.10512 | train_rmsle: 0.05156 | train_mae: 0.22343 | train_rmse: 0.32994 | train_mse: 0.10886 | valid_rmsle: 0.05017 | valid_mae: 0.21935 | valid_rmse: 0.32502 | valid_mse: 0.10564 | 0:00:23s\n", + "epoch 17 | loss: 0.10614 | train_rmsle: 0.05414 | train_mae: 0.22927 | train_rmse: 0.32792 | train_mse: 0.10753 | valid_rmsle: 0.05246 | valid_mae: 0.22524 | valid_rmse: 0.32246 | valid_mse: 0.10398 | 0:00:25s\n", + "epoch 18 | loss: 0.1028 | train_rmsle: 0.05195 | train_mae: 0.22859 | train_rmse: 0.32487 | train_mse: 0.10554 | valid_rmsle: 0.05053 | valid_mae: 0.22437 | valid_rmse: 0.32036 | valid_mse: 0.10263 | 0:00:26s\n", + "epoch 19 | loss: 0.10342 | train_rmsle: 0.05142 | train_mae: 0.22324 | train_rmse: 0.32451 | train_mse: 0.10531 | valid_rmsle: 0.04992 | valid_mae: 0.21936 | valid_rmse: 0.31937 | valid_mse: 0.10199 | 0:00:27s\n", + "epoch 20 | loss: 0.10317 | train_rmsle: 0.05125 | train_mae: 0.22449 | train_rmse: 0.32393 | 
train_mse: 0.10493 | valid_rmsle: 0.04942 | valid_mae: 0.22011 | valid_rmse: 0.31785 | valid_mse: 0.10103 | 0:00:29s\n", + "epoch 21 | loss: 0.10194 | train_rmsle: 0.0503 | train_mae: 0.21491 | train_rmse: 0.32128 | train_mse: 0.10322 | valid_rmsle: 0.0485 | valid_mae: 0.21037 | valid_rmse: 0.31517 | valid_mse: 0.09933 | 0:00:30s\n", + "epoch 22 | loss: 0.10324 | train_rmsle: 0.05067 | train_mae: 0.22173 | train_rmse: 0.3247 | train_mse: 0.10543 | valid_rmsle: 0.04861 | valid_mae: 0.21679 | valid_rmse: 0.31732 | valid_mse: 0.10069 | 0:00:31s\n", + "epoch 23 | loss: 0.10248 | train_rmsle: 0.05005 | train_mae: 0.21331 | train_rmse: 0.32139 | train_mse: 0.10329 | valid_rmsle: 0.04804 | valid_mae: 0.20833 | valid_rmse: 0.31477 | valid_mse: 0.09908 | 0:00:33s\n", + "epoch 24 | loss: 0.10151 | train_rmsle: 0.04969 | train_mae: 0.215 | train_rmse: 0.32014 | train_mse: 0.10249 | valid_rmsle: 0.04756 | valid_mae: 0.20974 | valid_rmse: 0.31347 | valid_mse: 0.09826 | 0:00:34s\n", + "epoch 25 | loss: 0.10211 | train_rmsle: 0.04935 | train_mae: 0.21158 | train_rmse: 0.3224 | train_mse: 0.10394 | valid_rmsle: 0.04687 | valid_mae: 0.20514 | valid_rmse: 0.31442 | valid_mse: 0.09886 | 0:00:36s\n", + "epoch 26 | loss: 0.10208 | train_rmsle: 0.04962 | train_mae: 0.21607 | train_rmse: 0.31855 | train_mse: 0.10148 | valid_rmsle: 0.04761 | valid_mae: 0.21083 | valid_rmse: 0.31159 | valid_mse: 0.09709 | 0:00:37s\n", + "epoch 27 | loss: 0.10094 | train_rmsle: 0.05275 | train_mae: 0.22095 | train_rmse: 0.32081 | train_mse: 0.10292 | valid_rmsle: 0.05063 | valid_mae: 0.2158 | valid_rmse: 0.3131 | valid_mse: 0.09803 | 0:00:39s\n", + "epoch 28 | loss: 0.10124 | train_rmsle: 0.04891 | train_mae: 0.2123 | train_rmse: 0.32237 | train_mse: 0.10392 | valid_rmsle: 0.04737 | valid_mae: 0.20892 | valid_rmse: 0.31726 | valid_mse: 0.10065 | 0:00:40s\n", + "epoch 29 | loss: 0.10122 | train_rmsle: 0.04956 | train_mae: 0.2099 | train_rmse: 0.31865 | train_mse: 0.10154 | valid_rmsle: 0.04759 | valid_mae: 
0.20479 | valid_rmse: 0.31198 | valid_mse: 0.09733 | 0:00:41s\n", + "epoch 30 | loss: 0.10027 | train_rmsle: 0.0496 | train_mae: 0.21145 | train_rmse: 0.31576 | train_mse: 0.0997 | valid_rmsle: 0.04841 | valid_mae: 0.20842 | valid_rmse: 0.31175 | valid_mse: 0.09719 | 0:00:43s\n", + "epoch 31 | loss: 0.10034 | train_rmsle: 0.05005 | train_mae: 0.21385 | train_rmse: 0.3167 | train_mse: 0.1003 | valid_rmsle: 0.049 | valid_mae: 0.21084 | valid_rmse: 0.31334 | valid_mse: 0.09818 | 0:00:44s\n", + "epoch 32 | loss: 0.09935 | train_rmsle: 0.04969 | train_mae: 0.21522 | train_rmse: 0.31826 | train_mse: 0.10129 | valid_rmsle: 0.04807 | valid_mae: 0.2116 | valid_rmse: 0.31319 | valid_mse: 0.09809 | 0:00:46s\n", + "epoch 33 | loss: 0.09984 | train_rmsle: 0.04837 | train_mae: 0.21348 | train_rmse: 0.31861 | train_mse: 0.10152 | valid_rmsle: 0.04715 | valid_mae: 0.21017 | valid_rmse: 0.31453 | valid_mse: 0.09893 | 0:00:47s\n", + "epoch 34 | loss: 0.09977 | train_rmsle: 0.04906 | train_mae: 0.2057 | train_rmse: 0.31576 | train_mse: 0.09971 | valid_rmsle: 0.04742 | valid_mae: 0.20152 | valid_rmse: 0.31029 | valid_mse: 0.09628 | 0:00:48s\n", + "epoch 35 | loss: 0.09988 | train_rmsle: 0.04863 | train_mae: 0.20762 | train_rmse: 0.31561 | train_mse: 0.09961 | valid_rmsle: 0.04672 | valid_mae: 0.20284 | valid_rmse: 0.30913 | valid_mse: 0.09556 | 0:00:50s\n", + "epoch 36 | loss: 0.10015 | train_rmsle: 0.04951 | train_mae: 0.21187 | train_rmse: 0.31592 | train_mse: 0.09981 | valid_rmsle: 0.04835 | valid_mae: 0.20902 | valid_rmse: 0.31217 | valid_mse: 0.09745 | 0:00:51s\n", + "epoch 37 | loss: 0.10147 | train_rmsle: 0.05169 | train_mae: 0.2214 | train_rmse: 0.32128 | train_mse: 0.10322 | valid_rmsle: 0.04984 | valid_mae: 0.21678 | valid_rmse: 0.31476 | valid_mse: 0.09908 | 0:00:53s\n", + "epoch 38 | loss: 0.10198 | train_rmsle: 0.04936 | train_mae: 0.21867 | train_rmse: 0.32333 | train_mse: 0.10454 | valid_rmsle: 0.04824 | valid_mae: 0.21547 | valid_rmse: 0.31959 | valid_mse: 0.10214 | 
0:00:54s\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 39 | loss: 0.10109 | train_rmsle: 0.05026 | train_mae: 0.21729 | train_rmse: 0.31905 | train_mse: 0.1018 | valid_rmsle: 0.04844 | valid_mae: 0.21238 | valid_rmse: 0.31259 | valid_mse: 0.09771 | 0:00:55s\n", + "epoch 40 | loss: 0.1015 | train_rmsle: 0.05009 | train_mae: 0.21863 | train_rmse: 0.31728 | train_mse: 0.10067 | valid_rmsle: 0.04864 | valid_mae: 0.21431 | valid_rmse: 0.31218 | valid_mse: 0.09746 | 0:00:57s\n", + "epoch 41 | loss: 0.10149 | train_rmsle: 0.04867 | train_mae: 0.21139 | train_rmse: 0.32118 | train_mse: 0.10316 | valid_rmsle: 0.04726 | valid_mae: 0.20743 | valid_rmse: 0.31627 | valid_mse: 0.10002 | 0:00:58s\n", + "epoch 42 | loss: 0.10082 | train_rmsle: 0.0507 | train_mae: 0.21215 | train_rmse: 0.31643 | train_mse: 0.10013 | valid_rmsle: 0.04898 | valid_mae: 0.20688 | valid_rmse: 0.31061 | valid_mse: 0.09648 | 0:01:00s\n", + "epoch 43 | loss: 0.1012 | train_rmsle: 0.05138 | train_mae: 0.21497 | train_rmse: 0.31801 | train_mse: 0.10113 | valid_rmsle: 0.04964 | valid_mae: 0.21004 | valid_rmse: 0.31243 | valid_mse: 0.09761 | 0:01:01s\n", + "epoch 44 | loss: 0.09969 | train_rmsle: 0.05366 | train_mae: 0.21694 | train_rmse: 0.32127 | train_mse: 0.10321 | valid_rmsle: 0.05198 | valid_mae: 0.21147 | valid_rmse: 0.31592 | valid_mse: 0.09981 | 0:01:02s\n", + "epoch 45 | loss: 0.09983 | train_rmsle: 0.04806 | train_mae: 0.20948 | train_rmse: 0.3201 | train_mse: 0.10246 | valid_rmsle: 0.04689 | valid_mae: 0.20623 | valid_rmse: 0.31651 | valid_mse: 0.10018 | 0:01:04s\n", + "epoch 46 | loss: 0.09854 | train_rmsle: 0.04937 | train_mae: 0.20906 | train_rmse: 0.31537 | train_mse: 0.09946 | valid_rmsle: 0.04784 | valid_mae: 0.20508 | valid_rmse: 0.31005 | valid_mse: 0.09613 | 0:01:05s\n", + "epoch 47 | loss: 0.09865 | train_rmsle: 0.04765 | train_mae: 0.20356 | train_rmse: 0.31644 | train_mse: 0.10013 | valid_rmsle: 0.0461 | valid_mae: 0.19918 | valid_rmse: 0.31127 | 
valid_mse: 0.09689 | 0:01:07s\n", + "epoch 48 | loss: 0.09813 | train_rmsle: 0.04759 | train_mae: 0.2065 | train_rmse: 0.31669 | train_mse: 0.10029 | valid_rmsle: 0.04641 | valid_mae: 0.20337 | valid_rmse: 0.31259 | valid_mse: 0.09771 | 0:01:08s\n", + "epoch 49 | loss: 0.09894 | train_rmsle: 0.04811 | train_mae: 0.21058 | train_rmse: 0.31469 | train_mse: 0.09903 | valid_rmsle: 0.04708 | valid_mae: 0.20755 | valid_rmse: 0.31112 | valid_mse: 0.0968 | 0:01:09s\n", + "epoch 50 | loss: 0.09902 | train_rmsle: 0.04838 | train_mae: 0.20795 | train_rmse: 0.31515 | train_mse: 0.09932 | valid_rmsle: 0.04738 | valid_mae: 0.20433 | valid_rmse: 0.31136 | valid_mse: 0.09695 | 0:01:11s\n", + "epoch 51 | loss: 0.09911 | train_rmsle: 0.0484 | train_mae: 0.20765 | train_rmse: 0.31505 | train_mse: 0.09925 | valid_rmsle: 0.04698 | valid_mae: 0.20378 | valid_rmse: 0.31041 | valid_mse: 0.09635 | 0:01:12s\n", + "epoch 52 | loss: 0.09917 | train_rmsle: 0.04775 | train_mae: 0.20601 | train_rmse: 0.3163 | train_mse: 0.10005 | valid_rmsle: 0.04622 | valid_mae: 0.20242 | valid_rmse: 0.31109 | valid_mse: 0.09678 | 0:01:14s\n", + "epoch 53 | loss: 0.0982 | train_rmsle: 0.04794 | train_mae: 0.20522 | train_rmse: 0.31544 | train_mse: 0.0995 | valid_rmsle: 0.04642 | valid_mae: 0.20135 | valid_rmse: 0.31003 | valid_mse: 0.09612 | 0:01:15s\n", + "epoch 54 | loss: 0.09837 | train_rmsle: 0.04802 | train_mae: 0.20703 | train_rmse: 0.31501 | train_mse: 0.09923 | valid_rmsle: 0.04645 | valid_mae: 0.20307 | valid_rmse: 0.30927 | valid_mse: 0.09565 | 0:01:16s\n", + "epoch 55 | loss: 0.09796 | train_rmsle: 0.04762 | train_mae: 0.2029 | train_rmse: 0.31699 | train_mse: 0.10048 | valid_rmsle: 0.04655 | valid_mae: 0.20033 | valid_rmse: 0.31329 | valid_mse: 0.09815 | 0:01:18s\n", + "epoch 56 | loss: 0.09951 | train_rmsle: 0.04796 | train_mae: 0.20457 | train_rmse: 0.31803 | train_mse: 0.10114 | valid_rmsle: 0.0467 | valid_mae: 0.20155 | valid_rmse: 0.31322 | valid_mse: 0.09811 | 0:01:19s\n", + "epoch 57 | loss: 
0.09864 | train_rmsle: 0.04795 | train_mae: 0.20608 | train_rmse: 0.31984 | train_mse: 0.1023 | valid_rmsle: 0.04688 | valid_mae: 0.20307 | valid_rmse: 0.31594 | valid_mse: 0.09982 | 0:01:21s\n", + "epoch 58 | loss: 0.09906 | train_rmsle: 0.04846 | train_mae: 0.20418 | train_rmse: 0.31676 | train_mse: 0.10034 | valid_rmsle: 0.04721 | valid_mae: 0.20046 | valid_rmse: 0.31195 | valid_mse: 0.09732 | 0:01:22s\n", + "epoch 59 | loss: 0.09942 | train_rmsle: 0.04874 | train_mae: 0.20833 | train_rmse: 0.31395 | train_mse: 0.09856 | valid_rmsle: 0.04808 | valid_mae: 0.20561 | valid_rmse: 0.31126 | valid_mse: 0.09689 | 0:01:23s\n", + "epoch 60 | loss: 0.09816 | train_rmsle: 0.04848 | train_mae: 0.20814 | train_rmse: 0.3157 | train_mse: 0.09967 | valid_rmsle: 0.04738 | valid_mae: 0.20452 | valid_rmse: 0.31145 | valid_mse: 0.097 | 0:01:25s\n", + "epoch 61 | loss: 0.09783 | train_rmsle: 0.04989 | train_mae: 0.20517 | train_rmse: 0.31546 | train_mse: 0.09951 | valid_rmsle: 0.04939 | valid_mae: 0.20392 | valid_rmse: 0.31329 | valid_mse: 0.09815 | 0:01:26s\n", + "epoch 62 | loss: 0.09757 | train_rmsle: 0.0489 | train_mae: 0.20754 | train_rmse: 0.31342 | train_mse: 0.09823 | valid_rmsle: 0.04812 | valid_mae: 0.20535 | valid_rmse: 0.31025 | valid_mse: 0.09626 | 0:01:28s\n", + "epoch 63 | loss: 0.09733 | train_rmsle: 0.05012 | train_mae: 0.20823 | train_rmse: 0.31422 | train_mse: 0.09873 | valid_rmsle: 0.04929 | valid_mae: 0.20605 | valid_rmse: 0.31077 | valid_mse: 0.09658 | 0:01:29s\n", + "epoch 64 | loss: 0.09746 | train_rmsle: 0.04716 | train_mae: 0.20427 | train_rmse: 0.31561 | train_mse: 0.09961 | valid_rmsle: 0.04616 | valid_mae: 0.20203 | valid_rmse: 0.31183 | valid_mse: 0.09724 | 0:01:30s\n", + "epoch 65 | loss: 0.09895 | train_rmsle: 0.04738 | train_mae: 0.2032 | train_rmse: 0.31327 | train_mse: 0.09814 | valid_rmsle: 0.0465 | valid_mae: 0.20116 | valid_rmse: 0.31021 | valid_mse: 0.09623 | 0:01:32s\n", + "epoch 66 | loss: 0.09839 | train_rmsle: 0.04817 | train_mae: 0.21045 | 
train_rmse: 0.31406 | train_mse: 0.09864 | valid_rmsle: 0.04732 | valid_mae: 0.20821 | valid_rmse: 0.31109 | valid_mse: 0.09677 | 0:01:33s\n", + "epoch 67 | loss: 0.09867 | train_rmsle: 0.04858 | train_mae: 0.20448 | train_rmse: 0.31458 | train_mse: 0.09896 | valid_rmsle: 0.04756 | valid_mae: 0.20188 | valid_rmse: 0.31064 | valid_mse: 0.0965 | 0:01:34s\n", + "epoch 68 | loss: 0.0975 | train_rmsle: 0.04824 | train_mae: 0.20539 | train_rmse: 0.31288 | train_mse: 0.09789 | valid_rmsle: 0.04716 | valid_mae: 0.203 | valid_rmse: 0.30895 | valid_mse: 0.09545 | 0:01:36s\n", + "epoch 69 | loss: 0.09763 | train_rmsle: 0.04792 | train_mae: 0.20796 | train_rmse: 0.31485 | train_mse: 0.09913 | valid_rmsle: 0.04667 | valid_mae: 0.20488 | valid_rmse: 0.31019 | valid_mse: 0.09622 | 0:01:37s\n", + "epoch 70 | loss: 0.09723 | train_rmsle: 0.04897 | train_mae: 0.2058 | train_rmse: 0.31261 | train_mse: 0.09773 | valid_rmsle: 0.04791 | valid_mae: 0.2035 | valid_rmse: 0.30854 | valid_mse: 0.0952 | 0:01:39s\n", + "epoch 71 | loss: 0.09762 | train_rmsle: 0.04741 | train_mae: 0.20414 | train_rmse: 0.31256 | train_mse: 0.0977 | valid_rmsle: 0.04639 | valid_mae: 0.20201 | valid_rmse: 0.30874 | valid_mse: 0.09532 | 0:01:40s\n", + "epoch 72 | loss: 0.0971 | train_rmsle: 0.04862 | train_mae: 0.20357 | train_rmse: 0.31456 | train_mse: 0.09895 | valid_rmsle: 0.04736 | valid_mae: 0.20057 | valid_rmse: 0.30995 | valid_mse: 0.09607 | 0:01:41s\n", + "epoch 73 | loss: 0.09689 | train_rmsle: 0.04786 | train_mae: 0.21148 | train_rmse: 0.31723 | train_mse: 0.10063 | valid_rmsle: 0.04678 | valid_mae: 0.20891 | valid_rmse: 0.31349 | valid_mse: 0.09828 | 0:01:43s\n", + "epoch 74 | loss: 0.09743 | train_rmsle: 0.04749 | train_mae: 0.20384 | train_rmse: 0.31325 | train_mse: 0.09812 | valid_rmsle: 0.04617 | valid_mae: 0.20132 | valid_rmse: 0.30846 | valid_mse: 0.09515 | 0:01:44s\n", + "epoch 75 | loss: 0.0972 | train_rmsle: 0.04803 | train_mae: 0.20561 | train_rmse: 0.31244 | train_mse: 0.09762 | valid_rmsle: 
0.04672 | valid_mae: 0.20277 | valid_rmse: 0.30754 | valid_mse: 0.09458 | 0:01:45s\n", + "epoch 76 | loss: 0.09595 | train_rmsle: 0.0478 | train_mae: 0.20355 | train_rmse: 0.31466 | train_mse: 0.09901 | valid_rmsle: 0.04658 | valid_mae: 0.20084 | valid_rmse: 0.31016 | valid_mse: 0.0962 | 0:01:47s\n", + "epoch 77 | loss: 0.09632 | train_rmsle: 0.0472 | train_mae: 0.20232 | train_rmse: 0.31319 | train_mse: 0.09809 | valid_rmsle: 0.04605 | valid_mae: 0.20031 | valid_rmse: 0.30921 | valid_mse: 0.09561 | 0:01:48s\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 78 | loss: 0.09692 | train_rmsle: 0.04901 | train_mae: 0.20755 | train_rmse: 0.31433 | train_mse: 0.0988 | valid_rmsle: 0.04771 | valid_mae: 0.20417 | valid_rmse: 0.30959 | valid_mse: 0.09585 | 0:01:50s\n", + "epoch 79 | loss: 0.09665 | train_rmsle: 0.04741 | train_mae: 0.2011 | train_rmse: 0.31131 | train_mse: 0.09691 | valid_rmsle: 0.04623 | valid_mae: 0.19879 | valid_rmse: 0.30713 | valid_mse: 0.09433 | 0:01:51s\n", + "epoch 80 | loss: 0.09707 | train_rmsle: 0.04718 | train_mae: 0.20445 | train_rmse: 0.31355 | train_mse: 0.09831 | valid_rmsle: 0.04589 | valid_mae: 0.20147 | valid_rmse: 0.30907 | valid_mse: 0.09552 | 0:01:52s\n", + "epoch 81 | loss: 0.09612 | train_rmsle: 0.04716 | train_mae: 0.20586 | train_rmse: 0.31175 | train_mse: 0.09719 | valid_rmsle: 0.04576 | valid_mae: 0.20278 | valid_rmse: 0.30691 | valid_mse: 0.09419 | 0:01:54s\n", + "epoch 82 | loss: 0.09604 | train_rmsle: 0.04763 | train_mae: 0.20646 | train_rmse: 0.31163 | train_mse: 0.09711 | valid_rmsle: 0.04651 | valid_mae: 0.20362 | valid_rmse: 0.30755 | valid_mse: 0.09459 | 0:01:55s\n", + "epoch 83 | loss: 0.09563 | train_rmsle: 0.04713 | train_mae: 0.20618 | train_rmse: 0.31666 | train_mse: 0.10027 | valid_rmsle: 0.04591 | valid_mae: 0.20295 | valid_rmse: 0.3122 | valid_mse: 0.09747 | 0:01:56s\n", + "epoch 84 | loss: 0.09528 | train_rmsle: 0.0471 | train_mae: 0.20568 | train_rmse: 0.31303 | train_mse: 
0.09799 | valid_rmsle: 0.04643 | valid_mae: 0.20403 | valid_rmse: 0.31028 | valid_mse: 0.09627 | 0:01:58s\n", + "epoch 85 | loss: 0.09574 | train_rmsle: 0.04854 | train_mae: 0.2081 | train_rmse: 0.31425 | train_mse: 0.09875 | valid_rmsle: 0.04779 | valid_mae: 0.20611 | valid_rmse: 0.31135 | valid_mse: 0.09694 | 0:01:59s\n", + "epoch 86 | loss: 0.09567 | train_rmsle: 0.04867 | train_mae: 0.20816 | train_rmse: 0.31457 | train_mse: 0.09895 | valid_rmsle: 0.0474 | valid_mae: 0.20504 | valid_rmse: 0.30997 | valid_mse: 0.09608 | 0:02:01s\n", + "epoch 87 | loss: 0.09636 | train_rmsle: 0.04716 | train_mae: 0.20594 | train_rmse: 0.31315 | train_mse: 0.09806 | valid_rmsle: 0.04654 | valid_mae: 0.20452 | valid_rmse: 0.31063 | valid_mse: 0.09649 | 0:02:02s\n", + "epoch 88 | loss: 0.09591 | train_rmsle: 0.04893 | train_mae: 0.20342 | train_rmse: 0.31486 | train_mse: 0.09914 | valid_rmsle: 0.04787 | valid_mae: 0.20101 | valid_rmse: 0.31063 | valid_mse: 0.09649 | 0:02:03s\n", + "epoch 89 | loss: 0.09616 | train_rmsle: 0.04726 | train_mae: 0.2021 | train_rmse: 0.3109 | train_mse: 0.09666 | valid_rmsle: 0.04637 | valid_mae: 0.20035 | valid_rmse: 0.30758 | valid_mse: 0.09461 | 0:02:05s\n", + "epoch 90 | loss: 0.09552 | train_rmsle: 0.04688 | train_mae: 0.20108 | train_rmse: 0.31004 | train_mse: 0.09612 | valid_rmsle: 0.04626 | valid_mae: 0.19981 | valid_rmse: 0.30742 | valid_mse: 0.09451 | 0:02:06s\n", + "epoch 91 | loss: 0.09573 | train_rmsle: 0.04755 | train_mae: 0.20916 | train_rmse: 0.31493 | train_mse: 0.09918 | valid_rmsle: 0.04691 | valid_mae: 0.20723 | valid_rmse: 0.31232 | valid_mse: 0.09754 | 0:02:08s\n", + "epoch 92 | loss: 0.09485 | train_rmsle: 0.04686 | train_mae: 0.20371 | train_rmse: 0.31634 | train_mse: 0.10007 | valid_rmsle: 0.04612 | valid_mae: 0.20201 | valid_rmse: 0.31364 | valid_mse: 0.09837 | 0:02:09s\n", + "epoch 93 | loss: 0.09578 | train_rmsle: 0.04734 | train_mae: 0.20391 | train_rmse: 0.31296 | train_mse: 0.09794 | valid_rmsle: 0.0466 | valid_mae: 0.20216 
| valid_rmse: 0.31012 | valid_mse: 0.09617 | 0:02:10s\n", + "epoch 94 | loss: 0.09567 | train_rmsle: 0.04701 | train_mae: 0.20346 | train_rmse: 0.31216 | train_mse: 0.09744 | valid_rmsle: 0.0463 | valid_mae: 0.20177 | valid_rmse: 0.3092 | valid_mse: 0.0956 | 0:02:12s\n", + "epoch 95 | loss: 0.09615 | train_rmsle: 0.04657 | train_mae: 0.19859 | train_rmse: 0.31408 | train_mse: 0.09864 | valid_rmsle: 0.04581 | valid_mae: 0.19735 | valid_rmse: 0.31106 | valid_mse: 0.09676 | 0:02:13s\n", + "epoch 96 | loss: 0.09594 | train_rmsle: 0.04694 | train_mae: 0.20517 | train_rmse: 0.31133 | train_mse: 0.09692 | valid_rmsle: 0.04651 | valid_mae: 0.20409 | valid_rmse: 0.30935 | valid_mse: 0.0957 | 0:02:14s\n", + "epoch 97 | loss: 0.09541 | train_rmsle: 0.04739 | train_mae: 0.20526 | train_rmse: 0.31208 | train_mse: 0.09739 | valid_rmsle: 0.04679 | valid_mae: 0.2037 | valid_rmse: 0.30957 | valid_mse: 0.09583 | 0:02:16s\n", + "epoch 98 | loss: 0.09606 | train_rmsle: 0.04706 | train_mae: 0.20423 | train_rmse: 0.31524 | train_mse: 0.09938 | valid_rmsle: 0.04626 | valid_mae: 0.20216 | valid_rmse: 0.3122 | valid_mse: 0.09747 | 0:02:17s\n", + "epoch 99 | loss: 0.09641 | train_rmsle: 0.04973 | train_mae: 0.20321 | train_rmse: 0.31253 | train_mse: 0.09767 | valid_rmsle: 0.04895 | valid_mae: 0.2007 | valid_rmse: 0.30926 | valid_mse: 0.09564 | 0:02:18s\n", + "Stop training because you reached max_epochs = 100 with best_epoch = 81 and best_valid_mse = 0.09419\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/work/pytorch_tabnet/callbacks.py:172: UserWarning: Best weights from best epoch are automatically used!\n", + " warnings.warn(wrn_msg)\n" + ] + } + ], "source": [ "clf.fit(\n", " X_train=X_train, y_train=y_train,\n", @@ -196,7 +368,8 @@ " patience=50,\n", " batch_size=1024, virtual_batch_size=128,\n", " num_workers=0,\n", - " drop_last=False\n", + " drop_last=False,\n", + " augmentations=aug, #aug\n", ") " ] },