From 92c0b813d5c9ba1bb21166634c20ff9f1a3689dd Mon Sep 17 00:00:00 2001 From: Optimox Date: Tue, 8 Mar 2022 10:59:59 +0100 Subject: [PATCH] chore: poetry update --- README.md | 6 + poetry.lock | 196 +++++++++++++++++-------------- pyproject.toml | 2 +- pytorch_tabnet/abstract_model.py | 4 +- pytorch_tabnet/augmentations.py | 31 ++--- 5 files changed, 133 insertions(+), 106 deletions(-) diff --git a/README.md b/README.md index 14ab004c..607622c5 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,12 @@ A complete example can be found within the notebook `pretraining_example.ipynb`. /!\ : current implementation is trying to reconstruct the original inputs, but Batch Normalization applies a random transformation that can't be deduced by a single line, making the reconstruction harder. Lowering the `batch_size` might make the pretraining easier. +# Data augmentation on the fly + +It is now possible to apply custom data augmentation pipeline during training. +Templates for ClassificationSMOTE and RegressionSMOTE have been added in `pytorch-tabnet/augmentations.py` and can be used as is. + + # Easy saving and loading It's really easy to save and re-load a trained model, this makes TabNet production ready. diff --git a/poetry.lock b/poetry.lock index c0e3ba6b..6a2912e0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -65,17 +65,17 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] name = "attrs" -version = "21.2.0" +version = "21.4.0" description = "Classes Without Boilerplate" category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.extras] -dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] -tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"] -tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] [[package]] name = "babel" @@ -130,7 +130,7 @@ pycparser = "*" [[package]] name = "charset-normalizer" -version = "2.0.9" +version = "2.0.12" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." category = "dev" optional = false @@ -176,7 +176,7 @@ python-versions = ">=3.6, <3.7" [[package]] name = "decorator" -version = "5.1.0" +version = "5.1.1" description = "Decorators for Humans" category = "dev" optional = false @@ -253,6 +253,21 @@ docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] perf = ["ipython"] testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] +[[package]] +name = "importlib-resources" +version = "5.4.0" +description = "Read resources from Python packages" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-black (>=0.3.7)", "pytest-mypy"] + [[package]] name = "iniconfig" version = "1.1.1" @@ -282,7 +297,7 @@ test = ["pytest (!=5.3.4)", "pytest-cov", "flaky", "nose", "jedi (<=0.17.2)"] [[package]] name = "ipython" -version = "7.16.2" +version = "7.16.3" description = "IPython: Productive Interactive Computing" category = "dev" optional = false @@ -431,7 +446,7 @@ test = ["async-generator", "ipykernel", "ipython", "mock", "pytest-asyncio", "py [[package]] name = "jupyter-console" -version = "6.4.0" +version = "6.4.2" description = "Jupyter terminal console" category = "dev" optional = false @@ -449,7 +464,7 @@ test = ["pexpect"] [[package]] name = "jupyter-core" -version = "4.9.1" +version = "4.9.2" description = "Jupyter core package. A base package on which Jupyter projects rely." category = "dev" optional = false @@ -603,7 +618,7 @@ python-versions = ">=3.5" [[package]] name = "notebook" -version = "6.4.6" +version = "6.4.8" description = "A web-based notebook environment for interactive computing" category = "dev" optional = false @@ -721,11 +736,11 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "prometheus-client" -version = "0.12.0" +version = "0.13.1" description = "Python client for the Prometheus monitoring system." category = "dev" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.6" [package.extras] twisted = ["twisted"] @@ -783,7 +798,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] name = "pygments" -version = "2.10.0" +version = "2.11.2" description = "Pygments is a syntax highlighting package written in Python." category = "dev" optional = false @@ -791,7 +806,7 @@ python-versions = ">=3.5" [[package]] name = "pyparsing" -version = "3.0.6" +version = "3.0.7" description = "Python parsing module" category = "dev" optional = false @@ -851,7 +866,7 @@ python-versions = "*" [[package]] name = "pywin32" -version = "302" +version = "303" description = "Python for Window Extensions" category = "dev" optional = false @@ -859,7 +874,7 @@ python-versions = "*" [[package]] name = "pywinpty" -version = "1.1.6" +version = "2.0.3" description = "Pseudo terminal support for Windows from Python." category = "dev" optional = false @@ -901,11 +916,17 @@ test = ["flaky", "pytest", "pytest-qt"] [[package]] name = "qtpy" -version = "1.11.3" -description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5, PyQt4 and PySide) and additional custom QWidgets." +version = "2.0.1" +description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." category = "dev" optional = false -python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" +python-versions = ">=3.6" + +[package.dependencies] +packaging = "*" + +[package.extras] +test = ["pytest (>=6.0.0)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "recommonmark" @@ -922,7 +943,7 @@ sphinx = ">=1.3.1" [[package]] name = "requests" -version = "2.26.0" +version = "2.27.1" description = "Python HTTP for Humans." category = "dev" optional = false @@ -1115,11 +1136,11 @@ test = ["pytest"] [[package]] name = "terminado" -version = "0.12.1" +version = "0.13.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." category = "dev" optional = false -python-versions = ">=3.6" +python-versions = "*" [package.dependencies] ptyprocess = {version = "*", markers = "os_name != \"nt\""} @@ -1131,18 +1152,18 @@ test = ["pytest"] [[package]] name = "testpath" -version = "0.5.0" +version = "0.6.0" description = "Test utilities for code working with files and commands" category = "dev" optional = false python-versions = ">= 3.5" [package.extras] -test = ["pytest", "pathlib2"] +test = ["pytest"] [[package]] name = "threadpoolctl" -version = "3.0.0" +version = "3.1.0" description = "threadpoolctl" category = "main" optional = false @@ -1174,7 +1195,7 @@ python-versions = ">= 3.5" [[package]] name = "tqdm" -version = "4.62.3" +version = "4.63.0" description = "Fast, Extensible Progress Meter" category = "main" optional = false @@ -1182,6 +1203,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +importlib-resources = {version = "*", markers = "python_version < \"3.7\""} [package.extras] dev = ["py-make (>=0.1.0)", "twine", "wheel"] @@ -1206,7 +1228,7 @@ test = ["pytest", "mock"] [[package]] name = "typing-extensions" -version = "4.0.1" +version = "4.1.1" description = "Backported and Experimental Type Hints for Python 3.6+" category = "dev" optional = false @@ -1214,14 +1236,13 @@ python-versions = ">=3.6" [[package]] name = "urllib3" -version = "1.26.7" +version = "1.22" description = "HTTP library with thread-safe connection pooling, file post, and more." category = "dev" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" +python-versions = "*" [package.extras] -brotli = ["brotlipy (>=0.6.0)"] secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] @@ -1276,7 +1297,7 @@ scipy = "*" name = "zipp" version = "3.6.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -1286,8 +1307,8 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" -python-versions = "^3.6" -content-hash = "924b692f42c5b45e28c771ce997fce9ecb16db9f77c05b22bc0620360130926a" +python-versions = ">=3.6" +content-hash = "d38522a7948f24326479797abe2f75ad8b77200c58152779d423040f13f0b0a3" [metadata.files] alabaster = [ @@ -1334,8 +1355,8 @@ atomicwrites = [ {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, ] attrs = [ - {file = "attrs-21.2.0-py2.py3-none-any.whl", hash = "sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1"}, - {file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"}, + {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, + {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, ] babel = [ {file = "Babel-2.9.1-py2.py3-none-any.whl", hash = "sha256:ab49e12b91d937cd11f0b67cb259a57ab4ad2b59ac7a3b41d6c06c0ac5b0def9"}, @@ -1406,8 +1427,8 @@ cffi = [ {file = "cffi-1.15.0.tar.gz", hash = "sha256:920f0d66a896c2d99f0adbb391f990a84091179542c205fa53ce5787aff87954"}, ] charset-normalizer = [ - {file = "charset-normalizer-2.0.9.tar.gz", hash = "sha256:b0b883e8e874edfdece9c28f314e3dd5badf067342e42fb162203335ae61aa2c"}, - {file = "charset_normalizer-2.0.9-py3-none-any.whl", hash = "sha256:1eecaa09422db5be9e29d7fc65664e6c33bd06f9ced7838578ba40d58bdf3721"}, + {file = "charset-normalizer-2.0.12.tar.gz", hash = "sha256:2857e29ff0d34db842cd7ca3230549d1a697f96ee6d3fb071cfa6c7393832597"}, + {file = "charset_normalizer-2.0.12-py3-none-any.whl", hash = "sha256:6881edbebdb17b39b4eaaa821b438bf6eddffb4468cf344f09f89def34a8b1df"}, ] colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, @@ -1426,8 +1447,8 @@ dataclasses = [ {file = "dataclasses-0.8.tar.gz", hash = "sha256:8479067f342acf957dc82ec415d355ab5edb7e7646b90dc6e2fd1d96ad084c97"}, ] decorator = [ - {file = "decorator-5.1.0-py3-none-any.whl", hash = "sha256:7b12e7c3c6ab203a29e157335e9122cb03de9ab7264b137594103fd4a683b374"}, - {file = "decorator-5.1.0.tar.gz", hash = "sha256:e59913af105b9860aa2c8d3272d9de5a56a4e608db9a2f167a8480b323d529a7"}, + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] defusedxml = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, @@ -1457,6 +1478,10 @@ importlib-metadata = [ {file = "importlib_metadata-4.8.3-py3-none-any.whl", hash = "sha256:65a9576a5b2d58ca44d133c42a241905cc45e34d2c06fd5ba2bafa221e5d7b5e"}, {file = "importlib_metadata-4.8.3.tar.gz", hash = "sha256:766abffff765960fcc18003801f7044eb6755ffae4521c8e8ce8e83b9c9b0668"}, ] +importlib-resources = [ + {file = "importlib_resources-5.4.0-py3-none-any.whl", hash = "sha256:33a95faed5fc19b4bc16b29a6eeae248a3fe69dd55d4d229d2b480e23eeaad45"}, + {file = "importlib_resources-5.4.0.tar.gz", hash = "sha256:d756e2f85dd4de2ba89be0b21dba2a3bbec2e871a42a3a16719258a11f87506b"}, +] iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, @@ -1466,8 +1491,8 @@ ipykernel = [ {file = "ipykernel-5.5.6.tar.gz", hash = "sha256:4ea44b90ae1f7c38987ad58ea0809562a17c2695a0499644326f334aecd369ec"}, ] ipython = [ - {file = "ipython-7.16.2-py3-none-any.whl", hash = "sha256:2f644313be4fdc5c8c2a17467f2949c29423c9e283a159d1fc9bf450a1a300af"}, - {file = "ipython-7.16.2.tar.gz", hash = "sha256:613085f8acb0f35f759e32bea35fba62c651a4a2e409a0da11414618f5eec0c4"}, + {file = "ipython-7.16.3-py3-none-any.whl", hash = "sha256:c0427ed8bc33ac481faf9d3acf7e84e0010cdaada945e0badd1e2e74cc075833"}, + {file = "ipython-7.16.3.tar.gz", hash = "sha256:5ac47dc9af66fc2f5530c12069390877ae372ac905edca75a92a6e363b5d7caa"}, ] ipython-genutils = [ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, @@ -1503,12 +1528,12 @@ jupyter-client = [ {file = "jupyter_client-6.1.13.tar.gz", hash = "sha256:d03558bc9b7955d8b4a6df604a8d9d257e00bcea7fb364fe41cdef81d998a966"}, ] jupyter-console = [ - {file = "jupyter_console-6.4.0-py3-none-any.whl", hash = "sha256:7799c4ea951e0e96ba8260575423cb323ea5a03fcf5503560fa3e15748869e27"}, - {file = "jupyter_console-6.4.0.tar.gz", hash = "sha256:242248e1685039cd8bff2c2ecb7ce6c1546eb50ee3b08519729e6e881aec19c7"}, + {file = "jupyter_console-6.4.2-py3-none-any.whl", hash = "sha256:1d52cf1a80f0c7accaa2b8c68ca3e6fa2311ec33ac9651d6cb6b9168cca1dad9"}, + {file = "jupyter_console-6.4.2.tar.gz", hash = "sha256:fce5bccac926c690924168ad46cae33a7d78d643a7b60af0f260af25d38ecf26"}, ] jupyter-core = [ - {file = "jupyter_core-4.9.1-py3-none-any.whl", hash = "sha256:1c091f3bbefd6f2a8782f2c1db662ca8478ac240e962ae2c66f0b87c818154ea"}, - {file = "jupyter_core-4.9.1.tar.gz", hash = "sha256:dce8a7499da5a53ae3afd5a9f4b02e5df1d57250cf48f3ad79da23b4778cd6fa"}, + {file = "jupyter_core-4.9.2-py3-none-any.whl", hash = "sha256:f875e4d27e202590311d468fa55f90c575f201490bd0c18acabe4e318db4a46d"}, + {file = "jupyter_core-4.9.2.tar.gz", hash = "sha256:d69baeb9ffb128b8cd2657fcf2703f89c769d1673c851812119e3a2a0e93ad9a"}, ] jupyterlab-pygments = [ {file = "jupyterlab_pygments-0.1.2-py2.py3-none-any.whl", hash = "sha256:abfb880fd1561987efaefcb2d2ac75145d2a5d0139b1876d5be806e32f630008"}, @@ -1664,8 +1689,8 @@ nest-asyncio = [ {file = "nest_asyncio-1.5.4.tar.gz", hash = "sha256:f969f6013a16fadb4adcf09d11a68a4f617c6049d7af7ac2c676110169a63abd"}, ] notebook = [ - {file = "notebook-6.4.6-py3-none-any.whl", hash = "sha256:5cad068fa82cd4fb98d341c052100ed50cd69fbfb4118cb9b8ab5a346ef27551"}, - {file = "notebook-6.4.6.tar.gz", hash = "sha256:7bcdf79bd1cda534735bd9830d2cbedab4ee34d8fe1df6e7b946b3aab0902ba3"}, + {file = "notebook-6.4.8-py3-none-any.whl", hash = "sha256:3e702fcc54b8ae597533c3864793b7a1e971dec9e112f67235828d8a798fd654"}, + {file = "notebook-6.4.8.tar.gz", hash = "sha256:1e985c9dc6f678bdfffb9dc657306b5469bfa62d73e03f74e8defbf76d284312"}, ] numpy = [ {file = "numpy-1.19.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cc6bd4fd593cb261332568485e20a0712883cf631f6f5e8e86a52caa8b2b50ff"}, @@ -1749,8 +1774,8 @@ pluggy = [ {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, ] prometheus-client = [ - {file = "prometheus_client-0.12.0-py2.py3-none-any.whl", hash = "sha256:317453ebabff0a1b02df7f708efbab21e3489e7072b61cb6957230dd004a0af0"}, - {file = "prometheus_client-0.12.0.tar.gz", hash = "sha256:1b12ba48cee33b9b0b9de64a1047cbd3c5f2d0ab6ebcead7ddda613a750ec3c5"}, + {file = "prometheus_client-0.13.1-py3-none-any.whl", hash = "sha256:357a447fd2359b0a1d2e9b311a0c5778c330cfbe186d880ad5a6b39884652316"}, + {file = "prometheus_client-0.13.1.tar.gz", hash = "sha256:ada41b891b79fca5638bd5cfe149efa86512eaa55987893becd2c6d8d0a5dfc5"}, ] prompt-toolkit = [ {file = "prompt_toolkit-3.0.3-py3-none-any.whl", hash = "sha256:c93e53af97f630f12f5f62a3274e79527936ed466f038953dfa379d4941f651a"}, @@ -1777,12 +1802,12 @@ pyflakes = [ {file = "pyflakes-2.1.1.tar.gz", hash = "sha256:d976835886f8c5b31d47970ed689944a0262b5f3afa00a5a7b4dc81e5449f8a2"}, ] pygments = [ - {file = "Pygments-2.10.0-py3-none-any.whl", hash = "sha256:b8e67fe6af78f492b3c4b3e2970c0624cbf08beb1e493b2c99b9fa1b67a20380"}, - {file = "Pygments-2.10.0.tar.gz", hash = "sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6"}, + {file = "Pygments-2.11.2-py3-none-any.whl", hash = "sha256:44238f1b60a76d78fc8ca0528ee429702aae011c265fe6a8dd8b63049ae41c65"}, + {file = "Pygments-2.11.2.tar.gz", hash = "sha256:4e426f72023d88d03b2fa258de560726ce890ff3b630f88c21cbb8b2503b8c6a"}, ] pyparsing = [ - {file = "pyparsing-3.0.6-py3-none-any.whl", hash = "sha256:04ff808a5b90911829c55c4e26f75fa5ca8a2f5f36aa3a51f68e27033341d3e4"}, - {file = "pyparsing-3.0.6.tar.gz", hash = "sha256:d9bdec0013ef1eb5a84ab39a3b3868911598afa494f5faa038647101504e2b81"}, + {file = "pyparsing-3.0.7-py3-none-any.whl", hash = "sha256:a6c06a88f252e6c322f65faf8f418b16213b51bdfaece0524c1c1bc30c63c484"}, + {file = "pyparsing-3.0.7.tar.gz", hash = "sha256:18ee9022775d270c55187733956460083db60b37d0d0fb357445f3094eed3eea"}, ] pyrsistent = [ {file = "pyrsistent-0.18.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f4c8cabb46ff8e5d61f56a037974228e978f26bfefce4f61a4b1ac0ba7a2ab72"}, @@ -1820,24 +1845,25 @@ pytz = [ {file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"}, ] pywin32 = [ - {file = "pywin32-302-cp310-cp310-win32.whl", hash = "sha256:251b7a9367355ccd1a4cd69cd8dd24bd57b29ad83edb2957cfa30f7ed9941efa"}, - {file = "pywin32-302-cp310-cp310-win_amd64.whl", hash = "sha256:79cf7e6ddaaf1cd47a9e50cc74b5d770801a9db6594464137b1b86aa91edafcc"}, - {file = "pywin32-302-cp36-cp36m-win32.whl", hash = "sha256:fe21c2fb332d03dac29de070f191bdbf14095167f8f2165fdc57db59b1ecc006"}, - {file = "pywin32-302-cp36-cp36m-win_amd64.whl", hash = "sha256:d3761ab4e8c5c2dbc156e2c9ccf38dd51f936dc77e58deb940ffbc4b82a30528"}, - {file = "pywin32-302-cp37-cp37m-win32.whl", hash = "sha256:48dd4e348f1ee9538dd4440bf201ea8c110ea6d9f3a5010d79452e9fa80480d9"}, - {file = "pywin32-302-cp37-cp37m-win_amd64.whl", hash = "sha256:496df89f10c054c9285cc99f9d509e243f4e14ec8dfc6d78c9f0bf147a893ab1"}, - {file = "pywin32-302-cp38-cp38-win32.whl", hash = "sha256:e372e477d938a49266136bff78279ed14445e00718b6c75543334351bf535259"}, - {file = "pywin32-302-cp38-cp38-win_amd64.whl", hash = "sha256:543552e66936378bd2d673c5a0a3d9903dba0b0a87235ef0c584f058ceef5872"}, - {file = "pywin32-302-cp39-cp39-win32.whl", hash = "sha256:2393c1a40dc4497fd6161b76801b8acd727c5610167762b7c3e9fd058ef4a6ab"}, - {file = "pywin32-302-cp39-cp39-win_amd64.whl", hash = "sha256:af5aea18167a31efcacc9f98a2ca932c6b6a6d91ebe31f007509e293dea12580"}, + {file = "pywin32-303-cp310-cp310-win32.whl", hash = "sha256:6fed4af057039f309263fd3285d7b8042d41507343cd5fa781d98fcc5b90e8bb"}, + {file = "pywin32-303-cp310-cp310-win_amd64.whl", hash = "sha256:51cb52c5ec6709f96c3f26e7795b0bf169ee0d8395b2c1d7eb2c029a5008ed51"}, + {file = "pywin32-303-cp311-cp311-win32.whl", hash = "sha256:d9b5d87ca944eb3aa4cd45516203ead4b37ab06b8b777c54aedc35975dec0dee"}, + {file = "pywin32-303-cp311-cp311-win_amd64.whl", hash = "sha256:fcf44032f5b14fcda86028cdf49b6ebdaea091230eb0a757282aa656e4732439"}, + {file = "pywin32-303-cp36-cp36m-win32.whl", hash = "sha256:aad484d52ec58008ca36bd4ad14a71d7dd0a99db1a4ca71072213f63bf49c7d9"}, + {file = "pywin32-303-cp36-cp36m-win_amd64.whl", hash = "sha256:2a09632916b6bb231ba49983fe989f2f625cea237219530e81a69239cd0c4559"}, + {file = "pywin32-303-cp37-cp37m-win32.whl", hash = "sha256:b1675d82bcf6dbc96363fca747bac8bff6f6e4a447a4287ac652aa4b9adc796e"}, + {file = "pywin32-303-cp37-cp37m-win_amd64.whl", hash = "sha256:c268040769b48a13367221fced6d4232ed52f044ffafeda247bd9d2c6bdc29ca"}, + {file = "pywin32-303-cp38-cp38-win32.whl", hash = "sha256:5f9ec054f5a46a0f4dfd72af2ce1372f3d5a6e4052af20b858aa7df2df7d355b"}, + {file = "pywin32-303-cp38-cp38-win_amd64.whl", hash = "sha256:793bf74fce164bcffd9d57bb13c2c15d56e43c9542a7b9687b4fccf8f8a41aba"}, + {file = "pywin32-303-cp39-cp39-win32.whl", hash = "sha256:7d3271c98434617a11921c5ccf74615794d97b079e22ed7773790822735cc352"}, + {file = "pywin32-303-cp39-cp39-win_amd64.whl", hash = "sha256:79cbb862c11b9af19bcb682891c1b91942ec2ff7de8151e2aea2e175899cda34"}, ] pywinpty = [ - {file = "pywinpty-1.1.6-cp310-none-win_amd64.whl", hash = "sha256:5f526f21b569b5610a61e3b6126259c76da979399598e5154498582df3736ade"}, - {file = "pywinpty-1.1.6-cp36-none-win_amd64.whl", hash = "sha256:7576e14f42b31fa98b62d24ded79754d2ea4625570c016b38eb347ce158a30f2"}, - {file = "pywinpty-1.1.6-cp37-none-win_amd64.whl", hash = "sha256:979ffdb9bdbe23db3f46fc7285fd6dbb86b80c12325a50582b211b3894072354"}, - {file = "pywinpty-1.1.6-cp38-none-win_amd64.whl", hash = "sha256:2308b1fc77545427610a705799d4ead5e7f00874af3fb148a03e202437456a7e"}, - {file = "pywinpty-1.1.6-cp39-none-win_amd64.whl", hash = "sha256:c703bf569a98ab7844b9daf37e88ab86f31862754ef6910a8b3824993a525c72"}, - {file = "pywinpty-1.1.6.tar.gz", hash = "sha256:8808f07350c709119cc4464144d6e749637f98e15acc1e5d3c37db1953d2eebc"}, + {file = "pywinpty-2.0.3-cp310-none-win_amd64.whl", hash = "sha256:7a330ef7a2ce284370b1a1fdd2a80c523585464fa5e5ab934c9f27220fa7feab"}, + {file = "pywinpty-2.0.3-cp37-none-win_amd64.whl", hash = "sha256:6455f1075f978942d318f95616661c605d5e0f991c5b176c0c852d237aafefc0"}, + {file = "pywinpty-2.0.3-cp38-none-win_amd64.whl", hash = "sha256:2e7a288a8121393c526d4e6ec7d65edef75d68c7787ab9560e438df867b75a5d"}, + {file = "pywinpty-2.0.3-cp39-none-win_amd64.whl", hash = "sha256:def51627e6aa659f33ea7a0ea4c6b68365c83af4aad7940600f844746817a0ed"}, + {file = "pywinpty-2.0.3.tar.gz", hash = "sha256:6b29a826e896105370c38d53904c3aaac6c36146a50448fc0ed5082cf9d092bc"}, ] pyzmq = [ {file = "pyzmq-22.3.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:6b217b8f9dfb6628f74b94bdaf9f7408708cb02167d644edca33f38746ca12dd"}, @@ -1893,16 +1919,16 @@ qtconsole = [ {file = "qtconsole-5.2.2.tar.gz", hash = "sha256:8f9db97b27782184efd0a0f2d57ea3bd852d053747a2e442a9011329c082976d"}, ] qtpy = [ - {file = "QtPy-1.11.3-py2.py3-none-any.whl", hash = "sha256:e121fbee8e95645af29c5a4aceba8d657991551fc1aa3b6b6012faf4725a1d20"}, - {file = "QtPy-1.11.3.tar.gz", hash = "sha256:d427addd37386a8d786db81864a5536700861d95bf085cb31d1bea855d699557"}, + {file = "QtPy-2.0.1-py3-none-any.whl", hash = "sha256:d93f2c98e97387fcc9d623d509772af5b6c15ab9d8f9f4c5dfbad9a73ad34812"}, + {file = "QtPy-2.0.1.tar.gz", hash = "sha256:adfd073ffbd2de81dc7aaa0b983499ef5c59c96adcfdcc9dea60d42ca885eb8f"}, ] recommonmark = [ {file = "recommonmark-0.6.0-py2.py3-none-any.whl", hash = "sha256:2ec4207a574289355d5b6ae4ae4abb29043346ca12cdd5f07d374dc5987d2852"}, {file = "recommonmark-0.6.0.tar.gz", hash = "sha256:29cd4faeb6c5268c633634f2d69aef9431e0f4d347f90659fd0aab20e541efeb"}, ] requests = [ - {file = "requests-2.26.0-py2.py3-none-any.whl", hash = "sha256:6c1246513ecd5ecd4528a0906f910e8f0f9c6b8ec72030dc9fd154dc1a6efd24"}, - {file = "requests-2.26.0.tar.gz", hash = "sha256:b8aa58f8cf793ffd8782d3d8cb19e66ef36f7aba4353eec859e74678b01b07a7"}, + {file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"}, + {file = "requests-2.27.1.tar.gz", hash = "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61"}, ] scikit-learn = [ {file = "scikit-learn-0.24.2.tar.gz", hash = "sha256:d14701a12417930392cd3898e9646cf5670c190b933625ebe7511b1f7d7b8736"}, @@ -2011,16 +2037,16 @@ sphinxcontrib-serializinghtml = [ {file = "sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl", hash = "sha256:352a9a00ae864471d3a7ead8d7d79f5fc0b57e8b3f95e9867eb9eb28999b92fd"}, ] terminado = [ - {file = "terminado-0.12.1-py3-none-any.whl", hash = "sha256:09fdde344324a1c9c6e610ee4ca165c4bb7f5bbf982fceeeb38998a988ef8452"}, - {file = "terminado-0.12.1.tar.gz", hash = "sha256:b20fd93cc57c1678c799799d117874367cc07a3d2d55be95205b1a88fa08393f"}, + {file = "terminado-0.13.0-py3-none-any.whl", hash = "sha256:50a18654ad0cff153fdeb58711c9a7b25e0e2b74fb721dbaddd9e80d5612fac6"}, + {file = "terminado-0.13.0.tar.gz", hash = "sha256:713531ccb5db7d4f544651f14050da79809030f00d1afa21462088cf32fb143a"}, ] testpath = [ - {file = "testpath-0.5.0-py3-none-any.whl", hash = "sha256:8044f9a0bab6567fc644a3593164e872543bb44225b0e24846e2c89237937589"}, - {file = "testpath-0.5.0.tar.gz", hash = "sha256:1acf7a0bcd3004ae8357409fc33751e16d37ccc650921da1094a86581ad1e417"}, + {file = "testpath-0.6.0-py3-none-any.whl", hash = "sha256:8ada9f80a2ac6fb0391aa7cdb1a7d11cfa8429f693eda83f74dde570fe6fa639"}, + {file = "testpath-0.6.0.tar.gz", hash = "sha256:2f1b97e6442c02681ebe01bd84f531028a7caea1af3825000f52345c30285e0f"}, ] threadpoolctl = [ - {file = "threadpoolctl-3.0.0-py3-none-any.whl", hash = "sha256:4fade5b3b48ae4b1c30f200b28f39180371104fccc642e039e0f2435ec8cc211"}, - {file = "threadpoolctl-3.0.0.tar.gz", hash = "sha256:d03115321233d0be715f0d3a5ad1d6c065fe425ddc2d671ca8e45e9fd5d7a52a"}, + {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, + {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, ] toml = [ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, @@ -2083,20 +2109,20 @@ tornado = [ {file = "tornado-6.1.tar.gz", hash = "sha256:33c6e81d7bd55b468d2e793517c909b139960b6c790a60b7991b9b6b76fb9791"}, ] tqdm = [ - {file = "tqdm-4.62.3-py2.py3-none-any.whl", hash = "sha256:8dd278a422499cd6b727e6ae4061c40b48fce8b76d1ccbf5d34fca9b7f925b0c"}, - {file = "tqdm-4.62.3.tar.gz", hash = "sha256:d359de7217506c9851b7869f3708d8ee53ed70a1b8edbba4dbcb47442592920d"}, + {file = "tqdm-4.63.0-py2.py3-none-any.whl", hash = "sha256:e643e071046f17139dea55b880dc9b33822ce21613b4a4f5ea57f202833dbc29"}, + {file = "tqdm-4.63.0.tar.gz", hash = "sha256:1d9835ede8e394bb8c9dcbffbca02d717217113adc679236873eeaac5bc0b3cd"}, ] traitlets = [ {file = "traitlets-4.3.3-py2.py3-none-any.whl", hash = "sha256:70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44"}, {file = "traitlets-4.3.3.tar.gz", hash = "sha256:d023ee369ddd2763310e4c3eae1ff649689440d4ae59d7485eb4cfbbe3e359f7"}, ] typing-extensions = [ - {file = "typing_extensions-4.0.1-py3-none-any.whl", hash = "sha256:7f001e5ac290a0c0401508864c7ec868be4e701886d5b573a9528ed3973d9d3b"}, - {file = "typing_extensions-4.0.1.tar.gz", hash = "sha256:4ca091dea149f945ec56afb48dae714f21e8692ef22a395223bcd328961b6a0e"}, + {file = "typing_extensions-4.1.1-py3-none-any.whl", hash = "sha256:21c85e0fe4b9a155d0799430b0ad741cdce7e359660ccbd8b530613e8df88ce2"}, + {file = "typing_extensions-4.1.1.tar.gz", hash = "sha256:1a9462dcc3347a79b1f1c0271fbe79e844580bb598bafa1ed208b94da3cdcd42"}, ] urllib3 = [ - {file = "urllib3-1.26.7-py2.py3-none-any.whl", hash = "sha256:c4fdf4019605b6e5423637e01bc9fe4daef873709a7973e195ceba0a62bbc844"}, - {file = "urllib3-1.26.7.tar.gz", hash = "sha256:4987c65554f7a2dbf30c18fd48778ef124af6fab771a377103da0585e2336ece"}, + {file = "urllib3-1.22-py2.py3-none-any.whl", hash = "sha256:06330f386d6e4b195fbfc736b297f58c5a892e4440e54d294d7004e3a9bbea1b"}, + {file = "urllib3-1.22.tar.gz", hash = "sha256:cc44da8e1145637334317feebd728bd869a35285b93cbb4cca2577da7e62db4f"}, ] wcwidth = [ {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, diff --git a/pyproject.toml b/pyproject.toml index 16118657..1f96b264 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ keywords = ["tabnet", "pytorch", "neural-networks" ] exclude = ["tabnet/*.ipynb"] [tool.poetry.dependencies] -python = "^3.6" +python = ">=3.6" numpy="^1.17" torch="^1.2" diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 2a1b2f61..c5a19bdf 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -478,11 +478,11 @@ def _train_batch(self, X, y): """ batch_logs = {"batch_size": X.shape[0]} - X = X.to(self.device).float() # Is this .float() needed ? + X = X.to(self.device).float() y = y.to(self.device).float() if self.augmentations is not None: - X, y = self.augmentations(X, y) + X, y = self.augmentations(X, y) for param in self.network.parameters(): param.grad = None diff --git a/pytorch_tabnet/augmentations.py b/pytorch_tabnet/augmentations.py index 8c8d3ce0..287fa365 100644 --- a/pytorch_tabnet/augmentations.py +++ b/pytorch_tabnet/augmentations.py @@ -2,16 +2,14 @@ from pytorch_tabnet.utils import define_device import numpy as np -# TODO : change this so that p would be the proportion of rows that are changed -# add a beta argument (beta distribution) + class RegressionSMOTE(): """ Apply SMOTE This will average a percentage p of the elements in the batch with other elements. - The target will be averaged as well (this might work with binary classification and certain loss), - following a beta distribution. - + The target will be averaged as well (this might work with binary classification + and certain loss), following a beta distribution. """ def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): "" @@ -21,9 +19,9 @@ def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): self.alpha = alpha self.beta = beta self.p = p - if (p < 0.) or (p > 1.0): + if (p < 0.) or (p > 1.0): raise ValueError("Value of p should be between 0. and 1.") - + def _set_seed(self): torch.manual_seed(self.seed) np.random.seed(self.seed) @@ -39,14 +37,15 @@ def __call__(self, X, y): random_betas = torch.from_numpy(np_betas).to(self.device).float() index_permute = torch.randperm(batch_size, device=self.device) - X[idx_to_change] = random_betas[idx_to_change, None]*X[idx_to_change] + \ - (1 - random_betas[idx_to_change, None])*X[index_permute][idx_to_change].view(X[idx_to_change].size()) + X[idx_to_change] = random_betas[idx_to_change, None] * X[idx_to_change] + X[idx_to_change] += (1 - random_betas[idx_to_change, None]) * X[index_permute][idx_to_change].view(X[idx_to_change].size()) # noqa - y[idx_to_change] = random_betas[idx_to_change, None]*y[idx_to_change] + \ - (1 - random_betas[idx_to_change, None])*y[index_permute][idx_to_change].view(y[idx_to_change].size()) + y[idx_to_change] = random_betas[idx_to_change, None] * y[idx_to_change] + y[idx_to_change] += (1 - random_betas[idx_to_change, None]) * y[index_permute][idx_to_change].view(y[idx_to_change].size()) # noqa return X, y + class ClassificationSMOTE(): """ Apply SMOTE for classification tasks. @@ -62,7 +61,7 @@ def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): self.alpha = alpha self.beta = beta self.p = p - if (p < 0.) or (p > 1.0): + if (p < 0.) or (p > 1.0): raise ValueError("Value of p should be between 0. and 1.") def _set_seed(self): @@ -80,11 +79,7 @@ def __call__(self, X, y): random_betas = torch.from_numpy(np_betas).to(self.device).float() index_permute = torch.randperm(batch_size, device=self.device) - X[idx_to_change] = random_betas[idx_to_change, None]*X[idx_to_change] + \ - (1 - random_betas[idx_to_change, None])*X[index_permute][idx_to_change].view(X[idx_to_change].size()) - + X[idx_to_change] = random_betas[idx_to_change, None] * X[idx_to_change] + X[idx_to_change] += (1 - random_betas[idx_to_change, None]) * X[index_permute][idx_to_change].view(X[idx_to_change].size()) # noqa return X, y - - - \ No newline at end of file