diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..b7bbb5c
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,21 @@
+name: Test suite
+
+on: [push, pull_request]
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python 3.10
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.10'
+ - name: Install pipenv
+ run: pip install pipenv
+ - name: Run tests
+ run: |
+ PIP_FIND_LINKS=https://download.pytorch.org/whl/torch pipenv install torch==1.13.1+cpu
+ pipenv sync --dev
+ pipenv run pytest -vv
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a989108
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,31 @@
+# IDE
+.vscode/
+.idea/
+
+# Models
+lightning_logs/
+mlruns/
+
+# Data
+datasets/
+
+# Testing
+test/resources/out
+.pytest_cache/
+
+# LaTeX
+*.aux
+*.log
+*.out
+*.fls
+*.fdb_latexmk
+*.gz
+*.xdv
+*.bbl
+*.blg
+*.bcf
+*.run.xml
+
+# Metadata
+*.DS_Store
+*.ipynb_checkpoints
\ No newline at end of file
diff --git a/.readme/header.png b/.readme/header.png
new file mode 100644
index 0000000..b9d025b
Binary files /dev/null and b/.readme/header.png differ
diff --git a/.readme/onex.svg b/.readme/onex.svg
new file mode 100644
index 0000000..9bac4b3
--- /dev/null
+++ b/.readme/onex.svg
@@ -0,0 +1,945 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/.readme/recall.svg b/.readme/recall.svg
new file mode 100644
index 0000000..a300a2a
--- /dev/null
+++ b/.readme/recall.svg
@@ -0,0 +1,973 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/.style.yapf b/.style.yapf
new file mode 100644
index 0000000..d762a26
--- /dev/null
+++ b/.style.yapf
@@ -0,0 +1,3 @@
+[style]
+based_on_style = pep8
+column_limit = 128
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..5b22c13
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Otto (GmbH & Co KG), https://www.otto.de/jobs/technology/ueberblick/
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/Pipfile b/Pipfile
new file mode 100644
index 0000000..6601e61
--- /dev/null
+++ b/Pipfile
@@ -0,0 +1,23 @@
+[[source]]
+url = "https://pypi.org/simple"
+verify_ssl = true
+name = "pypi"
+
+[packages]
+torch = "==1.13.1"
+pytorch-lightning = "*"
+tensorboardx = "*"
+onnx = "*"
+pandas = "*"
+kaggle = "*"
+gdown = "*"
+mlflow = "*"
+
+[dev-packages]
+pytest = "*"
+ipykernel = "*"
+autopep8 = "*"
+yapf = "*"
+
+[requires]
+python_version = "3.10"
diff --git a/Pipfile.lock b/Pipfile.lock
new file mode 100644
index 0000000..0bb60e7
--- /dev/null
+++ b/Pipfile.lock
@@ -0,0 +1,2086 @@
+{
+ "_meta": {
+ "hash": {
+ "sha256": "c1714f1ab70c9a1d7f2e580a7ef5118fd54bfe716a2a0fc8d4802b0f00360c9d"
+ },
+ "pipfile-spec": 6,
+ "requires": {
+ "python_version": "3.10"
+ },
+ "sources": [
+ {
+ "name": "pypi",
+ "url": "https://pypi.org/simple",
+ "verify_ssl": true
+ }
+ ]
+ },
+ "default": {
+ "aiohttp": {
+ "hashes": [
+ "sha256:03543dcf98a6619254b409be2d22b51f21ec66272be4ebda7b04e6412e4b2e14",
+ "sha256:03baa76b730e4e15a45f81dfe29a8d910314143414e528737f8589ec60cf7391",
+ "sha256:0a63f03189a6fa7c900226e3ef5ba4d3bd047e18f445e69adbd65af433add5a2",
+ "sha256:10c8cefcff98fd9168cdd86c4da8b84baaa90bf2da2269c6161984e6737bf23e",
+ "sha256:147ae376f14b55f4f3c2b118b95be50a369b89b38a971e80a17c3fd623f280c9",
+ "sha256:176a64b24c0935869d5bbc4c96e82f89f643bcdf08ec947701b9dbb3c956b7dd",
+ "sha256:17b79c2963db82086229012cff93ea55196ed31f6493bb1ccd2c62f1724324e4",
+ "sha256:1a45865451439eb320784918617ba54b7a377e3501fb70402ab84d38c2cd891b",
+ "sha256:1b3ea7edd2d24538959c1c1abf97c744d879d4e541d38305f9bd7d9b10c9ec41",
+ "sha256:22f6eab15b6db242499a16de87939a342f5a950ad0abaf1532038e2ce7d31567",
+ "sha256:3032dcb1c35bc330134a5b8a5d4f68c1a87252dfc6e1262c65a7e30e62298275",
+ "sha256:33587f26dcee66efb2fff3c177547bd0449ab7edf1b73a7f5dea1e38609a0c54",
+ "sha256:34ce9f93a4a68d1272d26030655dd1b58ff727b3ed2a33d80ec433561b03d67a",
+ "sha256:3a80464982d41b1fbfe3154e440ba4904b71c1a53e9cd584098cd41efdb188ef",
+ "sha256:3b90467ebc3d9fa5b0f9b6489dfb2c304a1db7b9946fa92aa76a831b9d587e99",
+ "sha256:3d89efa095ca7d442a6d0cbc755f9e08190ba40069b235c9886a8763b03785da",
+ "sha256:3d8ef1a630519a26d6760bc695842579cb09e373c5f227a21b67dc3eb16cfea4",
+ "sha256:3f43255086fe25e36fd5ed8f2ee47477408a73ef00e804cb2b5cba4bf2ac7f5e",
+ "sha256:40653609b3bf50611356e6b6554e3a331f6879fa7116f3959b20e3528783e699",
+ "sha256:41a86a69bb63bb2fc3dc9ad5ea9f10f1c9c8e282b471931be0268ddd09430b04",
+ "sha256:493f5bc2f8307286b7799c6d899d388bbaa7dfa6c4caf4f97ef7521b9cb13719",
+ "sha256:4a6cadebe132e90cefa77e45f2d2f1a4b2ce5c6b1bfc1656c1ddafcfe4ba8131",
+ "sha256:4c745b109057e7e5f1848c689ee4fb3a016c8d4d92da52b312f8a509f83aa05e",
+ "sha256:4d347a172f866cd1d93126d9b239fcbe682acb39b48ee0873c73c933dd23bd0f",
+ "sha256:4dac314662f4e2aa5009977b652d9b8db7121b46c38f2073bfeed9f4049732cd",
+ "sha256:4ddaae3f3d32fc2cb4c53fab020b69a05c8ab1f02e0e59665c6f7a0d3a5be54f",
+ "sha256:5393fb786a9e23e4799fec788e7e735de18052f83682ce2dfcabaf1c00c2c08e",
+ "sha256:59f029a5f6e2d679296db7bee982bb3d20c088e52a2977e3175faf31d6fb75d1",
+ "sha256:5a7bdf9e57126dc345b683c3632e8ba317c31d2a41acd5800c10640387d193ed",
+ "sha256:5b3f2e06a512e94722886c0827bee9807c86a9f698fac6b3aee841fab49bbfb4",
+ "sha256:5ce45967538fb747370308d3145aa68a074bdecb4f3a300869590f725ced69c1",
+ "sha256:5e14f25765a578a0a634d5f0cd1e2c3f53964553a00347998dfdf96b8137f777",
+ "sha256:618c901dd3aad4ace71dfa0f5e82e88b46ef57e3239fc7027773cb6d4ed53531",
+ "sha256:652b1bff4f15f6287550b4670546a2947f2a4575b6c6dff7760eafb22eacbf0b",
+ "sha256:6c08e8ed6fa3d477e501ec9db169bfac8140e830aa372d77e4a43084d8dd91ab",
+ "sha256:6ddb2a2026c3f6a68c3998a6c47ab6795e4127315d2e35a09997da21865757f8",
+ "sha256:6e601588f2b502c93c30cd5a45bfc665faaf37bbe835b7cfd461753068232074",
+ "sha256:6e74dd54f7239fcffe07913ff8b964e28b712f09846e20de78676ce2a3dc0bfc",
+ "sha256:7235604476a76ef249bd64cb8274ed24ccf6995c4a8b51a237005ee7a57e8643",
+ "sha256:7ab43061a0c81198d88f39aaf90dae9a7744620978f7ef3e3708339b8ed2ef01",
+ "sha256:7c7837fe8037e96b6dd5cfcf47263c1620a9d332a87ec06a6ca4564e56bd0f36",
+ "sha256:80575ba9377c5171407a06d0196b2310b679dc752d02a1fcaa2bc20b235dbf24",
+ "sha256:80a37fe8f7c1e6ce8f2d9c411676e4bc633a8462844e38f46156d07a7d401654",
+ "sha256:8189c56eb0ddbb95bfadb8f60ea1b22fcfa659396ea36f6adcc521213cd7b44d",
+ "sha256:854f422ac44af92bfe172d8e73229c270dc09b96535e8a548f99c84f82dde241",
+ "sha256:880e15bb6dad90549b43f796b391cfffd7af373f4646784795e20d92606b7a51",
+ "sha256:8b631e26df63e52f7cce0cce6507b7a7f1bc9b0c501fcde69742130b32e8782f",
+ "sha256:8c29c77cc57e40f84acef9bfb904373a4e89a4e8b74e71aa8075c021ec9078c2",
+ "sha256:91f6d540163f90bbaef9387e65f18f73ffd7c79f5225ac3d3f61df7b0d01ad15",
+ "sha256:92c0cea74a2a81c4c76b62ea1cac163ecb20fb3ba3a75c909b9fa71b4ad493cf",
+ "sha256:9bcb89336efa095ea21b30f9e686763f2be4478f1b0a616969551982c4ee4c3b",
+ "sha256:a1f4689c9a1462f3df0a1f7e797791cd6b124ddbee2b570d34e7f38ade0e2c71",
+ "sha256:a3fec6a4cb5551721cdd70473eb009d90935b4063acc5f40905d40ecfea23e05",
+ "sha256:a5d794d1ae64e7753e405ba58e08fcfa73e3fad93ef9b7e31112ef3c9a0efb52",
+ "sha256:a86d42d7cba1cec432d47ab13b6637bee393a10f664c425ea7b305d1301ca1a3",
+ "sha256:adfbc22e87365a6e564c804c58fc44ff7727deea782d175c33602737b7feadb6",
+ "sha256:aeb29c84bb53a84b1a81c6c09d24cf33bb8432cc5c39979021cc0f98c1292a1a",
+ "sha256:aede4df4eeb926c8fa70de46c340a1bc2c6079e1c40ccf7b0eae1313ffd33519",
+ "sha256:b744c33b6f14ca26b7544e8d8aadff6b765a80ad6164fb1a430bbadd593dfb1a",
+ "sha256:b7a00a9ed8d6e725b55ef98b1b35c88013245f35f68b1b12c5cd4100dddac333",
+ "sha256:bb96fa6b56bb536c42d6a4a87dfca570ff8e52de2d63cabebfd6fb67049c34b6",
+ "sha256:bbcf1a76cf6f6dacf2c7f4d2ebd411438c275faa1dc0c68e46eb84eebd05dd7d",
+ "sha256:bca5f24726e2919de94f047739d0a4fc01372801a3672708260546aa2601bf57",
+ "sha256:bf2e1a9162c1e441bf805a1fd166e249d574ca04e03b34f97e2928769e91ab5c",
+ "sha256:c4eb3b82ca349cf6fadcdc7abcc8b3a50ab74a62e9113ab7a8ebc268aad35bb9",
+ "sha256:c6cc15d58053c76eacac5fa9152d7d84b8d67b3fde92709195cb984cfb3475ea",
+ "sha256:c6cd05ea06daca6ad6a4ca3ba7fe7dc5b5de063ff4daec6170ec0f9979f6c332",
+ "sha256:c844fd628851c0bc309f3c801b3a3d58ce430b2ce5b359cd918a5a76d0b20cb5",
+ "sha256:c9cb1565a7ad52e096a6988e2ee0397f72fe056dadf75d17fa6b5aebaea05622",
+ "sha256:cab9401de3ea52b4b4c6971db5fb5c999bd4260898af972bf23de1c6b5dd9d71",
+ "sha256:cd468460eefef601ece4428d3cf4562459157c0f6523db89365202c31b6daebb",
+ "sha256:d1e6a862b76f34395a985b3cd39a0d949ca80a70b6ebdea37d3ab39ceea6698a",
+ "sha256:d1f9282c5f2b5e241034a009779e7b2a1aa045f667ff521e7948ea9b56e0c5ff",
+ "sha256:d265f09a75a79a788237d7f9054f929ced2e69eb0bb79de3798c468d8a90f945",
+ "sha256:db3fc6120bce9f446d13b1b834ea5b15341ca9ff3f335e4a951a6ead31105480",
+ "sha256:dbf3a08a06b3f433013c143ebd72c15cac33d2914b8ea4bea7ac2c23578815d6",
+ "sha256:de04b491d0e5007ee1b63a309956eaed959a49f5bb4e84b26c8f5d49de140fa9",
+ "sha256:e4b09863aae0dc965c3ef36500d891a3ff495a2ea9ae9171e4519963c12ceefd",
+ "sha256:e595432ac259af2d4630008bf638873d69346372d38255774c0e286951e8b79f",
+ "sha256:e75b89ac3bd27d2d043b234aa7b734c38ba1b0e43f07787130a0ecac1e12228a",
+ "sha256:ea9eb976ffdd79d0e893869cfe179a8f60f152d42cb64622fca418cd9b18dc2a",
+ "sha256:eafb3e874816ebe2a92f5e155f17260034c8c341dad1df25672fb710627c6949",
+ "sha256:ee3c36df21b5714d49fc4580247947aa64bcbe2939d1b77b4c8dcb8f6c9faecc",
+ "sha256:f352b62b45dff37b55ddd7b9c0c8672c4dd2eb9c0f9c11d395075a84e2c40f75",
+ "sha256:fabb87dd8850ef0f7fe2b366d44b77d7e6fa2ea87861ab3844da99291e81e60f",
+ "sha256:fe11310ae1e4cd560035598c3f29d86cef39a83d244c7466f95c27ae04850f10",
+ "sha256:fe7ba4a51f33ab275515f66b0a236bcde4fb5561498fe8f898d4e549b2e4509f"
+ ],
+ "version": "==3.8.4"
+ },
+ "aiosignal": {
+ "hashes": [
+ "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc",
+ "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.3.1"
+ },
+ "alembic": {
+ "hashes": [
+ "sha256:6a810a6b012c88b33458fceb869aef09ac75d6ace5291915ba7fae44de372c01",
+ "sha256:dc871798a601fab38332e38d6ddb38d5e734f60034baeb8e2db5b642fccd8ab8"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.11.1"
+ },
+ "async-timeout": {
+ "hashes": [
+ "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15",
+ "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==4.0.2"
+ },
+ "attrs": {
+ "hashes": [
+ "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04",
+ "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==23.1.0"
+ },
+ "beautifulsoup4": {
+ "hashes": [
+ "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da",
+ "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"
+ ],
+ "markers": "python_full_version >= '3.6.0'",
+ "version": "==4.12.2"
+ },
+ "blinker": {
+ "hashes": [
+ "sha256:4afd3de66ef3a9f8067559fb7a1cbe555c17dcbe15971b05d1b625c3e7abe213",
+ "sha256:c3d739772abb7bc2860abf5f2ec284223d9ad5c76da018234f6f50d6f31ab1f0"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.6.2"
+ },
+ "certifi": {
+ "hashes": [
+ "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7",
+ "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==2023.5.7"
+ },
+ "charset-normalizer": {
+ "hashes": [
+ "sha256:04afa6387e2b282cf78ff3dbce20f0cc071c12dc8f685bd40960cc68644cfea6",
+ "sha256:04eefcee095f58eaabe6dc3cc2262f3bcd776d2c67005880894f447b3f2cb9c1",
+ "sha256:0be65ccf618c1e7ac9b849c315cc2e8a8751d9cfdaa43027d4f6624bd587ab7e",
+ "sha256:0c95f12b74681e9ae127728f7e5409cbbef9cd914d5896ef238cc779b8152373",
+ "sha256:0ca564606d2caafb0abe6d1b5311c2649e8071eb241b2d64e75a0d0065107e62",
+ "sha256:10c93628d7497c81686e8e5e557aafa78f230cd9e77dd0c40032ef90c18f2230",
+ "sha256:11d117e6c63e8f495412d37e7dc2e2fff09c34b2d09dbe2bee3c6229577818be",
+ "sha256:11d3bcb7be35e7b1bba2c23beedac81ee893ac9871d0ba79effc7fc01167db6c",
+ "sha256:12a2b561af122e3d94cdb97fe6fb2bb2b82cef0cdca131646fdb940a1eda04f0",
+ "sha256:12d1a39aa6b8c6f6248bb54550efcc1c38ce0d8096a146638fd4738e42284448",
+ "sha256:1435ae15108b1cb6fffbcea2af3d468683b7afed0169ad718451f8db5d1aff6f",
+ "sha256:1c60b9c202d00052183c9be85e5eaf18a4ada0a47d188a83c8f5c5b23252f649",
+ "sha256:1e8fcdd8f672a1c4fc8d0bd3a2b576b152d2a349782d1eb0f6b8e52e9954731d",
+ "sha256:20064ead0717cf9a73a6d1e779b23d149b53daf971169289ed2ed43a71e8d3b0",
+ "sha256:21fa558996782fc226b529fdd2ed7866c2c6ec91cee82735c98a197fae39f706",
+ "sha256:22908891a380d50738e1f978667536f6c6b526a2064156203d418f4856d6e86a",
+ "sha256:3160a0fd9754aab7d47f95a6b63ab355388d890163eb03b2d2b87ab0a30cfa59",
+ "sha256:322102cdf1ab682ecc7d9b1c5eed4ec59657a65e1c146a0da342b78f4112db23",
+ "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5",
+ "sha256:3573d376454d956553c356df45bb824262c397c6e26ce43e8203c4c540ee0acb",
+ "sha256:3747443b6a904001473370d7810aa19c3a180ccd52a7157aacc264a5ac79265e",
+ "sha256:38e812a197bf8e71a59fe55b757a84c1f946d0ac114acafaafaf21667a7e169e",
+ "sha256:3a06f32c9634a8705f4ca9946d667609f52cf130d5548881401f1eb2c39b1e2c",
+ "sha256:3a5fc78f9e3f501a1614a98f7c54d3969f3ad9bba8ba3d9b438c3bc5d047dd28",
+ "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d",
+ "sha256:3dc5b6a8ecfdc5748a7e429782598e4f17ef378e3e272eeb1340ea57c9109f41",
+ "sha256:4155b51ae05ed47199dc5b2a4e62abccb274cee6b01da5b895099b61b1982974",
+ "sha256:49919f8400b5e49e961f320c735388ee686a62327e773fa5b3ce6721f7e785ce",
+ "sha256:53d0a3fa5f8af98a1e261de6a3943ca631c526635eb5817a87a59d9a57ebf48f",
+ "sha256:5f008525e02908b20e04707a4f704cd286d94718f48bb33edddc7d7b584dddc1",
+ "sha256:628c985afb2c7d27a4800bfb609e03985aaecb42f955049957814e0491d4006d",
+ "sha256:65ed923f84a6844de5fd29726b888e58c62820e0769b76565480e1fdc3d062f8",
+ "sha256:6734e606355834f13445b6adc38b53c0fd45f1a56a9ba06c2058f86893ae8017",
+ "sha256:6baf0baf0d5d265fa7944feb9f7451cc316bfe30e8df1a61b1bb08577c554f31",
+ "sha256:6f4f4668e1831850ebcc2fd0b1cd11721947b6dc7c00bf1c6bd3c929ae14f2c7",
+ "sha256:6f5c2e7bc8a4bf7c426599765b1bd33217ec84023033672c1e9a8b35eaeaaaf8",
+ "sha256:6f6c7a8a57e9405cad7485f4c9d3172ae486cfef1344b5ddd8e5239582d7355e",
+ "sha256:7381c66e0561c5757ffe616af869b916c8b4e42b367ab29fedc98481d1e74e14",
+ "sha256:73dc03a6a7e30b7edc5b01b601e53e7fc924b04e1835e8e407c12c037e81adbd",
+ "sha256:74db0052d985cf37fa111828d0dd230776ac99c740e1a758ad99094be4f1803d",
+ "sha256:75f2568b4189dda1c567339b48cba4ac7384accb9c2a7ed655cd86b04055c795",
+ "sha256:78cacd03e79d009d95635e7d6ff12c21eb89b894c354bd2b2ed0b4763373693b",
+ "sha256:80d1543d58bd3d6c271b66abf454d437a438dff01c3e62fdbcd68f2a11310d4b",
+ "sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b",
+ "sha256:891cf9b48776b5c61c700b55a598621fdb7b1e301a550365571e9624f270c203",
+ "sha256:8f25e17ab3039b05f762b0a55ae0b3632b2e073d9c8fc88e89aca31a6198e88f",
+ "sha256:9a3267620866c9d17b959a84dd0bd2d45719b817245e49371ead79ed4f710d19",
+ "sha256:a04f86f41a8916fe45ac5024ec477f41f886b3c435da2d4e3d2709b22ab02af1",
+ "sha256:aaf53a6cebad0eae578f062c7d462155eada9c172bd8c4d250b8c1d8eb7f916a",
+ "sha256:abc1185d79f47c0a7aaf7e2412a0eb2c03b724581139193d2d82b3ad8cbb00ac",
+ "sha256:ac0aa6cd53ab9a31d397f8303f92c42f534693528fafbdb997c82bae6e477ad9",
+ "sha256:ac3775e3311661d4adace3697a52ac0bab17edd166087d493b52d4f4f553f9f0",
+ "sha256:b06f0d3bf045158d2fb8837c5785fe9ff9b8c93358be64461a1089f5da983137",
+ "sha256:b116502087ce8a6b7a5f1814568ccbd0e9f6cfd99948aa59b0e241dc57cf739f",
+ "sha256:b82fab78e0b1329e183a65260581de4375f619167478dddab510c6c6fb04d9b6",
+ "sha256:bd7163182133c0c7701b25e604cf1611c0d87712e56e88e7ee5d72deab3e76b5",
+ "sha256:c36bcbc0d5174a80d6cccf43a0ecaca44e81d25be4b7f90f0ed7bcfbb5a00909",
+ "sha256:c3af8e0f07399d3176b179f2e2634c3ce9c1301379a6b8c9c9aeecd481da494f",
+ "sha256:c84132a54c750fda57729d1e2599bb598f5fa0344085dbde5003ba429a4798c0",
+ "sha256:cb7b2ab0188829593b9de646545175547a70d9a6e2b63bf2cd87a0a391599324",
+ "sha256:cca4def576f47a09a943666b8f829606bcb17e2bc2d5911a46c8f8da45f56755",
+ "sha256:cf6511efa4801b9b38dc5546d7547d5b5c6ef4b081c60b23e4d941d0eba9cbeb",
+ "sha256:d16fd5252f883eb074ca55cb622bc0bee49b979ae4e8639fff6ca3ff44f9f854",
+ "sha256:d2686f91611f9e17f4548dbf050e75b079bbc2a82be565832bc8ea9047b61c8c",
+ "sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60",
+ "sha256:dd5653e67b149503c68c4018bf07e42eeed6b4e956b24c00ccdf93ac79cdff84",
+ "sha256:de5695a6f1d8340b12a5d6d4484290ee74d61e467c39ff03b39e30df62cf83a0",
+ "sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b",
+ "sha256:e1b25e3ad6c909f398df8921780d6a3d120d8c09466720226fc621605b6f92b1",
+ "sha256:e633940f28c1e913615fd624fcdd72fdba807bf53ea6925d6a588e84e1151531",
+ "sha256:e89df2958e5159b811af9ff0f92614dabf4ff617c03a4c1c6ff53bf1c399e0e1",
+ "sha256:ea9f9c6034ea2d93d9147818f17c2a0860d41b71c38b9ce4d55f21b6f9165a11",
+ "sha256:f645caaf0008bacf349875a974220f1f1da349c5dbe7c4ec93048cdc785a3326",
+ "sha256:f8303414c7b03f794347ad062c0516cee0e15f7a612abd0ce1e25caf6ceb47df",
+ "sha256:fca62a8301b605b954ad2e9c3666f9d97f63872aa4efcae5492baca2056b74ab"
+ ],
+ "markers": "python_full_version >= '3.7.0'",
+ "version": "==3.1.0"
+ },
+ "click": {
+ "hashes": [
+ "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e",
+ "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==8.1.3"
+ },
+ "cloudpickle": {
+ "hashes": [
+ "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f",
+ "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==2.2.1"
+ },
+ "contourpy": {
+ "hashes": [
+ "sha256:052cc634bf903c604ef1a00a5aa093c54f81a2612faedaa43295809ffdde885e",
+ "sha256:084eaa568400cfaf7179b847ac871582199b1b44d5699198e9602ecbbb5f6104",
+ "sha256:0b6616375d7de55797d7a66ee7d087efe27f03d336c27cf1f32c02b8c1a5ac70",
+ "sha256:0b7b04ed0961647691cfe5d82115dd072af7ce8846d31a5fac6c142dcce8b882",
+ "sha256:143dde50520a9f90e4a2703f367cf8ec96a73042b72e68fcd184e1279962eb6f",
+ "sha256:17cfaf5ec9862bc93af1ec1f302457371c34e688fbd381f4035a06cd47324f48",
+ "sha256:181cbace49874f4358e2929aaf7ba84006acb76694102e88dd15af861996c16e",
+ "sha256:189ceb1525eb0655ab8487a9a9c41f42a73ba52d6789754788d1883fb06b2d8a",
+ "sha256:18a64814ae7bce73925131381603fff0116e2df25230dfc80d6d690aa6e20b37",
+ "sha256:1f0cbd657e9bde94cd0e33aa7df94fb73c1ab7799378d3b3f902eb8eb2e04a3a",
+ "sha256:1f795597073b09d631782e7245016a4323cf1cf0b4e06eef7ea6627e06a37ff2",
+ "sha256:25ae46595e22f93592d39a7eac3d638cda552c3e1160255258b695f7b58e5655",
+ "sha256:27bc79200c742f9746d7dd51a734ee326a292d77e7d94c8af6e08d1e6c15d545",
+ "sha256:2b836d22bd2c7bb2700348e4521b25e077255ebb6ab68e351ab5aa91ca27e027",
+ "sha256:30f511c05fab7f12e0b1b7730ebdc2ec8deedcfb505bc27eb570ff47c51a8f15",
+ "sha256:317267d915490d1e84577924bd61ba71bf8681a30e0d6c545f577363157e5e94",
+ "sha256:397b0ac8a12880412da3551a8cb5a187d3298a72802b45a3bd1805e204ad8439",
+ "sha256:438ba416d02f82b692e371858143970ed2eb6337d9cdbbede0d8ad9f3d7dd17d",
+ "sha256:53cc3a40635abedbec7f1bde60f8c189c49e84ac180c665f2cd7c162cc454baa",
+ "sha256:5d123a5bc63cd34c27ff9c7ac1cd978909e9c71da12e05be0231c608048bb2ae",
+ "sha256:62013a2cf68abc80dadfd2307299bfa8f5aa0dcaec5b2954caeb5fa094171103",
+ "sha256:89f06eff3ce2f4b3eb24c1055a26981bffe4e7264acd86f15b97e40530b794bc",
+ "sha256:90c81f22b4f572f8a2110b0b741bb64e5a6427e0a198b2cdc1fbaf85f352a3aa",
+ "sha256:911ff4fd53e26b019f898f32db0d4956c9d227d51338fb3b03ec72ff0084ee5f",
+ "sha256:9382a1c0bc46230fb881c36229bfa23d8c303b889b788b939365578d762b5c18",
+ "sha256:9f2931ed4741f98f74b410b16e5213f71dcccee67518970c42f64153ea9313b9",
+ "sha256:a67259c2b493b00e5a4d0f7bfae51fb4b3371395e47d079a4446e9b0f4d70e76",
+ "sha256:a698c6a7a432789e587168573a864a7ea374c6be8d4f31f9d87c001d5a843493",
+ "sha256:bc00bb4225d57bff7ebb634646c0ee2a1298402ec10a5fe7af79df9a51c1bfd9",
+ "sha256:bcb41692aa09aeb19c7c213411854402f29f6613845ad2453d30bf421fe68fed",
+ "sha256:d4f26b25b4f86087e7d75e63212756c38546e70f2a92d2be44f80114826e1cd4",
+ "sha256:d551f3a442655f3dcc1285723f9acd646ca5858834efeab4598d706206b09c9f",
+ "sha256:dffcc2ddec1782dd2f2ce1ef16f070861af4fb78c69862ce0aab801495dda6a3",
+ "sha256:e53046c3863828d21d531cc3b53786e6580eb1ba02477e8681009b6aa0870b21",
+ "sha256:e5cec36c5090e75a9ac9dbd0ff4a8cf7cecd60f1b6dc23a374c7d980a1cd710e",
+ "sha256:e7a117ce7df5a938fe035cad481b0189049e8d92433b4b33aa7fc609344aafa1",
+ "sha256:e94bef2580e25b5fdb183bf98a2faa2adc5b638736b2c0a4da98691da641316a",
+ "sha256:ed614aea8462735e7d70141374bd7650afd1c3f3cb0c2dbbcbe44e14331bf002",
+ "sha256:fb3b7d9e6243bfa1efb93ccfe64ec610d85cfe5aec2c25f97fbbd2e58b531256"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==1.1.0"
+ },
+ "cycler": {
+ "hashes": [
+ "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3",
+ "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==0.11.0"
+ },
+ "databricks-cli": {
+ "hashes": [
+ "sha256:5a545063449f3b9ad904644c0f251058485e29e564dedf8d4e4a7b45caf9549b",
+ "sha256:5b025943c70bbd374415264d38bfaddfb34ce070fadb083d851aec311e0f8901"
+ ],
+ "version": "==0.17.7"
+ },
+ "docker": {
+ "hashes": [
+ "sha256:aa6d17830045ba5ef0168d5eaa34d37beeb113948c413affe1d5991fc11f9a20",
+ "sha256:aecd2277b8bf8e506e484f6ab7aec39abe0038e29fa4a6d3ba86c3fe01844ed9"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==6.1.3"
+ },
+ "entrypoints": {
+ "hashes": [
+ "sha256:b706eddaa9218a19ebcd67b56818f05bb27589b1ca9e8d797b74affad4ccacd4",
+ "sha256:f174b5ff827504fd3cd97cc3f8649f3693f51538c7e4bdf3ef002c8429d42f9f"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==0.4"
+ },
+ "filelock": {
+ "hashes": [
+ "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81",
+ "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==3.12.2"
+ },
+ "flask": {
+ "hashes": [
+ "sha256:77fd4e1249d8c9923de34907236b747ced06e5467ecac1a7bb7115ae0e9670b0",
+ "sha256:8c2f9abd47a9e8df7f0c3f091ce9497d011dc3b31effcf4c85a6e2b50f4114ef"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==2.3.2"
+ },
+ "fonttools": {
+ "hashes": [
+ "sha256:00ab569b2a3e591e00425023ade87e8fef90380c1dde61be7691cb524ca5f743",
+ "sha256:022c4a16b412293e7f1ce21b8bab7a6f9d12c4ffdf171fdc67122baddb973069",
+ "sha256:05171f3c546f64d78569f10adc0de72561882352cac39ec7439af12304d8d8c0",
+ "sha256:14037c31138fbd21847ad5e5441dfdde003e0a8f3feb5812a1a21fd1c255ffbd",
+ "sha256:15abb3d055c1b2dff9ce376b6c3db10777cb74b37b52b78f61657634fd348a0d",
+ "sha256:18ea64ac43e94c9e0c23d7a9475f1026be0e25b10dda8f236fc956188761df97",
+ "sha256:1a003608400dd1cca3e089e8c94973c6b51a4fb1ef00ff6d7641617b9242e637",
+ "sha256:1bc4c5b147be8dbc5df9cc8ac5e93ee914ad030fe2a201cc8f02f499db71011d",
+ "sha256:200729d12461e2038700d31f0d49ad5a7b55855dec7525074979a06b46f88505",
+ "sha256:337b6e83d7ee73c40ea62407f2ce03b07c3459e213b6f332b94a69923b9e1cb9",
+ "sha256:37467cee0f32cada2ec08bc16c9c31f9b53ea54b2f5604bf25a1246b5f50593a",
+ "sha256:425b74a608427499b0e45e433c34ddc350820b6f25b7c8761963a08145157a66",
+ "sha256:530c5d35109f3e0cea2535742d6a3bc99c0786cf0cbd7bb2dc9212387f0d908c",
+ "sha256:56d4d85f5374b45b08d2f928517d1e313ea71b4847240398decd0ab3ebbca885",
+ "sha256:5e00334c66f4e83535384cb5339526d01d02d77f142c23b2f97bd6a4f585497a",
+ "sha256:5fdf60f8a5c6bcce7d024a33f7e4bc7921f5b74e8ea13bccd204f2c8b86f3470",
+ "sha256:6a8d71b9a5c884c72741868e845c0e563c5d83dcaf10bb0ceeec3b4b2eb14c67",
+ "sha256:6d5adf4ba114f028fc3f5317a221fd8b0f4ef7a2e5524a2b1e0fd891b093791a",
+ "sha256:7449e5e306f3a930a8944c85d0cbc8429cba13503372a1a40f23124d6fb09b58",
+ "sha256:7961575221e3da0841c75da53833272c520000d76f7f71274dbf43370f8a1065",
+ "sha256:7f6e3fa3da923063c286320e728ba2270e49c73386e3a711aa680f4b0747d692",
+ "sha256:882983279bf39afe4e945109772c2ffad2be2c90983d6559af8b75c19845a80a",
+ "sha256:8a917828dbfdb1cbe50cf40eeae6fbf9c41aef9e535649ed8f4982b2ef65c091",
+ "sha256:8c4305b171b61040b1ee75d18f9baafe58bd3b798d1670078efe2c92436bfb63",
+ "sha256:91784e21a1a085fac07c6a407564f4a77feb471b5954c9ee55a4f9165151f6c1",
+ "sha256:94c915f6716589f78bc00fbc14c5b8de65cfd11ee335d32504f1ef234524cb24",
+ "sha256:97d95b8301b62bdece1af943b88bcb3680fd385f88346a4a899ee145913b414a",
+ "sha256:a954b90d1473c85a22ecf305761d9fd89da93bbd31dae86e7dea436ad2cb5dc9",
+ "sha256:aa83b3f151bc63970f39b2b42a06097c5a22fd7ed9f7ba008e618de4503d3895",
+ "sha256:b802dcbf9bcff74672f292b2466f6589ab8736ce4dcf36f48eb994c2847c4b30",
+ "sha256:bae8c13abbc2511e9a855d2142c0ab01178dd66b1a665798f357da0d06253e0d",
+ "sha256:c55f1b4109dbc3aeb496677b3e636d55ef46dc078c2a5e3f3db4e90f1c6d2907",
+ "sha256:eb52c10fda31159c22c7ed85074e05f8b97da8773ea461706c273e31bcbea836",
+ "sha256:ec468c022d09f1817c691cf884feb1030ef6f1e93e3ea6831b0d8144c06480d1"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==4.40.0"
+ },
+ "frozenlist": {
+ "hashes": [
+ "sha256:008a054b75d77c995ea26629ab3a0c0d7281341f2fa7e1e85fa6153ae29ae99c",
+ "sha256:02c9ac843e3390826a265e331105efeab489ffaf4dd86384595ee8ce6d35ae7f",
+ "sha256:034a5c08d36649591be1cbb10e09da9f531034acfe29275fc5454a3b101ce41a",
+ "sha256:05cdb16d09a0832eedf770cb7bd1fe57d8cf4eaf5aced29c4e41e3f20b30a784",
+ "sha256:0693c609e9742c66ba4870bcee1ad5ff35462d5ffec18710b4ac89337ff16e27",
+ "sha256:0771aed7f596c7d73444c847a1c16288937ef988dc04fb9f7be4b2aa91db609d",
+ "sha256:0af2e7c87d35b38732e810befb9d797a99279cbb85374d42ea61c1e9d23094b3",
+ "sha256:14143ae966a6229350021384870458e4777d1eae4c28d1a7aa47f24d030e6678",
+ "sha256:180c00c66bde6146a860cbb81b54ee0df350d2daf13ca85b275123bbf85de18a",
+ "sha256:1841e200fdafc3d51f974d9d377c079a0694a8f06de2e67b48150328d66d5483",
+ "sha256:23d16d9f477bb55b6154654e0e74557040575d9d19fe78a161bd33d7d76808e8",
+ "sha256:2b07ae0c1edaa0a36339ec6cce700f51b14a3fc6545fdd32930d2c83917332cf",
+ "sha256:2c926450857408e42f0bbc295e84395722ce74bae69a3b2aa2a65fe22cb14b99",
+ "sha256:2e24900aa13212e75e5b366cb9065e78bbf3893d4baab6052d1aca10d46d944c",
+ "sha256:303e04d422e9b911a09ad499b0368dc551e8c3cd15293c99160c7f1f07b59a48",
+ "sha256:352bd4c8c72d508778cf05ab491f6ef36149f4d0cb3c56b1b4302852255d05d5",
+ "sha256:3843f84a6c465a36559161e6c59dce2f2ac10943040c2fd021cfb70d58c4ad56",
+ "sha256:394c9c242113bfb4b9aa36e2b80a05ffa163a30691c7b5a29eba82e937895d5e",
+ "sha256:3bbdf44855ed8f0fbcd102ef05ec3012d6a4fd7c7562403f76ce6a52aeffb2b1",
+ "sha256:40de71985e9042ca00b7953c4f41eabc3dc514a2d1ff534027f091bc74416401",
+ "sha256:41fe21dc74ad3a779c3d73a2786bdf622ea81234bdd4faf90b8b03cad0c2c0b4",
+ "sha256:47df36a9fe24054b950bbc2db630d508cca3aa27ed0566c0baf661225e52c18e",
+ "sha256:4ea42116ceb6bb16dbb7d526e242cb6747b08b7710d9782aa3d6732bd8d27649",
+ "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a",
+ "sha256:5c11e43016b9024240212d2a65043b70ed8dfd3b52678a1271972702d990ac6d",
+ "sha256:5cf820485f1b4c91e0417ea0afd41ce5cf5965011b3c22c400f6d144296ccbc0",
+ "sha256:5d8860749e813a6f65bad8285a0520607c9500caa23fea6ee407e63debcdbef6",
+ "sha256:6327eb8e419f7d9c38f333cde41b9ae348bec26d840927332f17e887a8dcb70d",
+ "sha256:65a5e4d3aa679610ac6e3569e865425b23b372277f89b5ef06cf2cdaf1ebf22b",
+ "sha256:66080ec69883597e4d026f2f71a231a1ee9887835902dbe6b6467d5a89216cf6",
+ "sha256:783263a4eaad7c49983fe4b2e7b53fa9770c136c270d2d4bbb6d2192bf4d9caf",
+ "sha256:7f44e24fa70f6fbc74aeec3e971f60a14dde85da364aa87f15d1be94ae75aeef",
+ "sha256:7fdfc24dcfce5b48109867c13b4cb15e4660e7bd7661741a391f821f23dfdca7",
+ "sha256:810860bb4bdce7557bc0febb84bbd88198b9dbc2022d8eebe5b3590b2ad6c842",
+ "sha256:841ea19b43d438a80b4de62ac6ab21cfe6827bb8a9dc62b896acc88eaf9cecba",
+ "sha256:84610c1502b2461255b4c9b7d5e9c48052601a8957cd0aea6ec7a7a1e1fb9420",
+ "sha256:899c5e1928eec13fd6f6d8dc51be23f0d09c5281e40d9cf4273d188d9feeaf9b",
+ "sha256:8bae29d60768bfa8fb92244b74502b18fae55a80eac13c88eb0b496d4268fd2d",
+ "sha256:8df3de3a9ab8325f94f646609a66cbeeede263910c5c0de0101079ad541af332",
+ "sha256:8fa3c6e3305aa1146b59a09b32b2e04074945ffcfb2f0931836d103a2c38f936",
+ "sha256:924620eef691990dfb56dc4709f280f40baee568c794b5c1885800c3ecc69816",
+ "sha256:9309869032abb23d196cb4e4db574232abe8b8be1339026f489eeb34a4acfd91",
+ "sha256:9545a33965d0d377b0bc823dcabf26980e77f1b6a7caa368a365a9497fb09420",
+ "sha256:9ac5995f2b408017b0be26d4a1d7c61bce106ff3d9e3324374d66b5964325448",
+ "sha256:9bbbcedd75acdfecf2159663b87f1bb5cfc80e7cd99f7ddd9d66eb98b14a8411",
+ "sha256:a4ae8135b11652b08a8baf07631d3ebfe65a4c87909dbef5fa0cdde440444ee4",
+ "sha256:a6394d7dadd3cfe3f4b3b186e54d5d8504d44f2d58dcc89d693698e8b7132b32",
+ "sha256:a97b4fe50b5890d36300820abd305694cb865ddb7885049587a5678215782a6b",
+ "sha256:ae4dc05c465a08a866b7a1baf360747078b362e6a6dbeb0c57f234db0ef88ae0",
+ "sha256:b1c63e8d377d039ac769cd0926558bb7068a1f7abb0f003e3717ee003ad85530",
+ "sha256:b1e2c1185858d7e10ff045c496bbf90ae752c28b365fef2c09cf0fa309291669",
+ "sha256:b4395e2f8d83fbe0c627b2b696acce67868793d7d9750e90e39592b3626691b7",
+ "sha256:b756072364347cb6aa5b60f9bc18e94b2f79632de3b0190253ad770c5df17db1",
+ "sha256:ba64dc2b3b7b158c6660d49cdb1d872d1d0bf4e42043ad8d5006099479a194e5",
+ "sha256:bed331fe18f58d844d39ceb398b77d6ac0b010d571cba8267c2e7165806b00ce",
+ "sha256:c188512b43542b1e91cadc3c6c915a82a5eb95929134faf7fd109f14f9892ce4",
+ "sha256:c21b9aa40e08e4f63a2f92ff3748e6b6c84d717d033c7b3438dd3123ee18f70e",
+ "sha256:ca713d4af15bae6e5d79b15c10c8522859a9a89d3b361a50b817c98c2fb402a2",
+ "sha256:cd4210baef299717db0a600d7a3cac81d46ef0e007f88c9335db79f8979c0d3d",
+ "sha256:cfe33efc9cb900a4c46f91a5ceba26d6df370ffddd9ca386eb1d4f0ad97b9ea9",
+ "sha256:d5cd3ab21acbdb414bb6c31958d7b06b85eeb40f66463c264a9b343a4e238642",
+ "sha256:dfbac4c2dfcc082fcf8d942d1e49b6aa0766c19d3358bd86e2000bf0fa4a9cf0",
+ "sha256:e235688f42b36be2b6b06fc37ac2126a73b75fb8d6bc66dd632aa35286238703",
+ "sha256:eb82dbba47a8318e75f679690190c10a5e1f447fbf9df41cbc4c3afd726d88cb",
+ "sha256:ebb86518203e12e96af765ee89034a1dbb0c3c65052d1b0c19bbbd6af8a145e1",
+ "sha256:ee78feb9d293c323b59a6f2dd441b63339a30edf35abcb51187d2fc26e696d13",
+ "sha256:eedab4c310c0299961ac285591acd53dc6723a1ebd90a57207c71f6e0c2153ab",
+ "sha256:efa568b885bca461f7c7b9e032655c0c143d305bf01c30caf6db2854a4532b38",
+ "sha256:efce6ae830831ab6a22b9b4091d411698145cb9b8fc869e1397ccf4b4b6455cb",
+ "sha256:f163d2fd041c630fed01bc48d28c3ed4a3b003c00acd396900e11ee5316b56bb",
+ "sha256:f20380df709d91525e4bee04746ba612a4df0972c1b8f8e1e8af997e678c7b81",
+ "sha256:f30f1928162e189091cf4d9da2eac617bfe78ef907a761614ff577ef4edfb3c8",
+ "sha256:f470c92737afa7d4c3aacc001e335062d582053d4dbe73cda126f2d7031068dd",
+ "sha256:ff8bf625fe85e119553b5383ba0fb6aa3d0ec2ae980295aaefa552374926b3f4"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.3.3"
+ },
+ "fsspec": {
+ "extras": [
+ "http"
+ ],
+ "hashes": [
+ "sha256:1cbad1faef3e391fba6dc005ae9b5bdcbf43005c9167ce78c915549c352c869a",
+ "sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==2023.6.0"
+ },
+ "gdown": {
+ "hashes": [
+ "sha256:347f23769679aaf7efa73e5655270fcda8ca56be65eb84a4a21d143989541045",
+ "sha256:65d495699e7c2c61af0d0e9c32748fb4f79abaf80d747a87456c7be14aac2560"
+ ],
+ "index": "pypi",
+ "version": "==4.7.1"
+ },
+ "gitdb": {
+ "hashes": [
+ "sha256:6eb990b69df4e15bad899ea868dc46572c3f75339735663b81de79b06f17eb9a",
+ "sha256:c286cf298426064079ed96a9e4a9d39e7f3e9bf15ba60701e95f5492f28415c7"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==4.0.10"
+ },
+ "gitpython": {
+ "hashes": [
+ "sha256:8ce3bcf69adfdf7c7d503e78fd3b1c492af782d58893b650adb2ac8912ddd573",
+ "sha256:f04893614f6aa713a60cbbe1e6a97403ef633103cdd0ef5eb6efe0deb98dbe8d"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==3.1.31"
+ },
+ "greenlet": {
+ "hashes": [
+ "sha256:03a8f4f3430c3b3ff8d10a2a86028c660355ab637cee9333d63d66b56f09d52a",
+ "sha256:0bf60faf0bc2468089bdc5edd10555bab6e85152191df713e2ab1fcc86382b5a",
+ "sha256:18a7f18b82b52ee85322d7a7874e676f34ab319b9f8cce5de06067384aa8ff43",
+ "sha256:18e98fb3de7dba1c0a852731c3070cf022d14f0d68b4c87a19cc1016f3bb8b33",
+ "sha256:1a819eef4b0e0b96bb0d98d797bef17dc1b4a10e8d7446be32d1da33e095dbb8",
+ "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088",
+ "sha256:2780572ec463d44c1d3ae850239508dbeb9fed38e294c68d19a24d925d9223ca",
+ "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343",
+ "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645",
+ "sha256:2dd11f291565a81d71dab10b7033395b7a3a5456e637cf997a6f33ebdf06f8db",
+ "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df",
+ "sha256:32e5b64b148966d9cccc2c8d35a671409e45f195864560829f395a54226408d3",
+ "sha256:36abbf031e1c0f79dd5d596bfaf8e921c41df2bdf54ee1eed921ce1f52999a86",
+ "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2",
+ "sha256:3a51c9751078733d88e013587b108f1b7a1fb106d402fb390740f002b6f6551a",
+ "sha256:3c9b12575734155d0c09d6c3e10dbd81665d5c18e1a7c6597df72fd05990c8cf",
+ "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7",
+ "sha256:4b58adb399c4d61d912c4c331984d60eb66565175cdf4a34792cd9600f21b394",
+ "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40",
+ "sha256:5454276c07d27a740c5892f4907c86327b632127dd9abec42ee62e12427ff7e3",
+ "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6",
+ "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74",
+ "sha256:703f18f3fda276b9a916f0934d2fb6d989bf0b4fb5a64825260eb9bfd52d78f0",
+ "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3",
+ "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91",
+ "sha256:7cafd1208fdbe93b67c7086876f061f660cfddc44f404279c1585bbf3cdc64c5",
+ "sha256:7efde645ca1cc441d6dc4b48c0f7101e8d86b54c8530141b09fd31cef5149ec9",
+ "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8",
+ "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b",
+ "sha256:910841381caba4f744a44bf81bfd573c94e10b3045ee00de0cbf436fe50673a6",
+ "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb",
+ "sha256:937e9020b514ceedb9c830c55d5c9872abc90f4b5862f89c0887033ae33c6f73",
+ "sha256:94c817e84245513926588caf1152e3b559ff794d505555211ca041f032abbb6b",
+ "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df",
+ "sha256:9d14b83fab60d5e8abe587d51c75b252bcc21683f24699ada8fb275d7712f5a9",
+ "sha256:9f35ec95538f50292f6d8f2c9c9f8a3c6540bbfec21c9e5b4b751e0a7c20864f",
+ "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0",
+ "sha256:acd2162a36d3de67ee896c43effcd5ee3de247eb00354db411feb025aa319857",
+ "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a",
+ "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249",
+ "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30",
+ "sha256:b9ec052b06a0524f0e35bd8790686a1da006bd911dd1ef7d50b77bfbad74e292",
+ "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b",
+ "sha256:bdfea8c661e80d3c1c99ad7c3ff74e6e87184895bbaca6ee8cc61209f8b9b85d",
+ "sha256:be4ed120b52ae4d974aa40215fcdfde9194d63541c7ded40ee12eb4dda57b76b",
+ "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c",
+ "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca",
+ "sha256:c9c59a2120b55788e800d82dfa99b9e156ff8f2227f07c5e3012a45a399620b7",
+ "sha256:cd021c754b162c0fb55ad5d6b9d960db667faad0fa2ff25bb6e1301b0b6e6a75",
+ "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae",
+ "sha256:d5508f0b173e6aa47273bdc0a0b5ba055b59662ba7c7ee5119528f466585526b",
+ "sha256:d75209eed723105f9596807495d58d10b3470fa6732dd6756595e89925ce2470",
+ "sha256:db1a39669102a1d8d12b57de2bb7e2ec9066a6f2b3da35ae511ff93b01b5d564",
+ "sha256:dbfcfc0218093a19c252ca8eb9aee3d29cfdcb586df21049b9d777fd32c14fd9",
+ "sha256:e0f72c9ddb8cd28532185f54cc1453f2c16fb417a08b53a855c4e6a418edd099",
+ "sha256:e7c8dc13af7db097bed64a051d2dd49e9f0af495c26995c00a9ee842690d34c0",
+ "sha256:ea9872c80c132f4663822dd2a08d404073a5a9b5ba6155bea72fb2a79d1093b5",
+ "sha256:eff4eb9b7eb3e4d0cae3d28c283dc16d9bed6b193c2e1ace3ed86ce48ea8df19",
+ "sha256:f82d4d717d8ef19188687aa32b8363e96062911e63ba22a0cff7802a8e58e5f1",
+ "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"
+ ],
+ "markers": "platform_machine == 'aarch64' or (platform_machine == 'ppc64le' or (platform_machine == 'x86_64' or (platform_machine == 'amd64' or (platform_machine == 'AMD64' or (platform_machine == 'win32' or platform_machine == 'WIN32')))))",
+ "version": "==2.0.2"
+ },
+ "gunicorn": {
+ "hashes": [
+ "sha256:9dcc4547dbb1cb284accfb15ab5667a0e5d1881cc443e0677b4882a4067a807e",
+ "sha256:e0a968b5ba15f8a328fdfd7ab1fcb5af4470c28aaf7e55df02a99bc13138e6e8"
+ ],
+ "markers": "platform_system != 'Windows'",
+ "version": "==20.1.0"
+ },
+ "idna": {
+ "hashes": [
+ "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4",
+ "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"
+ ],
+ "markers": "python_version >= '3.5'",
+ "version": "==3.4"
+ },
+ "importlib-metadata": {
+ "hashes": [
+ "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4",
+ "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==6.7.0"
+ },
+ "itsdangerous": {
+ "hashes": [
+ "sha256:2c2349112351b88699d8d4b6b075022c0808887cb7ad10069318a8b0bc88db44",
+ "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==2.1.2"
+ },
+ "jinja2": {
+ "hashes": [
+ "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852",
+ "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"
+ ],
+ "markers": "platform_system != 'Windows'",
+ "version": "==3.1.2"
+ },
+ "joblib": {
+ "hashes": [
+ "sha256:0b12a65dc76c530dbd790dd92881f75c40932b4254a7c8e608a868df408ca0a3",
+ "sha256:172d56d4c43dd6bcd953bea213018c4084cf754963bbf54b8dae40faea716b98"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.3.0"
+ },
+ "kaggle": {
+ "hashes": [
+ "sha256:8364c56c36125cb8195876d911d0bc9efbc173167e963b9e9ef793789ba9a7fe"
+ ],
+ "index": "pypi",
+ "version": "==1.5.13"
+ },
+ "kiwisolver": {
+ "hashes": [
+ "sha256:02f79693ec433cb4b5f51694e8477ae83b3205768a6fb48ffba60549080e295b",
+ "sha256:03baab2d6b4a54ddbb43bba1a3a2d1627e82d205c5cf8f4c924dc49284b87166",
+ "sha256:1041feb4cda8708ce73bb4dcb9ce1ccf49d553bf87c3954bdfa46f0c3f77252c",
+ "sha256:10ee06759482c78bdb864f4109886dff7b8a56529bc1609d4f1112b93fe6423c",
+ "sha256:1d1573129aa0fd901076e2bfb4275a35f5b7aa60fbfb984499d661ec950320b0",
+ "sha256:283dffbf061a4ec60391d51e6155e372a1f7a4f5b15d59c8505339454f8989e4",
+ "sha256:28bc5b299f48150b5f822ce68624e445040595a4ac3d59251703779836eceff9",
+ "sha256:2a66fdfb34e05b705620dd567f5a03f239a088d5a3f321e7b6ac3239d22aa286",
+ "sha256:2e307eb9bd99801f82789b44bb45e9f541961831c7311521b13a6c85afc09767",
+ "sha256:2e407cb4bd5a13984a6c2c0fe1845e4e41e96f183e5e5cd4d77a857d9693494c",
+ "sha256:2f5e60fabb7343a836360c4f0919b8cd0d6dbf08ad2ca6b9cf90bf0c76a3c4f6",
+ "sha256:36dafec3d6d6088d34e2de6b85f9d8e2324eb734162fba59d2ba9ed7a2043d5b",
+ "sha256:3fe20f63c9ecee44560d0e7f116b3a747a5d7203376abeea292ab3152334d004",
+ "sha256:41dae968a94b1ef1897cb322b39360a0812661dba7c682aa45098eb8e193dbdf",
+ "sha256:4bd472dbe5e136f96a4b18f295d159d7f26fd399136f5b17b08c4e5f498cd494",
+ "sha256:4ea39b0ccc4f5d803e3337dd46bcce60b702be4d86fd0b3d7531ef10fd99a1ac",
+ "sha256:5853eb494c71e267912275e5586fe281444eb5e722de4e131cddf9d442615626",
+ "sha256:5bce61af018b0cb2055e0e72e7d65290d822d3feee430b7b8203d8a855e78766",
+ "sha256:6295ecd49304dcf3bfbfa45d9a081c96509e95f4b9d0eb7ee4ec0530c4a96514",
+ "sha256:62ac9cc684da4cf1778d07a89bf5f81b35834cb96ca523d3a7fb32509380cbf6",
+ "sha256:70e7c2e7b750585569564e2e5ca9845acfaa5da56ac46df68414f29fea97be9f",
+ "sha256:7577c1987baa3adc4b3c62c33bd1118c3ef5c8ddef36f0f2c950ae0b199e100d",
+ "sha256:75facbe9606748f43428fc91a43edb46c7ff68889b91fa31f53b58894503a191",
+ "sha256:787518a6789009c159453da4d6b683f468ef7a65bbde796bcea803ccf191058d",
+ "sha256:78d6601aed50c74e0ef02f4204da1816147a6d3fbdc8b3872d263338a9052c51",
+ "sha256:7c43e1e1206cd421cd92e6b3280d4385d41d7166b3ed577ac20444b6995a445f",
+ "sha256:81e38381b782cc7e1e46c4e14cd997ee6040768101aefc8fa3c24a4cc58e98f8",
+ "sha256:841293b17ad704d70c578f1f0013c890e219952169ce8a24ebc063eecf775454",
+ "sha256:872b8ca05c40d309ed13eb2e582cab0c5a05e81e987ab9c521bf05ad1d5cf5cb",
+ "sha256:877272cf6b4b7e94c9614f9b10140e198d2186363728ed0f701c6eee1baec1da",
+ "sha256:8c808594c88a025d4e322d5bb549282c93c8e1ba71b790f539567932722d7bd8",
+ "sha256:8ed58b8acf29798b036d347791141767ccf65eee7f26bde03a71c944449e53de",
+ "sha256:91672bacaa030f92fc2f43b620d7b337fd9a5af28b0d6ed3f77afc43c4a64b5a",
+ "sha256:968f44fdbf6dd757d12920d63b566eeb4d5b395fd2d00d29d7ef00a00582aac9",
+ "sha256:9f85003f5dfa867e86d53fac6f7e6f30c045673fa27b603c397753bebadc3008",
+ "sha256:a553dadda40fef6bfa1456dc4be49b113aa92c2a9a9e8711e955618cd69622e3",
+ "sha256:a68b62a02953b9841730db7797422f983935aeefceb1679f0fc85cbfbd311c32",
+ "sha256:abbe9fa13da955feb8202e215c4018f4bb57469b1b78c7a4c5c7b93001699938",
+ "sha256:ad881edc7ccb9d65b0224f4e4d05a1e85cf62d73aab798943df6d48ab0cd79a1",
+ "sha256:b1792d939ec70abe76f5054d3f36ed5656021dcad1322d1cc996d4e54165cef9",
+ "sha256:b428ef021242344340460fa4c9185d0b1f66fbdbfecc6c63eff4b7c29fad429d",
+ "sha256:b533558eae785e33e8c148a8d9921692a9fe5aa516efbdff8606e7d87b9d5824",
+ "sha256:ba59c92039ec0a66103b1d5fe588fa546373587a7d68f5c96f743c3396afc04b",
+ "sha256:bc8d3bd6c72b2dd9decf16ce70e20abcb3274ba01b4e1c96031e0c4067d1e7cd",
+ "sha256:bc9db8a3efb3e403e4ecc6cd9489ea2bac94244f80c78e27c31dcc00d2790ac2",
+ "sha256:bf7d9fce9bcc4752ca4a1b80aabd38f6d19009ea5cbda0e0856983cf6d0023f5",
+ "sha256:c2dbb44c3f7e6c4d3487b31037b1bdbf424d97687c1747ce4ff2895795c9bf69",
+ "sha256:c79ebe8f3676a4c6630fd3f777f3cfecf9289666c84e775a67d1d358578dc2e3",
+ "sha256:c97528e64cb9ebeff9701e7938653a9951922f2a38bd847787d4a8e498cc83ae",
+ "sha256:d0611a0a2a518464c05ddd5a3a1a0e856ccc10e67079bb17f265ad19ab3c7597",
+ "sha256:d06adcfa62a4431d404c31216f0f8ac97397d799cd53800e9d3efc2fbb3cf14e",
+ "sha256:d41997519fcba4a1e46eb4a2fe31bc12f0ff957b2b81bac28db24744f333e955",
+ "sha256:d5b61785a9ce44e5a4b880272baa7cf6c8f48a5180c3e81c59553ba0cb0821ca",
+ "sha256:da152d8cdcab0e56e4f45eb08b9aea6455845ec83172092f09b0e077ece2cf7a",
+ "sha256:da7e547706e69e45d95e116e6939488d62174e033b763ab1496b4c29b76fabea",
+ "sha256:db5283d90da4174865d520e7366801a93777201e91e79bacbac6e6927cbceede",
+ "sha256:db608a6757adabb32f1cfe6066e39b3706d8c3aa69bbc353a5b61edad36a5cb4",
+ "sha256:e0ea21f66820452a3f5d1655f8704a60d66ba1191359b96541eaf457710a5fc6",
+ "sha256:e7da3fec7408813a7cebc9e4ec55afed2d0fd65c4754bc376bf03498d4e92686",
+ "sha256:e92a513161077b53447160b9bd8f522edfbed4bd9759e4c18ab05d7ef7e49408",
+ "sha256:ecb1fa0db7bf4cff9dac752abb19505a233c7f16684c5826d1f11ebd9472b871",
+ "sha256:efda5fc8cc1c61e4f639b8067d118e742b812c930f708e6667a5ce0d13499e29",
+ "sha256:f0a1dbdb5ecbef0d34eb77e56fcb3e95bbd7e50835d9782a45df81cc46949750",
+ "sha256:f0a71d85ecdd570ded8ac3d1c0f480842f49a40beb423bb8014539a9f32a5897",
+ "sha256:f4f270de01dd3e129a72efad823da90cc4d6aafb64c410c9033aba70db9f1ff0",
+ "sha256:f6cb459eea32a4e2cf18ba5fcece2dbdf496384413bc1bae15583f19e567f3b2",
+ "sha256:f8ad8285b01b0d4695102546b342b493b3ccc6781fc28c8c6a1bb63e95d22f09",
+ "sha256:f9f39e2f049db33a908319cf46624a569b36983c7c78318e9726a4cb8923b26c"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.4.4"
+ },
+ "lightning-utilities": {
+ "hashes": [
+ "sha256:22aa107b51c8f50ccef54d08885eb370903eb04148cddb2891b9c65c59de2a6e",
+ "sha256:8e5d95c7c57f026cdfed7c154303e88c93a7a5e868c9944cb02cf71f1db29720"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==0.8.0"
+ },
+ "mako": {
+ "hashes": [
+ "sha256:c97c79c018b9165ac9922ae4f32da095ffd3c4e6872b45eded42926deea46818",
+ "sha256:d60a3903dc3bb01a18ad6a89cdbe2e4eadc69c0bc8ef1e3773ba53d44c3f7a34"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.2.4"
+ },
+ "markdown": {
+ "hashes": [
+ "sha256:065fd4df22da73a625f14890dd77eb8040edcbd68794bcd35943be14490608b2",
+ "sha256:8bf101198e004dc93e84a12a7395e31aac6a9c9942848ae1d99b9d72cf9b3520"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==3.4.3"
+ },
+ "markupsafe": {
+ "hashes": [
+ "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e",
+ "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e",
+ "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431",
+ "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686",
+ "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559",
+ "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc",
+ "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c",
+ "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0",
+ "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4",
+ "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9",
+ "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575",
+ "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba",
+ "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d",
+ "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3",
+ "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00",
+ "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155",
+ "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac",
+ "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52",
+ "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f",
+ "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8",
+ "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b",
+ "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24",
+ "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea",
+ "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198",
+ "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0",
+ "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee",
+ "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be",
+ "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2",
+ "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707",
+ "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6",
+ "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58",
+ "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779",
+ "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636",
+ "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c",
+ "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad",
+ "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee",
+ "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc",
+ "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2",
+ "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48",
+ "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7",
+ "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e",
+ "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b",
+ "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa",
+ "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5",
+ "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e",
+ "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb",
+ "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9",
+ "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57",
+ "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc",
+ "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==2.1.3"
+ },
+ "matplotlib": {
+ "hashes": [
+ "sha256:08308bae9e91aca1ec6fd6dda66237eef9f6294ddb17f0d0b3c863169bf82353",
+ "sha256:14645aad967684e92fc349493fa10c08a6da514b3d03a5931a1bac26e6792bd1",
+ "sha256:21e9cff1a58d42e74d01153360de92b326708fb205250150018a52c70f43c290",
+ "sha256:28506a03bd7f3fe59cd3cd4ceb2a8d8a2b1db41afede01f66c42561b9be7b4b7",
+ "sha256:2bf092f9210e105f414a043b92af583c98f50050559616930d884387d0772aba",
+ "sha256:3032884084f541163f295db8a6536e0abb0db464008fadca6c98aaf84ccf4717",
+ "sha256:3a2cb34336110e0ed8bb4f650e817eed61fa064acbefeb3591f1b33e3a84fd96",
+ "sha256:3ba2af245e36990facf67fde840a760128ddd71210b2ab6406e640188d69d136",
+ "sha256:3d7bc90727351fb841e4d8ae620d2d86d8ed92b50473cd2b42ce9186104ecbba",
+ "sha256:438196cdf5dc8d39b50a45cb6e3f6274edbcf2254f85fa9b895bf85851c3a613",
+ "sha256:46a561d23b91f30bccfd25429c3c706afe7d73a5cc64ef2dfaf2b2ac47c1a5dc",
+ "sha256:4cf327e98ecf08fcbb82685acaf1939d3338548620ab8dfa02828706402c34de",
+ "sha256:4f99e1b234c30c1e9714610eb0c6d2f11809c9c78c984a613ae539ea2ad2eb4b",
+ "sha256:544764ba51900da4639c0f983b323d288f94f65f4024dc40ecb1542d74dc0500",
+ "sha256:56d94989191de3fcc4e002f93f7f1be5da476385dde410ddafbb70686acf00ea",
+ "sha256:57bfb8c8ea253be947ccb2bc2d1bb3862c2bccc662ad1b4626e1f5e004557042",
+ "sha256:617f14ae9d53292ece33f45cba8503494ee199a75b44de7717964f70637a36aa",
+ "sha256:6eb88d87cb2c49af00d3bbc33a003f89fd9f78d318848da029383bfc08ecfbfb",
+ "sha256:75d4725d70b7c03e082bbb8a34639ede17f333d7247f56caceb3801cb6ff703d",
+ "sha256:770a205966d641627fd5cf9d3cb4b6280a716522cd36b8b284a8eb1581310f61",
+ "sha256:7b73305f25eab4541bd7ee0b96d87e53ae9c9f1823be5659b806cd85786fe882",
+ "sha256:7c9a4b2da6fac77bcc41b1ea95fadb314e92508bf5493ceff058e727e7ecf5b0",
+ "sha256:81a6b377ea444336538638d31fdb39af6be1a043ca5e343fe18d0f17e098770b",
+ "sha256:83111e6388dec67822e2534e13b243cc644c7494a4bb60584edbff91585a83c6",
+ "sha256:8704726d33e9aa8a6d5215044b8d00804561971163563e6e6591f9dcf64340cc",
+ "sha256:89768d84187f31717349c6bfadc0e0d8c321e8eb34522acec8a67b1236a66332",
+ "sha256:8bf26ade3ff0f27668989d98c8435ce9327d24cffb7f07d24ef609e33d582439",
+ "sha256:8c587963b85ce41e0a8af53b9b2de8dddbf5ece4c34553f7bd9d066148dc719c",
+ "sha256:95cbc13c1fc6844ab8812a525bbc237fa1470863ff3dace7352e910519e194b1",
+ "sha256:97cc368a7268141afb5690760921765ed34867ffb9655dd325ed207af85c7529",
+ "sha256:a867bf73a7eb808ef2afbca03bcdb785dae09595fbe550e1bab0cd023eba3de0",
+ "sha256:b867e2f952ed592237a1828f027d332d8ee219ad722345b79a001f49df0936eb",
+ "sha256:c0bd19c72ae53e6ab979f0ac6a3fafceb02d2ecafa023c5cca47acd934d10be7",
+ "sha256:ce463ce590f3825b52e9fe5c19a3c6a69fd7675a39d589e8b5fbe772272b3a24",
+ "sha256:cf0e4f727534b7b1457898c4f4ae838af1ef87c359b76dcd5330fa31893a3ac7",
+ "sha256:def58098f96a05f90af7e92fd127d21a287068202aa43b2a93476170ebd99e87",
+ "sha256:e99bc9e65901bb9a7ce5e7bb24af03675cbd7c70b30ac670aa263240635999a4",
+ "sha256:eb7d248c34a341cd4c31a06fd34d64306624c8cd8d0def7abb08792a5abfd556",
+ "sha256:f67bfdb83a8232cb7a92b869f9355d677bce24485c460b19d01970b64b2ed476",
+ "sha256:f883a22a56a84dba3b588696a2b8a1ab0d2c3d41be53264115c71b0a942d8fdb",
+ "sha256:fbdeeb58c0cf0595efe89c05c224e0a502d1aa6a8696e68a73c3efc6bc354304"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==3.7.1"
+ },
+ "mlflow": {
+ "hashes": [
+ "sha256:355bf5c0214f9f137f0b3d78f2f0a7f284eebbb23a027ad3e686eb1b84c1bbe9",
+ "sha256:6598f78f7ece59a9480573af57b09708a3283ece88f196e5d172c2040cec323f"
+ ],
+ "index": "pypi",
+ "version": "==2.4.1"
+ },
+ "multidict": {
+ "hashes": [
+ "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9",
+ "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8",
+ "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03",
+ "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710",
+ "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161",
+ "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664",
+ "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569",
+ "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067",
+ "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313",
+ "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706",
+ "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2",
+ "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636",
+ "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49",
+ "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93",
+ "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603",
+ "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0",
+ "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60",
+ "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4",
+ "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e",
+ "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1",
+ "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60",
+ "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951",
+ "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc",
+ "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe",
+ "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95",
+ "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d",
+ "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8",
+ "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed",
+ "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2",
+ "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775",
+ "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87",
+ "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c",
+ "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2",
+ "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98",
+ "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3",
+ "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe",
+ "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78",
+ "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660",
+ "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176",
+ "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e",
+ "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988",
+ "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c",
+ "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c",
+ "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0",
+ "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449",
+ "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f",
+ "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde",
+ "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5",
+ "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d",
+ "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac",
+ "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a",
+ "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9",
+ "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca",
+ "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11",
+ "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35",
+ "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063",
+ "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b",
+ "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982",
+ "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258",
+ "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1",
+ "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52",
+ "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480",
+ "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7",
+ "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461",
+ "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d",
+ "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc",
+ "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779",
+ "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a",
+ "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547",
+ "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0",
+ "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171",
+ "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf",
+ "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d",
+ "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==6.0.4"
+ },
+ "numpy": {
+ "hashes": [
+ "sha256:0ac6edfb35d2a99aaf102b509c8e9319c499ebd4978df4971b94419a116d0790",
+ "sha256:26815c6c8498dc49d81faa76d61078c4f9f0859ce7817919021b9eba72b425e3",
+ "sha256:4aedd08f15d3045a4e9c648f1e04daca2ab1044256959f1f95aafeeb3d794c16",
+ "sha256:4c69fe5f05eea336b7a740e114dec995e2f927003c30702d896892403df6dbf0",
+ "sha256:5177310ac2e63d6603f659fadc1e7bab33dd5a8db4e0596df34214eeab0fee3b",
+ "sha256:5aa48bebfb41f93043a796128854b84407d4df730d3fb6e5dc36402f5cd594c0",
+ "sha256:5b1b90860bf7d8a8c313b372d4f27343a54f415b20fb69dd601b7efe1029c91e",
+ "sha256:6c284907e37f5e04d2412950960894b143a648dea3f79290757eb878b91acbd1",
+ "sha256:6d183b5c58513f74225c376643234c369468e02947b47942eacbb23c1671f25d",
+ "sha256:7412125b4f18aeddca2ecd7219ea2d2708f697943e6f624be41aa5f8a9852cc4",
+ "sha256:7cd981ccc0afe49b9883f14761bb57c964df71124dcd155b0cba2b591f0d64b9",
+ "sha256:85cdae87d8c136fd4da4dad1e48064d700f63e923d5af6c8c782ac0df8044542",
+ "sha256:8aa130c3042052d656751df5e81f6d61edff3e289b5994edcf77f54118a8d9f4",
+ "sha256:95367ccd88c07af21b379be1725b5322362bb83679d36691f124a16357390153",
+ "sha256:9c7211d7920b97aeca7b3773a6783492b5b93baba39e7c36054f6e749fc7490c",
+ "sha256:9e3f2b96e3b63c978bc29daaa3700c028fe3f049ea3031b58aa33fe2a5809d24",
+ "sha256:b76aa836a952059d70a2788a2d98cb2a533ccd46222558b6970348939e55fc24",
+ "sha256:b792164e539d99d93e4e5e09ae10f8cbe5466de7d759fc155e075237e0c274e4",
+ "sha256:c0dc071017bc00abb7d7201bac06fa80333c6314477b3d10b52b58fa6a6e38f6",
+ "sha256:cc3fda2b36482891db1060f00f881c77f9423eead4c3579629940a3e12095fe8",
+ "sha256:d6b267f349a99d3908b56645eebf340cb58f01bd1e773b4eea1a905b3f0e4208",
+ "sha256:d76a84998c51b8b68b40448ddd02bd1081bb33abcdc28beee6cd284fe11036c6",
+ "sha256:e559c6afbca484072a98a51b6fa466aae785cfe89b69e8b856c3191bc8872a82",
+ "sha256:ecc68f11404930e9c7ecfc937aa423e1e50158317bf67ca91736a9864eae0232",
+ "sha256:f1accae9a28dc3cda46a91de86acf69de0d1b5f4edd44a9b0c3ceb8036dfff19"
+ ],
+ "markers": "python_version >= '3.9'",
+ "version": "==1.25.0"
+ },
+ "nvidia-cublas-cu11": {
+ "hashes": [
+ "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e",
+ "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"
+ ],
+ "markers": "platform_system == 'Linux'",
+ "version": "==11.10.3.66"
+ },
+ "nvidia-cuda-nvrtc-cu11": {
+ "hashes": [
+ "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03",
+ "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3",
+ "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"
+ ],
+ "markers": "platform_system == 'Linux'",
+ "version": "==11.7.99"
+ },
+ "nvidia-cuda-runtime-cu11": {
+ "hashes": [
+ "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7",
+ "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"
+ ],
+ "markers": "platform_system == 'Linux'",
+ "version": "==11.7.99"
+ },
+ "nvidia-cudnn-cu11": {
+ "hashes": [
+ "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7",
+ "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"
+ ],
+ "markers": "platform_system == 'Linux'",
+ "version": "==8.5.0.96"
+ },
+ "oauthlib": {
+ "hashes": [
+ "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca",
+ "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==3.2.2"
+ },
+ "onnx": {
+ "hashes": [
+ "sha256:00b0d2620c10dcb9ec33441e807dc5851d2843d445e0faab5e22c8ad6874a67a",
+ "sha256:01893a4a2d70b68e8ee20269ccde4069a6fd243dc9e296643e2afeb0050527bc",
+ "sha256:0639427ac61e5a0181f4f7c89f9fc82b3c9715c95071f9c3de79bbe303a4ae65",
+ "sha256:0753b0f118be71ff109dd994a3d6769e5871e9feaddfada77931c63f9de534b3",
+ "sha256:18cd98f7e234e268cb60c47a1f8ea5f6ffba50fe11de924b17498b1571d0cd2c",
+ "sha256:1fe8ba794d261d722018bd1385f02f966aace0fcb5448881ab5dd55ab0ebb81b",
+ "sha256:296e689aa54a9ae4e560b2bb149a64e96775699a0624af5f631665b9cda90482",
+ "sha256:2fab7e6e1c2d9d6479edad8e9088cdfd87ea293cb08f31565adabfb33c6e5789",
+ "sha256:3315c304d23a06ebd07fffe2456ab7f1e0a8dba317393d5c17a671ae2da6645e",
+ "sha256:369c3ecace7e8c7df6efbcbc712b262626796ae4a83decd29111afafa025a30c",
+ "sha256:43b85087c6b919de66872a043c7f4899fe6f840e11ffca7e662b2ce9e4cc2927",
+ "sha256:45d3effe59e20d0a9fdc51f5bb8f38299086c79576b894ed945e6a058c4b210a",
+ "sha256:54614942574415ef3f0bce0800c6f41ecea8201f8042754e204ee8c0a8e473e1",
+ "sha256:5e780fd1ed25493596a141e93303d0b2897acb9ebfdee7047a916d8f8e525ab3",
+ "sha256:6e966f5ef38a0521595cad6a1d14d9ae205c593d2824d8c1fa044fa5ba15370d",
+ "sha256:6fbcdc1a0c1057785bc5f7254aca0cf0b49d19c74696f1ade107638054157315",
+ "sha256:7800b6ec74b1fe3fbb3bf4a2380e2f4007c1a7f2d6927599ad40eead6eae5e19",
+ "sha256:9d28d64cbac3ebdc0c9761a300340c60ec60316099906e354e5059e90335fb3b",
+ "sha256:a593b46015326feb949781d030cb1d0d5d388cca52bff2e2995badf55d56b38d",
+ "sha256:a8f7454acded506b6359ee0837c8527c64964973d7d25ed6b16b7d4314599502",
+ "sha256:a9702e7dd120bca421a820020151cbb1003077e17ded29cc8d44ff32a9a57ad8",
+ "sha256:ac1545159f2e7fbc5b4a3ae032cd4d9ddeafc62c4f27fe22cbc3ecff49338992",
+ "sha256:ba92fed1aa27cba385bc3890fbbe6484603e837e67c957b22899f93c70990cc4",
+ "sha256:bbdca51da9fa9ec43eebd8c640bf71c05daa2afbeaa2c6478466470e28e41111",
+ "sha256:c16dacf577700ff9cb076c61c880d1a4bc612eed96280396a54ee1e1bd7e2d68",
+ "sha256:cd683d4aa6d55365582055a6c1e10a55d6c08a59e9216cbb67e37ad3a5b2b980",
+ "sha256:d8c3a2354d9d997c7a4a5e467b5373c98dc549d4a33c77d5723e1eda7e87559c",
+ "sha256:dcfaeb2d15e93c456003fac13ffa35144ba9d2666a83e2cef650dd5c90a2b768",
+ "sha256:e1607f97007515df303c1f40b77363545af99a1f32d2f73240c8aa526cdbd109",
+ "sha256:ed099fbdada4accead109a4479d5f73fb974566cce8d3c6fca94774f9645934c",
+ "sha256:fb35c2c347486416f87f41557242c05d7ee804d3676c6c8c98eef6f5b1889e7b"
+ ],
+ "index": "pypi",
+ "version": "==1.14.0"
+ },
+ "packaging": {
+ "hashes": [
+ "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61",
+ "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==23.1"
+ },
+ "pandas": {
+ "hashes": [
+ "sha256:02755de164da6827764ceb3bbc5f64b35cb12394b1024fdf88704d0fa06e0e2f",
+ "sha256:0a1e0576611641acde15c2322228d138258f236d14b749ad9af498ab69089e2d",
+ "sha256:1eb09a242184092f424b2edd06eb2b99d06dc07eeddff9929e8667d4ed44e181",
+ "sha256:30a89d0fec4263ccbf96f68592fd668939481854d2ff9da709d32a047689393b",
+ "sha256:50e451932b3011b61d2961b4185382c92cc8c6ee4658dcd4f320687bb2d000ee",
+ "sha256:51a93d422fbb1bd04b67639ba4b5368dffc26923f3ea32a275d2cc450f1d1c86",
+ "sha256:598e9020d85a8cdbaa1815eb325a91cfff2bb2b23c1442549b8a3668e36f0f77",
+ "sha256:66d00300f188fa5de73f92d5725ced162488f6dc6ad4cecfe4144ca29debe3b8",
+ "sha256:69167693cb8f9b3fc060956a5d0a0a8dbfed5f980d9fd2c306fb5b9c855c814c",
+ "sha256:6d6d10c2142d11d40d6e6c0a190b1f89f525bcf85564707e31b0a39e3b398e08",
+ "sha256:713f2f70abcdade1ddd68fc91577cb090b3544b07ceba78a12f799355a13ee44",
+ "sha256:7376e13d28eb16752c398ca1d36ccfe52bf7e887067af9a0474de6331dd948d2",
+ "sha256:77550c8909ebc23e56a89f91b40ad01b50c42cfbfab49b3393694a50549295ea",
+ "sha256:7b21cb72958fc49ad757685db1919021d99650d7aaba676576c9e88d3889d456",
+ "sha256:9ebb9f1c22ddb828e7fd017ea265a59d80461d5a79154b49a4207bd17514d122",
+ "sha256:a18e5c72b989ff0f7197707ceddc99828320d0ca22ab50dd1b9e37db45b010c0",
+ "sha256:a6b5f14cd24a2ed06e14255ff40fe2ea0cfaef79a8dd68069b7ace74bd6acbba",
+ "sha256:b42b120458636a981077cfcfa8568c031b3e8709701315e2bfa866324a83efa8",
+ "sha256:c4af689352c4fe3d75b2834933ee9d0ccdbf5d7a8a7264f0ce9524e877820c08",
+ "sha256:c7319b6e68de14e6209460f72a8d1ef13c09fb3d3ef6c37c1e65b35d50b5c145",
+ "sha256:cf3f0c361a4270185baa89ec7ab92ecaa355fe783791457077473f974f654df5",
+ "sha256:dd46bde7309088481b1cf9c58e3f0e204b9ff9e3244f441accd220dd3365ce7c",
+ "sha256:dd5476b6c3fe410ee95926873f377b856dbc4e81a9c605a0dc05aaccc6a7c6c6",
+ "sha256:e69140bc2d29a8556f55445c15f5794490852af3de0f609a24003ef174528b79",
+ "sha256:f908a77cbeef9bbd646bd4b81214cbef9ac3dda4181d5092a4aa9797d1bc7774"
+ ],
+ "index": "pypi",
+ "version": "==2.0.2"
+ },
+ "pillow": {
+ "hashes": [
+ "sha256:07999f5834bdc404c442146942a2ecadd1cb6292f5229f4ed3b31e0a108746b1",
+ "sha256:0852ddb76d85f127c135b6dd1f0bb88dbb9ee990d2cd9aa9e28526c93e794fba",
+ "sha256:1781a624c229cb35a2ac31cc4a77e28cafc8900733a864870c49bfeedacd106a",
+ "sha256:1e7723bd90ef94eda669a3c2c19d549874dd5badaeefabefd26053304abe5799",
+ "sha256:229e2c79c00e85989a34b5981a2b67aa079fd08c903f0aaead522a1d68d79e51",
+ "sha256:22baf0c3cf0c7f26e82d6e1adf118027afb325e703922c8dfc1d5d0156bb2eeb",
+ "sha256:252a03f1bdddce077eff2354c3861bf437c892fb1832f75ce813ee94347aa9b5",
+ "sha256:2dfaaf10b6172697b9bceb9a3bd7b951819d1ca339a5ef294d1f1ac6d7f63270",
+ "sha256:322724c0032af6692456cd6ed554bb85f8149214d97398bb80613b04e33769f6",
+ "sha256:35f6e77122a0c0762268216315bf239cf52b88865bba522999dc38f1c52b9b47",
+ "sha256:375f6e5ee9620a271acb6820b3d1e94ffa8e741c0601db4c0c4d3cb0a9c224bf",
+ "sha256:3ded42b9ad70e5f1754fb7c2e2d6465a9c842e41d178f262e08b8c85ed8a1d8e",
+ "sha256:432b975c009cf649420615388561c0ce7cc31ce9b2e374db659ee4f7d57a1f8b",
+ "sha256:482877592e927fd263028c105b36272398e3e1be3269efda09f6ba21fd83ec66",
+ "sha256:489f8389261e5ed43ac8ff7b453162af39c3e8abd730af8363587ba64bb2e865",
+ "sha256:54f7102ad31a3de5666827526e248c3530b3a33539dbda27c6843d19d72644ec",
+ "sha256:560737e70cb9c6255d6dcba3de6578a9e2ec4b573659943a5e7e4af13f298f5c",
+ "sha256:5671583eab84af046a397d6d0ba25343c00cd50bce03787948e0fff01d4fd9b1",
+ "sha256:5ba1b81ee69573fe7124881762bb4cd2e4b6ed9dd28c9c60a632902fe8db8b38",
+ "sha256:5d4ebf8e1db4441a55c509c4baa7a0587a0210f7cd25fcfe74dbbce7a4bd1906",
+ "sha256:60037a8db8750e474af7ffc9faa9b5859e6c6d0a50e55c45576bf28be7419705",
+ "sha256:608488bdcbdb4ba7837461442b90ea6f3079397ddc968c31265c1e056964f1ef",
+ "sha256:6608ff3bf781eee0cd14d0901a2b9cc3d3834516532e3bd673a0a204dc8615fc",
+ "sha256:662da1f3f89a302cc22faa9f14a262c2e3951f9dbc9617609a47521c69dd9f8f",
+ "sha256:7002d0797a3e4193c7cdee3198d7c14f92c0836d6b4a3f3046a64bd1ce8df2bf",
+ "sha256:763782b2e03e45e2c77d7779875f4432e25121ef002a41829d8868700d119392",
+ "sha256:77165c4a5e7d5a284f10a6efaa39a0ae8ba839da344f20b111d62cc932fa4e5d",
+ "sha256:7c9af5a3b406a50e313467e3565fc99929717f780164fe6fbb7704edba0cebbe",
+ "sha256:7ec6f6ce99dab90b52da21cf0dc519e21095e332ff3b399a357c187b1a5eee32",
+ "sha256:833b86a98e0ede388fa29363159c9b1a294b0905b5128baf01db683672f230f5",
+ "sha256:84a6f19ce086c1bf894644b43cd129702f781ba5751ca8572f08aa40ef0ab7b7",
+ "sha256:8507eda3cd0608a1f94f58c64817e83ec12fa93a9436938b191b80d9e4c0fc44",
+ "sha256:85ec677246533e27770b0de5cf0f9d6e4ec0c212a1f89dfc941b64b21226009d",
+ "sha256:8aca1152d93dcc27dc55395604dcfc55bed5f25ef4c98716a928bacba90d33a3",
+ "sha256:8d935f924bbab8f0a9a28404422da8af4904e36d5c33fc6f677e4c4485515625",
+ "sha256:8f36397bf3f7d7c6a3abdea815ecf6fd14e7fcd4418ab24bae01008d8d8ca15e",
+ "sha256:91ec6fe47b5eb5a9968c79ad9ed78c342b1f97a091677ba0e012701add857829",
+ "sha256:965e4a05ef364e7b973dd17fc765f42233415974d773e82144c9bbaaaea5d089",
+ "sha256:96e88745a55b88a7c64fa49bceff363a1a27d9a64e04019c2281049444a571e3",
+ "sha256:99eb6cafb6ba90e436684e08dad8be1637efb71c4f2180ee6b8f940739406e78",
+ "sha256:9adf58f5d64e474bed00d69bcd86ec4bcaa4123bfa70a65ce72e424bfb88ed96",
+ "sha256:9b1af95c3a967bf1da94f253e56b6286b50af23392a886720f563c547e48e964",
+ "sha256:a0aa9417994d91301056f3d0038af1199eb7adc86e646a36b9e050b06f526597",
+ "sha256:a0f9bb6c80e6efcde93ffc51256d5cfb2155ff8f78292f074f60f9e70b942d99",
+ "sha256:a127ae76092974abfbfa38ca2d12cbeddcdeac0fb71f9627cc1135bedaf9d51a",
+ "sha256:aaf305d6d40bd9632198c766fb64f0c1a83ca5b667f16c1e79e1661ab5060140",
+ "sha256:aca1c196f407ec7cf04dcbb15d19a43c507a81f7ffc45b690899d6a76ac9fda7",
+ "sha256:ace6ca218308447b9077c14ea4ef381ba0b67ee78d64046b3f19cf4e1139ad16",
+ "sha256:b416f03d37d27290cb93597335a2f85ed446731200705b22bb927405320de903",
+ "sha256:bf548479d336726d7a0eceb6e767e179fbde37833ae42794602631a070d630f1",
+ "sha256:c1170d6b195555644f0616fd6ed929dfcf6333b8675fcca044ae5ab110ded296",
+ "sha256:c380b27d041209b849ed246b111b7c166ba36d7933ec6e41175fd15ab9eb1572",
+ "sha256:c446d2245ba29820d405315083d55299a796695d747efceb5717a8b450324115",
+ "sha256:c830a02caeb789633863b466b9de10c015bded434deb3ec87c768e53752ad22a",
+ "sha256:cb841572862f629b99725ebaec3287fc6d275be9b14443ea746c1dd325053cbd",
+ "sha256:cfa4561277f677ecf651e2b22dc43e8f5368b74a25a8f7d1d4a3a243e573f2d4",
+ "sha256:cfcc2c53c06f2ccb8976fb5c71d448bdd0a07d26d8e07e321c103416444c7ad1",
+ "sha256:d3c6b54e304c60c4181da1c9dadf83e4a54fd266a99c70ba646a9baa626819eb",
+ "sha256:d3d403753c9d5adc04d4694d35cf0391f0f3d57c8e0030aac09d7678fa8030aa",
+ "sha256:d9c206c29b46cfd343ea7cdfe1232443072bbb270d6a46f59c259460db76779a",
+ "sha256:e49eb4e95ff6fd7c0c402508894b1ef0e01b99a44320ba7d8ecbabefddcc5569",
+ "sha256:f8286396b351785801a976b1e85ea88e937712ee2c3ac653710a4a57a8da5d9c",
+ "sha256:f8fc330c3370a81bbf3f88557097d1ea26cd8b019d6433aa59f71195f5ddebbf",
+ "sha256:fbd359831c1657d69bb81f0db962905ee05e5e9451913b18b831febfe0519082",
+ "sha256:fe7e1c262d3392afcf5071df9afa574544f28eac825284596ac6db56e6d11062",
+ "sha256:fed1e1cf6a42577953abbe8e6cf2fe2f566daebde7c34724ec8803c4c0cda579"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==9.5.0"
+ },
+ "protobuf": {
+ "hashes": [
+ "sha256:0149053336a466e3e0b040e54d0b615fc71de86da66791c592cc3c8d18150bf8",
+ "sha256:08fe19d267608d438aa37019236db02b306e33f6b9902c3163838b8e75970223",
+ "sha256:29660574cd769f2324a57fb78127cda59327eb6664381ecfe1c69731b83e8288",
+ "sha256:2991f5e7690dab569f8f81702e6700e7364cc3b5e572725098215d3da5ccc6ac",
+ "sha256:3b01a5274ac920feb75d0b372d901524f7e3ad39c63b1a2d55043f3887afe0c1",
+ "sha256:3bcbeb2bf4bb61fe960dd6e005801a23a43578200ea8ceb726d1f6bd0e562ba1",
+ "sha256:447b9786ac8e50ae72cae7a2eec5c5df6a9dbf9aa6f908f1b8bda6032644ea62",
+ "sha256:514b6bbd54a41ca50c86dd5ad6488afe9505901b3557c5e0f7823a0cf67106fb",
+ "sha256:5cb9e41188737f321f4fce9a4337bf40a5414b8d03227e1d9fbc59bc3a216e35",
+ "sha256:7a92beb30600332a52cdadbedb40d33fd7c8a0d7f549c440347bc606fb3fe34b",
+ "sha256:84ea0bd90c2fdd70ddd9f3d3fc0197cc24ecec1345856c2b5ba70e4d99815359",
+ "sha256:aca6e86a08c5c5962f55eac9b5bd6fce6ed98645d77e8bfc2b952ecd4a8e4f6a",
+ "sha256:cc14358a8742c4e06b1bfe4be1afbdf5c9f6bd094dff3e14edb78a1513893ff5"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==4.23.3"
+ },
+ "pyarrow": {
+ "hashes": [
+ "sha256:051f9f5ccf585f12d7de836e50965b3c235542cc896959320d9776ab93f3b33d",
+ "sha256:1887bdae17ec3b4c046fcf19951e71b6a619f39fa674f9881216173566c8f718",
+ "sha256:2d3c4cbbf81e6dd23fe921bc91dc4619ea3b79bc58ef10bce0f49bdafb103daf",
+ "sha256:345e1828efdbd9aa4d4de7d5676778aba384a2c3add896d995b23d368e60e5af",
+ "sha256:3de26da901216149ce086920547dfff5cd22818c9eab67ebc41e863a5883bac7",
+ "sha256:43364daec02f69fec89d2315f7fbfbeec956e0d991cbbef471681bd77875c40f",
+ "sha256:459a1c0ed2d68671188b2118c63bac91eaef6fc150c77ddd8a583e3c795737bf",
+ "sha256:6251e38470da97a5b2e00de5c6a049149f7b2bd62f12fa5dbb9ac674119ba71a",
+ "sha256:6895b5fb74289d055c43db3af0de6e16b07586c45763cb5e558d38b86a91e3a7",
+ "sha256:6d288029a94a9bb5407ceebdd7110ba398a00412c5b0155ee9813a40d246c5df",
+ "sha256:749be7fd2ff260683f9cc739cb862fb11be376de965a2a8ccbf2693b098db6c7",
+ "sha256:85e705e33eaf666bbe508a16fd5ba27ca061e177916b7a317ba5a51bee43384c",
+ "sha256:8d6009fdf8986332b2169314da482baed47ac053311c8934ac6651e614deacd6",
+ "sha256:9120c3eb2b1f6f516a3b7a9714ed860882d9ef98c4b17edcdc91d95b7528db60",
+ "sha256:a3c63124fc26bf5f95f508f5d04e1ece8cc23a8b0af2a1e6ab2b1ec3fdc91b24",
+ "sha256:b13329f79fa4472324f8d32dc1b1216616d09bd1e77cfb13104dec5463632c36",
+ "sha256:bb656150d3d12ec1396f6dde542db1675a95c0cc8366d507347b0beed96e87ca",
+ "sha256:be2757e9275875d2a9c6e6052ac7957fbbfc7bc7370e4a036a9b893e96fedaba",
+ "sha256:c780f4dc40460015d80fcd6a6140de80b615349ed68ef9adb653fe351778c9b3",
+ "sha256:cce317fc96e5b71107bf1f9f184d5e54e2bd14bbf3f9a3d62819961f0af86fec",
+ "sha256:cdacf515ec276709ac8042c7d9bd5be83b4f5f39c6c037a17a60d7ebfd92c890",
+ "sha256:ce4aebdf412bd0eeb800d8e47db854f9f9f7e2f5a0220440acf219ddfddd4f63",
+ "sha256:cf812306d66f40f69e684300f7af5111c11f6e0d89d6b733e05a3de44961529d",
+ "sha256:e0d8730c7f6e893f6db5d5b86eda42c0a130842d101992b581e2138e4d5663d3",
+ "sha256:e2c9cb8eeabbadf5fcfc3d1ddea616c7ce893db2ce4dcef0ac13b099ad7ca082"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==12.0.1"
+ },
+ "pyjwt": {
+ "hashes": [
+ "sha256:ba2b425b15ad5ef12f200dc67dd56af4e26de2331f965c5439994dad075876e1",
+ "sha256:bd6ca4a3c4285c1a2d4349e5a035fdf8fb94e04ccd0fcbe6ba289dae9cc3e074"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==2.7.0"
+ },
+ "pyparsing": {
+ "hashes": [
+ "sha256:d554a96d1a7d3ddaf7183104485bc19fd80543ad6ac5bdb6426719d766fb06c1",
+ "sha256:edb662d6fe322d6e990b1594b5feaeadf806803359e3d4d42f11e295e588f0ea"
+ ],
+ "markers": "python_full_version >= '3.6.8'",
+ "version": "==3.1.0"
+ },
+ "pysocks": {
+ "hashes": [
+ "sha256:08e69f092cc6dbe92a0fdd16eeb9b9ffbc13cadfe5ca4c7bd92ffb078b293299",
+ "sha256:2725bd0a9925919b9b51739eea5f9e2bae91e83288108a9ad338b2e3a4435ee5",
+ "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0"
+ ],
+ "version": "==1.7.1"
+ },
+ "python-dateutil": {
+ "hashes": [
+ "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86",
+ "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"
+ ],
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
+ "version": "==2.8.2"
+ },
+ "python-slugify": {
+ "hashes": [
+ "sha256:70ca6ea68fe63ecc8fa4fcf00ae651fc8a5d02d93dcd12ae6d4fc7ca46c4d395",
+ "sha256:ce0d46ddb668b3be82f4ed5e503dbc33dd815d83e2eb6824211310d3fb172a27"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==8.0.1"
+ },
+ "pytorch-lightning": {
+ "hashes": [
+ "sha256:6b1b04c818250ebfe5cde576b3e559f4d00c84c2f80598ef40224b2a749a5398",
+ "sha256:e766975a6beaf40808a577a997663ca349adbf6a544c3de18dbcbdc2c5f7d47d"
+ ],
+ "index": "pypi",
+ "version": "==2.0.4"
+ },
+ "pytz": {
+ "hashes": [
+ "sha256:1d8ce29db189191fb55338ee6d0387d82ab59f3d00eac103412d64e0ebd0c588",
+ "sha256:a151b3abb88eda1d4e34a9814df37de2a80e301e68ba0fd856fb9b46bfbbbffb"
+ ],
+ "version": "==2023.3"
+ },
+ "pyyaml": {
+ "hashes": [
+ "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf",
+ "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293",
+ "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b",
+ "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57",
+ "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b",
+ "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4",
+ "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07",
+ "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba",
+ "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9",
+ "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287",
+ "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513",
+ "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0",
+ "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782",
+ "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0",
+ "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92",
+ "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f",
+ "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2",
+ "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc",
+ "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1",
+ "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c",
+ "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86",
+ "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4",
+ "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c",
+ "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34",
+ "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b",
+ "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d",
+ "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c",
+ "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb",
+ "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7",
+ "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737",
+ "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3",
+ "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d",
+ "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358",
+ "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53",
+ "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78",
+ "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803",
+ "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a",
+ "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f",
+ "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174",
+ "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==6.0"
+ },
+ "querystring-parser": {
+ "hashes": [
+ "sha256:644fce1cffe0530453b43a83a38094dbe422ccba8c9b2f2a1c00280e14ca8a62",
+ "sha256:d2fa90765eaf0de96c8b087872991a10238e89ba015ae59fedfed6bd61c242a0"
+ ],
+ "version": "==1.2.4"
+ },
+ "requests": {
+ "extras": [
+ "socks"
+ ],
+ "hashes": [
+ "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f",
+ "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==2.31.0"
+ },
+ "scikit-learn": {
+ "hashes": [
+ "sha256:065e9673e24e0dc5113e2dd2b4ca30c9d8aa2fa90f4c0597241c93b63130d233",
+ "sha256:2dd3ffd3950e3d6c0c0ef9033a9b9b32d910c61bd06cb8206303fb4514b88a49",
+ "sha256:2e2642baa0ad1e8f8188917423dd73994bf25429f8893ddbe115be3ca3183584",
+ "sha256:44b47a305190c28dd8dd73fc9445f802b6ea716669cfc22ab1eb97b335d238b1",
+ "sha256:6477eed40dbce190f9f9e9d0d37e020815825b300121307942ec2110302b66a3",
+ "sha256:6fe83b676f407f00afa388dd1fdd49e5c6612e551ed84f3b1b182858f09e987d",
+ "sha256:7d5312d9674bed14f73773d2acf15a3272639b981e60b72c9b190a0cffed5bad",
+ "sha256:7f69313884e8eb311460cc2f28676d5e400bd929841a2c8eb8742ae78ebf7c20",
+ "sha256:8156db41e1c39c69aa2d8599ab7577af53e9e5e7a57b0504e116cc73c39138dd",
+ "sha256:8429aea30ec24e7a8c7ed8a3fa6213adf3814a6efbea09e16e0a0c71e1a1a3d7",
+ "sha256:8b0670d4224a3c2d596fd572fb4fa673b2a0ccfb07152688ebd2ea0b8c61025c",
+ "sha256:953236889928d104c2ef14027539f5f2609a47ebf716b8cbe4437e85dce42744",
+ "sha256:99cc01184e347de485bf253d19fcb3b1a3fb0ee4cea5ee3c43ec0cc429b6d29f",
+ "sha256:9c710ff9f9936ba8a3b74a455ccf0dcf59b230caa1e9ba0223773c490cab1e51",
+ "sha256:ad66c3848c0a1ec13464b2a95d0a484fd5b02ce74268eaa7e0c697b904f31d6c",
+ "sha256:bf036ea7ef66115e0d49655f16febfa547886deba20149555a41d28f56fd6d3c",
+ "sha256:dfeaf8be72117eb61a164ea6fc8afb6dfe08c6f90365bde2dc16456e4bc8e45f",
+ "sha256:e6e574db9914afcb4e11ade84fab084536a895ca60aadea3041e85b8ac963edb",
+ "sha256:ea061bf0283bf9a9f36ea3c5d3231ba2176221bbd430abd2603b1c3b2ed85c89",
+ "sha256:fe0aa1a7029ed3e1dcbf4a5bc675aa3b1bc468d9012ecf6c6f081251ca47f590",
+ "sha256:fe175ee1dab589d2e1033657c5b6bec92a8a3b69103e3dd361b58014729975c3"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==1.2.2"
+ },
+ "scipy": {
+ "hashes": [
+ "sha256:2c29bae479b17d85208dfdfc67e50d5944ee23211f236728aadde9b0b7c1c33e",
+ "sha256:2e4f14c11fbf825319dbd7f467639a241e7c956c34edb1e036ec7bb6271e4f7b",
+ "sha256:586608ea35206257d4e0ce6f154a6cfef71723b2c1f6d40de5e0b0e8a81cd2ff",
+ "sha256:6302c7cba5bf99c901653ff158746625526cc438f058bce41514d7469b79b2c3",
+ "sha256:6666a1e31b2123a077f0dc7ab1053e36479cfd457fb9f5c367e7198505c6607a",
+ "sha256:684d44607eacd5dd367c7a9e76e922523fa9c0a7f2379a4d0fc4d70d751464cc",
+ "sha256:7a92bd3cd4acad2e0e0b360176d5ec68b100983c8145add8a8233acddf4e5fcc",
+ "sha256:80015b8928f91bd40377b2b1010ba2e09b03680cbfc291208740494aeb8debf2",
+ "sha256:83867a63515c4e3fce3272d81200dda614d70f4c3a22f047d84021bfe83d7929",
+ "sha256:894ced9a2cdb050ff5e392f274617af46dca896d5c9112fa4a2019929554d321",
+ "sha256:a53f9cebcfda6158c241c35a559407a4ef6b8cb0863eb4144958fe0a0b7c3dae",
+ "sha256:b269ed44e2e2e43611f2ae95ba551fd98abbdc1a7ea8268f72f75876982368c4",
+ "sha256:c61ea63124da6a3cff38126426912cc86420898b4902a9bc5e5b6524547a6dcb",
+ "sha256:ccc70892ea674f93183c5c4139557b611e42f644dd755da4b19ca974ab770672",
+ "sha256:d8e631c3c49c24f30828580b8126fe3be5cca5409dad5b797418a5b8965eeafa",
+ "sha256:ebf4b2ea26d50312731ddba2406389c5ddcbff9d777cf3277ea11decc81e5dfb",
+ "sha256:f0c9c160d117fe71cd2a12ef21cce8e0475ade2fd97c761ef327b9839089bd16",
+ "sha256:f9b0248cb9d08eead44cde47cbf6339f1e9aa0dfde28f5fb27950743e317bd5d",
+ "sha256:fad4006248513528e0c496de295a9f4d2b65086cc0e388f748e7dbf49fa12760"
+ ],
+ "markers": "python_version < '3.13' and python_version >= '3.9'",
+ "version": "==1.11.0"
+ },
+ "setuptools": {
+ "hashes": [
+ "sha256:11e52c67415a381d10d6b462ced9cfb97066179f0e871399e006c4ab101fc85f",
+ "sha256:baf1fdb41c6da4cd2eae722e135500da913332ab3f2f5c7d33af9b492acb5235"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==68.0.0"
+ },
+ "six": {
+ "hashes": [
+ "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926",
+ "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"
+ ],
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
+ "version": "==1.16.0"
+ },
+ "smmap": {
+ "hashes": [
+ "sha256:2aba19d6a040e78d8b09de5c57e96207b09ed71d8e55ce0959eeee6c8e190d94",
+ "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==5.0.0"
+ },
+ "soupsieve": {
+ "hashes": [
+ "sha256:1c1bfee6819544a3447586c889157365a27e10d88cde3ad3da0cf0ddf646feb8",
+ "sha256:89d12b2d5dfcd2c9e8c22326da9d9aa9cb3dfab0a83a024f05704076ee8d35ea"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==2.4.1"
+ },
+ "sqlalchemy": {
+ "hashes": [
+ "sha256:04383f1e3452f6739084184e427e9d5cb4e68ddc765d52157bf5ef30d5eca14f",
+ "sha256:125f9f7e62ddf8b590c069729080ffe18b68a20d9882eb0947f72e06274601d7",
+ "sha256:1822620c89779b85f7c23d535c8e04b79c517739ae07aaed48c81e591ed5498e",
+ "sha256:21583808d37f126a647652c90332ac1d3a102edf3c94bcc3319edcc0ea2300cc",
+ "sha256:218fb20c01e95004f50a3062bf4c447dcb360cab8274232f31947e254f118298",
+ "sha256:2269b1f9b8be47e52b70936069a25a3771eff53367aa5cc59bb94f28a6412e13",
+ "sha256:234678ed6576531b8e4be255b980f20368bf07241a2e67b84e6b0fe679edb9c4",
+ "sha256:28da17059ecde53e2d10ba813d38db942b9f6344360b2958b25872d5cb729d35",
+ "sha256:2c6ff5767d954f6091113fedcaaf49cdec2197ae4c5301fe83d5ae4393c82f33",
+ "sha256:36a87e26fe8fa8c466fae461a8fcb780d0a1cbf8206900759fc6fe874475a3ce",
+ "sha256:394ac3adf3676fad76d4b8fcecddf747627f17f0738dc94bac15f303d05b03d4",
+ "sha256:40a3dc52b2b16f08b5c16b9ee7646329e4b3411e9280e5e8d57b19eaa51cbef4",
+ "sha256:48111d56afea5699bab72c38ec95561796b81befff9e13d1dd5ce251ab25f51d",
+ "sha256:48b40dc2895841ea89d89df9eb3ac69e2950a659db20a369acf4259f68e6dc1f",
+ "sha256:513411d73503a6fc5804f01fae3b3d44f267c1b3a06cfeac02e9286a7330e857",
+ "sha256:51736cfb607cf4e8fafb693906f9bc4e5ee55be0b096d44bd7f20cd8489b8571",
+ "sha256:5f40e3a7d0a464f1c8593f2991e5520b2f5b26da24e88000bbd4423f86103d4f",
+ "sha256:6150560fcffc6aee5ec9a97419ac768c7a9f56baf7a7eb59cb4b1b6a4d463ad9",
+ "sha256:724355973297bbe547f3eb98b46ade65a67a3d5a6303f17ab59a2dc6fb938943",
+ "sha256:74ddcafb6488f382854a7da851c404c394be3729bb3d91b02ad86c5458140eff",
+ "sha256:7830e01b02d440c27f2a5be68296e74ccb55e6a5b5962ffafd360b98930b2e5e",
+ "sha256:7f31d4e7ca1dd8ca5a27fd5eaa0f9e2732fe769ff7dd35bf7bba179597e4df07",
+ "sha256:8741d3d401383e54b2aada37cbd10f55c5d444b360eae3a82f74a2be568a7710",
+ "sha256:910d45bf3673f0e4ef13858674bd23cfdafdc8368b45b948bf511797dbbb401d",
+ "sha256:aa995b21f853864996e4056d9fde479bcecf8b7bff4beb3555eebbbba815f35d",
+ "sha256:af7e2ba75bf84b64adb331918188dda634689a2abb151bc1a583e488363fd2f8",
+ "sha256:b0eaf82cc844f6b46defe15ad243ea00d1e39ed3859df61130c263dc7204da6e",
+ "sha256:b114a16bc03dfe20b625062e456affd7b9938286e05a3f904a025b9aacc29dd4",
+ "sha256:b47be4c6281a86670ea5cfbbbe6c3a65366a8742f5bc8b986f790533c60b5ddb",
+ "sha256:ba03518e64d86f000dc24ab3d3a1aa876bcbaa8aa15662ac2df5e81537fa3394",
+ "sha256:cc9c2630c423ac4973492821b2969f5fe99d9736f3025da670095668fbfcd4d5",
+ "sha256:cf07ff9920cb3ca9d73525dfd4f36ddf9e1a83734ea8b4f724edfd9a2c6e82d9",
+ "sha256:cf175d26f6787cce30fe6c04303ca0aeeb0ad40eeb22e3391f24b32ec432a1e1",
+ "sha256:d0aeb3afaa19f187a70fa592fbe3c20a056b57662691fd3abf60f016aa5c1848",
+ "sha256:e186e9e95fb5d993b075c33fe4f38a22105f7ce11cecb5c17b5618181e356702",
+ "sha256:e2d5c3596254cf1a96474b98e7ce20041c74c008b0f101c1cb4f8261cb77c6d3",
+ "sha256:e3189432db2f5753b4fde1aa90a61c69976f4e7e31d1cf4611bfe3514ed07478",
+ "sha256:e3a6b2788f193756076061626679c5c5a6d600ddf8324f986bc72004c3e9d92e",
+ "sha256:ead58cae2a089eee1b0569060999cb5f2b2462109498a0937cc230a7556945a1",
+ "sha256:f2f389f77c68dc22cb51f026619291c4a38aeb4b7ecb5f998fd145b2d81ca513",
+ "sha256:f593170fc09c5abb1205a738290b39532f7380094dc151805009a07ae0e85330"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==2.0.17"
+ },
+ "sqlparse": {
+ "hashes": [
+ "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3",
+ "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c"
+ ],
+ "markers": "python_version >= '3.5'",
+ "version": "==0.4.4"
+ },
+ "tabulate": {
+ "hashes": [
+ "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c",
+ "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==0.9.0"
+ },
+ "tensorboardx": {
+ "hashes": [
+ "sha256:02e2b84d7dc102edb7a052c77041db30fd6ba9b990635178919b8e9cfa157e96",
+ "sha256:4960feb79b1b84fd2b020885b09fd70962caec277d4bc194f338a6c203cd78ca"
+ ],
+ "index": "pypi",
+ "version": "==2.6.1"
+ },
+ "text-unidecode": {
+ "hashes": [
+ "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8",
+ "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93"
+ ],
+ "version": "==1.3"
+ },
+ "threadpoolctl": {
+ "hashes": [
+ "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b",
+ "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==3.1.0"
+ },
+ "torch": {
+ "hashes": [
+ "sha256:0122806b111b949d21fa1a5f9764d1fd2fcc4a47cb7f8ff914204fd4fc752ed5",
+ "sha256:0aa46f0ac95050c604bcf9ef71da9f1172e5037fdf2ebe051962d47b123848e7",
+ "sha256:0d9b8061048cfb78e675b9d2ea8503bfe30db43d583599ae8626b1263a0c1380",
+ "sha256:22128502fd8f5b25ac1cd849ecb64a418382ae81dd4ce2b5cebaa09ab15b0d9b",
+ "sha256:2c3581a3fd81eb1f0f22997cddffea569fea53bafa372b2c0471db373b26aafc",
+ "sha256:2ee7b81e9c457252bddd7d3da66fb1f619a5d12c24d7074de91c4ddafb832c93",
+ "sha256:33e67eea526e0bbb9151263e65417a9ef2d8fa53cbe628e87310060c9dcfa312",
+ "sha256:393a6273c832e047581063fb74335ff50b4c566217019cc6ace318cd79eb0566",
+ "sha256:50ff5e76d70074f6653d191fe4f6a42fdbe0cf942fbe2a3af0b75eaa414ac038",
+ "sha256:5e1e722a41f52a3f26f0c4fcec227e02c6c42f7c094f32e49d4beef7d1e213ea",
+ "sha256:6930791efa8757cb6974af73d4996b6b50c592882a324b8fb0589c6a9ba2ddaf",
+ "sha256:727dbf00e2cf858052364c0e2a496684b9cb5aa01dc8a8bc8bbb7c54502bdcdd",
+ "sha256:76024be052b659ac1304ab8475ab03ea0a12124c3e7626282c9c86798ac7bc11",
+ "sha256:98124598cdff4c287dbf50f53fb455f0c1e3a88022b39648102957f3445e9b76",
+ "sha256:d9fe785d375f2e26a5d5eba5de91f89e6a3be5d11efb497e76705fdf93fa3c2e",
+ "sha256:df8434b0695e9ceb8cc70650afc1310d8ba949e6db2a0525ddd9c3b2b181e5fe",
+ "sha256:e0df902a7c7dd6c795698532ee5970ce898672625635d885eade9976e5a04949",
+ "sha256:ea8dda84d796094eb8709df0fcd6b56dc20b58fdd6bc4e8d7109930dafc8e419",
+ "sha256:eeeb204d30fd40af6a2d80879b46a7efbe3cf43cdbeb8838dd4f3d126cc90b2b",
+ "sha256:f402ca80b66e9fbd661ed4287d7553f7f3899d9ab54bf5c67faada1555abde28",
+ "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"
+ ],
+ "index": "pypi",
+ "version": "==1.13.1"
+ },
+ "torchmetrics": {
+ "hashes": [
+ "sha256:1fe45a14b44dd65d90199017dd5a4b5a128d56a8a311da7916c402c18c671494",
+ "sha256:45f892f3534e91f3ad9e2488d1b05a93b7cb76b7d037969435a41a1f24750d9a"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==0.11.4"
+ },
+ "tqdm": {
+ "hashes": [
+ "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5",
+ "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==4.65.0"
+ },
+ "typing-extensions": {
+ "hashes": [
+ "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26",
+ "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==4.6.3"
+ },
+ "tzdata": {
+ "hashes": [
+ "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a",
+ "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"
+ ],
+ "markers": "python_version >= '2'",
+ "version": "==2023.3"
+ },
+ "urllib3": {
+ "hashes": [
+ "sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f",
+ "sha256:8f135f6502756bde6b2a9b28989df5fbe87c9970cecaa69041edcce7f0589b14"
+ ],
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'",
+ "version": "==1.26.16"
+ },
+ "websocket-client": {
+ "hashes": [
+ "sha256:c951af98631d24f8df89ab1019fc365f2227c0892f12fd150e935607c79dd0dd",
+ "sha256:f1f9f2ad5291f0225a49efad77abf9e700b6fef553900623060dad6e26503b9d"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.6.1"
+ },
+ "werkzeug": {
+ "hashes": [
+ "sha256:935539fa1413afbb9195b24880778422ed620c0fc09670945185cce4d91a8890",
+ "sha256:98c774df2f91b05550078891dee5f0eb0cb797a522c757a2452b9cee5b202330"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==2.3.6"
+ },
+ "wheel": {
+ "hashes": [
+ "sha256:cd1196f3faee2b31968d626e1731c94f99cbdb67cf5a46e4f5656cbee7738873",
+ "sha256:d236b20e7cb522daf2390fa84c55eea81c5c30190f90f29ae2ca1ad8355bf247"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==0.40.0"
+ },
+ "yarl": {
+ "hashes": [
+ "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571",
+ "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3",
+ "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3",
+ "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c",
+ "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7",
+ "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04",
+ "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191",
+ "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea",
+ "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4",
+ "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4",
+ "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095",
+ "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e",
+ "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74",
+ "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef",
+ "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33",
+ "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde",
+ "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45",
+ "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf",
+ "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b",
+ "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac",
+ "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0",
+ "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528",
+ "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716",
+ "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb",
+ "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18",
+ "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72",
+ "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6",
+ "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582",
+ "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5",
+ "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368",
+ "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc",
+ "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9",
+ "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be",
+ "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a",
+ "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80",
+ "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8",
+ "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6",
+ "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417",
+ "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574",
+ "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59",
+ "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608",
+ "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82",
+ "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1",
+ "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3",
+ "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d",
+ "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8",
+ "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc",
+ "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac",
+ "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8",
+ "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955",
+ "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0",
+ "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367",
+ "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb",
+ "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a",
+ "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623",
+ "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2",
+ "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6",
+ "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7",
+ "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4",
+ "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051",
+ "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938",
+ "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8",
+ "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9",
+ "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3",
+ "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5",
+ "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9",
+ "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333",
+ "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185",
+ "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3",
+ "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560",
+ "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b",
+ "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7",
+ "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78",
+ "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.9.2"
+ },
+ "zipp": {
+ "hashes": [
+ "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b",
+ "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==3.15.0"
+ }
+ },
+ "develop": {
+ "asttokens": {
+ "hashes": [
+ "sha256:4622110b2a6f30b77e1473affaa97e711bc2f07d3f10848420ff1898edbe94f3",
+ "sha256:6b0ac9e93fb0335014d382b8fa9b3afa7df546984258005da0b9e7095b3deb1c"
+ ],
+ "version": "==2.2.1"
+ },
+ "autopep8": {
+ "hashes": [
+ "sha256:86e9303b5e5c8160872b2f5ef611161b2893e9bfe8ccc7e2f76385947d57a2f1",
+ "sha256:f9849cdd62108cb739dbcdbfb7fdcc9a30d1b63c4cc3e1c1f893b5360941b61c"
+ ],
+ "index": "pypi",
+ "version": "==2.0.2"
+ },
+ "backcall": {
+ "hashes": [
+ "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e",
+ "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"
+ ],
+ "version": "==0.2.0"
+ },
+ "comm": {
+ "hashes": [
+ "sha256:16613c6211e20223f215fc6d3b266a247b6e2641bf4e0a3ad34cb1aff2aa3f37",
+ "sha256:a61efa9daffcfbe66fd643ba966f846a624e4e6d6767eda9cf6e993aadaab93e"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==0.1.3"
+ },
+ "debugpy": {
+ "hashes": [
+ "sha256:0679b7e1e3523bd7d7869447ec67b59728675aadfc038550a63a362b63029d2c",
+ "sha256:279d64c408c60431c8ee832dfd9ace7c396984fd7341fa3116aee414e7dcd88d",
+ "sha256:33edb4afa85c098c24cc361d72ba7c21bb92f501104514d4ffec1fb36e09c01a",
+ "sha256:38ed626353e7c63f4b11efad659be04c23de2b0d15efff77b60e4740ea685d07",
+ "sha256:5224eabbbeddcf1943d4e2821876f3e5d7d383f27390b82da5d9558fd4eb30a9",
+ "sha256:53f7a456bc50706a0eaabecf2d3ce44c4d5010e46dfc65b6b81a518b42866267",
+ "sha256:9cd10cf338e0907fdcf9eac9087faa30f150ef5445af5a545d307055141dd7a4",
+ "sha256:aaf6da50377ff4056c8ed470da24632b42e4087bc826845daad7af211e00faad",
+ "sha256:b3e7ac809b991006ad7f857f016fa92014445085711ef111fdc3f74f66144096",
+ "sha256:bae1123dff5bfe548ba1683eb972329ba6d646c3a80e6b4c06cd1b1dd0205e9b",
+ "sha256:c0ff93ae90a03b06d85b2c529eca51ab15457868a377c4cc40a23ab0e4e552a3",
+ "sha256:c4c2f0810fa25323abfdfa36cbbbb24e5c3b1a42cb762782de64439c575d67f2",
+ "sha256:d71b31117779d9a90b745720c0eab54ae1da76d5b38c8026c654f4a066b0130a",
+ "sha256:dbe04e7568aa69361a5b4c47b4493d5680bfa3a911d1e105fbea1b1f23f3eb45",
+ "sha256:de86029696e1b3b4d0d49076b9eba606c226e33ae312a57a46dca14ff370894d",
+ "sha256:e3876611d114a18aafef6383695dfc3f1217c98a9168c1aaf1a02b01ec7d8d1e",
+ "sha256:ed6d5413474e209ba50b1a75b2d9eecf64d41e6e4501977991cdc755dc83ab0f",
+ "sha256:f90a2d4ad9a035cee7331c06a4cf2245e38bd7c89554fe3b616d90ab8aab89cc"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.6.7"
+ },
+ "decorator": {
+ "hashes": [
+ "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330",
+ "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"
+ ],
+ "markers": "python_version >= '3.5'",
+ "version": "==5.1.1"
+ },
+ "exceptiongroup": {
+ "hashes": [
+ "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e",
+ "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"
+ ],
+ "markers": "python_version < '3.11'",
+ "version": "==1.1.1"
+ },
+ "executing": {
+ "hashes": [
+ "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc",
+ "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"
+ ],
+ "version": "==1.2.0"
+ },
+ "importlib-metadata": {
+ "hashes": [
+ "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4",
+ "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==6.7.0"
+ },
+ "iniconfig": {
+ "hashes": [
+ "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3",
+ "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==2.0.0"
+ },
+ "ipykernel": {
+ "hashes": [
+ "sha256:bc00662dc44d4975b668cdb5fefb725e38e9d8d6e28441a519d043f38994922d",
+ "sha256:dd4e18116357f36a1e459b3768412371bee764c51844cbf25c4ed1eb9cae4a54"
+ ],
+ "index": "pypi",
+ "version": "==6.23.3"
+ },
+ "ipython": {
+ "hashes": [
+ "sha256:1d197b907b6ba441b692c48cf2a3a2de280dc0ac91a3405b39349a50272ca0a1",
+ "sha256:248aca623f5c99a6635bc3857677b7320b9b8039f99f070ee0d20a5ca5a8e6bf"
+ ],
+ "markers": "python_version >= '3.9'",
+ "version": "==8.14.0"
+ },
+ "jedi": {
+ "hashes": [
+ "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e",
+ "sha256:bae794c30d07f6d910d32a7048af09b5a39ed740918da923c6b780790ebac612"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==0.18.2"
+ },
+ "jupyter-client": {
+ "hashes": [
+ "sha256:3af69921fe99617be1670399a0b857ad67275eefcfa291e2c81a160b7b650f5f",
+ "sha256:7441af0c0672edc5d28035e92ba5e32fadcfa8a4e608a434c228836a89df6158"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==8.3.0"
+ },
+ "jupyter-core": {
+ "hashes": [
+ "sha256:5ba5c7938a7f97a6b0481463f7ff0dbac7c15ba48cf46fa4035ca6e838aa1aba",
+ "sha256:ae9036db959a71ec1cac33081eeb040a79e681f08ab68b0883e9a676c7a90dce"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==5.3.1"
+ },
+ "matplotlib-inline": {
+ "hashes": [
+ "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311",
+ "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"
+ ],
+ "markers": "python_version >= '3.5'",
+ "version": "==0.1.6"
+ },
+ "nest-asyncio": {
+ "hashes": [
+ "sha256:b9a953fb40dceaa587d109609098db21900182b16440652454a146cffb06e8b8",
+ "sha256:d267cc1ff794403f7df692964d1d2a3fa9418ffea2a3f6859a439ff482fef290"
+ ],
+ "markers": "python_version >= '3.5'",
+ "version": "==1.5.6"
+ },
+ "packaging": {
+ "hashes": [
+ "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61",
+ "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==23.1"
+ },
+ "parso": {
+ "hashes": [
+ "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0",
+ "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==0.8.3"
+ },
+ "pexpect": {
+ "hashes": [
+ "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937",
+ "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"
+ ],
+ "markers": "sys_platform != 'win32'",
+ "version": "==4.8.0"
+ },
+ "pickleshare": {
+ "hashes": [
+ "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca",
+ "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"
+ ],
+ "version": "==0.7.5"
+ },
+ "platformdirs": {
+ "hashes": [
+ "sha256:b0cabcb11063d21a0b261d557acb0a9d2126350e63b70cdf7db6347baea456dc",
+ "sha256:ca9ed98ce73076ba72e092b23d3c93ea6c4e186b3f1c3dad6edd98ff6ffcca2e"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==3.8.0"
+ },
+ "pluggy": {
+ "hashes": [
+ "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849",
+ "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==1.2.0"
+ },
+ "prompt-toolkit": {
+ "hashes": [
+ "sha256:23ac5d50538a9a38c8bde05fecb47d0b403ecd0662857a86f886f798563d5b9b",
+ "sha256:45ea77a2f7c60418850331366c81cf6b5b9cf4c7fd34616f733c5427e6abbb1f"
+ ],
+ "markers": "python_full_version >= '3.7.0'",
+ "version": "==3.0.38"
+ },
+ "psutil": {
+ "hashes": [
+ "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d",
+ "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217",
+ "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4",
+ "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c",
+ "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f",
+ "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da",
+ "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4",
+ "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42",
+ "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5",
+ "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4",
+ "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9",
+ "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f",
+ "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30",
+ "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"
+ ],
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
+ "version": "==5.9.5"
+ },
+ "ptyprocess": {
+ "hashes": [
+ "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35",
+ "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"
+ ],
+ "version": "==0.7.0"
+ },
+ "pure-eval": {
+ "hashes": [
+ "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350",
+ "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"
+ ],
+ "version": "==0.2.2"
+ },
+ "pycodestyle": {
+ "hashes": [
+ "sha256:347187bdb476329d98f695c213d7295a846d1152ff4fe9bacb8a9590b8ee7053",
+ "sha256:8a4eaf0d0495c7395bdab3589ac2db602797d76207242c17d470186815706610"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==2.10.0"
+ },
+ "pygments": {
+ "hashes": [
+ "sha256:8ace4d3c1dd481894b2005f560ead0f9f19ee64fe983366be1a21e171d12775c",
+ "sha256:db2db3deb4b4179f399a09054b023b6a586b76499d36965813c71aa8ed7b5fd1"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==2.15.1"
+ },
+ "pytest": {
+ "hashes": [
+ "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32",
+ "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"
+ ],
+ "index": "pypi",
+ "version": "==7.4.0"
+ },
+ "python-dateutil": {
+ "hashes": [
+ "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86",
+ "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"
+ ],
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
+ "version": "==2.8.2"
+ },
+ "pyzmq": {
+ "hashes": [
+ "sha256:01f06f33e12497dca86353c354461f75275a5ad9eaea181ac0dc1662da8074fa",
+ "sha256:0b6b42f7055bbc562f63f3df3b63e3dd1ebe9727ff0f124c3aa7bcea7b3a00f9",
+ "sha256:0c4fc2741e0513b5d5a12fe200d6785bbcc621f6f2278893a9ca7bed7f2efb7d",
+ "sha256:108c96ebbd573d929740d66e4c3d1bdf31d5cde003b8dc7811a3c8c5b0fc173b",
+ "sha256:13bbe36da3f8aaf2b7ec12696253c0bf6ffe05f4507985a8844a1081db6ec22d",
+ "sha256:154bddda2a351161474b36dba03bf1463377ec226a13458725183e508840df89",
+ "sha256:19d0383b1f18411d137d891cab567de9afa609b214de68b86e20173dc624c101",
+ "sha256:1a6169e69034eaa06823da6a93a7739ff38716142b3596c180363dee729d713d",
+ "sha256:1fc56a0221bdf67cfa94ef2d6ce5513a3d209c3dfd21fed4d4e87eca1822e3a3",
+ "sha256:2a21fec5c3cea45421a19ccbe6250c82f97af4175bc09de4d6dd78fb0cb4c200",
+ "sha256:2b15247c49d8cbea695b321ae5478d47cffd496a2ec5ef47131a9e79ddd7e46c",
+ "sha256:2f5efcc29056dfe95e9c9db0dfbb12b62db9c4ad302f812931b6d21dd04a9119",
+ "sha256:2f666ae327a6899ff560d741681fdcdf4506f990595201ed39b44278c471ad98",
+ "sha256:332616f95eb400492103ab9d542b69d5f0ff628b23129a4bc0a2fd48da6e4e0b",
+ "sha256:33d5c8391a34d56224bccf74f458d82fc6e24b3213fc68165c98b708c7a69325",
+ "sha256:3575699d7fd7c9b2108bc1c6128641a9a825a58577775ada26c02eb29e09c517",
+ "sha256:3830be8826639d801de9053cf86350ed6742c4321ba4236e4b5568528d7bfed7",
+ "sha256:3a522510e3434e12aff80187144c6df556bb06fe6b9d01b2ecfbd2b5bfa5c60c",
+ "sha256:3bed53f7218490c68f0e82a29c92335daa9606216e51c64f37b48eb78f1281f4",
+ "sha256:414b8beec76521358b49170db7b9967d6974bdfc3297f47f7d23edec37329b00",
+ "sha256:442d3efc77ca4d35bee3547a8e08e8d4bb88dadb54a8377014938ba98d2e074a",
+ "sha256:47b915ba666c51391836d7ed9a745926b22c434efa76c119f77bcffa64d2c50c",
+ "sha256:48e5e59e77c1a83162ab3c163fc01cd2eebc5b34560341a67421b09be0891287",
+ "sha256:4a82faae00d1eed4809c2f18b37f15ce39a10a1c58fe48b60ad02875d6e13d80",
+ "sha256:4a983c8694667fd76d793ada77fd36c8317e76aa66eec75be2653cef2ea72883",
+ "sha256:4c2fc7aad520a97d64ffc98190fce6b64152bde57a10c704b337082679e74f67",
+ "sha256:4cb27ef9d3bdc0c195b2dc54fcb8720e18b741624686a81942e14c8b67cc61a6",
+ "sha256:4d67609b37204acad3d566bb7391e0ecc25ef8bae22ff72ebe2ad7ffb7847158",
+ "sha256:5482f08d2c3c42b920e8771ae8932fbaa0a67dff925fc476996ddd8155a170f3",
+ "sha256:5489738a692bc7ee9a0a7765979c8a572520d616d12d949eaffc6e061b82b4d1",
+ "sha256:5693dcc4f163481cf79e98cf2d7995c60e43809e325b77a7748d8024b1b7bcba",
+ "sha256:58416db767787aedbfd57116714aad6c9ce57215ffa1c3758a52403f7c68cff5",
+ "sha256:5873d6a60b778848ce23b6c0ac26c39e48969823882f607516b91fb323ce80e5",
+ "sha256:5af31493663cf76dd36b00dafbc839e83bbca8a0662931e11816d75f36155897",
+ "sha256:5e7fbcafa3ea16d1de1f213c226005fea21ee16ed56134b75b2dede5a2129e62",
+ "sha256:65346f507a815a731092421d0d7d60ed551a80d9b75e8b684307d435a5597425",
+ "sha256:6581e886aec3135964a302a0f5eb68f964869b9efd1dbafdebceaaf2934f8a68",
+ "sha256:69511d604368f3dc58d4be1b0bad99b61ee92b44afe1cd9b7bd8c5e34ea8248a",
+ "sha256:7018289b402ebf2b2c06992813523de61d4ce17bd514c4339d8f27a6f6809492",
+ "sha256:71c7b5896e40720d30cd77a81e62b433b981005bbff0cb2f739e0f8d059b5d99",
+ "sha256:75217e83faea9edbc29516fc90c817bc40c6b21a5771ecb53e868e45594826b0",
+ "sha256:7e23a8c3b6c06de40bdb9e06288180d630b562db8ac199e8cc535af81f90e64b",
+ "sha256:80c41023465d36280e801564a69cbfce8ae85ff79b080e1913f6e90481fb8957",
+ "sha256:831ba20b660b39e39e5ac8603e8193f8fce1ee03a42c84ade89c36a251449d80",
+ "sha256:851fb2fe14036cfc1960d806628b80276af5424db09fe5c91c726890c8e6d943",
+ "sha256:8751f9c1442624da391bbd92bd4b072def6d7702a9390e4479f45c182392ff78",
+ "sha256:8b45d722046fea5a5694cba5d86f21f78f0052b40a4bbbbf60128ac55bfcc7b6",
+ "sha256:8b697774ea8273e3c0460cf0bba16cd85ca6c46dfe8b303211816d68c492e132",
+ "sha256:90146ab578931e0e2826ee39d0c948d0ea72734378f1898939d18bc9c823fcf9",
+ "sha256:9301cf1d7fc1ddf668d0abbe3e227fc9ab15bc036a31c247276012abb921b5ff",
+ "sha256:95bd3a998d8c68b76679f6b18f520904af5204f089beebb7b0301d97704634dd",
+ "sha256:968b0c737797c1809ec602e082cb63e9824ff2329275336bb88bd71591e94a90",
+ "sha256:97d984b1b2f574bc1bb58296d3c0b64b10e95e7026f8716ed6c0b86d4679843f",
+ "sha256:9e68ae9864d260b18f311b68d29134d8776d82e7f5d75ce898b40a88df9db30f",
+ "sha256:adecf6d02b1beab8d7c04bc36f22bb0e4c65a35eb0b4750b91693631d4081c70",
+ "sha256:af56229ea6527a849ac9fb154a059d7e32e77a8cba27e3e62a1e38d8808cb1a5",
+ "sha256:b324fa769577fc2c8f5efcd429cef5acbc17d63fe15ed16d6dcbac2c5eb00849",
+ "sha256:b5a07c4f29bf7cb0164664ef87e4aa25435dcc1f818d29842118b0ac1eb8e2b5",
+ "sha256:bad172aba822444b32eae54c2d5ab18cd7dee9814fd5c7ed026603b8cae2d05f",
+ "sha256:bdca18b94c404af6ae5533cd1bc310c4931f7ac97c148bbfd2cd4bdd62b96253",
+ "sha256:be24a5867b8e3b9dd5c241de359a9a5217698ff616ac2daa47713ba2ebe30ad1",
+ "sha256:be86a26415a8b6af02cd8d782e3a9ae3872140a057f1cadf0133de685185c02b",
+ "sha256:c66b7ff2527e18554030319b1376d81560ca0742c6e0b17ff1ee96624a5f1afd",
+ "sha256:c8398a1b1951aaa330269c35335ae69744be166e67e0ebd9869bdc09426f3871",
+ "sha256:cad9545f5801a125f162d09ec9b724b7ad9b6440151b89645241d0120e119dcc",
+ "sha256:cb6d161ae94fb35bb518b74bb06b7293299c15ba3bc099dccd6a5b7ae589aee3",
+ "sha256:d40682ac60b2a613d36d8d3a0cd14fbdf8e7e0618fbb40aa9fa7b796c9081584",
+ "sha256:d6128d431b8dfa888bf51c22a04d48bcb3d64431caf02b3cb943269f17fd2994",
+ "sha256:dbc466744a2db4b7ca05589f21ae1a35066afada2f803f92369f5877c100ef62",
+ "sha256:ddbef8b53cd16467fdbfa92a712eae46dd066aa19780681a2ce266e88fbc7165",
+ "sha256:e21cc00e4debe8f54c3ed7b9fcca540f46eee12762a9fa56feb8512fd9057161",
+ "sha256:eb52e826d16c09ef87132c6e360e1879c984f19a4f62d8a935345deac43f3c12",
+ "sha256:f0d9e7ba6a815a12c8575ba7887da4b72483e4cfc57179af10c9b937f3f9308f",
+ "sha256:f1e931d9a92f628858a50f5bdffdfcf839aebe388b82f9d2ccd5d22a38a789dc",
+ "sha256:f45808eda8b1d71308c5416ef3abe958f033fdbb356984fabbfc7887bed76b3f",
+ "sha256:f6d39e42a0aa888122d1beb8ec0d4ddfb6c6b45aecb5ba4013c27e2f28657765",
+ "sha256:fc34fdd458ff77a2a00e3c86f899911f6f269d393ca5675842a6e92eea565bae"
+ ],
+ "markers": "python_version >= '3.6'",
+ "version": "==25.1.0"
+ },
+ "six": {
+ "hashes": [
+ "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926",
+ "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"
+ ],
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
+ "version": "==1.16.0"
+ },
+ "stack-data": {
+ "hashes": [
+ "sha256:32d2dd0376772d01b6cb9fc996f3c8b57a357089dec328ed4b6553d037eaf815",
+ "sha256:cbb2a53eb64e5785878201a97ed7c7b94883f48b87bfb0bbe8b623c74679e4a8"
+ ],
+ "version": "==0.6.2"
+ },
+ "tomli": {
+ "hashes": [
+ "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc",
+ "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"
+ ],
+ "markers": "python_version < '3.11'",
+ "version": "==2.0.1"
+ },
+ "tornado": {
+ "hashes": [
+ "sha256:05615096845cf50a895026f749195bf0b10b8909f9be672f50b0fe69cba368e4",
+ "sha256:0c325e66c8123c606eea33084976c832aa4e766b7dff8aedd7587ea44a604cdf",
+ "sha256:29e71c847a35f6e10ca3b5c2990a52ce38b233019d8e858b755ea6ce4dcdd19d",
+ "sha256:4b927c4f19b71e627b13f3db2324e4ae660527143f9e1f2e2fb404f3a187e2ba",
+ "sha256:5b17b1cf5f8354efa3d37c6e28fdfd9c1c1e5122f2cb56dac121ac61baa47cbe",
+ "sha256:6a0848f1aea0d196a7c4f6772197cbe2abc4266f836b0aac76947872cd29b411",
+ "sha256:7efcbcc30b7c654eb6a8c9c9da787a851c18f8ccd4a5a3a95b05c7accfa068d2",
+ "sha256:834ae7540ad3a83199a8da8f9f2d383e3c3d5130a328889e4cc991acc81e87a0",
+ "sha256:b46a6ab20f5c7c1cb949c72c1994a4585d2eaa0be4853f50a03b5031e964fc7c",
+ "sha256:c2de14066c4a38b4ecbbcd55c5cc4b5340eb04f1c5e81da7451ef555859c833f",
+ "sha256:c367ab6c0393d71171123ca5515c61ff62fe09024fa6bf299cd1339dc9456829"
+ ],
+ "markers": "python_version >= '3.8'",
+ "version": "==6.3.2"
+ },
+ "traitlets": {
+ "hashes": [
+ "sha256:9e6ec080259b9a5940c797d58b613b5e31441c2257b87c2e795c5228ae80d2d8",
+ "sha256:f6cde21a9c68cf756af02035f72d5a723bf607e862e7be33ece505abf4a3bad9"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==5.9.0"
+ },
+ "wcwidth": {
+ "hashes": [
+ "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e",
+ "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"
+ ],
+ "version": "==0.2.6"
+ },
+ "yapf": {
+ "hashes": [
+ "sha256:958587eb5c8ec6c860119a9c25d02addf30a44f75aa152a4220d30e56a98037c",
+ "sha256:b8bfc1f280949153e795181768ca14ef43d7312629a06c43e7abd279323af313"
+ ],
+ "index": "pypi",
+ "version": "==0.40.1"
+ },
+ "zipp": {
+ "hashes": [
+ "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b",
+ "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"
+ ],
+ "markers": "python_version >= '3.7'",
+ "version": "==3.15.0"
+ }
+ }
+}
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..67a70a8
--- /dev/null
+++ b/README.md
@@ -0,0 +1,186 @@
+
+
+# TRON: Scalable Session-Based Transformer Recommender using Optimized Negative Sampling
+
+[![GitHub stars](https://img.shields.io/github/stars/otto-de/TRON.svg?style=for-the-badge&color=yellow)](https://github.com/otto-de/TRON)
+[![Test suite](https://img.shields.io/github/actions/workflow/status/otto-de/TRON/test.yml?branch=main&style=for-the-badge)](https://github.com/otto-de/TRON/actions/workflows/test.yml)
+[![Conference](https://img.shields.io/badge/Conference-RecSys%202023-4b44ce?style=for-the-badge)](https://recsys.acm.org/recsys23/)
+[![OTTO jobs](https://img.shields.io/badge/otto-jobs-F00020?style=for-the-badge&logo=otto)](https://www.otto.de/jobs/technology/ueberblick/)
+
+**TRON is a scalable session-based Transformer Recommender using Optimized Negative-sampling. This repository contains the [PyTorch Lightning](https://github.com/Lightning-AI/lightning) implementation for our upcoming paper: _Scaling Session-Based Transformer Recommendations using Optimized Negative Sampling and Loss Functions\*_, authored by [Timo Wilm](https://www.linkedin.com/in/timo-wilm/), [Philipp Normann](https://www.linkedin.com/in/pnormann), [Sophie Baumeister](https://www.linkedin.com/in/sophie-baumeister/), and [Paul-Vincent Kobow](https://www.linkedin.com/in/paul-vincent-kobow/).**
+
+
+
+
+
+## 🎯 Abstract
+
+This work introduces **TRON**, a scalable session-based **T**ransformer **R**ecommender using **O**ptimized **N**egative-sampling. Motivated by the scalability and performance limitations of prevailing models such as SASRec and GRU4Rec+, TRON integrates top-k negative sampling and listwise loss functions to enhance its recommendation accuracy. Evaluations on relevant large-scale e-commerce datasets show that TRON improves upon the recommendation quality of current methods while maintaining training speeds similar to SASRec. A live A/B test yielded an 18.14% increase in click-through rate over SASRec, highlighting the potential of TRON in practical settings. For further research, we provide access to our [source code](https://github.com/otto-de/TRON) and an [anonymized dataset](https://github.com/otto-de/recsys-dataset).
+
+
+
+
+
+
+
+
+ Offline evaluation results on our private OTTO dataset used for the online A/B test of our three groups.
+ Online results of our A/B test relative to the SASRec baseline. The error bars indicate the 95% confidence interval.
+
+
+
+## 🚀 Quick Start
+
+1. Clone the repository:
+
+```bash
+git clone https://github.com/otto-de/TRON.git
+```
+
+2. Install the dependencies:
+
+```bash
+pip install pipenv
+pipenv install --dev
+```
+
+3. Execute the test scripts:
+
+```bash
+pipenv run pytest
+```
+
+4. Install 7zip, gzip, and zip utilities on your system:
+
+```bash
+sudo apt-get install 7zip gzip unzip
+```
+
+5. Prepare a dataset (e.g., yoochoose):
+
+```bash
+./prepare.sh yoochoose
+```
+
+6. Run the main script with a configuration file:
+
+```bash
+pipenv run python -m src --config-filename tron/yoochoose.json
+```
+
+## 🗂️ Repository Structure
+
+```yaml
+.
+├── Pipfile
+├── Pipfile.lock
+├── README.md
+├── configs # Contains experiment configuration files
+├── doc # Contains the paper and related files
+├── prepare.sh # Script to prepare datasets
+├── src # Source code for the models
+└── test # Test scripts
+```
+
+## ⚙️ Config File Documentation
+
+The [config folder](configs/) contains JSON configuration files for all experiments performed in our research. These configurations detail the model's parameters and options.
+
+Here's an explanation of each parameter in the config file:
+
+- `model`: The base model to be used (e.g., "sasrec", "gru4rec").
+- `dataset`: The dataset to be used for training (e.g., "yoochoose", "otto", "diginetica").
+- `hidden_size`: The size of the hidden layers and item embeddings.
+- `num_layers`: The number of layers in the model.
+- `dropout`: The dropout rate applied to the model's layers.
+- `num_batch_negatives`: The number of negative samples from the batch. Limited by `batch_size` - 1.
+- `num_uniform_negatives`: The number of uniformly sampled negatives.
+- `reject_uniform_session_items`: If true, items from the same session won't be used as uniform negatives. Becomes slow if `num_uniform_negatives` is large.
+- `reject_in_batch_items`: If true, items from the same session won't be used as batch negatives.
+- `sampling_style`: The style of negative sampling to use (e.g., "eventwise", "sessionwise", "batchwise"). Has significant impact on training speed.
+- `loss`: The loss function to use (e.g., "bce", "bpr-max", "ssm").
+- `bpr_penalty`: The penalty factor for BPR-Max loss. Ignored if not using BPR-Max loss.
+- `max_epochs`: The maximum number of training epochs.
+- `batch_size`: The batch size used for training and validation.
+- `max_session_length`: The maximum length of a session. Longer sessions will be truncated.
+- `lr`: The learning rate for the optimizer.
+- `limit_val_batches`: The fraction of validation data to use for the validation step.
+- `accelerator`: The device type to be used for training (e.g., "gpu", "cpu").
+- `overfit_batches`: The fraction or number of batches of training data to use for overfitting. Set to 0 for no overfitting. See [PyTorch Lightning docs](https://lightning.ai/docs/pytorch/stable/common/trainer.html#overfit-batches) for more details.
+- `share_embeddings`: If true, the embedding weights are shared between the input and output layers.
+- `output_bias`: If true, includes bias in the output layer.
+- `shuffling_style`: The style of shuffling to use for the training dataset (e.g., "no_shuffling", "shuffling_with_replacement", "shuffling_without_replacement").
+- `optimizer`: The optimizer to use for training (e.g., "adam", "adagrad")
+- `topk_sampling`: If true, top-k negative sampling is enabled.
+- `topk_sampling_k`: If `topk_sampling` is true, this parameter specifies the number of top k negative samples to be used for training.
+
+### Example Config File for TRON on the OTTO Dataset
+
+```json
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 16384,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
+```
+
+For all config files used in our experiments, refer to the [configs directory](configs/).
+
+## 🙌 Contribution
+
+Contributions to TRON are welcome and appreciated. For issues or suggestions for improvements, please open an issue or create a pull request. We believe that open source knowledge sharing is the best way to advance the field of recommender systems.
+
+## 📖 Citing
+
+If TRON aids your research, please consider citing our work:
+
+```bibtex
+@inproceedings{wilm2023tron,
+ title={Scaling Session-Based Transformer Recommendations using Optimized Negative Sampling and Loss Functions},
+ author={Wilm, Timo and Normann, Philipp and Baumeister, Sophie and Kobow, Paul-Vincent},
+ booktitle={Proceedings of the 17th ACM Conference on Recommender Systems},
+ pages={To be updated},
+ year={2023}
+}
+```
+
+## 📜 License
+
+This project is [MIT licensed](./LICENSE).
+
+## 📞 Contact
+
+For any queries or questions, please reach out to us via our LinkedIn profiles:
+
+- [Timo Wilm](https://www.linkedin.com/in/timo-wilm)
+- [Philipp Normann](https://www.linkedin.com/in/pnormann/)
+- [Sophie Baumeister](https://www.linkedin.com/in/sophie-baumeister-9a5a59200/)
+
+For specific issues related to the codebase or for feature requests, please create a new issue on our [GitHub page](https://github.com/otto-de/TRON/issues).
+
+If this project aids your research or you find it interesting, we would appreciate it if you could star ⭐ the repository. Thanks for your support!
+
+
+\* To be published in the proceedings of the 17th ACM Conference on Recommender Systems (RecSys 2023).
diff --git a/configs/experiment1/sasrec_diginetica_uniform512_inbatch16.json b/configs/experiment1/sasrec_diginetica_uniform512_inbatch16.json
new file mode 100644
index 0000000..2bcf47c
--- /dev/null
+++ b/configs/experiment1/sasrec_diginetica_uniform512_inbatch16.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "diginetica",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 16,
+ "num_uniform_negatives": 512,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 100,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment1/sasrec_diginetica_uniform8192_inbatch127.json b/configs/experiment1/sasrec_diginetica_uniform8192_inbatch127.json
new file mode 100644
index 0000000..efb1ec4
--- /dev/null
+++ b/configs/experiment1/sasrec_diginetica_uniform8192_inbatch127.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "diginetica",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 100,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment1/sasrec_otto_uniform512_inbatch16.json b/configs/experiment1/sasrec_otto_uniform512_inbatch16.json
new file mode 100644
index 0000000..84bab11
--- /dev/null
+++ b/configs/experiment1/sasrec_otto_uniform512_inbatch16.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 16,
+ "num_uniform_negatives": 512,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment1/sasrec_otto_uniform8192_inbatch127.json b/configs/experiment1/sasrec_otto_uniform8192_inbatch127.json
new file mode 100644
index 0000000..43dc632
--- /dev/null
+++ b/configs/experiment1/sasrec_otto_uniform8192_inbatch127.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/configs/experiment1/sasrec_yoochoose_uniform512_inbatch16.json b/configs/experiment1/sasrec_yoochoose_uniform512_inbatch16.json
new file mode 100644
index 0000000..aada8ec
--- /dev/null
+++ b/configs/experiment1/sasrec_yoochoose_uniform512_inbatch16.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "yoochoose",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 16,
+ "num_uniform_negatives": 512,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment1/sasrec_yoochoose_uniform8192_inbatch127.json b/configs/experiment1/sasrec_yoochoose_uniform8192_inbatch127.json
new file mode 100644
index 0000000..81dd593
--- /dev/null
+++ b/configs/experiment1/sasrec_yoochoose_uniform8192_inbatch127.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "yoochoose",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/configs/experiment2/sasrec_diginetica_bpr-max.json b/configs/experiment2/sasrec_diginetica_bpr-max.json
new file mode 100644
index 0000000..0d54ed5
--- /dev/null
+++ b/configs/experiment2/sasrec_diginetica_bpr-max.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "diginetica",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.1,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bpr-max",
+ "bpr_penalty": 0.1,
+ "max_epochs": 100,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment2/sasrec_diginetica_ssm.json b/configs/experiment2/sasrec_diginetica_ssm.json
new file mode 100644
index 0000000..9e322f6
--- /dev/null
+++ b/configs/experiment2/sasrec_diginetica_ssm.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "diginetica",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.1,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 100,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment2/sasrec_otto_bpr-max.json b/configs/experiment2/sasrec_otto_bpr-max.json
new file mode 100644
index 0000000..c7509be
--- /dev/null
+++ b/configs/experiment2/sasrec_otto_bpr-max.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bpr-max",
+ "bpr_penalty": 0.03,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment2/sasrec_otto_ssm.json b/configs/experiment2/sasrec_otto_ssm.json
new file mode 100644
index 0000000..8b8aebd
--- /dev/null
+++ b/configs/experiment2/sasrec_otto_ssm.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment2/sasrec_yoochoose_bpr-max.json b/configs/experiment2/sasrec_yoochoose_bpr-max.json
new file mode 100644
index 0000000..0f53038
--- /dev/null
+++ b/configs/experiment2/sasrec_yoochoose_bpr-max.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "yoochoose",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "bpr-max",
+ "bpr_penalty": 0.125,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment2/sasrec_yoochoose_ssm.json b/configs/experiment2/sasrec_yoochoose_ssm.json
new file mode 100644
index 0000000..66e243d
--- /dev/null
+++ b/configs/experiment2/sasrec_yoochoose_ssm.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "yoochoose",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment3/sasrec_diginetica_topk100_uniform16384_inbatch127.json b/configs/experiment3/sasrec_diginetica_topk100_uniform16384_inbatch127.json
new file mode 100644
index 0000000..9a87516
--- /dev/null
+++ b/configs/experiment3/sasrec_diginetica_topk100_uniform16384_inbatch127.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "diginetica",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.2,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 16384,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 100,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment3/sasrec_diginetica_topk100_uniform8192_inbatch127.json b/configs/experiment3/sasrec_diginetica_topk100_uniform8192_inbatch127.json
new file mode 100644
index 0000000..9342057
--- /dev/null
+++ b/configs/experiment3/sasrec_diginetica_topk100_uniform8192_inbatch127.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "diginetica",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.2,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 100,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/experiment3/sasrec_otto_topk100_uniform16384_inbatch127.json b/configs/experiment3/sasrec_otto_topk100_uniform16384_inbatch127.json
new file mode 100644
index 0000000..6d8be48
--- /dev/null
+++ b/configs/experiment3/sasrec_otto_topk100_uniform16384_inbatch127.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 16384,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/configs/experiment3/sasrec_otto_topk100_uniform8192_inbatch127.json b/configs/experiment3/sasrec_otto_topk100_uniform8192_inbatch127.json
new file mode 100644
index 0000000..78ac04c
--- /dev/null
+++ b/configs/experiment3/sasrec_otto_topk100_uniform8192_inbatch127.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/configs/experiment3/sasrec_yoochoose_topk100_uniform16384_inbatch127.json b/configs/experiment3/sasrec_yoochoose_topk100_uniform16384_inbatch127.json
new file mode 100644
index 0000000..59496ba
--- /dev/null
+++ b/configs/experiment3/sasrec_yoochoose_topk100_uniform16384_inbatch127.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "yoochoose",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 16384,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/configs/experiment3/sasrec_yoochoose_topk100_uniform8192_inbatch127.json b/configs/experiment3/sasrec_yoochoose_topk100_uniform8192_inbatch127.json
new file mode 100644
index 0000000..5a5f9e0
--- /dev/null
+++ b/configs/experiment3/sasrec_yoochoose_topk100_uniform8192_inbatch127.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "yoochoose",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/configs/onex/sasrec_otto_mcauley.json b/configs/onex/sasrec_otto_mcauley.json
new file mode 100644
index 0000000..de6746c
--- /dev/null
+++ b/configs/onex/sasrec_otto_mcauley.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "onex",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 0,
+ "num_uniform_negatives": 1,
+ "reject_uniform_session_items": true,
+ "reject_in_batch_items": true,
+ "sampling_style": "eventwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.001,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/onex/sasrec_otto_ours.json b/configs/onex/sasrec_otto_ours.json
new file mode 100644
index 0000000..209162c
--- /dev/null
+++ b/configs/onex/sasrec_otto_ours.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "onex",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 16384,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/configs/onex/sasrec_otto_status_quo.json b/configs/onex/sasrec_otto_status_quo.json
new file mode 100644
index 0000000..1021f07
--- /dev/null
+++ b/configs/onex/sasrec_otto_status_quo.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "onex",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 8192,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "sessionwise",
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/standard/gru4rec_diginetica.json b/configs/standard/gru4rec_diginetica.json
new file mode 100644
index 0000000..5cbc7ec
--- /dev/null
+++ b/configs/standard/gru4rec_diginetica.json
@@ -0,0 +1,27 @@
+{
+ "model": "gru4rec",
+ "dataset": "diginetica",
+ "hidden_size": 100,
+ "num_layers": 1,
+ "dropout": 0.0,
+ "num_batch_negatives": null,
+ "num_uniform_negatives": 2048,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "loss": "bpr-max",
+ "bpr_penalty": 0.5,
+ "max_epochs": 10,
+ "batch_size": 32,
+ "max_session_length": 200,
+ "lr": 0.2,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "output_bias": true,
+ "share_embeddings": true,
+ "original_gru": false,
+ "shuffling_style": "no_shuffling",
+ "final_activation": true,
+ "optimizer": "adagrad"
+}
\ No newline at end of file
diff --git a/configs/standard/gru4rec_otto.json b/configs/standard/gru4rec_otto.json
new file mode 100644
index 0000000..b4690eb
--- /dev/null
+++ b/configs/standard/gru4rec_otto.json
@@ -0,0 +1,27 @@
+{
+ "model": "gru4rec",
+ "dataset": "otto",
+ "hidden_size": 100,
+ "num_layers": 1,
+ "dropout": 0.0,
+ "num_batch_negatives": null,
+ "num_uniform_negatives": 2048,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "loss": "bpr-max",
+ "bpr_penalty": 0.5,
+ "max_epochs": 1,
+ "batch_size": 32,
+ "max_session_length": 200,
+ "lr": 0.2,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "output_bias": true,
+ "share_embeddings": true,
+ "original_gru": false,
+ "shuffling_style": "no_shuffling",
+ "final_activation": true,
+ "optimizer": "adagrad"
+}
\ No newline at end of file
diff --git a/configs/standard/gru4rec_yoochoose.json b/configs/standard/gru4rec_yoochoose.json
new file mode 100644
index 0000000..beaa78d
--- /dev/null
+++ b/configs/standard/gru4rec_yoochoose.json
@@ -0,0 +1,27 @@
+{
+ "model": "gru4rec",
+ "dataset": "yoochoose",
+ "hidden_size": 100,
+ "num_layers": 1,
+ "dropout": 0.0,
+ "num_batch_negatives": null,
+ "num_uniform_negatives": 2048,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "loss": "bpr-max",
+ "bpr_penalty": 0.5,
+ "max_epochs": 3,
+ "batch_size": 32,
+ "max_session_length": 200,
+ "lr": 0.2,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "output_bias": true,
+ "share_embeddings": true,
+ "original_gru": false,
+ "shuffling_style": "no_shuffling",
+ "final_activation": true,
+ "optimizer": "adagrad"
+}
\ No newline at end of file
diff --git a/configs/standard/sasrec_diginetica.json b/configs/standard/sasrec_diginetica.json
new file mode 100644
index 0000000..1d88307
--- /dev/null
+++ b/configs/standard/sasrec_diginetica.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "diginetica",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 0,
+ "num_uniform_negatives": 1,
+ "reject_uniform_session_items": true,
+ "reject_in_batch_items": true,
+ "sampling_style": "eventwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 100,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.001,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/standard/sasrec_otto.json b/configs/standard/sasrec_otto.json
new file mode 100644
index 0000000..da3041f
--- /dev/null
+++ b/configs/standard/sasrec_otto.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 0,
+ "num_uniform_negatives": 1,
+ "reject_uniform_session_items": true,
+ "reject_in_batch_items": true,
+ "sampling_style": "eventwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.001,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/standard/sasrec_yoochoose.json b/configs/standard/sasrec_yoochoose.json
new file mode 100644
index 0000000..8edbb48
--- /dev/null
+++ b/configs/standard/sasrec_yoochoose.json
@@ -0,0 +1,25 @@
+{
+ "model": "sasrec",
+ "dataset": "yoochoose",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 0,
+ "num_uniform_negatives": 1,
+ "reject_uniform_session_items": true,
+ "reject_in_batch_items": true,
+ "sampling_style": "eventwise",
+ "loss": "bce",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.001,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/tron/diginetica.json b/configs/tron/diginetica.json
new file mode 100644
index 0000000..9a87516
--- /dev/null
+++ b/configs/tron/diginetica.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "diginetica",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.2,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 16384,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 100,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
\ No newline at end of file
diff --git a/configs/tron/otto.json b/configs/tron/otto.json
new file mode 100644
index 0000000..6d8be48
--- /dev/null
+++ b/configs/tron/otto.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "otto",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 16384,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/configs/tron/yoochoose.json b/configs/tron/yoochoose.json
new file mode 100644
index 0000000..59496ba
--- /dev/null
+++ b/configs/tron/yoochoose.json
@@ -0,0 +1,27 @@
+{
+ "model": "sasrec",
+ "dataset": "yoochoose",
+ "hidden_size": 200,
+ "num_layers": 2,
+ "dropout": 0.05,
+ "num_batch_negatives": 127,
+ "num_uniform_negatives": 16384,
+ "reject_uniform_session_items": false,
+ "reject_in_batch_items": true,
+ "sampling_style": "batchwise",
+ "topk_sampling": true,
+ "topk_sampling_k": 100,
+ "loss": "ssm",
+ "bpr_penalty": 1.0,
+ "max_epochs": 10,
+ "batch_size": 128,
+ "max_session_length": 50,
+ "lr": 0.0005,
+ "limit_val_batches": 1.0,
+ "accelerator": "gpu",
+ "overfit_batches": 0,
+ "share_embeddings": true,
+ "output_bias": false,
+ "shuffling_style": "no_shuffling",
+ "optimizer": "adam"
+}
diff --git a/prepare.sh b/prepare.sh
new file mode 100755
index 0000000..5f5893e
--- /dev/null
+++ b/prepare.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+set -e
+
+DATASET=$1
+
+function prepare_yoochoose {
+ echo "Downloading yoochoose"
+ wget -nc https://s3-eu-west-1.amazonaws.com/yc-rdata/yoochoose-data.7z -P datasets/yoochoose/
+ 7zz x -aos datasets/yoochoose/yoochoose-data.7z -odatasets/yoochoose/
+
+ echo "Preprocessing yoochoose"
+ pipenv run python -m src.preprocessing --dataset yoochoose
+}
+
+function download_digitinica {
+ if [ ! -f datasets/diginetica/dataset-train-diginetica.zip ]; then
+ mkdir -p datasets/diginetica
+ echo "Please download the dataset and save it to datasets/diginetica/dataset-train-diginetica.zip"
+ if [ "$(uname)" == "Darwin" ]; then
+ open https://drive.google.com/uc?id=0B7XZSACQf0KdenRmMk8yVUU5LWc
+ else
+ xdg-open https://drive.google.com/uc?id=0B7XZSACQf0KdenRmMk8yVUU5LWc
+ fi
+ echo "Press enter to continue"
+ read
+ fi
+}
+
+function prepare_diginetica {
+ echo "Downloading diginetica"
+ download_digitinica
+ unzip -n datasets/diginetica/dataset-train-diginetica.zip -d datasets/diginetica/
+
+ echo "Preprocessing diginetica"
+ pipenv run python -m src.preprocessing --dataset diginetica
+}
+
+function prepare_otto {
+ echo "Downloading otto"
+ pipenv run kaggle datasets download -d otto/recsys-dataset -p datasets/otto/
+ unzip -n datasets/otto/recsys-dataset.zip -d datasets/otto/
+
+ echo "Preprocessing otto"
+ pipenv run python -m src.preprocessing --dataset otto
+}
+
+if [ "$DATASET" = "yoochoose" ]; then
+ prepare_yoochoose
+elif [ "$DATASET" = "diginetica" ]; then
+ prepare_diginetica
+elif [ "$DATASET" = "otto" ]; then
+ prepare_otto
+else
+ echo "Unknown dataset"
+ exit 1
+fi
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/__main__.py b/src/__main__.py
new file mode 100644
index 0000000..0600732
--- /dev/null
+++ b/src/__main__.py
@@ -0,0 +1,47 @@
+import json
+from argparse import ArgumentParser
+
+import mlflow
+
+from src.gru4rec.train import train_gru
+from src.sasrec.train import train_sasrec
+
+
+def read_stats(data_dir, dataset):
+ with open(f"{data_dir}/{dataset}/{dataset}_stats.json", "r") as f:
+ stats = json.load(f)
+ train_stats = stats["train"]
+ test_stats = stats["test"]
+ return train_stats, test_stats, stats["num_items"]
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("--config-filename", type=str)
+ parser.add_argument("--config-dir", type=str, default="configs")
+ parser.add_argument("--data-dir", type=str, default="datasets")
+ args = parser.parse_args()
+
+ with open(f"{args.config_dir}/{args.config_filename}.json", "r") as f:
+ config = json.load(f)
+
+ train_stats, test_stats, num_items = read_stats(args.data_dir, config["dataset"])
+
+ if config["model"] == "sasrec":
+ trainer, model, train_loader, test_loader = train_sasrec(config, args.data_dir, train_stats, test_stats, num_items)
+ elif config["model"] == "gru4rec":
+ trainer, model, train_loader, test_loader = train_gru(config, args.data_dir, train_stats, test_stats, num_items)
+ else:
+ raise ValueError('sasrec or gru4rec must be provided as model')
+
+ if config["overfit_batches"] > 0:
+ test_loader = train_loader
+
+ mlflow.pytorch.autolog(log_every_n_epoch=1, log_every_n_step=100)
+
+ with mlflow.start_run(run_name=args.config_filename) as run:
+ mlflow.log_params(config)
+ trainer.fit(model, train_loader, test_loader)
+
+ if config["model"] == "sasrec":
+ model.export(trainer.logger.log_dir)
\ No newline at end of file
diff --git a/src/gru4rec/__init__.py b/src/gru4rec/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/gru4rec/dataset.py b/src/gru4rec/dataset.py
new file mode 100644
index 0000000..6cfb08a
--- /dev/null
+++ b/src/gru4rec/dataset.py
@@ -0,0 +1,130 @@
+import itertools
+import json
+import random
+import warnings
+from copy import copy
+
+import numpy as np
+from torch import long, tensor
+from torch.utils.data.dataset import IterableDataset
+
+from src.shared.sample import (sample_in_batch_negatives, sample_uniform,
+ sample_uniform_negatives_with_shape)
+from src.shared.utils import get_offsets
+
+
+def label_session(session):
+ without_label = session[:-1]
+ labels = session[1:]
+ for idx in range(len(without_label)):
+ without_label[idx]['label'] = labels[idx]['aid']
+ return without_label
+
+
+def get_inactive_buffer_sessions(labeled_session_buffer):
+ inactive_buffer_session_indices = []
+ for session_idx, session in enumerate(labeled_session_buffer):
+ if len(session) == 0:
+ inactive_buffer_session_indices.append(session_idx)
+ return inactive_buffer_session_indices
+
+
+class Gru4RecDataset(IterableDataset):
+
+ def __init__(self,
+ sessions_path,
+ total_sessions,
+ num_items,
+ max_seqlen,
+ shuffling_style="no_shuffling",
+ num_uniform_negatives=1,
+ num_in_batch_negatives=None,
+ reject_uniform_session_items=False,
+ reject_in_batch_items=True,
+ sampling_style="sessionwise",
+ batch_size=128):
+ self.session_path = sessions_path
+ self.total_sessions = total_sessions
+ self.num_items = num_items
+ self.max_seqlen = max_seqlen
+ self.num_uniform_negatives = num_uniform_negatives
+ self.num_in_batch_negatives = num_in_batch_negatives
+ if self.num_in_batch_negatives is None:
+ self.num_in_batch_negatives = batch_size - 1
+ self.reject_uniform_session_items = reject_uniform_session_items
+ self.reject_in_batch_items = reject_in_batch_items
+ self.sampling_style = sampling_style
+ self.shuffling_style = shuffling_style
+ self.batch_size = batch_size
+ self.line_offsets = get_offsets(sessions_path)
+ self.__reset_dataset__()
+ if self.sampling_style == "eventwise":
+ self.sampling_style = "sessionwise"
+ warnings.warn("Warning eventwise is not supported and is set to sessionwise ...")
+
+ def __reset_dataset__(self):
+ self.offset_queue = copy(self.line_offsets)
+ if self.shuffling_style=="shuffle_without_replacement":
+ random.shuffle(self.offset_queue)
+ assert len(self.line_offsets) == self.total_sessions, f"{len(self.line_offsets)} != {self.total_sessions}"
+ self.offset_queue = iter(self.offset_queue)
+ self.labeled_session_buffer = [[]] * self.batch_size
+ self.clicks = [[]] * self.batch_size
+
+ def process_data(self, line_offsets):
+ while True:
+ keep_state = [1.] * self.batch_size
+ with open(self.session_path, "rt") as f:
+ inactive = get_inactive_buffer_sessions(self.labeled_session_buffer)
+ for inactive_index in inactive:
+ try:
+ next_session_index = next(self.offset_queue)
+ except:
+ self.__reset_dataset__()
+ return
+ if self.shuffling_style=="shuffle_with_replacement":
+ next_session_index = line_offsets[np.random.randint(0, self.total_sessions)]
+ f.seek(next_session_index)
+ session = json.loads(f.readline())
+ self.labeled_session_buffer[inactive_index] = label_session(
+ session["events"][-(self.max_seqlen + 1):])
+ keep_state[inactive_index] = 0.
+ self.clicks[inactive_index] = [event['aid'] for event in
+ self.labeled_session_buffer[inactive_index]]
+ batch = [session.pop(0) for session in self.labeled_session_buffer]
+ clicks = [int(event["aid"]) for event in batch]
+ labels = [int(event["label"]) for event in batch]
+ if self.sampling_style == "batchwise":
+ uniform_negatives = sample_uniform(self.num_items, [1, self.num_uniform_negatives],
+ set(itertools.chain.from_iterable(self.clicks)),
+ self.reject_uniform_session_items)
+ else:
+ uniform_negatives = np.array([sample_uniform_negatives_with_shape(session_clicks, self.num_items, 1,
+ self.num_uniform_negatives,
+ self.sampling_style,
+ self.reject_uniform_session_items) for
+ session_clicks in
+ self.clicks])
+ in_batch_negatives = sample_in_batch_negatives(clicks, self.num_in_batch_negatives, [1] * self.batch_size,
+ self.reject_in_batch_items)
+ yield {
+ 'clicks': tensor(clicks, dtype=long),
+ 'labels': tensor(labels, dtype=long).unsqueeze(1),
+ 'keep_state': tensor(keep_state).unsqueeze(1),
+ 'uniform_negatives': tensor(uniform_negatives, dtype=long),
+ 'in_batch_negatives': tensor(in_batch_negatives, dtype=long)
+ }
+
+ def __iter__(self):
+ return self.process_data(self.line_offsets)
+
+ def dynamic_collate(self, batch):
+ batch = batch[0]
+ return {
+ 'clicks': batch['clicks'],
+ 'labels': batch['labels'],
+ 'keep_state': batch['keep_state'],
+ 'uniform_negatives': batch['uniform_negatives'],
+ 'in_batch_negatives': batch['in_batch_negatives'],
+ 'mask': tensor([[1.] * self.batch_size])
+ }
diff --git a/src/gru4rec/model.py b/src/gru4rec/model.py
new file mode 100644
index 0000000..636d75a
--- /dev/null
+++ b/src/gru4rec/model.py
@@ -0,0 +1,156 @@
+import math
+import warnings
+from functools import partial
+
+import pytorch_lightning as pl
+import torch
+from torch import concat, nn, tensor
+
+from src.shared.evaluate import validate_batch_per_timestamp
+from src.shared.loss import (bce_loss, bpr_max_loss, calc_loss,
+ sampled_softmax_loss)
+
+
+def sparse_output(item_lookup, bias_lookup, output, items_to_predict):
+ embeddings = item_lookup(items_to_predict)
+ logits = torch.matmul(embeddings, output.t())
+ bias = bias_lookup(items_to_predict).squeeze(1)
+ return bias + logits.t()
+
+
+def dense_output(linear_layer, output, items_to_predict):
+ return linear_layer(output)[:, items_to_predict.view(-1)]
+
+
+def clean_state(curr_state, keep_state):
+ return curr_state * keep_state
+
+class GRU4REC(pl.LightningModule):
+
+ def __init__(self,
+ hidden_size,
+ dropout_rate,
+ num_items,
+ batch_size,
+ sampling_style="batchwise",
+ topk_sampling=False,
+ topk_sampling_k=1000,
+ learning_rate=0.001,
+ num_layers=1,
+ loss='bce',
+ bpr_penalty=None,
+ optimizer='adagrad',
+ output_bias=False,
+ share_embeddings=True,
+ original_gru=False,
+ final_activation=True):
+ super(GRU4REC, self).__init__()
+ self.num_items = num_items
+ self.learning_rate = learning_rate
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.dropout_hidden = dropout_rate
+ self.batch_size = batch_size
+ self.sampling_style = sampling_style
+ if sampling_style == "eventwise":
+ warnings.warn("Warning eventwise is not supported and is set to sessionwise ...")
+ self.sampling_style = sampling_style
+ self.output_bias = output_bias
+ self.share_embeddings = share_embeddings
+ self.original_gru = original_gru
+
+ if original_gru:
+ warnings.warn("Warning gru original cannot share input and output embeddings, share embedding is set to False")
+ self.share_embeddings = False
+
+ if output_bias and share_embeddings:
+ self.item_embedding = nn.Embedding(num_items + 1, hidden_size + 1, padding_idx=0)
+ elif self.original_gru:
+ self.item_embedding = nn.Embedding(num_items + 1, 3 * hidden_size, padding_idx=0)
+ else:
+ self.item_embedding = nn.Embedding(num_items + 1, hidden_size, padding_idx=0)
+
+ if share_embeddings:
+ self.output_embedding = self.item_embedding
+ elif (not share_embeddings) and output_bias:
+ self.output_embedding = nn.Embedding(num_items + 1, hidden_size + 1, padding_idx=0)
+ else:
+ self.output_embedding = nn.Embedding(num_items + 1, hidden_size, padding_idx=0)
+
+ torch.nn.init.xavier_uniform_(self.item_embedding.weight.data, gain=1 / math.sqrt(6))
+ torch.nn.init.xavier_uniform_(self.output_embedding.weight.data, gain=1 / math.sqrt(6))
+
+ self.gru = nn.GRU(int(3 * self.hidden_size) if self.original_gru else self.hidden_size,
+ self.hidden_size,
+ self.num_layers,
+ dropout=self.dropout_hidden,
+ batch_first=True)
+ if final_activation:
+ self.final_activation = nn.ELU(0.5)
+ else:
+ self.final_activation = nn.Identity()
+
+ if self.original_gru:
+ self.gru.weight_ih_l0 = nn.Parameter(data=torch.eye(3 * self.hidden_size), requires_grad=False)
+ self.register_buffer('current_state', torch.zeros([num_layers, batch_size, hidden_size], device=self.device))
+ self.register_buffer('loss_mask', torch.ones(1, self.batch_size, device=self.device))
+ self.register_buffer('bias_ones', torch.ones([self.batch_size, 1, 1]))
+ self.loss_fn = loss
+ if self.loss_fn == 'bce':
+ self.loss = bce_loss
+ elif self.loss_fn == 'ssm':
+ self.loss = sampled_softmax_loss
+ elif self.loss_fn == 'bpr-max':
+ if bpr_penalty is not None:
+ self.loss = partial(bpr_max_loss, bpr_penalty)
+ else:
+ raise ValueError('bpr_penalty must be provided for bpr_max loss')
+ else:
+ raise ValueError('Loss function not supported')
+
+ self.topk_sampling = topk_sampling
+ self.topk_sampling_k = topk_sampling_k
+ self.optimizer = optimizer
+ self.save_hyperparameters()
+
+ def forward(self, item_indices, in_state, keep_state):
+ embedded = self.item_embedding(item_indices.unsqueeze(1))
+ embedded = embedded[:, :, :-1] if self.output_bias and self.share_embeddings else embedded
+ in_state = clean_state(in_state, keep_state)
+ gru_output, out_state = self.gru(embedded, in_state)
+ scores = concat([gru_output, self.bias_ones], dim=-1) if self.output_bias else gru_output
+ return scores, out_state
+
+ def training_step(self, batch, _):
+ x_hat, c_state = self.forward(batch["clicks"], self.current_state, batch["keep_state"])
+
+ self.current_state = c_state.detach()
+ train_loss = calc_loss(self.loss, x_hat, batch["labels"], batch["uniform_negatives"], batch["in_batch_negatives"],
+ batch["mask"], self.output_embedding, self.sampling_style, self.final_activation,
+ self.topk_sampling, self.topk_sampling_k, self.device)
+
+ self.log("train_loss", train_loss)
+
+ return train_loss
+
+ def validation_step(self, batch, _batch_idx):
+ x_hat, self.current_state = self.forward(batch["clicks"], self.current_state, batch["keep_state"])
+ cut_offs = tensor([5, 10, 20], device=self.device)
+ recall, mrr = validate_batch_per_timestamp(batch, x_hat, self.output_embedding, cut_offs)
+ test_loss = calc_loss(self.loss, x_hat, batch["labels"], batch["uniform_negatives"], batch["in_batch_negatives"],
+ batch["mask"], self.output_embedding, self.sampling_style, self.final_activation,
+ self.topk_sampling, self.topk_sampling_k, self.device)
+ for i, k in enumerate(cut_offs.tolist()):
+ self.log(f'recall_cutoff_{k}', recall[i])
+ self.log(f'mrr_cutoff_{k}', mrr[i])
+ self.log('test_seq_len', x_hat.shape[1])
+ self.log('test_loss', test_loss)
+
+ def configure_optimizers(self):
+ if self.optimizer == 'adam':
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+ elif self.optimizer == 'adagrad':
+ optimizer = torch.optim.Adagrad(self.parameters(), lr=self.learning_rate)
+ else:
+ raise ValueError('Optimizer not supported, please use adam or adagrad')
+ return optimizer
diff --git a/src/gru4rec/train.py b/src/gru4rec/train.py
new file mode 100644
index 0000000..9c2c16a
--- /dev/null
+++ b/src/gru4rec/train.py
@@ -0,0 +1,83 @@
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.trainer.trainer import Trainer
+from torch.utils.data import DataLoader
+
+from src.gru4rec.dataset import Gru4RecDataset
+from src.gru4rec.model import GRU4REC
+
+
+def train_gru(config, data_dir, train_stats, test_stats, num_items):
+ checkpoint_callback = ModelCheckpoint(save_top_k=1,
+ monitor='recall_cutoff_20',
+ mode='max',
+ filename=f'gru4rec-{config["dataset"]}-' + '{epoch}-{recall_cutoff_20:.3f}')
+
+ trainer = Trainer(max_epochs=config["max_epochs"],
+ precision=16,
+ limit_val_batches=config["limit_val_batches"],
+ log_every_n_steps=1,
+ accelerator=config["accelerator"],
+ devices=1,
+ overfit_batches=config["overfit_batches"],
+ callbacks=[checkpoint_callback])
+
+ train_set = Gru4RecDataset(f'{data_dir}/{config["dataset"]}/{config["dataset"]}_train.jsonl',
+ train_stats["num_sessions"],
+ num_items=num_items,
+ max_seqlen=config["max_session_length"],
+ shuffling_style=config["shuffling_style"],
+ num_in_batch_negatives=config["num_batch_negatives"],
+ num_uniform_negatives=config["num_uniform_negatives"],
+ reject_uniform_session_items=config["reject_uniform_session_items"],
+ reject_in_batch_items=config["reject_in_batch_items"],
+ sampling_style=config["sampling_style"],
+ batch_size=config["batch_size"])
+
+ test_set = Gru4RecDataset(f'{data_dir}/{config["dataset"]}/{config["dataset"]}_test.jsonl',
+ test_stats["num_sessions"],
+ num_items=num_items,
+ max_seqlen=config["max_session_length"],
+ shuffling_style="no_shuffling",
+ num_in_batch_negatives=config["num_batch_negatives"],
+ num_uniform_negatives=config["num_uniform_negatives"],
+ reject_uniform_session_items=config["reject_uniform_session_items"],
+ reject_in_batch_items=config["reject_in_batch_items"],
+ sampling_style=config["sampling_style"],
+ batch_size=config["batch_size"])
+
+ train_loader = DataLoader(
+ train_set,
+ drop_last=True,
+ batch_size=1,
+ pin_memory=True,
+ num_workers=1,
+ collate_fn=train_set.dynamic_collate,
+ prefetch_factor=100)
+
+ test_loader = DataLoader(
+ test_set,
+ drop_last=True,
+ batch_size=1,
+ pin_memory=True,
+ num_workers=1,
+ collate_fn=test_set.dynamic_collate,
+ prefetch_factor=10)
+
+ model = GRU4REC(hidden_size=config["hidden_size"],
+ dropout_rate=config["dropout"],
+ num_items=num_items,
+ learning_rate=config["lr"],
+ batch_size=config["batch_size"],
+ sampling_style=config["sampling_style"],
+ topk_sampling=config.get("topk_sampling", False),
+ topk_sampling_k=config.get("topk_sampling_k", 1000),
+ num_layers=config["num_layers"],
+ loss=config["loss"],
+ bpr_penalty=config["bpr_penalty"],
+ optimizer=config["optimizer"],
+ output_bias=config["output_bias"],
+ share_embeddings=config["share_embeddings"],
+ original_gru=config["original_gru"],
+ final_activation=config["final_activation"])
+
+ return trainer, model, train_loader, test_loader
diff --git a/src/preprocessing.py b/src/preprocessing.py
new file mode 100644
index 0000000..c07fa45
--- /dev/null
+++ b/src/preprocessing.py
@@ -0,0 +1,284 @@
+# Yoochoose Data: https://s3-eu-west-1.amazonaws.com/yc-rdata/yoochoose-data.7z
+# Diginetica Data: https://drive.google.com/file/d/0B7XZSACQf0KdenRmMk8yVUU5LWc/
+# Beauty Data: http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/
+
+import argparse
+import json
+import time
+import logging as log
+from datetime import datetime, timedelta
+from enum import Enum
+from pathlib import Path
+from tqdm.auto import tqdm
+
+
+def read_file(filename, header=False):
+ with open(filename, "r") as f:
+ file_content = f.readlines()
+ return file_content if not header else file_content[1:]
+
+
+def sort_events(events):
+ return sorted(events, key=lambda event: event["ts"])
+
+
+def create_sessions(events, dataset_name):
+ sessions = dict()
+ for event in tqdm(events):
+ if dataset_name == "diginetica":
+ sid, _uid, aid, timeframe, eventdate = event.strip().split(";")
+ ts = (datetime.strptime(eventdate, '%Y-%m-%d') + timedelta(milliseconds=int(timeframe))).timestamp()
+ elif dataset_name == "yoochoose":
+ sid, ts, aid, _cat = event.strip().split(",")
+ ts = datetime.strptime(ts, "%Y-%m-%dT%H:%M:%S.%fZ").timestamp()
+ if not sid in sessions:
+ sessions[sid] = list()
+ sessions[sid].append({"aid": aid, "ts": ts, "type": "clicks"})
+ sessions = [{"session": sid, "events": sort_events(events)} for sid, events in sessions.items()]
+ return sessions
+
+
+def sort_sessions(sessions):
+ return sorted(sessions, key=lambda x: x["events"][0]["ts"])
+
+
+def filter_short_sessions(sessions, min_session_len=2):
+ return [session for session in tqdm(sessions) if len(session["events"]) >= min_session_len]
+
+
+def get_aid_support(sessions):
+ aid_support = {}
+ for session in sessions:
+ for event in session["events"]:
+ aid = event["aid"]
+ if aid in aid_support:
+ aid_support[aid] += 1
+ else:
+ aid_support[aid] = 1
+ return aid_support
+
+
+def filter_low_aid_support(sessions, min_aid_support=5):
+ aid_support = get_aid_support(sessions)
+ for session in tqdm(sessions):
+ session["events"] = list(filter(lambda event: aid_support[event["aid"]] >= min_aid_support, session["events"]))
+ return sessions
+
+
+def get_session_lengths(sessions):
+ return {session["session"]: len(session["events"]) for session in sessions}
+
+
+def filter_low_aid_and_sessions(sessions, min_aid_support, min_session_len):
+ session_lengths = get_session_lengths(sessions)
+ aid_support = get_aid_support(sessions)
+ filtered_sessions = list()
+ for session in tqdm(sessions):
+ if session_lengths[session["session"]] >= min_session_len:
+ session["events"] = list(filter(lambda event: aid_support[event["aid"]] >= min_aid_support, session["events"]))
+ if len(session["events"]) > 0:
+ filtered_sessions.append(session)
+ return filtered_sessions
+
+
+def apply_session_filtering(sessions, min_session_len=2, min_aid_support=5):
+ sessions = filter_short_sessions(sessions, min_session_len)
+ sessions = filter_low_aid_support(sessions, min_aid_support)
+ return filter_short_sessions(sessions, min_session_len)
+
+
+def train_test_split(sessions, dataset_name, split_seconds, split_idx):
+ max_date = max([session["events"][0]["ts"] for session in sessions])
+ if dataset_name == "diginetica":
+ max_date = datetime.fromtimestamp(int(max_date)).strftime('%Y-%m-%d')
+ max_date = time.mktime(time.strptime(max_date, '%Y-%m-%d'))
+ splitdate = max_date - split_seconds
+ train_sessions = filter(lambda session: session["events"][split_idx]["ts"] < splitdate, sessions)
+ test_sessions = filter(lambda session: session["events"][split_idx]["ts"] >= splitdate, sessions)
+ return (list(train_sessions), list(test_sessions))
+
+
+def filter_test_aids(train_sessions, test_sessions):
+ train_aids = [event["aid"] for session in train_sessions for event in session["events"]]
+ test_aids = [event["aid"] for session in test_sessions for event in session["events"]]
+ aids_to_remove = set(test_aids).difference(set(train_aids))
+ for session in test_sessions:
+ session["events"] = [event for event in session["events"] if not event["aid"] in aids_to_remove]
+ return (test_sessions, train_aids)
+
+
+def create_aid_to_idx(train_aids):
+ aid_to_idx = dict()
+ aid_counter = 1
+ for aid in tqdm(train_aids):
+ if not aid in aid_to_idx:
+ aid_to_idx[aid] = aid_counter
+ aid_counter += 1
+ return aid_to_idx
+
+
+def remap_indices(sessions, aid_to_idx):
+ num_events = 0
+ num_sessions = 0
+ for session in tqdm(sessions):
+ for event in session["events"]:
+ event["aid"] = aid_to_idx[event["aid"]]
+ num_events += 1
+ num_sessions += 1
+ return sessions, num_sessions, num_events
+
+
+def write_file(sessions, filename):
+ with open(filename, "w") as f:
+ for s in tqdm(sessions):
+ f.write(json.dumps(s) + "\n")
+
+
+def write_stats(num_items, num_train_sessions, num_train_events, num_test_sessions=None, num_test_events=None, filename=None):
+ stats = {
+ "train": {
+ "num_sessions": num_train_sessions,
+ "num_events": num_train_events
+ },
+ "num_items": num_items,
+ "test": {
+ "num_sessions": num_test_sessions,
+ "num_events": num_test_events
+ }
+ }
+ with open(filename, "w") as f:
+ f.write(json.dumps(stats))
+
+
+def run_preprocessing(config, data_dir):
+ dataset_name = config["dataset_name"]
+ events = read_file(config["data_file"], header=config["header"])
+ log.info(f"Read {len(events)} events from {config['data_file']}")
+
+ log.info("Creating sessions...")
+ sessions = create_sessions(events, dataset_name)
+ log.info(f"Created {len(sessions)} sessions for {dataset_name}")
+
+ log.info("Filtering sessions...")
+ sessions = apply_session_filtering(sessions)
+ log.info(f"Remaining sessions after filtering: {len(sessions)}")
+
+ log.info("Splitting sessions into train and test...")
+ train_sessions, test_sessions = train_test_split(sessions, dataset_name, config["split_seconds"], config["split_idx"])
+ log.info(f"Split sessions into {len(train_sessions)} train and {len(test_sessions)} test sessions")
+ test_sessions, train_aids = filter_test_aids(train_sessions, test_sessions)
+ test_sessions = filter_short_sessions(test_sessions)
+ log.info(f"Remaining test sessions after filtering: {len(test_sessions)}")
+
+ log.info("Creating item indices...")
+ aid_to_idx = create_aid_to_idx(train_aids)
+ log.info(f"Created {len(aid_to_idx)} item indices")
+
+ log.info("Remapping item indices...")
+ train_sessions, num_train_sessions, num_train_events = remap_indices(train_sessions, aid_to_idx)
+ test_sessions, num_test_sessions, num_test_events = remap_indices(test_sessions, aid_to_idx)
+
+ log.info("Sorting sessions")
+ train_sessions = sort_sessions(train_sessions)
+ test_sessions = sort_sessions(test_sessions)
+
+ output_dir = data_dir / dataset_name
+ output_dir.mkdir(parents=True, exist_ok=True)
+ log.info(f"Writing sessions to {output_dir}")
+ write_file(train_sessions, output_dir / f"{dataset_name}_train.jsonl")
+ write_file(test_sessions, output_dir / f"{dataset_name}_test.jsonl")
+
+ stats_file = output_dir / f"{dataset_name}_stats.json"
+ log.info(f"Writing stats to {stats_file}")
+ write_stats(len(set(train_aids)), num_train_sessions, num_train_events, num_test_sessions, num_test_events, stats_file)
+
+
+def filter_non_clicks(in_file, out_file):
+ num_sessions = 0
+ num_events = 0
+ items = set()
+ log.info(f"Filtering non-clicks from {in_file} to {out_file}")
+ with open(in_file, "r") as read_file:
+ with open(out_file, "w") as write_file:
+ for line in read_file:
+ session = json.loads(line)
+ session["events"] = list(filter(lambda d: d['type'] == "clicks", session["events"]))
+ session["events"] = increment_aids(session["events"])
+ num_sessions += 1
+ num_events += len(session["events"])
+ items.update([event["aid"] for event in session["events"]])
+ write_file.write(json.dumps(session, separators=(',', ':')) + "\n")
+ if num_sessions % 1000000 == 0:
+ log.info(f"Processed {num_sessions} sessions")
+ return num_sessions, num_events, len(items)
+
+
+def increment_aids(events):
+ for event in events:
+ event["aid"] = event["aid"] + 1
+ return events
+
+
+def run_preprocessing_otto(data_dir):
+ num_train_sessions, num_train_events, num_items = filter_non_clicks(f"{data_dir}/otto/otto-recsys-train.jsonl",
+ f"{data_dir}/otto/otto_train.jsonl")
+ num_test_sessions, num_test_events, _ = filter_non_clicks(f"{data_dir}/otto/otto-recsys-test.jsonl",
+ f"{data_dir}/otto/otto_test.jsonl")
+ stats_file = f"{data_dir}/otto/otto_stats.json"
+ log.info(f"Writing stats to {stats_file}")
+ write_stats(num_items, num_train_sessions, num_train_events, num_test_sessions, num_test_events, stats_file)
+
+
+class DatasetConf(Enum):
+ YOOCHOOSE = 'yoochoose'
+ DIGINETICA = 'diginetica'
+ OTTO = 'otto'
+ ALL = 'all'
+
+ def __str__(self):
+ return self.value
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset", type=DatasetConf, default=DatasetConf.ALL)
+ parser.add_argument("--data_dir", type=str, default="datasets")
+
+ args = parser.parse_args()
+ data_dir = Path(args.data_dir)
+
+ log.basicConfig(level=log.INFO)
+ log.info(f"Running preprocessing for {args.dataset} dataset")
+
+ yoochoose_conf = {
+ "dataset_name": "yoochoose",
+ "data_file": data_dir / "yoochoose" / "yoochoose-clicks.dat",
+ "header": False,
+ "split_seconds": 86400 * 1, # 1 day (for testing)
+ "split_idx": -1 # use last session timestamp for split
+ }
+
+ diginetica_conf = {
+ "dataset_name": "diginetica",
+ "data_file": data_dir / "diginetica" / "train-item-views.csv",
+ "header": True,
+ "split_seconds": 86400 * 7, # 7 days (for testing)
+ "split_idx": 0 # use first session timestamp for split
+ }
+
+ if args.dataset == DatasetConf.YOOCHOOSE:
+ run_preprocessing(yoochoose_conf, data_dir)
+ elif args.dataset == DatasetConf.DIGINETICA:
+ run_preprocessing(diginetica_conf, data_dir)
+ elif args.dataset == DatasetConf.OTTO:
+ run_preprocessing_otto(data_dir)
+ elif args.dataset == DatasetConf.ALL:
+ run_preprocessing(yoochoose_conf, data_dir)
+ run_preprocessing(diginetica_conf, data_dir)
+ run_preprocessing_otto(data_dir)
+
+ log.info("All done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/sasrec/__init__.py b/src/sasrec/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/sasrec/dataset.py b/src/sasrec/dataset.py
new file mode 100644
index 0000000..2017c4b
--- /dev/null
+++ b/src/sasrec/dataset.py
@@ -0,0 +1,103 @@
+import json
+
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+
+from src.shared.sample import (sample_in_batch_negatives, sample_uniform,
+ sample_uniform_negatives_with_shape)
+from src.shared.utils import get_offsets
+
+
+class SasRecDataset(Dataset):
+
+ def __init__(self,
+ sessions_path,
+ total_sessions,
+ num_items,
+ max_seqlen,
+ num_uniform_negatives=1,
+ num_in_batch_negatives=0,
+ reject_uniform_session_items=False,
+ reject_in_batch_items=True,
+ sampling_style="eventwise",
+ shuffling_style="no_shuffling"
+ ):
+ self.session_path = sessions_path
+ self.total_sessions = total_sessions
+ self.num_items = num_items
+ self.max_seqlen = max_seqlen
+ self.shuffling_style = shuffling_style
+ self.num_uniform_negatives = num_uniform_negatives
+ self.num_in_batch_negatives = num_in_batch_negatives
+ self.reject_uniform_session_items = reject_uniform_session_items
+ self.reject_in_batch_items = reject_in_batch_items
+ self.sampling_style = sampling_style
+ self.line_offsets = get_offsets(sessions_path)
+
+ assert self.sampling_style in {"eventwise", "sessionwise", "batchwise"}
+ assert len(self.line_offsets) == self.total_sessions, f"{len(self.line_offsets)} != {self.total_sessions}"
+
+ def __len__(self):
+ return self.total_sessions
+
+ def __getitem__(self, idx):
+ with open(self.session_path, "rt") as f:
+
+ if self.shuffling_style=="shuffle_with_replacement":
+ idx = np.random.randint(0,self.total_sessions)
+
+ f.seek(self.line_offsets[idx])
+ line = f.readline()
+ session = json.loads(line)
+ session = session["events"]
+
+ assert sorted(session, key=lambda d: d["ts"]) == session
+
+ clicks = [int(event["aid"]) for event in session if event["type"] == "clicks"]
+
+ clicks = clicks[-(self.max_seqlen + 1):]
+ session_len = min(len(clicks) - 1, self.max_seqlen)
+ labels = clicks[1:]
+ clicks = clicks[:-1]
+ negatives = sample_uniform_negatives_with_shape(clicks, self.num_items, session_len, self.num_uniform_negatives, self.sampling_style, self.reject_uniform_session_items)
+
+ return {'clicks': clicks, 'labels': labels, 'session_len': session_len, "uniform_negatives": negatives.tolist()}
+
+
+ def dynamic_collate(self, batch):
+ batch_clicks = list()
+ batch_mask = list()
+ batch_labels = list()
+ batch_session_len = list()
+ batch_positives = list()
+ max_len = self.max_seqlen
+ batch_uniform_negatives = list()
+ in_batch_negatives = list()
+
+ for item in batch:
+ session_len = item["session_len"]
+ batch_clicks.append((max_len - session_len) * [0] + item["clicks"])
+ batch_mask.append((max_len - session_len) * [0.] + session_len * [1.])
+ batch_labels.append((max_len - session_len) * [0] + item["labels"])
+ batch_session_len.append(session_len)
+ batch_positives.extend(item["clicks"])
+
+ if self.sampling_style=="eventwise":
+ batch_uniform_negatives.append((max_len - session_len) * [[0]*self.num_uniform_negatives] + item["uniform_negatives"])
+ elif self.sampling_style=="sessionwise":
+ batch_uniform_negatives.append(item["uniform_negatives"])
+
+ if self.sampling_style=="batchwise":
+ batch_uniform_negatives = sample_uniform(self.num_items, [self.num_uniform_negatives], set(batch_positives), self.reject_in_batch_items)
+
+ in_batch_negatives = sample_in_batch_negatives(batch_positives, self.num_in_batch_negatives, batch_session_len, self.reject_in_batch_items)
+
+ return {
+ 'clicks': torch.tensor(batch_clicks, dtype=torch.long),
+ 'labels': torch.tensor(batch_labels, dtype=torch.long),
+ 'mask': torch.tensor(batch_mask, dtype=torch.float),
+ 'session_len': torch.tensor(batch_session_len, dtype=torch.long),
+ 'in_batch_negatives': torch.tensor(in_batch_negatives, dtype=torch.long),
+ 'uniform_negatives': torch.tensor(batch_uniform_negatives, dtype=torch.long)
+ }
diff --git a/src/sasrec/model.py b/src/sasrec/model.py
new file mode 100644
index 0000000..91762b7
--- /dev/null
+++ b/src/sasrec/model.py
@@ -0,0 +1,211 @@
+from functools import partial
+from pathlib import Path
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from torch import concat, diag, logical_and, logical_or, nn, tensor, tile
+from torch.nn import Dropout
+
+from src.shared.evaluate import validate_batch_per_timestamp
+from src.shared.logits_computation import multiply_head_with_embedding
+from src.shared.loss import (bce_loss, bpr_max_loss, calc_loss,
+ sampled_softmax_loss)
+
+
+class DynamicPositionEmbedding(torch.nn.Module):
+
+ def __init__(self, max_len, dimension):
+ super(DynamicPositionEmbedding, self).__init__()
+ self.max_len = max_len
+ self.embedding = nn.Embedding(max_len, dimension)
+ self.pos_indices = torch.arange(0, self.max_len, dtype=torch.int)
+ self.register_buffer('pos_indices_const', self.pos_indices)
+
+ def forward(self, x, device='cpu'):
+ seq_len = x.shape[1]
+ return self.embedding(self.pos_indices_const[-seq_len:]) + x
+
+
+class SASRec(pl.LightningModule):
+
+ def __init__(self,
+ hidden_size,
+ dropout_rate,
+ max_len,
+ num_items,
+ batch_size,
+ sampling_style,
+ topk_sampling=False,
+ topk_sampling_k=1000,
+ learning_rate=0.001,
+ num_layers=2,
+ loss='bce',
+ bpr_penalty=None,
+ optimizer='adam',
+ output_bias=False,
+ share_embeddings=True,
+ final_activation=False):
+ super(SASRec, self).__init__()
+ self.learning_rate = learning_rate
+ self.hidden_size = hidden_size
+ self.dropout_rate = dropout_rate
+ self.num_items = num_items
+ self.batch_size = batch_size
+ self.num_layers = num_layers
+ self.max_len = max_len
+ self.output_bias = output_bias
+ self.share_embeddings = share_embeddings
+ self.future_mask = torch.triu(torch.ones(max_len, max_len) * float('-inf'), diagonal=1)
+ self.register_buffer('future_mask_const', self.future_mask)
+ self.register_buffer('seq_diag_const', ~diag(torch.ones(max_len, dtype=torch.bool)))
+ self.register_buffer('bias_ones', torch.ones([self.batch_size, self.max_len, 1]))
+ if output_bias and share_embeddings:
+ self.item_embedding = nn.Embedding(num_items + 1, hidden_size + 1, padding_idx=0)
+ else:
+ self.item_embedding = nn.Embedding(num_items + 1, hidden_size, padding_idx=0)
+ self.positional_embedding_layer = DynamicPositionEmbedding(max_len, hidden_size)
+
+ torch.nn.init.xavier_uniform_(self.item_embedding.weight.data)
+ torch.nn.init.xavier_uniform_(self.positional_embedding_layer.embedding.weight.data)
+
+ if share_embeddings:
+ self.output_embedding = self.item_embedding
+ elif (not share_embeddings) and output_bias:
+ self.output_embedding = nn.Embedding(num_items + 1, hidden_size + 1, padding_idx=0)
+ else:
+ self.output_embedding = nn.Embedding(num_items + 1, hidden_size, padding_idx=0)
+
+ self.norm = nn.LayerNorm([hidden_size])
+ self.input_dropout = Dropout(dropout_rate)
+ encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size,
+ nhead=1,
+ dim_feedforward=hidden_size,
+ dropout=dropout_rate,
+ batch_first=True,
+ norm_first=True)
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.num_layers, norm=self.norm)
+ self.merge_attn_mask = True
+ if final_activation:
+ self.final_activation = nn.ELU(0.5)
+ else:
+ self.final_activation = nn.Identity()
+
+ self.loss_fn = loss
+ if self.loss_fn == 'bce':
+ self.loss = bce_loss
+ elif self.loss_fn == 'ssm':
+ self.loss = sampled_softmax_loss
+ elif self.loss_fn == 'bpr-max':
+ if bpr_penalty is not None:
+ self.loss = partial(bpr_max_loss, bpr_penalty)
+ else:
+ raise ValueError('bpr_penalty must be provided for bpr_max loss')
+ else:
+ raise ValueError('Loss function not supported')
+
+ self.sampling_style = sampling_style
+ self.topk_sampling = topk_sampling
+ self.topk_sampling_k = topk_sampling_k
+ self.optimizer = optimizer
+ self.save_hyperparameters()
+
+ def merge_attn_masks(self, padding_mask):
+ batch_size = padding_mask.shape[0]
+ seq_len = padding_mask.shape[1]
+
+ if not self.merge_attn_mask:
+ return self.future_mask_const[:seq_len, :seq_len]
+
+ padding_mask_broadcast = ~padding_mask.bool().unsqueeze(1)
+ future_masks = tile(self.future_mask_const[:seq_len, :seq_len], (batch_size, 1, 1))
+ merged_masks = logical_or(padding_mask_broadcast, future_masks)
+ # Always allow self-attention to prevent NaN loss
+ # See: https://github.com/pytorch/pytorch/issues/41508
+ diag_masks = tile(self.seq_diag_const[:seq_len, :seq_len], (batch_size, 1, 1))
+ return logical_and(diag_masks, merged_masks)
+
+ def forward(self, item_indices, mask):
+ att_mask = self.merge_attn_masks(mask)
+ items = self.item_embedding(
+ item_indices)[:, :, :-1] if self.output_bias and self.share_embeddings else self.item_embedding(item_indices)
+ x = items * np.sqrt(self.hidden_size)
+ x = self.positional_embedding_layer(x)
+ x = self.encoder(self.input_dropout(x), att_mask)
+ return concat([x, self.bias_ones], dim=-1) if self.output_bias else x
+
+ def training_step(self, batch, _):
+ x_hat = self.forward(batch["clicks"], batch["mask"])
+ train_loss = calc_loss(self.loss, x_hat, batch["labels"], batch["uniform_negatives"], batch["in_batch_negatives"],
+ batch["mask"], self.output_embedding, self.sampling_style, self.final_activation,
+ self.topk_sampling, self.topk_sampling_k, self.device)
+ self.log("train_loss", train_loss)
+ return train_loss
+
+ def validation_step(self, batch, _batch_idx):
+ x_hat = self.forward(batch['clicks'], batch['mask'])
+ cut_offs = tensor([5, 10, 20], device=self.device)
+ recall, mrr = validate_batch_per_timestamp(batch, x_hat, self.output_embedding, cut_offs)
+ test_loss = calc_loss(self.loss, x_hat, batch["labels"], batch["uniform_negatives"], batch["in_batch_negatives"],
+ batch["mask"], self.output_embedding, self.sampling_style, self.final_activation,
+ self.topk_sampling, self.topk_sampling_k, self.device)
+ for i, k in enumerate(cut_offs.tolist()):
+ self.log(f'recall_cutoff_{k}', recall[i])
+ self.log(f'mrr_cutoff_{k}', mrr[i])
+ self.log('test_seq_len', x_hat.shape[1])
+ self.log('test_loss', test_loss)
+
+ def configure_optimizers(self):
+ if self.optimizer == 'adam':
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+ elif self.optimizer == 'adagrad':
+ optimizer = torch.optim.Adagrad(self.parameters(), lr=self.learning_rate)
+ else:
+ raise ValueError('Optimizer not supported, please use adam or adagrad')
+ return optimizer
+
+ def export_topk_onnx(self, out_dir):
+ top_k_model = TopKModel(self)
+ top_k_model.export_onnx(out_dir)
+
+ def export(self, out_dir):
+ self.export_topk_onnx(out_dir)
+
+
+class TopKModel(pl.LightningModule):
+
+ def __init__(self, model: SASRec):
+ super(TopKModel, self).__init__()
+ self.model = model
+ # example input for self.forward(item_indices, k)
+ self.example_input_array = (torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), torch.tensor(10))
+
+ def forward(self, item_indices, k):
+ mask = torch.ones(item_indices.shape[0]).unsqueeze(0)
+ self.model.merge_attn_mask = False
+ x_hat = self.model.forward(item_indices.unsqueeze(0), mask)[:, -1]
+ logits = multiply_head_with_embedding(x_hat, self.model.item_embedding.weight)
+ logits[0][0] = -torch.inf # set score for padding item to -inf
+ scores, indices = torch.topk(logits, k)
+ return indices.squeeze(0), scores.squeeze(0)
+
+ def export_onnx(self, out_dir, verbose=True):
+ Path(out_dir).mkdir(parents=True, exist_ok=True)
+ self.to_onnx(f"{out_dir}/sasrec.onnx",
+ export_params=True,
+ opset_version=13,
+ verbose=verbose,
+ do_constant_folding=False,
+ input_names=["item_indices", "k"],
+ output_names=[f"indices", "scores"],
+ dynamic_axes={
+ 'item_indices': {
+ 0: 'sequence'
+ },
+ 'indices': {
+ 0: 'k'
+ },
+ 'scores': {
+ 0: 'k'
+ }
+ })
\ No newline at end of file
diff --git a/src/sasrec/train.py b/src/sasrec/train.py
new file mode 100644
index 0000000..d847b44
--- /dev/null
+++ b/src/sasrec/train.py
@@ -0,0 +1,84 @@
+import os
+from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from torch.utils.data import DataLoader
+
+from src.sasrec.model import SASRec
+from src.sasrec.dataset import SasRecDataset
+
+
+def train_sasrec(config, data_dir, train_stats, test_stats, num_items):
+ checkpoint_callback = ModelCheckpoint(save_top_k=1,
+ monitor='recall_cutoff_20',
+ mode='max',
+ filename=f'sasrec-{config["dataset"]}-' + '{epoch}-{recall_cutoff_20:.3f}')
+
+ trainer = Trainer(max_epochs=config["max_epochs"],
+ precision=16,
+ limit_val_batches=config["limit_val_batches"],
+ log_every_n_steps=1,
+ accelerator=config["accelerator"],
+ devices=1,
+ overfit_batches=config["overfit_batches"],
+ callbacks=[checkpoint_callback])
+
+ assert 0 <= config["num_batch_negatives"] < config['batch_size']
+
+ train_set = SasRecDataset(f'{data_dir}/{config["dataset"]}/{config["dataset"]}_train.jsonl',
+ train_stats["num_sessions"],
+ num_items=num_items,
+ max_seqlen=config["max_session_length"],
+ num_in_batch_negatives=config["num_batch_negatives"],
+ num_uniform_negatives=config["num_uniform_negatives"],
+ reject_uniform_session_items=config["reject_uniform_session_items"],
+ reject_in_batch_items=config["reject_in_batch_items"],
+ sampling_style=config["sampling_style"],
+ shuffling_style=config["shuffling_style"])
+
+ test_set = SasRecDataset(f'{data_dir}/{config["dataset"]}/{config["dataset"]}_test.jsonl',
+ test_stats["num_sessions"],
+ num_items=num_items,
+ max_seqlen=config["max_session_length"],
+ num_in_batch_negatives=config["num_batch_negatives"],
+ num_uniform_negatives=config["num_uniform_negatives"],
+ reject_uniform_session_items=config["reject_uniform_session_items"],
+ reject_in_batch_items=config["reject_in_batch_items"],
+ sampling_style=config["sampling_style"],
+ shuffling_style="no_shuffling")
+
+ shuffle = True if config["shuffling_style"] == "shuffle_without_replacement" else False
+ train_loader = DataLoader(train_set,
+ drop_last=True,
+ batch_size=config["batch_size"],
+ shuffle=shuffle,
+ pin_memory=True,
+ persistent_workers=True,
+ num_workers=os.cpu_count(),
+ collate_fn=train_set.dynamic_collate)
+
+ test_loader = DataLoader(test_set,
+ drop_last=True,
+ batch_size=config["batch_size"],
+ shuffle=False,
+ pin_memory=True,
+ persistent_workers=True,
+ num_workers=os.cpu_count(),
+ collate_fn=test_set.dynamic_collate)
+
+ model = SASRec(hidden_size=config["hidden_size"],
+ dropout_rate=config["dropout"],
+ max_len=config["max_session_length"],
+ num_items=num_items,
+ batch_size=config["batch_size"],
+ sampling_style=config["sampling_style"],
+ topk_sampling=config.get("topk_sampling", False),
+ topk_sampling_k=config.get("topk_sampling_k", 1000),
+ learning_rate=config["lr"],
+ num_layers=config["num_layers"],
+ loss=config["loss"],
+ bpr_penalty=config["bpr_penalty"],
+ optimizer=config["optimizer"],
+ output_bias=config["output_bias"],
+ share_embeddings=config["share_embeddings"])
+
+ return trainer, model, train_loader, test_loader
diff --git a/src/shared/__init__.py b/src/shared/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/shared/evaluate.py b/src/shared/evaluate.py
new file mode 100644
index 0000000..ed49028
--- /dev/null
+++ b/src/shared/evaluate.py
@@ -0,0 +1,47 @@
+from torch import cumsum, flip, inf, max, stack, sum, topk, where
+
+from src.shared.logits_computation import multiply_head_with_embedding
+
+
+def calculate_ranks(logits, labels, cutoffs):
+ num_logits = logits.shape[-1]
+ k = min(num_logits, max(cutoffs).item())
+ _, indices = topk(logits, k=k, dim=-1)
+ indices = flip(indices, dims=[-1])
+ hits = indices == labels.unsqueeze(dim=-1)
+ ranks = sum(cumsum(hits, -1), -1) - 1.
+ ranks[ranks == -1] = float('inf')
+ return ranks
+
+
+def pointwise_mrr(ranks, cutoffs, mask):
+ res = where(ranks < cutoffs.unsqueeze(-1).unsqueeze(-1), ranks, float('inf'))
+ return (1 / (res + 1)) * mask
+
+
+def pointwise_recall(ranks, cutoffs, mask):
+ res = ranks < cutoffs.unsqueeze(-1).unsqueeze(-1)
+ return res.float() * mask
+
+
+def mean_metric(pointwise_metric, mask):
+ hits = sum(pointwise_metric, dim=(2, 1))
+ return hits / sum(mask).clamp(0.0000005)
+
+
+def validate_batch_per_timestamp(batch, x_hat, output_embedding, cut_offs):
+ recalls = []
+ mrrs = []
+ for t in range(x_hat.shape[1]):
+ mask = batch['mask'][:, t]
+ positives = batch['labels'][:, t]
+ logits = multiply_head_with_embedding(x_hat[:, t], output_embedding.weight)
+ logits[:, 0] = -inf # set score for padding item to -inf
+ ranks = calculate_ranks(logits, positives, cut_offs)
+ pw_rec = pointwise_recall(ranks, cut_offs, mask)
+ recalls.append(pw_rec.squeeze(dim=1))
+ pw_mrr = pointwise_mrr(ranks, cut_offs, mask)
+ mrrs.append(pw_mrr.squeeze(dim=1))
+ pw_rec = stack(recalls, dim=2)
+ pw_mrr = stack(mrrs, dim=2)
+ return mean_metric(pw_rec, batch["mask"]), mean_metric(pw_mrr, batch["mask"])
diff --git a/src/shared/logits_computation.py b/src/shared/logits_computation.py
new file mode 100644
index 0000000..980bd8c
--- /dev/null
+++ b/src/shared/logits_computation.py
@@ -0,0 +1,20 @@
+from torch import concat
+
+
+def multiply_head_with_embedding(prediction_head, embeddings):
+ return prediction_head.matmul(embeddings.transpose(-1, -2))
+
+
+def lookup_and_multiply(prediction_head, positives, uniform_negatives, in_batch_negatives, embedding_layer, sampling_style):
+ positive_logits = multiply_head_with_embedding(prediction_head.unsqueeze(-2),
+ embedding_layer(positives).unsqueeze(-2)).squeeze(-1)
+
+ if sampling_style == "eventwise":
+ uniform_negative_logits = multiply_head_with_embedding(prediction_head.unsqueeze(-2),
+ embedding_layer(uniform_negatives)).squeeze(-2)
+ else:
+ uniform_negative_logits = multiply_head_with_embedding(prediction_head, embedding_layer(uniform_negatives))
+
+ in_batch_negative_logits = multiply_head_with_embedding(prediction_head, embedding_layer(in_batch_negatives))
+ negative_logits = concat([uniform_negative_logits, in_batch_negative_logits], dim=-1)
+ return positive_logits, negative_logits
diff --git a/src/shared/loss.py b/src/shared/loss.py
new file mode 100644
index 0000000..82a15f7
--- /dev/null
+++ b/src/shared/loss.py
@@ -0,0 +1,69 @@
+import torch
+from torch import cat, exp, log, sigmoid, softmax, sum, tensor
+from torch.nn import CrossEntropyLoss
+
+from src.shared.logits_computation import lookup_and_multiply
+
+ce_loss = CrossEntropyLoss(reduction="none")
+
+
+def _elementwise_sampled_softmax_loss(pos_logits, neg_logits, mask, target):
+ sm_logits = cat((pos_logits, neg_logits), dim=-1)
+ shape = sm_logits.shape
+ return ce_loss(sm_logits.reshape([-1, shape[-1]]), target).reshape([shape[0], shape[1]]) * mask
+
+
+def sampled_softmax_loss(pos_logits, neg_logits, mask, device="cpu"):
+ target = tensor([0], device=device).tile(mask.numel())
+ elementwise_ssm_loss = _elementwise_sampled_softmax_loss(pos_logits, neg_logits, mask, target)
+ return sum(elementwise_ssm_loss) / sum(mask)
+
+
+def bce_loss(pos_logits, neg_logits, mask, epsilon=1e-10, device="cpu"):
+ loss = log(1. + exp(-pos_logits) + epsilon) + log(1. + exp(neg_logits) + epsilon).mean(-1, keepdim=True)
+ return (loss * mask.unsqueeze(-1)).sum() / mask.sum()
+
+
+def _diff_logits(pos_logits, neg_logits):
+ return (pos_logits - neg_logits)
+
+
+def _elementwise_bpr_max_loss_per_negative(pos_logits, neg_logits):
+ logits_diff = sigmoid(_diff_logits(pos_logits, neg_logits))
+ s_j = softmax(neg_logits - torch.max(neg_logits, dim=-1)[0].unsqueeze(-1), dim=-1)
+ return s_j * logits_diff
+
+
+def _bpr_max_loss_unregulized(pos_logits, neg_logits, mask):
+ bpr_max_loss_per_element = -log(sum(_elementwise_bpr_max_loss_per_negative(pos_logits, neg_logits), dim=-1))
+ return bpr_max_loss_per_element, sum(bpr_max_loss_per_element * mask) / sum(mask)
+
+
+def _bpr_max_loss_regularization(neg_logits, penalty, mask):
+ regularization = penalty * sum(softmax(neg_logits, dim=-1) * neg_logits * neg_logits, dim=-1)
+ return sum(regularization * mask) / sum(mask)
+
+
+def bpr_max_loss(penalty, pos_logits, neg_logits, mask, device="cpu"):
+ _, unregulized_bpr_max_loss = _bpr_max_loss_unregulized(pos_logits, neg_logits, mask)
+ return unregulized_bpr_max_loss + _bpr_max_loss_regularization(neg_logits, penalty, mask)
+
+
+def calc_loss(loss_fn,
+ x_hat,
+ labels,
+ uniform_negatives,
+ in_batch_negatives,
+ mask,
+ embeddings,
+ sampling_style,
+ final_activation,
+ topk_sampling=False,
+ topk_sampling_k=1000,
+ device="cpu"):
+ pos_logits, neg_logits = lookup_and_multiply(x_hat, labels, uniform_negatives, in_batch_negatives, embeddings,
+ sampling_style)
+ if topk_sampling:
+ neg_logits, _ = torch.topk(neg_logits, k=topk_sampling_k, dim=-1)
+ pos_scores, neg_scores = final_activation(pos_logits), final_activation(neg_logits)
+ return loss_fn(pos_scores, neg_scores, mask, device=device)
diff --git a/src/shared/sample.py b/src/shared/sample.py
new file mode 100644
index 0000000..8f81ad0
--- /dev/null
+++ b/src/shared/sample.py
@@ -0,0 +1,55 @@
+import itertools
+from random import sample
+
+import numpy as np
+
+
+def _uniform_negatives(num_items, shape):
+ return np.random.randint(1, num_items+1, shape)
+
+def _uniform_negatives_session_rejected(num_items, shape, in_session_items):
+ negatives = []
+ for _ in range(np.prod(shape)):
+ negative = np.random.randint(1, num_items+1)
+ while negative in in_session_items:
+ negative = np.random.randint(1, num_items+1)
+ negatives.append(negative)
+ return np.array(negatives).reshape(shape)
+
+def _infer_shape(session_len, num_uniform_negatives, sampling_style):
+ if sampling_style=="eventwise":
+ return [session_len, num_uniform_negatives]
+ elif sampling_style=="sessionwise":
+ return [num_uniform_negatives]
+ else:
+ return []
+
+def sample_uniform(num_items, shape, in_session_items, reject_session_items):
+ if reject_session_items:
+ return _uniform_negatives_session_rejected(num_items, shape, in_session_items)
+ else:
+ return _uniform_negatives(num_items, shape)
+
+def sample_uniform_negatives_with_shape(clicks, num_items, session_len, num_uniform_negatives, sampling_style, reject_session_items):
+ in_session_items = set(clicks)
+ shape = _infer_shape(session_len, num_uniform_negatives, sampling_style)
+ if shape:
+ negatives = sample_uniform(num_items, shape, in_session_items, reject_session_items)
+ else:
+ negatives = np.array([])
+ return negatives
+
+
+def sample_in_batch_negatives(batch_positives, num_in_batch_negatives, batch_session_len, reject_session_items):
+ in_batch_negatives = []
+ positive_indices = itertools.accumulate(batch_session_len)
+ positive_indices = [0] + [p for p in positive_indices]
+ if reject_session_items:
+ for i in range(len(positive_indices[:-1])):
+ candidate_positives = batch_positives[:positive_indices[i]] + batch_positives[
+ positive_indices[i + 1]:]
+ in_batch_negatives.append(sample(candidate_positives, num_in_batch_negatives))
+ else:
+ for i in range(len(batch_session_len)):
+ in_batch_negatives.append(sample(batch_positives, num_in_batch_negatives))
+ return in_batch_negatives
\ No newline at end of file
diff --git a/src/shared/utils.py b/src/shared/utils.py
new file mode 100644
index 0000000..bba842a
--- /dev/null
+++ b/src/shared/utils.py
@@ -0,0 +1,11 @@
+def get_offsets(sessions_path):
+ line_offsets = []
+ with open(sessions_path, "rt") as f:
+ offset = 0
+ for line_idx, line in enumerate(f):
+ line_len = len(line)
+ line_offsets.append((line_len, line_idx, offset))
+ offset += line_len
+ line_offsets = [offset for _, _, offset in line_offsets]
+ return line_offsets
+
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/resources/expected_filtered_sessions.jsonl b/test/resources/expected_filtered_sessions.jsonl
new file mode 100644
index 0000000..e013556
--- /dev/null
+++ b/test/resources/expected_filtered_sessions.jsonl
@@ -0,0 +1,5 @@
+{"session":12899779,"events":[{"aid":59626,"ts":1661724000278,"type":"clicks"},{"aid":875855,"ts":1661724026702,"type":"clicks"}]}
+{"session":12899780,"events":[{"aid":1142001,"ts":1661724000378,"type":"clicks"},{"aid":582733,"ts":1661724058352,"type":"clicks"},{"aid":973454,"ts":1661724109199,"type":"clicks"},{"aid":736516,"ts":1661724136868,"type":"clicks"},{"aid":1142001,"ts":1661724155248,"type":"clicks"},{"aid":260306,"ts":1661724221170,"type":"clicks"}]}
+{"session":12899781,"events":[{"aid":141737,"ts":1661724000559,"type":"clicks"},{"aid":199009,"ts":1661724022851,"type":"clicks"},{"aid":57316,"ts":1661724170835,"type":"clicks"},{"aid":194068,"ts":1661724246188,"type":"clicks"},{"aid":199009,"ts":1661780623778,"type":"clicks"},{"aid":199009,"ts":1661781274081,"type":"clicks"},{"aid":199009,"ts":1661804151788,"type":"clicks"},{"aid":199009,"ts":1662060028567,"type":"clicks"},{"aid":199009,"ts":1662060064706,"type":"clicks"},{"aid":918668,"ts":1662060160406,"type":"clicks"}]}
+{"session":12899782,"events":[{"aid":1669403,"ts":1661724000568,"type":"clicks"},{"aid":1494781,"ts":1661724163530,"type":"clicks"},{"aid":1494781,"ts":1661724190624,"type":"clicks"},{"aid":1494781,"ts":1661724203140,"type":"clicks"},{"aid":1674682,"ts":1661724816749,"type":"clicks"},{"aid":602723,"ts":1661724885670,"type":"clicks"},{"aid":1596099,"ts":1661725306961,"type":"clicks"},{"aid":45035,"ts":1661725434870,"type":"clicks"},{"aid":603160,"ts":1661725567598,"type":"clicks"},{"aid":413963,"ts":1661765601645,"type":"clicks"},{"aid":413963,"ts":1661765683962,"type":"clicks"},{"aid":779478,"ts":1661765990636,"type":"clicks"},{"aid":1037538,"ts":1661766045371,"type":"clicks"},{"aid":779478,"ts":1661766058833,"type":"clicks"},{"aid":476064,"ts":1661766167646,"type":"clicks"},{"aid":779478,"ts":1661766181059,"type":"clicks"},{"aid":562754,"ts":1661785953423,"type":"clicks"},{"aid":779478,"ts":1661785991473,"type":"clicks"},{"aid":562754,"ts":1661786014432,"type":"clicks"},{"aid":779478,"ts":1661786017266,"type":"clicks"},{"aid":476064,"ts":1661786024954,"type":"clicks"},{"aid":975117,"ts":1661786059083,"type":"clicks"},{"aid":595995,"ts":1661786240200,"type":"clicks"},{"aid":595995,"ts":1661786284439,"type":"clicks"},{"aid":1299063,"ts":1661786321468,"type":"clicks"},{"aid":1352726,"ts":1661802752201,"type":"clicks"},{"aid":1344774,"ts":1661802822218,"type":"clicks"},{"aid":794260,"ts":1661802866448,"type":"clicks"},{"aid":363337,"ts":1661802911345,"type":"clicks"},{"aid":829181,"ts":1661802976266,"type":"clicks"},{"aid":654810,"ts":1661803139163,"type":"clicks"},{"aid":723957,"ts":1661803162990,"type":"clicks"},{"aid":476064,"ts":1661803267739,"type":"clicks"},{"aid":975117,"ts":1661803282620,"type":"clicks"},{"aid":406002,"ts":1661803489097,"type":"clicks"},{"aid":889672,"ts":1661803500037,"type":"clicks"},{"aid":834355,"ts":1661803530264,"type":"clicks"},{"aid":889672,"ts":1661803545858,"type":"clicks"},{"aid":1099391,"ts":1661803555960,"type":"clicks"},{"aid":987400,"ts":1661803571124,"type":"clicks"},{"aid":638411,"ts":1661803600761,"type":"clicks"},{"aid":1072928,"ts":1661803621612,"type":"clicks"},{"aid":530900,"ts":1661803640817,"type":"clicks"},{"aid":229749,"ts":1661803651343,"type":"clicks"},{"aid":229749,"ts":1661803658101,"type":"clicks"},{"aid":740495,"ts":1661803681480,"type":"clicks"},{"aid":650187,"ts":1661971384912,"type":"clicks"},{"aid":1015404,"ts":1661982328030,"type":"clicks"},{"aid":1738522,"ts":1661982361511,"type":"clicks"}]}
+{"session":12899783,"events":[{"aid":255298,"ts":1661724000572,"type":"clicks"},{"aid":1114790,"ts":1661724004924,"type":"clicks"},{"aid":255298,"ts":1661724010953,"type":"clicks"},{"aid":300128,"ts":1661730417845,"type":"clicks"},{"aid":198386,"ts":1661730433368,"type":"clicks"},{"aid":300128,"ts":1661730438179,"type":"clicks"},{"aid":1729554,"ts":1661730496587,"type":"clicks"},{"aid":1216821,"ts":1661779640356,"type":"clicks"},{"aid":1754420,"ts":1661779655973,"type":"clicks"},{"aid":607639,"ts":1661779678936,"type":"clicks"},{"aid":1817896,"ts":1662041140398,"type":"clicks"},{"aid":455320,"ts":1662041200238,"type":"clicks"},{"aid":738412,"ts":1662041220812,"type":"clicks"},{"aid":946356,"ts":1662044060720,"type":"clicks"},{"aid":1351808,"ts":1662275228606,"type":"clicks"},{"aid":593244,"ts":1662275278107,"type":"clicks"},{"aid":1047552,"ts":1662275301331,"type":"clicks"},{"aid":13500,"ts":1662275325348,"type":"clicks"},{"aid":919342,"ts":1662276516147,"type":"clicks"},{"aid":332548,"ts":1662276539862,"type":"clicks"},{"aid":742845,"ts":1662279417428,"type":"clicks"}]}
diff --git a/test/resources/in_batch_negatives.jsonl b/test/resources/in_batch_negatives.jsonl
new file mode 100644
index 0000000..c465e75
--- /dev/null
+++ b/test/resources/in_batch_negatives.jsonl
@@ -0,0 +1,5 @@
+{"session": "0", "events": [{"ts": 1396314000.124, "type": "clicks", "aid": "1"}, {"ts": 1396314193.688, "type": "clicks", "aid": "1"}]}
+{"session": "1", "events": [{"ts": 1396314000.567, "type": "clicks", "aid": "2"}, {"ts": 1396314048.736, "type": "clicks", "aid": "2"}, {"ts": 1396314156.728, "type": "clicks", "aid": "2"}, {"ts": 1396314507.853, "type": "clicks", "aid": "2"}, {"ts": 1396314599.204, "type": "clicks", "aid": "2"}]}
+{"session": "10", "events": [{"ts": 1396314066.597, "type": "clicks", "aid": "3"}, {"ts": 1396314082.885, "type": "clicks", "aid": "3"}, {"ts": 1396314179.271, "type": "clicks", "aid": "3"}]}
+{"session": "100", "events": [{"ts": 1396315341.544, "type": "clicks", "aid": "4"}, {"ts": 1396315343.869, "type": "clicks", "aid": "4"}]}
+{"session": "1000", "events": [{"ts": 1396321206.746, "type": "clicks", "aid": "5"}, {"ts": 1396321242.609, "type": "clicks", "aid": "5"}, {"ts": 1396321247.929, "type": "clicks", "aid": "5"}]}
\ No newline at end of file
diff --git a/test/resources/test_diginetica-train-item-views.csv b/test/resources/test_diginetica-train-item-views.csv
new file mode 100644
index 0000000..77812ab
--- /dev/null
+++ b/test/resources/test_diginetica-train-item-views.csv
@@ -0,0 +1,21 @@
+session_id;user_id;item_id;timeframe;eventdate
+1;NA;1;594633;2023-06-12
+1;NA;2;1076889;2023-06-12
+1;NA;3;2109380;2023-06-12
+1;NA;4;164229;2023-06-12
+1;NA;5;1194078;2023-06-12
+1;NA;6;1127104;2023-06-12
+1;NA;7;358364;2023-06-12
+1;NA;8;3116072;2023-06-12
+1;NA;9;4136193;2023-06-12
+1;NA;10;9157201;2023-06-12
+2;NA;11;8161440;2023-06-12
+2;NA;12;3112140;2023-06-12
+2;NA;13;52254;2023-06-12
+2;NA;14;1123077;2023-06-12
+2;NA;15;3115861;2023-06-12
+2;NA;16;777297;2023-06-12
+2;NA;17;142019;2023-06-12
+2;NA;18;367911;2023-06-12
+2;NA;18;756206;2023-06-12
+2;NA;19;1080523;2023-06-12
diff --git a/test/resources/test_yoochoose-clicks.dat b/test/resources/test_yoochoose-clicks.dat
new file mode 100644
index 0000000..89a421b
--- /dev/null
+++ b/test/resources/test_yoochoose-clicks.dat
@@ -0,0 +1,19 @@
+1,2023-06-19T10:51:09.323Z,1,0
+1,2023-06-19T10:54:10.602Z,2,0
+1,2023-06-19T10:54:47.916Z,3,0
+1,2023-06-19T10:57:00.507Z,4,0
+2,2023-06-19T13:56:38.567Z,5,0
+2,2023-06-19T13:57:19.576Z,5,0
+2,2023-06-19T13:58:37.940Z,6,0
+2,2023-06-19T13:59:51.172Z,7,0
+2,2023-06-19T14:00:38.847Z,8,0
+2,2023-06-19T14:02:37.298Z,9,0
+3,2023-06-14T13:17:47.628Z,10,0
+3,2023-06-14T13:26:02.547Z,11,0
+3,2023-06-14T13:30:13.274Z,12,0
+4,2023-06-19T12:09:11.712Z,13,0
+4,2023-06-19T12:26:26.220Z,14,0
+6,2023-06-18T16:58:20.986Z,15,0
+6,2023-06-18T17:02:27.116Z,16,0
+7,2023-06-14T06:38:53.244Z,17,0
+7,2023-06-14T06:39:06.059Z,18,0
diff --git a/test/resources/train.jsonl b/test/resources/train.jsonl
new file mode 100644
index 0000000..9d7aeb4
--- /dev/null
+++ b/test/resources/train.jsonl
@@ -0,0 +1,10 @@
+{"session": "404391", "events": [{"aid": 33838, "ts": 1464127201.187, "type": "clicks"}, {"aid": 33838, "ts": 1464127201.522, "type": "clicks"}, {"aid": 4759, "ts": 1464127218.472, "type": "clicks"}, {"aid": 15406, "ts": 1464127243.334, "type": "clicks"}, {"aid": 12887, "ts": 1464127245.905, "type": "clicks"}, {"aid": 27601, "ts": 1464127251.938, "type": "clicks"}, {"aid": 15406, "ts": 1464127265.936, "type": "clicks"}, {"aid": 14564, "ts": 1464127406.279, "type": "clicks"}]}
+{"session": "308074", "events": [{"aid": 36617, "ts": 1464127203.563, "type": "clicks"}, {"aid": 34257, "ts": 1464127284.097, "type": "clicks"}]}
+{"session": "38422", "events": [{"aid": 31292, "ts": 1464127203.968, "type": "clicks"}, {"aid": 18083, "ts": 1464127243.181, "type": "clicks"}, {"aid": 12957, "ts": 1464127265.175, "type": "clicks"}]}
+{"session": "38847", "events": [{"aid": 14138, "ts": 1464127204.089, "type": "clicks"}, {"aid": 8977, "ts": 1464127821.903, "type": "clicks"}]}
+{"session": "158982", "events": [{"aid": 30011, "ts": 1464127204.21, "type": "clicks"}, {"aid": 2750, "ts": 1464127221.63, "type": "clicks"}, {"aid": 22418, "ts": 1464127230.282, "type": "clicks"}]}
+{"session": "255636", "events": [{"aid": 232, "ts": 1464127204.638, "type": "clicks"}, {"aid": 38279, "ts": 1464127257.395, "type": "clicks"}, {"aid": 30473, "ts": 1464127772.71, "type": "clicks"}]}
+{"session": "127854", "events": [{"aid": 11009, "ts": 1464127253.553, "type": "clicks"}, {"aid": 37445, "ts": 1464127532.556, "type": "clicks"}, {"aid": 11481, "ts": 1464127564.49, "type": "clicks"}, {"aid": 14647, "ts": 1464127766.499, "type": "clicks"}, {"aid": 22219, "ts": 1464127786.091, "type": "clicks"}, {"aid": 25188, "ts": 1464127930.796, "type": "clicks"}, {"aid": 11118, "ts": 1464127938.364, "type": "clicks"}, {"aid": 39930, "ts": 1464128008.524, "type": "clicks"}, {"aid": 537, "ts": 1464128106.51, "type": "clicks"}, {"aid": 3957, "ts": 1464128212.899, "type": "clicks"}, {"aid": 29957, "ts": 1464128380.66, "type": "clicks"}]}
+{"session": "78260", "events": [{"aid": 7946, "ts": 1464127204.774, "type": "clicks"}, {"aid": 25587, "ts": 1464127643.251, "type": "clicks"}]}
+{"session": "444838", "events": [{"aid": 8507, "ts": 1464127204.893, "type": "clicks"}, {"aid": 31649, "ts": 1464127306.19, "type": "clicks"}, {"aid": 16618, "ts": 1464127425.898, "type": "clicks"}, {"aid": 399, "ts": 1464127472.577, "type": "clicks"}]}
+{"session": "38247", "events": [{"aid": 12916, "ts": 1464127204.968, "type": "clicks"}, {"aid": 19221, "ts": 1464127272.281, "type": "clicks"}, {"aid": 39839, "ts": 1464127425.703, "type": "clicks"}, {"aid": 686, "ts": 1464127574.944, "type": "clicks"}, {"aid": 4869, "ts": 1464127592.098, "type": "clicks"}, {"aid": 686, "ts": 1464127609.094, "type": "clicks"}, {"aid": 6658, "ts": 1464127661.027, "type": "clicks"}, {"aid": 686, "ts": 1464127674.951, "type": "clicks"}, {"aid": 4869, "ts": 1464127683.549, "type": "clicks"}, {"aid": 27148, "ts": 1464127712.127, "type": "clicks"}]}
diff --git a/test/resources/unfiltered_sessions.jsonl b/test/resources/unfiltered_sessions.jsonl
new file mode 100644
index 0000000..167453c
--- /dev/null
+++ b/test/resources/unfiltered_sessions.jsonl
@@ -0,0 +1,5 @@
+{"session":12899779,"events":[{"aid":59625,"ts":1661724000278,"type":"clicks"},{"aid":875854,"ts":1661724026702,"type":"clicks"}]}
+{"session":12899780,"events":[{"aid":1142000,"ts":1661724000378,"type":"clicks"},{"aid":582732,"ts":1661724058352,"type":"clicks"},{"aid":973453,"ts":1661724109199,"type":"clicks"},{"aid":736515,"ts":1661724136868,"type":"clicks"},{"aid":1142000,"ts":1661724155248,"type":"clicks"},{"aid":260305,"ts":1661724221170,"type":"clicks"}]}
+{"session":12899781,"events":[{"aid":141736,"ts":1661724000559,"type":"clicks"},{"aid":199008,"ts":1661724022851,"type":"clicks"},{"aid":57315,"ts":1661724170835,"type":"clicks"},{"aid":194067,"ts":1661724246188,"type":"clicks"},{"aid":199008,"ts":1661780623778,"type":"clicks"},{"aid":199008,"ts":1661781274081,"type":"clicks"},{"aid":199008,"ts":1661781409993,"type":"carts"},{"aid":199008,"ts":1661804151788,"type":"clicks"},{"aid":199008,"ts":1662060028567,"type":"clicks"},{"aid":199008,"ts":1662060064706,"type":"clicks"},{"aid":918667,"ts":1662060160406,"type":"clicks"},{"aid":918667,"ts":1662060261372,"type":"carts"}]}
+{"session":12899782,"events":[{"aid":1669402,"ts":1661724000568,"type":"clicks"},{"aid":1494780,"ts":1661724163530,"type":"clicks"},{"aid":1494780,"ts":1661724190624,"type":"clicks"},{"aid":1494780,"ts":1661724203140,"type":"clicks"},{"aid":1494780,"ts":1661724244341,"type":"carts"},{"aid":1674681,"ts":1661724816749,"type":"clicks"},{"aid":602722,"ts":1661724885670,"type":"clicks"},{"aid":1596098,"ts":1661725306961,"type":"clicks"},{"aid":45034,"ts":1661725434870,"type":"clicks"},{"aid":603159,"ts":1661725567598,"type":"clicks"},{"aid":413962,"ts":1661765601645,"type":"clicks"},{"aid":413962,"ts":1661765608861,"type":"carts"},{"aid":413962,"ts":1661765683962,"type":"clicks"},{"aid":779477,"ts":1661765990636,"type":"clicks"},{"aid":1037537,"ts":1661766045371,"type":"clicks"},{"aid":779477,"ts":1661766058833,"type":"clicks"},{"aid":779477,"ts":1661766162910,"type":"carts"},{"aid":476063,"ts":1661766167646,"type":"clicks"},{"aid":562753,"ts":1661766178974,"type":"carts"},{"aid":779477,"ts":1661766181059,"type":"clicks"},{"aid":562753,"ts":1661785953423,"type":"clicks"},{"aid":476063,"ts":1661785989572,"type":"carts"},{"aid":779477,"ts":1661785991473,"type":"clicks"},{"aid":779477,"ts":1661786009326,"type":"carts"},{"aid":562753,"ts":1661786014432,"type":"clicks"},{"aid":779477,"ts":1661786017266,"type":"clicks"},{"aid":476063,"ts":1661786024954,"type":"clicks"},{"aid":975116,"ts":1661786059083,"type":"clicks"},{"aid":975116,"ts":1661786078505,"type":"carts"},{"aid":595994,"ts":1661786240200,"type":"clicks"},{"aid":595994,"ts":1661786254403,"type":"carts"},{"aid":595994,"ts":1661786284439,"type":"clicks"},{"aid":1299062,"ts":1661786321468,"type":"clicks"},{"aid":1352725,"ts":1661802752201,"type":"clicks"},{"aid":1344773,"ts":1661802822218,"type":"clicks"},{"aid":1344773,"ts":1661802836762,"type":"carts"},{"aid":794259,"ts":1661802866448,"type":"clicks"},{"aid":363336,"ts":1661802911345,"type":"clicks"},{"aid":829180,"ts":1661802976266,"type":"clicks"},{"aid":1711180,"ts":1661803002384,"type":"carts"},{"aid":127404,"ts":1661803019431,"type":"carts"},{"aid":654809,"ts":1661803139163,"type":"clicks"},{"aid":723956,"ts":1661803162990,"type":"clicks"},{"aid":476063,"ts":1661803267739,"type":"clicks"},{"aid":975116,"ts":1661803282620,"type":"clicks"},{"aid":406001,"ts":1661803489097,"type":"clicks"},{"aid":889671,"ts":1661803500037,"type":"clicks"},{"aid":889671,"ts":1661803510829,"type":"carts"},{"aid":834354,"ts":1661803530264,"type":"clicks"},{"aid":834354,"ts":1661803539038,"type":"carts"},{"aid":889671,"ts":1661803545858,"type":"clicks"},{"aid":1099390,"ts":1661803555960,"type":"clicks"},{"aid":987399,"ts":1661803571124,"type":"clicks"},{"aid":987399,"ts":1661803578438,"type":"carts"},{"aid":638410,"ts":1661803600761,"type":"clicks"},{"aid":1072927,"ts":1661803621612,"type":"clicks"},{"aid":530899,"ts":1661803640817,"type":"clicks"},{"aid":229748,"ts":1661803651343,"type":"clicks"},{"aid":229748,"ts":1661803658101,"type":"clicks"},{"aid":740494,"ts":1661803681480,"type":"clicks"},{"aid":740494,"ts":1661803697626,"type":"carts"},{"aid":834354,"ts":1661803710262,"type":"carts"},{"aid":1669402,"ts":1661803953178,"type":"orders"},{"aid":829180,"ts":1661803953178,"type":"orders"},{"aid":1696036,"ts":1661803953178,"type":"orders"},{"aid":479970,"ts":1661803953178,"type":"orders"},{"aid":834354,"ts":1661803953178,"type":"orders"},{"aid":1033148,"ts":1661803953178,"type":"orders"},{"aid":595994,"ts":1661803953178,"type":"orders"},{"aid":1007613,"ts":1661803953178,"type":"orders"},{"aid":987399,"ts":1661803953178,"type":"orders"},{"aid":740494,"ts":1661803953178,"type":"orders"},{"aid":127404,"ts":1661803953178,"type":"orders"},{"aid":1711180,"ts":1661803953178,"type":"orders"},{"aid":650186,"ts":1661971384912,"type":"clicks"},{"aid":1015403,"ts":1661982328030,"type":"clicks"},{"aid":1738521,"ts":1661982361511,"type":"clicks"}]}
+{"session":12899783,"events":[{"aid":255297,"ts":1661724000572,"type":"clicks"},{"aid":1114789,"ts":1661724004924,"type":"clicks"},{"aid":255297,"ts":1661724010953,"type":"clicks"},{"aid":300127,"ts":1661730417845,"type":"clicks"},{"aid":198385,"ts":1661730433368,"type":"clicks"},{"aid":300127,"ts":1661730438179,"type":"clicks"},{"aid":1729553,"ts":1661730496587,"type":"clicks"},{"aid":1216820,"ts":1661779640356,"type":"clicks"},{"aid":1754419,"ts":1661779655973,"type":"clicks"},{"aid":607638,"ts":1661779678936,"type":"clicks"},{"aid":1817895,"ts":1662041140398,"type":"clicks"},{"aid":455319,"ts":1662041200238,"type":"clicks"},{"aid":738411,"ts":1662041220812,"type":"clicks"},{"aid":946355,"ts":1662044060720,"type":"clicks"},{"aid":1351807,"ts":1662275228606,"type":"clicks"},{"aid":593243,"ts":1662275278107,"type":"clicks"},{"aid":1047551,"ts":1662275301331,"type":"clicks"},{"aid":13499,"ts":1662275325348,"type":"clicks"},{"aid":919341,"ts":1662276516147,"type":"clicks"},{"aid":332547,"ts":1662276539862,"type":"clicks"},{"aid":742844,"ts":1662279417428,"type":"clicks"}]}
diff --git a/test/test_evaluate.py b/test/test_evaluate.py
new file mode 100644
index 0000000..ad58228
--- /dev/null
+++ b/test/test_evaluate.py
@@ -0,0 +1,157 @@
+import torch
+from torch import allclose, equal, tensor
+
+from src.shared.evaluate import (calculate_ranks, mean_metric, pointwise_mrr,
+ pointwise_recall,
+ validate_batch_per_timestamp)
+
+
+def test_pointwise_recall():
+ ranks = tensor([[1., 0., float('inf')],
+ [0., 1., 1.],
+ [2., 2., 1.]])
+
+ mask = tensor([[1., 1., 1.],
+ [1., 1., 0.],
+ [0., 0., 0.]])
+
+ cutoffs = tensor([1, 3]).int()
+
+ expected = tensor([
+ [
+ [0., 1., 0.], # k=1
+ [1., 0., 0.], # k=1
+ [0., 0., 0.] # k=1
+ ],
+ [
+ [1., 1., 0.], # k=3
+ [1., 1., 0.], # k=3
+ [0., 0., 0.] # k=3
+ ]
+ ])
+ assert allclose(pointwise_recall(ranks, cutoffs, mask), expected)
+
+
+def test_pointwise_recall_per_timestamp():
+ logits = tensor([[5., 6., 7., 1.],
+ [5., 6., 7., 1.],
+ [5., 6., 7., 1.]])
+
+ labels = tensor([1, 2, 0])
+
+ mask = tensor([1., 1., 0.])
+
+ cutoffs = tensor([1, 3]).int()
+
+ expected = tensor([[[0., 1., 0.]], # k=1
+ [[1., 1., 0.]]]) # k=3
+ ranks = calculate_ranks(logits, labels, cutoffs)
+ assert allclose(pointwise_recall(ranks, cutoffs, mask), expected)
+
+
+def test_mean_recall():
+ recall_matrix = tensor([
+ [
+ [0., 0., 1.], # k=1
+ [0., 0., 0.], # k=1
+ ],
+ [
+ [0., 0., 0.], # k=3
+ [1., 1., 0.], # k=3
+ ]
+ ])
+
+ mask = tensor([[1., 1., 1.],
+ [1., 1., 0.]])
+
+ expected = tensor([
+ ((0 + 0 + 1) + (0 + 0 + 0)) / 5, # k=1
+ ((0 + 0 + 0) + (0 + 1 + 1)) / 5 # k=3
+ ])
+ assert allclose(mean_metric(recall_matrix, mask), expected)
+
+
+def test_calculate_ranks():
+ logits = tensor([[[5., 6., 7., 1.], [1., 2., 3., 0.], [4., 5., 1., 0.]],
+ [[5., 6., 7., 1.], [1., 2., 3., 0.], [4., 5., 1., 0.]],
+ [[5., 6., 7., 1.], [1., 2., 3., 0.], [4., 5., 1., 0.]]])
+
+ labels = tensor([[1, 2, 3],
+ [2, 1, 0],
+ [0, 0, 0]])
+
+ cutoffs = tensor([1, 3]).int()
+
+ expected_ranks = tensor([[1., 0., float('inf')],
+ [0., 1., 1.],
+ [2., 2., 1.]])
+
+ assert equal(calculate_ranks(logits, labels, cutoffs), expected_ranks)
+
+
+def test_pointwise_mrr():
+ ranks = tensor([[1., 0., float('inf')],
+ [0., 1., 1.],
+ [2., 2., 1.]])
+
+ mask = tensor([[1., 1., 1.],
+ [1., 1., 0.],
+ [0., 0., 0.]])
+
+ cutoffs = tensor([1, 3]).int()
+
+ expected = tensor([[[0., 1., 0.],
+ [1., 0., 0.],
+ [0., 0., 0.]],
+
+ [[.5, 1., 0.],
+ [1., .5, 0.],
+ [0., 0., 0.]]])
+ assert equal(pointwise_mrr(ranks, cutoffs, mask), expected)
+
+
+def test_mean_mrr():
+ mrr_matrix = tensor([[[0., 1., 0.],
+ [1., 0., 0.],
+ [0., 0., 0.]],
+
+ [[.5, 1., 0.],
+ [1., .5, 0.],
+ [0., 0., 0.]]])
+
+ mask = tensor([[1., 1., 1.],
+ [1., 1., 0.],
+ [0., 0., 0.]])
+
+ expected = tensor([
+ ((0 + 1 + 0) + (1 + 0 + 0) + (0 + 0 + 0)) / 5, # k=1
+ ((.5 + 1 + 0) + (1 + .5 + 0) + (0 + 0 + 0)) / 5 # k=3
+ ])
+
+ assert equal(mean_metric(mrr_matrix, mask), expected)
+
+def test_validate_batch_per_timestamp():
+ x_hat = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4]],
+ [[0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4]]])
+ output_embedding = torch.nn.Embedding(100, 4)
+
+ batch = {
+ 'labels': torch.tensor([[1, 2, 0],
+ [1, 2, 0]]),
+ 'mask': torch.tensor([[0., 1., 1.],
+ [1., 1., 1.]])
+ }
+ cut_offs = torch.tensor([1, 3]).int()
+
+ recalls, mrrs = validate_batch_per_timestamp(batch, x_hat, output_embedding, cut_offs)
+
+ assert recalls.shape == torch.Size([2])
+ assert mrrs.shape == torch.Size([2])
+ assert torch.greater_equal(recalls, 0).all()
+ assert torch.less_equal(recalls, 1).all()
+ assert torch.greater_equal(mrrs, 0).all()
+ assert torch.less_equal(mrrs, 1).all()
diff --git a/test/test_gru_dataset.py b/test/test_gru_dataset.py
new file mode 100644
index 0000000..ce2013d
--- /dev/null
+++ b/test/test_gru_dataset.py
@@ -0,0 +1,103 @@
+import torch
+from torch import equal, tensor
+from torch.utils.data.dataloader import DataLoader
+
+from src.gru4rec.dataset import (Gru4RecDataset, get_inactive_buffer_sessions,
+ label_session)
+
+
+def test_label_session():
+ session = [
+ {'aid': 33838, 'ts': 1464127201.522, 'type': 'clicks'},
+ {'aid': 4759, 'ts': 1464127218.472, 'type': 'clicks'},
+ {'aid': 15406, 'ts': 1464127243.334, 'type': 'clicks'},
+ {'aid': 12887, 'ts': 1464127245.905, 'type': 'clicks'},
+ {'aid': 27601, 'ts': 1464127251.938, 'type': 'clicks'},
+ {'aid': 15406, 'ts': 1464127265.936, 'type': 'clicks'},
+ {'aid': 14564, 'ts': 1464127406.279, 'type': 'clicks'}
+ ]
+
+ expected = [
+ {'aid': 33838, 'ts': 1464127201.522, 'type': 'clicks', 'label': 4759},
+ {'aid': 4759, 'ts': 1464127218.472, 'type': 'clicks', 'label': 15406},
+ {'aid': 15406, 'ts': 1464127243.334, 'type': 'clicks', 'label': 12887},
+ {'aid': 12887, 'ts': 1464127245.905, 'type': 'clicks', 'label': 27601},
+ {'aid': 27601, 'ts': 1464127251.938, 'type': 'clicks', 'label': 15406},
+ {'aid': 15406, 'ts': 1464127265.936, 'type': 'clicks', 'label': 14564}
+ ]
+
+ assert label_session(session) == expected
+
+
+def test_get_inactive_buffer_sessions():
+ labeled_session_buffer = [
+ [],
+ [{'aid': 33838, 'ts': 1464127201.522, 'type': 'clicks', 'label': 4759}, {'aid': 4759, 'ts': 1464127218.472, 'type': 'clicks', 'label': 15406}],
+ []
+ ]
+ expected = [0, 2]
+
+ assert get_inactive_buffer_sessions(labeled_session_buffer) == expected
+
+
+def test_dataset():
+ session_path = "test/resources/train.jsonl"
+ dataset = Gru4RecDataset(session_path, total_sessions=10, num_items=40_000, max_seqlen=6, batch_size=3, shuffling_style="no_shuffling", sampling_style="eventwise", num_uniform_negatives=5, reject_uniform_session_items=True)
+
+ expected_first_batch = {'clicks': tensor([33838, 36617, 31292]), 'labels': tensor([[4759], [34257], [18083]]), 'keep_state': tensor([[0.], [0.], [0.]])}
+ expected_second_batch = {'clicks': tensor([4759, 14138, 18083]), 'labels': tensor([[15406], [8977], [12957]]), 'keep_state': tensor([[1.], [0.], [1.]])}
+ first_batch = next(dataset.__iter__())
+
+ assert equal(first_batch['clicks'], expected_first_batch['clicks'])
+ assert equal(first_batch['labels'], expected_first_batch['labels'])
+ assert equal(first_batch['keep_state'], expected_first_batch['keep_state'])
+ assert first_batch['uniform_negatives'].shape == torch.Size([3, 5])
+ assert first_batch['in_batch_negatives'].shape == torch.Size([3, 2])
+
+ dataset.sampling_style = 'sessionwise'
+ second_batch = next(dataset.__iter__())
+ assert equal(second_batch['clicks'], expected_second_batch['clicks'])
+ assert equal(second_batch['labels'], expected_second_batch['labels'])
+ assert equal(second_batch['keep_state'], expected_second_batch['keep_state'])
+ assert second_batch['uniform_negatives'].shape == torch.Size([3, 5])
+ assert second_batch['in_batch_negatives'].shape == torch.Size([3, 2])
+
+ dataset.sampling_style = 'batchwise'
+ third_batch = next(dataset.__iter__())
+ assert third_batch['uniform_negatives'].shape == torch.Size([1, 5])
+ assert third_batch['in_batch_negatives'].shape == torch.Size([3, 2])
+
+
+def test_datalaoder():
+ session_path = "test/resources/train.jsonl"
+ dataset = Gru4RecDataset(session_path, total_sessions=10, num_items=40_000, max_seqlen=6, batch_size=3, shuffling_style="no_shuffling", sampling_style="eventwise")
+ dataloader = DataLoader(dataset,
+ batch_size=1,
+ shuffle=False,
+ drop_last=True,
+ collate_fn=dataset.dynamic_collate)
+
+ expected = {'clicks': tensor([ 4869, 39930, 16618]), 'labels': tensor([[686], [537], [399]]), 'keep_state': tensor([[1.], [1.], [1.]])}
+
+ for batch in dataloader:
+ last_batch = batch
+ assert True
+ assert equal(last_batch['clicks'], expected['clicks'])
+ assert equal(last_batch['labels'], expected['labels'])
+ assert equal(last_batch['keep_state'], expected['keep_state'])
+ assert len(last_batch['uniform_negatives'].tolist()) == 3
+ assert len(last_batch['in_batch_negatives'].tolist()) == 3
+
+
+def test_datalaoder_batchsize_too_large():
+ session_path = "test/resources/train.jsonl"
+ dataset = Gru4RecDataset(session_path, total_sessions=10, num_items=40_000, max_seqlen=6, batch_size=11, shuffling_style="no_shuffling")
+ dataloader = DataLoader(dataset,
+ batch_size=1,
+ shuffle=False,
+ collate_fn=dataset.dynamic_collate)
+
+ for batch in dataloader:
+ assert False
+ assert len(dataloader.dataset.labeled_session_buffer) == 11
+ assert dataloader.dataset.labeled_session_buffer[-1] == []
\ No newline at end of file
diff --git a/test/test_gru_model.py b/test/test_gru_model.py
new file mode 100644
index 0000000..15452b9
--- /dev/null
+++ b/test/test_gru_model.py
@@ -0,0 +1,97 @@
+import torch
+from torch import allclose, equal, tensor
+
+from src.gru4rec.model import GRU4REC, clean_state
+
+batch = {
+ 'clicks': tensor([1, 2]),
+ 'labels': tensor([[2], [3]]),
+ 'in_batch_negatives': tensor([
+ [[5, 6]],
+ [[6, 4]]
+ ]),
+ 'uniform_negatives': tensor([
+ [[5,6,7]],
+ [[4,5,6]]
+ ]),
+ 'keep_state': tensor([
+ [1.], [1.]
+ ]),
+ 'mask': tensor([
+ [1., 1.],
+ ]),
+ }
+
+def test_clean_state():
+ curr_state = torch.ones(2, 3, 4)
+ keep_state = tensor([[1.], [0.], [1.]])
+
+ expected = tensor([
+ [[1.,1.,1.,1.],[0.,0.,0.,0.],[1.,1.,1.,1.]],
+ [[1.,1.,1.,1.],[0.,0.,0.,0.],[1.,1.,1.,1.]]
+ ])
+ assert equal(clean_state(curr_state, keep_state), expected)
+
+
+def test_gru4Rec():
+ model = GRU4REC(num_items=40_000,hidden_size=10,num_layers=2,batch_size=3, dropout_rate=0.)
+ click_indices = tensor([33838, 33838, 33838])
+ in_state = torch.ones(2, 3, 10) # num_layer, batch_size, hidden_dim
+ keep_state = tensor([[1.], [0.], [1.]])
+
+ gru_output, out_state = model.forward(click_indices, in_state, keep_state)
+
+ assert gru_output.shape == torch.Size([3, 1, 10])
+ assert out_state.shape == torch.Size([2, 3, 10])
+ assert not allclose(gru_output[0], gru_output[1])
+ assert allclose(gru_output[0], gru_output[2])
+
+
+def test_gru4Re_with_output_bias():
+ model = GRU4REC(num_items=40_000,hidden_size=10,num_layers=2,batch_size=3, dropout_rate=0., output_bias=True)
+ click_indices = tensor([33838, 33838, 33838])
+ in_state = torch.ones(2, 3, 10) # num_layer, batch_size, hidden_dim
+ keep_state = tensor([[1.], [0.], [1.]])
+
+ gru_output, out_state = model.forward(click_indices, in_state, keep_state)
+
+ assert gru_output.shape == torch.Size([3, 1, 11])
+ assert out_state.shape == torch.Size([2, 3, 10])
+ assert not allclose(gru_output[0], gru_output[1])
+ assert allclose(gru_output[0], gru_output[2])
+
+
+def test_training_step_shared_no_bias():
+ model = GRU4REC(num_items=40_000,hidden_size=10,num_layers=2,batch_size=2, dropout_rate=0.)
+
+ loss = model.training_step(batch, None)
+ assert loss.shape == torch.Size([])
+ assert not allclose(model.current_state, torch.zeros(2,2,10))
+
+
+def test_training_step_not_shared_no_bias():
+ model = GRU4REC(num_items=40_000,hidden_size=10,num_layers=2,batch_size=2, dropout_rate=0., output_bias=False, share_embeddings=False)
+
+ loss = model.training_step(batch, None)
+ assert loss.shape == torch.Size([])
+ assert not allclose(model.current_state, torch.zeros(2,2,10))
+
+def test_training_step_not_shared_bias():
+ model = GRU4REC(num_items=40_000,hidden_size=10,num_layers=2,batch_size=2, dropout_rate=0., output_bias=True, share_embeddings=False)
+
+ loss = model.training_step(batch, None)
+ assert loss.shape == torch.Size([])
+ assert not allclose(model.current_state, torch.zeros(2,2,10))
+
+
+def test_training_step_not_shared_bias():
+ model = GRU4REC(num_items=40_000,hidden_size=10,num_layers=2,batch_size=2, dropout_rate=0., output_bias=True, share_embeddings=False, original_gru=True)
+
+ loss = model.training_step(batch, None)
+ assert loss.shape == torch.Size([])
+ assert not allclose(model.current_state, torch.zeros(2,2,10))
+
+def test_validation_step():
+ model = GRU4REC(num_items=40_000,hidden_size=10,num_layers=2,batch_size=2, dropout_rate=0.)
+ model.validation_step(batch, None)
+ assert True
diff --git a/test/test_logits_computation.py b/test/test_logits_computation.py
new file mode 100644
index 0000000..59f0cfc
--- /dev/null
+++ b/test/test_logits_computation.py
@@ -0,0 +1,87 @@
+import torch
+from torch import equal, tensor
+from torch.nn import Embedding
+
+from src.shared.logits_computation import (lookup_and_multiply,
+ multiply_head_with_embedding)
+
+
+def test_multiply_head_with_embedding_batchwise():
+ transformer_head = tensor([[[1.,0.],[0., 1.],[0.,0.]], [[1.,1.], [2., 0.],[0.,2.]]]) # [2,3,2] (batch_size, sequence_length, embedding_size)
+ batchwise_negative_embedding = tensor([[1.,0.], [1., 1.]]) # [2,2] (negative_samples, embedding_size)
+ expected_batchwise_multiplication = tensor([[[1. , 1.], [0., 1.], [0.,0.]], [[1., 2.], [2. , 2.], [0., 2.]]]) # [2,3,2] (batch_size, sequence_length, negative_samples)
+ assert equal(multiply_head_with_embedding(transformer_head,batchwise_negative_embedding), expected_batchwise_multiplication)
+
+
+def test_multiply_head_with_embedding_sessionwise():
+ transformer_head = tensor([[[1.,0.],[0., 1.],[0.,0.]], [[1.,1.], [2., 0.],[0.,2.]]]) # [2,3,2] (batch_size, sequence_length, embedding_size)
+ sessionwise_negative_embedding = tensor([[[1.,0.], [1., 1.]],[[1.,1.], [2., -1.]]]) # [2,2,2] (batch_size, negative_samples, embedding_size)
+ expected_sessionwise_multiplication = tensor([[[1. , 1.], [0., 1.], [0.,0.]], [[2., 1.], [2. , 4.], [2., -2.]]]) # [2,3,2] (batch_size, sequence_length, negative_samples)
+ assert equal(multiply_head_with_embedding(transformer_head,sessionwise_negative_embedding), expected_sessionwise_multiplication)
+
+
+def test_multiply_head_with_embedding_eventwise():
+ transformer_head = tensor([[[1.,0.],[0., 1.],[0.,0.]], [[1.,1.], [2., 0.],[0.,2.]]]).unsqueeze(-2) # [2,3,1,2] (batch_size, sequence_length, 1, embedding_size)
+ eventwise_negative_embedding = tensor([[[[1.,0.], [1., 1.]],[[0., 1.], [1., 0.]], [[0.,0.], [0., 0.]]],[[[1.,1.], [2., 2.]], [[2., 0.], [3., 0.]],[[0.,1.], [0., 3.]]]]) # [2,3,2,2] (batch_size, sequence_length, negative_samples, embedding_size)
+ expected_eventwise_multiplication = tensor([[[1. , 1.], [1., 0.], [0.,0.]], [[2., 4.], [4. , 6.], [2., 6.]]]) # [2,3,2] (batch_size, sequence_length, negative_samples)
+ assert equal(multiply_head_with_embedding(transformer_head,eventwise_negative_embedding).squeeze(-2), expected_eventwise_multiplication)
+
+
+def test_multiply_head_with_embedding_positives():
+ transformer_head = tensor([[[1.,0.],[0., 1.],[0.,0.]], [[1.,1.], [2., 0.],[0.,2.]]]).unsqueeze(-2) # [2,3,1,2] (batch_size, sequence_length, embedding_size)
+ eventwise_positive_embedding = tensor([[[1.,0.], [0., 1.], [0.,0.]],[[1.,1.], [2., 0.], [0.,1.]]]).unsqueeze(-2) # [2,3,1,2] (batch_size, sequence_length, embedding_size)
+ expected_positive_multiplication = tensor([[1., 1., 0.,], [2., 4., 2.]]) # [2,3] (batch_size, sequence_length)
+ assert equal(multiply_head_with_embedding(transformer_head,eventwise_positive_embedding).squeeze(-1).squeeze(-1), expected_positive_multiplication)
+
+
+def test_lookup_and_multiply_eventwise():
+ transformer_head = tensor([[[1.,0.],[0., 1.],[0.,0.]], [[1.,1.], [2., 0.],[0.,2.]]])
+ embedding_layer = Embedding(20, 2, padding_idx=0)
+ positives = tensor([[2,5,6],[7,9,8]], dtype=torch.long)
+ eventwise_uniform_negatives = tensor([[[1, 2],[3, 4], [5, 6]],[[7, 8], [9, 10],[11, 12]]], dtype=torch.long)
+ in_batch_negatives = tensor([[1, 3, 4],[6, 5, 8]], dtype=torch.long)
+
+ expected_pos_logits_shape = [2,3,1]
+ expected_neg_logits_shape = [2,3,5]
+ actual_pos_logits, actual_neg_logits = lookup_and_multiply(transformer_head, positives, eventwise_uniform_negatives, in_batch_negatives, embedding_layer, 'eventwise')
+
+ assert actual_pos_logits.shape == torch.Size(expected_pos_logits_shape)
+ assert actual_neg_logits.shape == torch.Size(expected_neg_logits_shape)
+
+
+def test_lookup_and_multiply_no_uniform_negatives():
+ transformer_head = tensor([[[1.,0.],[0., 1.],[0.,0.]], [[1.,1.], [2., 0.],[0.,2.]]])
+ embedding_layer = Embedding(20, 2, padding_idx=0)
+ positives = tensor([[2,5,6],[7,9,8]], dtype=torch.long) #(batch_size, seqlen)
+ elementwise_uniform_negatives = tensor([[[],[], []],[[], [],[]]], dtype=torch.long)
+ in_batch_negatives = tensor([[1, 3, 4],[6, 5, 8]], dtype=torch.long)
+
+ expected_pos_logits_shape = [2,3,1]
+ expected_neg_logits_shape = [2,3,3]
+ actual_pos_logits, actual_neg_logits = lookup_and_multiply(transformer_head, positives, elementwise_uniform_negatives, in_batch_negatives, embedding_layer, 'eventwise')
+
+ assert actual_pos_logits.shape == torch.Size(expected_pos_logits_shape)
+ assert actual_neg_logits.shape == torch.Size(expected_neg_logits_shape)
+
+
+def test_lookup_and_multiply_no_in_batch_negatives():
+ transformer_head = tensor([[[1.,0.],[0., 1.],[0.,0.]], [[1.,1.], [2., 0.],[0.,2.]]])
+ embedding_layer = Embedding(20, 2, padding_idx=0)
+ positives = tensor([[2,5,6],[7,9,8]], dtype=torch.long)
+ elementwise_uniform_negatives = tensor([[[1, 2],[3, 4], [5, 6]],[[7, 8], [9, 10],[11, 12]]], dtype=torch.long)
+ in_batch_negatives = tensor([[],[]], dtype=torch.long)
+
+ expected_pos_logits_shape = [2,3,1]
+ expected_neg_logits_shape = [2,3,2]
+ actual_pos_logits, actual_neg_logits = lookup_and_multiply(transformer_head, positives, elementwise_uniform_negatives, in_batch_negatives, embedding_layer, 'eventwise')
+
+ assert actual_pos_logits.shape == torch.Size(expected_pos_logits_shape)
+ assert actual_neg_logits.shape == torch.Size(expected_neg_logits_shape)
+
+
+def test_multiply_transformerhead_with_candidates_per_timestamp():
+ transformer_head = tensor([[1.,0.], [1.,1.]])
+ positive_embedding = tensor([[.2, .4],[.1, .2]])
+ expected_multiplication = tensor([[.2, .1], [.6, .3]])
+
+ assert equal(multiply_head_with_embedding(transformer_head, positive_embedding), expected_multiplication)
diff --git a/test/test_loss.py b/test/test_loss.py
new file mode 100644
index 0000000..97d6926
--- /dev/null
+++ b/test/test_loss.py
@@ -0,0 +1,99 @@
+import torch
+from torch import sigmoid, softmax, tensor
+
+from src.shared.loss import (_bpr_max_loss_regularization,
+ _bpr_max_loss_unregulized, _diff_logits,
+ _elementwise_bpr_max_loss_per_negative,
+ _elementwise_sampled_softmax_loss, bce_loss,
+ bpr_max_loss, sampled_softmax_loss)
+
+
+def test_elementwise_sampled_softmax_loss():
+ pos_logits = tensor([[[1], [2], [3], [4], [4]], [[0], [3], [2], [1], [4]]], dtype=torch.float)
+ neg_logits = tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [6, 3, 8], [6, 3, 8]], [
+ [0, 0, 0], [9, 8, 7], [6, 5, 4], [3, 2, 1], [6, 3, 8]]], dtype=torch.float)
+ mask = tensor([[1., 1., 1., 1., 1.], [0., 1., 1., 1., 1.]])
+ target = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
+ expected_elementwise_loss = tensor([[2.4938, 4.4197, 6.4093, 4.1488, 4.1488],[0.0000, 6.4093, 4.4197, 2.4938, 4.1488]])
+ assert torch.allclose(expected_elementwise_loss, _elementwise_sampled_softmax_loss(pos_logits, neg_logits, mask, target), atol=1e-5)
+
+
+def test_sampled_softmax_loss():
+ pos_logits = tensor([[[1], [2], [3], [4], [4]], [[0], [3], [2], [1], [4]]], dtype=torch.float)
+ neg_logits = tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [6, 3, 8], [6, 3, 8]], [
+ [0, 0, 0], [9, 8, 7], [6, 5, 4], [3, 2, 1], [6, 3, 8]]], dtype=torch.float)
+ mask = tensor([[1., 1., 1., 1., 1.], [0., 1., 1., 1., 1.]])
+ expected_loss = tensor(39.0920) / 9.
+ assert torch.allclose(expected_loss, sampled_softmax_loss(pos_logits, neg_logits, mask))
+
+
+def test_binary_cross_entropy_loss():
+ mask = tensor([[1., 1.], [1., 0.]])
+ positive_logits = tensor([[[1.], [-2.]],[[3.], [4.]]])
+ negative_logits = tensor([[[1.1, 1.2], [1.2, 2.1]],[[-1., 4.], [1.,- 4.]]])
+ assert torch.allclose(tensor(2.6397), bce_loss(positive_logits, negative_logits, mask), atol=0.001)
+
+
+def test_difference_positive_and_negative_logits():
+ pos_logits = tensor([[[1], [2], [3], [4], [4]],
+ [[0], [3], [2], [1], [4]]], dtype=torch.float)
+ neg_logits = tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [6, 3, 8], [6, 3, 8]],
+ [[0, 0, 0], [9, 8, 7], [6, 5, 4], [3, 2, 1], [6, 3, 8]]], dtype=torch.float)
+
+ expected_diff = tensor([[[0, -1, -2], [-2, -3, -4], [-4, -5, -6], [-2, 1, -4], [-2, 1, -4]],
+ [[0, 0, 0], [-6, -5, -4], [-4, -3, -2], [-2, -1, 0], [-2, 1, -4]]], dtype=torch.float)
+
+ assert torch.equal(_diff_logits(pos_logits, neg_logits), expected_diff)
+
+
+def test_elementwise_bpr_max_loss_per_negative():
+ pos_logits = tensor([[[1], [2], [3], [4], [4]],
+ [[0], [3], [2], [1], [4]]], dtype=torch.float)
+ neg_logits = tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [6, 3, 8], [6, 3, 8]],
+ [[0, 0, 0], [9, 8, 7], [6, 5, 4], [3, 2, 1], [6, 3, 8]]], dtype=torch.float)
+ expected = tensor([[[0.0450, 0.0658, 0.0793], [0.0107, 0.0116, 0.0120], [0.0016, 0.0016, 0.0016], [0.0141, 0.0043, 0.0157], [0.0141, 0.0043, 0.0157]],
+ [[0.1667, 0.1667, 0.1667], [0.0016, 0.0016, 0.0016], [0.0120, 0.0116, 0.0107], [0.0793, 0.0658, 0.0450], [0.0141, 0.0043, 0.0157]]])
+ actual = _elementwise_bpr_max_loss_per_negative(pos_logits, neg_logits)
+
+ assert torch.allclose(actual, expected, atol=0.0001)
+
+
+def test_elementwise_bpr_max_loss():
+ pos_logits = tensor([[[1], [2], [3], [4], [4]],
+ [[0], [3], [2], [1], [4]]], dtype=torch.float)
+ neg_logits = tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [6, 3, 8], [6, 3, 8]],
+ [[0, 0, 0], [9, 8, 7], [6, 5, 4], [3, 2, 1], [6, 3, 8]]], dtype=torch.float)
+ mask = tensor([[1., 1., 1., 1., 1.],
+ [0., 1., 1., 1., 1.]])
+
+ expected_unmasked = tensor([[1.6602, 3.3726, 5.3391, 3.3785, 3.3785],
+ [0.6929, 5.3391, 3.3726, 1.6602, 3.3785]], dtype=torch.float)
+ actual_bpr_max_loss_unregulized_unmasked, actual_bpr_max_loss_unregulized = _bpr_max_loss_unregulized(pos_logits, neg_logits, mask)
+
+ assert torch.allclose(actual_bpr_max_loss_unregulized_unmasked, expected_unmasked, atol=0.1)
+ assert torch.allclose(actual_bpr_max_loss_unregulized, tensor(30.8793 / 9), atol=0.01)
+
+
+def test_bpr_max_loss_regularization():
+ penalty = 1.
+ neg_logits = tensor([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5]],
+ [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9], [10, 10, 10]]], dtype=torch.float)
+ mask = tensor([[1., 1., 1., 1., 1.],
+ [0., 1., 1., 1., 1.]])
+
+ expected_regularization = tensor(349 / 9)
+ actual_regularization = _bpr_max_loss_regularization(neg_logits, penalty, mask)
+
+ assert torch.allclose(actual_regularization, expected_regularization)
+
+
+def test_bpr_max_loss():
+ pos_logits = tensor([[[1], [2], [3], [4], [4]],
+ [[0], [3], [2], [1], [4]]], dtype=torch.float)
+ neg_logits = tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [6, 3, 8], [6, 3, 8]],
+ [[0, 0, 0], [9, 8, 7], [6, 5, 4], [3, 2, 1], [6, 3, 8]]], dtype=torch.float)
+ mask = tensor([[1., 1., 1., 1., 1.],
+ [0., 1., 1., 1., 1.]])
+ penalty = 0.
+
+ assert torch.allclose(bpr_max_loss(penalty, pos_logits, neg_logits, mask), tensor(30.8793 / 9), atol=0.01)
\ No newline at end of file
diff --git a/test/test_preprocessing.py b/test/test_preprocessing.py
new file mode 100644
index 0000000..927be8a
--- /dev/null
+++ b/test/test_preprocessing.py
@@ -0,0 +1,27 @@
+from src.preprocessing import filter_non_clicks, increment_aids, sort_events, create_sessions
+import os
+from filecmp import cmp
+
+
+os.makedirs("test/resources/out", exist_ok=True)
+
+
+def test_increment_aids():
+ events = [{"aid":0,"ts":1,"type":"clicks"},{"aid":1,"ts":1,"type":"clicks"},{"aid":2,"ts":1,"type":"clicks"},{"aid":1,"ts":1,"type":"clicks"}]
+ expected_events = [{"aid":1,"ts":1,"type":"clicks"},{"aid":2,"ts":1,"type":"clicks"},{"aid":3,"ts":1,"type":"clicks"},{"aid":2,"ts":1,"type":"clicks"}]
+ assert expected_events == increment_aids(events)
+
+
+def test_filter_non_clicks():
+ num_sessions, num_events, num_items = filter_non_clicks("test/resources/unfiltered_sessions.jsonl", "test/resources/out/filtered_sessions.jsonl")
+ assert 5 == num_sessions
+ assert 88 == num_events
+ assert 66 == num_items
+ assert cmp("test/resources/expected_filtered_sessions.jsonl", "test/resources/out/filtered_sessions.jsonl")
+ os.remove("test/resources/out/filtered_sessions.jsonl")
+
+
+def test_sort_events():
+ events = [{"aid":1,"ts":5,"type":"clicks"},{"aid":2,"ts":1,"type":"clicks"},{"aid":3,"ts":3,"type":"clicks"},{"aid":2,"ts":0,"type":"clicks"}]
+ expected_events = [{"aid":2,"ts":0,"type":"clicks"},{"aid":2,"ts":1,"type":"clicks"},{"aid":3,"ts":3,"type":"clicks"},{"aid":1,"ts":5,"type":"clicks"}]
+ assert expected_events == sort_events(events)
diff --git a/test/test_sample.py b/test/test_sample.py
new file mode 100644
index 0000000..f8a5fdf
--- /dev/null
+++ b/test/test_sample.py
@@ -0,0 +1,87 @@
+import itertools
+
+import numpy as np
+
+from src.shared.sample import (_infer_shape, _uniform_negatives,
+ _uniform_negatives_session_rejected,
+ sample_in_batch_negatives, sample_uniform,
+ sample_uniform_negatives_with_shape)
+
+
+def test_uniform_negatives():
+ num_items = 10
+ shape = [5,2]
+ negatives = _uniform_negatives(num_items=num_items, shape=shape)
+ assert negatives.shape == (5,2)
+ assert set(list(itertools.chain(*negatives))).difference(set(range(1,11))) == set([])
+
+
+def test_uniform_negatives_with_0():
+ pass
+
+def test_uniform_negatives_session_rejected():
+ num_items = 10
+ shape = [5,2]
+ in_session_items = [1,5,10]
+
+ negatives = _uniform_negatives_session_rejected(num_items=num_items, shape=shape, in_session_items=in_session_items)
+
+ assert negatives.shape == (5,2)
+ assert set(in_session_items).intersection(set(list(itertools.chain(*negatives.tolist())))) == set([])
+
+def test_infer_shape():
+ session_len = 5
+ num_uniform_negatives = 2
+ shape_eventwise = _infer_shape(session_len=session_len, num_uniform_negatives=num_uniform_negatives, sampling_style="eventwise")
+ shape_sessionwise = _infer_shape(session_len=session_len, num_uniform_negatives=num_uniform_negatives, sampling_style="sessionwise")
+ shape_batchwise = _infer_shape(session_len=session_len, num_uniform_negatives=num_uniform_negatives, sampling_style="batchwise")
+
+ assert shape_eventwise==[5,2]
+ assert shape_sessionwise==[2,]
+ assert shape_batchwise==[]
+
+def test_sample_uniform():
+ num_items = 10
+ shape = [6,2]
+ clicks = [7,4,3]
+ with_rejection = sample_uniform(num_items=num_items, shape=shape, in_session_items=clicks, reject_session_items=True)
+ without_rejection = sample_uniform(num_items=num_items, shape=shape, in_session_items=clicks, reject_session_items=False)
+
+ assert with_rejection.shape == (6,2)
+ assert without_rejection.shape == (6,2)
+
+ for element in with_rejection.tolist():
+ assert set(element).isdisjoint(set(clicks))
+
+ for element in without_rejection.tolist():
+ assert set(element).issubset(set(range(1,11)))
+
+def test_sample_uniform_negatives_with_shape():
+ clicks = [7,4,3]
+ num_items = 10
+ session_len = 12
+ num_uniform_negatives = 3
+ elementwise_negatives = sample_uniform_negatives_with_shape(clicks=clicks, num_items=num_items, session_len=session_len, num_uniform_negatives=num_uniform_negatives, sampling_style="eventwise", reject_session_items=False)
+ sessionwise_negatives = sample_uniform_negatives_with_shape(clicks=clicks, num_items=num_items, session_len=session_len, num_uniform_negatives=num_uniform_negatives, sampling_style="sessionwise", reject_session_items=False)
+ batchwise_negatives = sample_uniform_negatives_with_shape(clicks=clicks, num_items=num_items, session_len=session_len, num_uniform_negatives=num_uniform_negatives, sampling_style="batchwise", reject_session_items=False)
+
+ assert elementwise_negatives.shape == (12,3)
+ assert sessionwise_negatives.shape == (3,)
+ assert batchwise_negatives.shape == (0,)
+
+def test_sample_in_batch_negatives():
+ batch_positives = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ num_in_batch_negatives = 2
+ batch_session_len = [5,2,3]
+
+ without_same_session_negatives = sample_in_batch_negatives(batch_positives=batch_positives, num_in_batch_negatives=num_in_batch_negatives, batch_session_len=batch_session_len, reject_session_items=True)
+ with_same_session_negatives = sample_in_batch_negatives(batch_positives=batch_positives, num_in_batch_negatives=num_in_batch_negatives, batch_session_len=batch_session_len, reject_session_items=False)
+
+ assert np.array(with_same_session_negatives).shape == (3, 2)
+ for element in with_same_session_negatives:
+ assert set(element).issubset(set(range(1,11)))
+
+ assert np.array(without_same_session_negatives).shape == (3, 2)
+ assert set(without_same_session_negatives[0]).issubset(set([6, 7, 8, 9, 10]))
+ assert set(without_same_session_negatives[1]).issubset(set([1, 2, 3, 4, 5, 8, 9, 10]))
+ assert set(without_same_session_negatives[2]).issubset(set([1, 2, 3, 4, 5, 6, 7]))
\ No newline at end of file
diff --git a/test/test_sas_dataset.py b/test/test_sas_dataset.py
new file mode 100644
index 0000000..6364aec
--- /dev/null
+++ b/test/test_sas_dataset.py
@@ -0,0 +1,141 @@
+import numpy as np
+from torch import allclose, tensor
+from torch.utils.data.dataloader import DataLoader
+
+from src.sasrec.dataset import SasRecDataset
+
+
+def test_dataset():
+ session_path = "test/resources/train.jsonl"
+ dataset = SasRecDataset(session_path, total_sessions=10, num_items=40_000, max_seqlen=6, shuffling_style="no_shuffling", num_uniform_negatives=3, num_in_batch_negatives=0, reject_uniform_session_items=False, sampling_style="eventwise")
+
+ expected_first_session = {
+ 'clicks': [33838, 4759, 15406, 12887, 27601, 15406],
+ 'labels': [4759, 15406, 12887, 27601, 15406, 14564],
+ 'session_len': 6
+ }
+
+ expected_second_session = {'clicks': [36617], 'labels': [34257], 'session_len': 1}
+
+ expected_third_session = {
+ 'clicks': [31292, 18083],
+ 'labels': [18083, 12957],
+ 'session_len': 2
+ }
+
+ expected_fourth_session = {
+ 'clicks': [14138],
+ 'labels': [8977],
+ 'session_len': 1
+ }
+
+ first_session = dataset.__getitem__(0)
+ second_session = dataset.__getitem__(1)
+ third_session = dataset.__getitem__(2)
+ fourth_session = dataset.__getitem__(3)
+
+ assert first_session['clicks'] == expected_first_session['clicks']
+ assert first_session['labels'] == expected_first_session['labels']
+ assert first_session['session_len'] == expected_first_session['session_len']
+ assert np.array(first_session['uniform_negatives']).shape == (6, 3)
+
+ assert second_session['clicks'] == expected_second_session['clicks']
+ assert second_session['labels'] == expected_second_session['labels']
+ assert second_session['session_len'] == expected_second_session['session_len']
+ assert np.array(second_session['uniform_negatives']).shape == (1, 3)
+
+ assert third_session['clicks'] == expected_third_session['clicks']
+ assert third_session['labels'] == expected_third_session['labels']
+ assert third_session['session_len'] == expected_third_session['session_len']
+ assert np.array(third_session['uniform_negatives']).shape == (2, 3)
+
+ assert fourth_session['clicks'] == expected_fourth_session['clicks']
+ assert fourth_session['labels'] == expected_fourth_session['labels']
+ assert fourth_session['session_len'] == expected_fourth_session['session_len']
+ assert np.array(fourth_session['uniform_negatives']).shape == (1, 3)
+
+
+def test_datalaoder():
+ session_path = "test/resources/train.jsonl"
+ dataset = SasRecDataset(sessions_path=session_path, total_sessions=10, num_items=40_000, max_seqlen=6, shuffling_style="no_shuffling", num_uniform_negatives=3, num_in_batch_negatives=2, reject_uniform_session_items=True, sampling_style="eventwise")
+ dataloader = DataLoader(dataset,
+ batch_size=3,
+ shuffle=False,
+ collate_fn=dataset.dynamic_collate)
+
+
+
+ expected_first_batch = {
+ 'clicks': tensor([
+ [33838, 4759, 15406, 12887, 27601, 15406],
+ [0, 0, 0, 0, 0, 36617],
+ [0, 0, 0, 0, 31292, 18083]]),
+ 'labels': tensor([
+ [4759, 15406, 12887, 27601, 15406, 14564],
+ [0, 0, 0, 0, 0, 34257],
+ [0, 0, 0, 0, 18083, 12957]]),
+ 'mask': tensor([
+ [1., 1., 1., 1., 1., 1.],
+ [0., 0., 0., 0., 0., 1.],
+ [0., 0., 0., 0., 1., 1.],]),
+ 'session_len': tensor([6, 1, 2]),
+ }
+
+ for batch in dataloader:
+ assert allclose(batch['clicks'],
+ expected_first_batch['clicks'])
+ assert allclose(batch['labels'],
+ expected_first_batch['labels'])
+ assert allclose(batch['mask'],
+ expected_first_batch['mask'])
+ assert allclose(batch['session_len'],
+ expected_first_batch['session_len'])
+ assert batch['in_batch_negatives'].shape == (3,2)
+ assert batch['uniform_negatives'].shape == (3,6,3)
+ assert set(batch['in_batch_negatives'].tolist()[0]).issubset([36617, 31292, 18083])
+ assert set(batch['in_batch_negatives'].tolist()[1]).issubset([33838, 4759, 15406, 12887, 27601, 15406, 31292, 18083])
+ assert set(batch['in_batch_negatives'].tolist()[2]).issubset([33838, 4759, 15406, 12887, 27601, 15406, 36617])
+ break
+
+ dataset.sampling_style="sessionwise"
+ batch = next(iter(dataloader))
+ assert batch['uniform_negatives'].shape == (3,3)
+
+ dataset.sampling_style="batchwise"
+ batch = next(iter(dataloader))
+ assert batch['uniform_negatives'].shape == (3,)
+
+
+def test_datalaoder_no_uniform_negatives():
+ session_path = "test/resources/train.jsonl"
+ dataset = SasRecDataset(sessions_path=session_path, total_sessions=10, num_items=40_000, max_seqlen=6, shuffling_style="no_shuffling", num_uniform_negatives=0, num_in_batch_negatives=2, reject_uniform_session_items=True, sampling_style="eventwise")
+ dataloader = DataLoader(dataset,
+ batch_size=3,
+ shuffle=False,
+ collate_fn=dataset.dynamic_collate)
+
+ for batch in dataloader:
+ assert batch['uniform_negatives'].shape == (3,6,0)
+ break
+
+ dataset.sampling_style="sessionwise"
+ batch = next(iter(dataloader))
+ assert batch['uniform_negatives'].shape == (3,0)
+
+ dataset.sampling_style="batchwise"
+ batch = next(iter(dataloader))
+ assert batch['uniform_negatives'].shape == (0,)
+
+
+def test_datalaoder_no_in_batch_negatives():
+ session_path = "test/resources/train.jsonl"
+ dataset = SasRecDataset(sessions_path=session_path, total_sessions=10, num_items=40_000, max_seqlen=6, shuffling_style="no_shuffling", num_uniform_negatives=3, num_in_batch_negatives=0, reject_uniform_session_items=True, sampling_style="eventwise")
+ dataloader = DataLoader(dataset,
+ batch_size=3,
+ shuffle=False,
+ collate_fn=dataset.dynamic_collate)
+
+
+ for batch in dataloader:
+ assert batch['in_batch_negatives'].shape == (3,0)
+ break
diff --git a/test/test_sas_model.py b/test/test_sas_model.py
new file mode 100644
index 0000000..4da8dfe
--- /dev/null
+++ b/test/test_sas_model.py
@@ -0,0 +1,125 @@
+import torch
+from torch import tensor
+
+from src.sasrec.model import SASRec
+
+batch = {
+ 'clicks': tensor([
+ [1, 2, 3, 4],
+ [0, 0, 0, 2],
+ [0, 0, 5, 6]]),
+ 'labels': tensor([
+ [2, 3, 4, 5],
+ [0, 0, 0, 3],
+ [0, 0, 6, 7]]),
+ 'in_batch_negatives': tensor([
+ [5, 6],
+ [6, 4],
+ [1, 2]
+ ]),
+ 'uniform_negatives': tensor([
+ [[5,6,7],[5,6,7],[5,6,7],[5,6,7]],
+ [[4,5,6],[4,5,6],[4,5,6],[4,5,6]],
+ [[3,4,9],[3,4,9],[3,4,9],[3,4,9]]
+ ]),
+ 'mask': tensor([
+ [1., 1., 1., 1.],
+ [0., 0., 0., 1.],
+ [0., 0., 1., 1.],]),
+ 'session_len': tensor([4, 1, 2]),
+}
+
+def test_forward():
+ model = SASRec(
+ hidden_size=8,
+ dropout_rate=0.,
+ max_len=3,
+ num_items=16,
+ learning_rate=0.01,
+ batch_size=2,
+ sampling_style='eventwise')
+
+ item_indices = tensor([[2,5,6],[0,9,8]], dtype=torch.long)
+ mask = tensor([[1.,1.,1.], [0.,1.,1.]], dtype=torch.float)
+
+ actual_shape = model.forward(item_indices, mask).shape
+ expected_shape = torch.Size([2, 3, 8])
+
+ assert actual_shape == expected_shape
+
+def test_forward_with_output_bias():
+ model = SASRec(
+ hidden_size=8,
+ dropout_rate=0.,
+ max_len=3,
+ num_items=16,
+ learning_rate=0.01,
+ batch_size=2,
+ output_bias=True,
+ sampling_style='eventwise')
+
+ item_indices = tensor([[2,5,6],[0,9,8]], dtype=torch.long)
+ mask = tensor([[1.,1.,1.], [0.,1.,1.]], dtype=torch.float)
+
+ actual_shape = model.forward(item_indices, mask).shape
+ expected_shape = torch.Size([2, 3, 9])
+
+ assert actual_shape == expected_shape
+
+
+
+def test_training_step():
+ model = SASRec(
+ hidden_size=8,
+ dropout_rate=0.,
+ max_len=4,
+ num_items=16,
+ learning_rate=0.01,
+ batch_size=2,
+ sampling_style='eventwise')
+
+ loss = model.training_step(batch, None)
+ assert loss.shape == torch.Size([])
+
+
+def test_training_step_not_shared_output_bias():
+ model = SASRec(
+ hidden_size=8,
+ dropout_rate=0.,
+ max_len=4,
+ num_items=16,
+ learning_rate=0.01,
+ batch_size=3,
+ output_bias=True,
+ share_embeddings=False,
+ sampling_style='eventwise')
+
+ loss = model.training_step(batch, None)
+ assert loss.shape == torch.Size([])
+
+def test_training_step_not_shared_output_no_output_bias():
+ model = SASRec(
+ hidden_size=8,
+ dropout_rate=0.,
+ max_len=4,
+ num_items=16,
+ learning_rate=0.01,
+ batch_size=3,
+ output_bias=False,
+ share_embeddings=False,
+ sampling_style='eventwise')
+
+ loss = model.training_step(batch, None)
+ assert loss.shape == torch.Size([])
+
+def test_validation_step():
+ model = SASRec(
+ hidden_size=8,
+ dropout_rate=0.,
+ max_len=4,
+ num_items=16,
+ learning_rate=0.01,
+ batch_size=3,
+ sampling_style='eventwise')
+
+ model.validation_step(batch, None)
\ No newline at end of file