diff --git a/.tool-versions b/.tool-versions index 4a678234c..bff6ea103 100644 --- a/.tool-versions +++ b/.tool-versions @@ -1,2 +1,2 @@ python 3.9.11 -poetry 1.4.2 +poetry 1.5.1 diff --git a/poetry.lock b/poetry.lock index 099d907d8..01ed34e65 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "absl-py" version = "1.4.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -16,7 +15,6 @@ files = [ name = "alembic" version = "1.11.1" description = "A database migration tool for SQLAlchemy." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -38,7 +36,6 @@ tz = ["python-dateutil"] name = "anyio" version = "3.7.1" description = "High level compatibility layer for multiple asynchronous event loop implementations" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -60,7 +57,6 @@ trio = ["trio (<0.22)"] name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" -category = "dev" optional = false python-versions = "*" files = [ @@ -72,7 +68,6 @@ files = [ name = "astunparse" version = "1.6.3" description = "An AST unparser for Python" -category = "dev" optional = false python-versions = "*" files = [ @@ -88,7 +83,6 @@ wheel = ">=0.23.0,<1.0" name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -107,7 +101,6 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" -category = "dev" optional = false python-versions = "*" files = [ @@ -119,7 +112,6 @@ files = [ name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" -category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -138,7 +130,6 @@ lxml = ["lxml"] name = "black" version = "22.12.0" description = "The uncompromising code formatter." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -174,7 +165,6 @@ uvloop = ["uvloop (>=0.15.2)"] name = "blake3" version = "0.3.3" description = "Python bindings for the Rust blake3 crate" -category = "main" optional = false python-versions = "*" files = [ @@ -209,7 +199,6 @@ files = [ name = "bleach" version = "6.0.0" description = "An easy safelist-based HTML-sanitizing tool." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -228,7 +217,6 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] name = "blinker" version = "1.6.2" description = "Fast, simple object-to-object and broadcast signaling" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -240,7 +228,6 @@ files = [ name = "boto3" version = "1.28.2" description = "The AWS SDK for Python" -category = "main" optional = false python-versions = ">= 3.7" files = [ @@ -260,7 +247,6 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.31.2" description = "Low-level, data-driven core of boto 3." -category = "main" optional = false python-versions = ">= 3.7" files = [ @@ -280,7 +266,6 @@ crt = ["awscrt (==0.16.9)"] name = "cachetools" version = "5.3.1" description = "Extensible memoizing collections and decorators" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -292,7 +277,6 @@ files = [ name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -304,7 +288,6 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." -category = "dev" optional = false python-versions = "*" files = [ @@ -381,7 +364,6 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." -category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -393,7 +375,6 @@ files = [ name = "charset-normalizer" version = "3.2.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -478,7 +459,6 @@ files = [ name = "click" version = "8.1.4" description = "Composable command line interface toolkit" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -493,7 +473,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "cloudpickle" version = "2.2.1" description = "Extended pickling support for Python objects" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -505,7 +484,6 @@ files = [ name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -517,7 +495,6 @@ files = [ name = "comm" version = "0.1.3" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -537,7 +514,6 @@ typing = ["mypy (>=0.990)"] name = "coverage" version = "6.5.0" description = "Code coverage measurement for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -603,7 +579,6 @@ toml = ["tomli"] name = "databricks-cli" version = "0.17.7" description = "A command line interface for Databricks" -category = "dev" optional = false python-versions = "*" files = [ @@ -624,7 +599,6 @@ urllib3 = ">=1.26.7,<2.0.0" name = "debugpy" version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -652,7 +626,6 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -664,7 +637,6 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -676,7 +648,6 @@ files = [ name = "distlib" version = "0.3.6" description = "Distribution utilities" -category = "dev" optional = false python-versions = "*" files = [ @@ -688,7 +659,6 @@ files = [ name = "docker" version = "6.1.3" description = "A Python library for the Docker Engine API." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -710,7 +680,6 @@ ssh = ["paramiko (>=2.4.3)"] name = "dockerfile" version = "3.2.0" description = "Parse a dockerfile into a high-level representation using the official go parser." -category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -724,7 +693,6 @@ files = [ name = "entrypoints" version = "0.4" description = "Discover and load entry points from installed packages." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -736,7 +704,6 @@ files = [ name = "exceptiongroup" version = "1.1.2" description = "Backport of PEP 654 (exception groups)" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -751,7 +718,6 @@ test = ["pytest (>=6)"] name = "fastapi" version = "0.95.2" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -773,7 +739,6 @@ test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6 name = "fastjsonschema" version = "2.17.1" description = "Fastest Python implementation of JSON schema" -category = "dev" optional = false python-versions = "*" files = [ @@ -788,7 +753,6 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.12.2" description = "A platform independent file lock." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -804,7 +768,6 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "p name = "flake8" version = "4.0.1" description = "the modular source code checker: pep8 pyflakes and co" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -821,7 +784,6 @@ pyflakes = ">=2.4.0,<2.5.0" name = "flask" version = "2.3.2" description = "A simple framework for building complex web applications." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -845,7 +807,6 @@ dotenv = ["python-dotenv"] name = "flatbuffers" version = "23.5.26" description = "The FlatBuffers serialization format for Python" -category = "dev" optional = false python-versions = "*" files = [ @@ -857,7 +818,6 @@ files = [ name = "fsspec" version = "2023.6.0" description = "File-system specification" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -893,7 +853,6 @@ tqdm = ["tqdm"] name = "gast" version = "0.4.0" description = "Python AST that abstracts the underlying Python version" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -905,7 +864,6 @@ files = [ name = "gitdb" version = "4.0.10" description = "Git Object Database" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -920,7 +878,6 @@ smmap = ">=3.0.1,<6" name = "gitpython" version = "3.1.32" description = "GitPython is a Python library used to interact with Git repositories" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -935,7 +892,6 @@ gitdb = ">=4.0.1,<5" name = "google-auth" version = "2.21.0" description = "Google Authentication Library" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -961,7 +917,6 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"] name = "google-auth-oauthlib" version = "1.0.0" description = "Google Authentication Library" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -980,7 +935,6 @@ tool = ["click (>=6.0.0)"] name = "google-pasta" version = "0.2.0" description = "pasta is an AST-based Python refactoring library" -category = "dev" optional = false python-versions = "*" files = [ @@ -996,7 +950,6 @@ six = "*" name = "greenlet" version = "2.0.2" description = "Lightweight in-process concurrent programming" -category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*" files = [ @@ -1070,7 +1023,6 @@ test = ["objgraph", "psutil"] name = "grpcio" version = "1.56.0" description = "HTTP/2-based RPC framework" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1128,7 +1080,6 @@ protobuf = ["grpcio-tools (>=1.56.0)"] name = "gunicorn" version = "20.1.0" description = "WSGI HTTP Server for UNIX" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1149,7 +1100,6 @@ tornado = ["tornado (>=0.2)"] name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1161,7 +1111,6 @@ files = [ name = "h5py" version = "3.9.0" description = "Read and write HDF5 files from Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1195,7 +1144,6 @@ numpy = ">=1.17.3" name = "httpcore" version = "0.17.3" description = "A minimal low-level HTTP client." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1207,17 +1155,16 @@ files = [ anyio = ">=3.0,<5.0" certifi = "*" h11 = ">=0.13,<0.15" -sniffio = ">=1.0.0,<2.0.0" +sniffio = "==1.*" [package.extras] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (>=1.0.0,<2.0.0)"] +socks = ["socksio (==1.*)"] [[package]] name = "httpx" version = "0.24.1" description = "The next generation HTTP client." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1227,24 +1174,23 @@ files = [ [package.dependencies] certifi = "*" -click = {version = ">=8.0.0,<9.0.0", optional = true, markers = "extra == \"cli\""} +click = {version = "==8.*", optional = true, markers = "extra == \"cli\""} httpcore = ">=0.15.0,<0.18.0" idna = "*" -pygments = {version = ">=2.0.0,<3.0.0", optional = true, markers = "extra == \"cli\""} +pygments = {version = "==2.*", optional = true, markers = "extra == \"cli\""} rich = {version = ">=10,<14", optional = true, markers = "extra == \"cli\""} sniffio = "*" [package.extras] brotli = ["brotli", "brotlicffi"] -cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (>=1.0.0,<2.0.0)"] +socks = ["socksio (==1.*)"] [[package]] name = "huggingface-hub" version = "0.16.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -1277,7 +1223,6 @@ typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "typ name = "identify" version = "2.5.24" description = "File identification library for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1292,7 +1237,6 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1304,7 +1248,6 @@ files = [ name = "importlib-metadata" version = "5.2.0" description = "Read metadata from Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1324,7 +1267,6 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "6.0.0" description = "Read resources from Python packages" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1343,7 +1285,6 @@ testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1351,11 +1292,28 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "inquirerpy" +version = "0.3.4" +description = "Python port of Inquirer.js (A collection of common interactive command-line user interfaces)" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "InquirerPy-0.3.4-py3-none-any.whl", hash = "sha256:c65fdfbac1fa00e3ee4fb10679f4d3ed7a012abf4833910e63c295827fe2a7d4"}, + {file = "InquirerPy-0.3.4.tar.gz", hash = "sha256:89d2ada0111f337483cb41ae31073108b2ec1e618a49d7110b0d7ade89fc197e"}, +] + +[package.dependencies] +pfzy = ">=0.3.1,<0.4.0" +prompt-toolkit = ">=3.0.1,<4.0.0" + +[package.extras] +docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"] + [[package]] name = "ipdb" version = "0.13.13" description = "IPython-enabled pdb" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1372,7 +1330,6 @@ tomli = {version = "*", markers = "python_version > \"3.6\" and python_version < name = "ipykernel" version = "6.24.0" description = "IPython Kernel for Jupyter" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1386,7 +1343,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -1406,7 +1363,6 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipython" version = "7.34.0" description = "IPython: Productive Interactive Computing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1443,7 +1399,6 @@ test = ["ipykernel", "nbformat", "nose (>=0.10.1)", "numpy (>=1.17)", "pygments" name = "isort" version = "5.12.0" description = "A Python utility / library to sort Python imports." -category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -1461,7 +1416,6 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] name = "itsdangerous" version = "2.1.2" description = "Safely pass data to untrusted environments and back." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1473,7 +1427,6 @@ files = [ name = "jax" version = "0.4.13" description = "Differentiate, compile, and transform Numpy code." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1504,7 +1457,6 @@ tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] name = "jedi" version = "0.18.2" description = "An autocompletion tool for Python that can be used for text editors." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1524,7 +1476,6 @@ testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1542,7 +1493,6 @@ i18n = ["Babel (>=2.7)"] name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1554,7 +1504,6 @@ files = [ name = "joblib" version = "1.3.1" description = "Lightweight pipelining with Python functions" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1566,7 +1515,6 @@ files = [ name = "jsonschema" version = "4.18.0" description = "An implementation of JSON Schema validation for Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1590,7 +1538,6 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jsonschema-specifications" version = "2023.6.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1606,7 +1553,6 @@ referencing = ">=0.28.0" name = "jupyter-client" version = "8.3.0" description = "Jupyter protocol implementation and client libraries" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1616,7 +1562,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -1630,7 +1576,6 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-core" version = "5.3.1" description = "Jupyter core package. A base package on which Jupyter projects rely." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1651,7 +1596,6 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1663,7 +1607,6 @@ files = [ name = "keras" version = "2.12.0" description = "Deep learning for humans." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1674,7 +1617,6 @@ files = [ name = "libclang" version = "16.0.0" description = "Clang Python Bindings, mirrored from the official LLVM repo: https://github.com/llvm/llvm-project/tree/main/clang/bindings/python, to make the installation process easier." -category = "dev" optional = false python-versions = "*" files = [ @@ -1692,7 +1634,6 @@ files = [ name = "lightgbm" version = "3.3.5" description = "LightGBM Python Package" -category = "dev" optional = false python-versions = "*" files = [ @@ -1716,7 +1657,6 @@ dask = ["dask[array] (>=2.0.0)", "dask[dataframe] (>=2.0.0)", "dask[distributed] name = "mako" version = "1.2.4" description = "A super-fast templating language that borrows the best ideas from the existing templating languages." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1736,7 +1676,6 @@ testing = ["pytest"] name = "markdown" version = "3.4.3" description = "Python implementation of John Gruber's Markdown." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1754,7 +1693,6 @@ testing = ["coverage", "pyyaml"] name = "markdown-it-py" version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1779,7 +1717,6 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1839,7 +1776,6 @@ files = [ name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1854,7 +1790,6 @@ traitlets = "*" name = "mccabe" version = "0.6.1" description = "McCabe checker, plugin for flake8" -category = "dev" optional = false python-versions = "*" files = [ @@ -1866,7 +1801,6 @@ files = [ name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1878,7 +1812,6 @@ files = [ name = "mistune" version = "3.0.1" description = "A sane and fast Markdown parser with useful plugins and renderers" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1890,7 +1823,6 @@ files = [ name = "ml-dtypes" version = "0.2.0" description = "" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1927,7 +1859,6 @@ dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] name = "mlflow" version = "1.30.1" description = "MLflow: A Platform for ML Development and Productionization" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1970,7 +1901,6 @@ sqlserver = ["mlflow-dbstore"] name = "msgpack" version = "1.0.5" description = "MessagePack serializer" -category = "main" optional = false python-versions = "*" files = [ @@ -2043,7 +1973,6 @@ files = [ name = "msgpack-numpy" version = "0.4.8" description = "Numpy data serialization using msgpack" -category = "main" optional = false python-versions = "*" files = [ @@ -2059,7 +1988,6 @@ numpy = ">=1.9.0" name = "mypy" version = "1.4.1" description = "Optional static typing for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2106,7 +2034,6 @@ reports = ["lxml"] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2118,7 +2045,6 @@ files = [ name = "nbclient" version = "0.8.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." -category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -2128,7 +2054,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" nbformat = ">=5.1" traitlets = ">=5.4" @@ -2141,7 +2067,6 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= name = "nbconvert" version = "7.6.0" description = "Converting Jupyter Notebooks" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2180,7 +2105,6 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.9.1" description = "The Jupyter Notebook format" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2202,7 +2126,6 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2214,7 +2137,6 @@ files = [ name = "nodeenv" version = "1.8.0" description = "Node.js virtual environment builder" -category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -2229,7 +2151,6 @@ setuptools = "*" name = "numpy" version = "1.23.5" description = "NumPy is the fundamental package for array computing with Python." -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2267,7 +2188,6 @@ files = [ name = "nvidia-cublas-cu11" version = "11.10.3.66" description = "CUBLAS native runtime libraries" -category = "dev" optional = false python-versions = ">=3" files = [ @@ -2283,7 +2203,6 @@ wheel = "*" name = "nvidia-cuda-nvrtc-cu11" version = "11.7.99" description = "NVRTC native runtime libraries" -category = "dev" optional = false python-versions = ">=3" files = [ @@ -2300,7 +2219,6 @@ wheel = "*" name = "nvidia-cuda-runtime-cu11" version = "11.7.99" description = "CUDA Runtime native Libraries" -category = "dev" optional = false python-versions = ">=3" files = [ @@ -2316,7 +2234,6 @@ wheel = "*" name = "nvidia-cudnn-cu11" version = "8.5.0.96" description = "cuDNN runtime libraries" -category = "dev" optional = false python-versions = ">=3" files = [ @@ -2332,7 +2249,6 @@ wheel = "*" name = "oauthlib" version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2349,7 +2265,6 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] name = "opt-einsum" version = "3.3.0" description = "Optimizing numpys einsum function" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2368,7 +2283,6 @@ tests = ["pytest", "pytest-cov", "pytest-pep8"] name = "packaging" version = "20.9" description = "Core utilities for Python packages" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2383,7 +2297,6 @@ pyparsing = ">=2.0.2" name = "pandas" version = "1.5.2" description = "Powerful data structures for data analysis, time series, and statistics" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2432,7 +2345,6 @@ test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2444,7 +2356,6 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2460,7 +2371,6 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathspec" version = "0.11.1" description = "Utility library for gitignore style pattern matching of file paths." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2472,7 +2382,6 @@ files = [ name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." -category = "dev" optional = false python-versions = "*" files = [ @@ -2483,11 +2392,24 @@ files = [ [package.dependencies] ptyprocess = ">=0.5" +[[package]] +name = "pfzy" +version = "0.3.4" +description = "Python port of the fzy fuzzy string matching algorithm" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "pfzy-0.3.4-py3-none-any.whl", hash = "sha256:5f50d5b2b3207fa72e7ec0ef08372ef652685470974a107d0d4999fc5a903a96"}, + {file = "pfzy-0.3.4.tar.gz", hash = "sha256:717ea765dd10b63618e7298b2d98efd819e0b30cd5905c9707223dceeb94b3f1"}, +] + +[package.extras] +docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"] + [[package]] name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" -category = "dev" optional = false python-versions = "*" files = [ @@ -2499,7 +2421,6 @@ files = [ name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2511,7 +2432,6 @@ files = [ name = "platformdirs" version = "3.8.1" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2527,7 +2447,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.2.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2543,7 +2462,6 @@ testing = ["pytest", "pytest-benchmark"] name = "pre-commit" version = "2.21.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2562,7 +2480,6 @@ virtualenv = ">=20.10.0" name = "prometheus-client" version = "0.17.1" description = "Python client for the Prometheus monitoring system." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2577,7 +2494,6 @@ twisted = ["twisted"] name = "prometheus-flask-exporter" version = "0.22.4" description = "Prometheus metrics exporter for Flask" -category = "dev" optional = false python-versions = "*" files = [ @@ -2593,7 +2509,6 @@ prometheus-client = "*" name = "prompt-toolkit" version = "3.0.39" description = "Library for building powerful interactive command lines in Python" -category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -2608,7 +2523,6 @@ wcwidth = "*" name = "protobuf" version = "4.23.4" description = "" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2631,7 +2545,6 @@ files = [ name = "psutil" version = "5.9.5" description = "Cross-platform lib for process and system monitoring in Python." -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2658,7 +2571,6 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" -category = "dev" optional = false python-versions = "*" files = [ @@ -2670,7 +2582,6 @@ files = [ name = "pyasn1" version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -2682,7 +2593,6 @@ files = [ name = "pyasn1-modules" version = "0.3.0" description = "A collection of ASN.1-based protocols modules" -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -2697,7 +2607,6 @@ pyasn1 = ">=0.4.6,<0.6.0" name = "pycodestyle" version = "2.8.0" description = "Python style guide checker" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -2709,7 +2618,6 @@ files = [ name = "pycparser" version = "2.21" description = "C parser in Python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2721,7 +2629,6 @@ files = [ name = "pydantic" version = "1.10.11" description = "Data validation and settings management using python type hints" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2774,7 +2681,6 @@ email = ["email-validator (>=1.0.3)"] name = "pyflakes" version = "2.4.0" description = "passive checker of Python programs" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2786,7 +2692,6 @@ files = [ name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2801,7 +2706,6 @@ plugins = ["importlib-metadata"] name = "pyjwt" version = "2.7.0" description = "JSON Web Token implementation in Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2819,7 +2723,6 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] name = "pyparsing" version = "3.1.0" description = "pyparsing module - Classes and methods to define and execute parsing grammars" -category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -2834,7 +2737,6 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pytest" version = "7.2.0" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2858,7 +2760,6 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2. name = "pytest-cov" version = "3.0.0" description = "Pytest plugin for measuring coverage." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2877,7 +2778,6 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "pytest-split" version = "0.8.1" description = "Pytest plugin which splits the test suite to equally sized sub suites based on test execution time." -category = "dev" optional = false python-versions = ">=3.7.1,<4.0" files = [ @@ -2892,7 +2792,6 @@ pytest = ">=5,<8" name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -2907,7 +2806,6 @@ six = ">=1.5" name = "python-json-logger" version = "2.0.7" description = "A python library adding a json log formatter" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2919,7 +2817,6 @@ files = [ name = "python-on-whales" version = "0.46.0" description = "A Docker client for Python, designed to be fun and intuitive!" -category = "main" optional = false python-versions = ">=3.7, <4" files = [ @@ -2938,7 +2835,6 @@ typing-extensions = "*" name = "pytz" version = "2022.7.1" description = "World timezone definitions, modern and historical" -category = "dev" optional = false python-versions = "*" files = [ @@ -2950,7 +2846,6 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" -category = "dev" optional = false python-versions = "*" files = [ @@ -2974,7 +2869,6 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3024,7 +2918,6 @@ files = [ name = "pyzmq" version = "25.1.0" description = "Python bindings for 0MQ" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3114,7 +3007,6 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "querystring-parser" version = "1.2.4" description = "QueryString parser for Python/Django that correctly handles nested dictionaries" -category = "dev" optional = false python-versions = "*" files = [ @@ -3129,7 +3021,6 @@ six = "*" name = "referencing" version = "0.29.1" description = "JSON Referencing + Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3145,7 +3036,6 @@ rpds-py = ">=0.7.0" name = "regex" version = "2023.6.3" description = "Alternative regular expression module, to replace re." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3243,7 +3133,6 @@ files = [ name = "requests" version = "2.31.0" description = "Python HTTP for Humans." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3265,7 +3154,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "requests-oauthlib" version = "1.3.1" description = "OAuthlib authentication support for Requests." -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3284,7 +3172,6 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"] name = "rich" version = "13.4.2" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3300,11 +3187,28 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rich-click" +version = "1.6.1" +description = "Format click help output nicely with rich" +optional = false +python-versions = ">=3.7" +files = [ + {file = "rich-click-1.6.1.tar.gz", hash = "sha256:f8ff96693ec6e261d1544e9f7d9a5811c5ef5d74c8adb4978430fc0dac16777e"}, + {file = "rich_click-1.6.1-py3-none-any.whl", hash = "sha256:0fcf4d1a09029d79322dd814ab0b2e66ac183633037561881d45abae8a161d95"}, +] + +[package.dependencies] +click = ">=7" +rich = ">=10.7.0" + +[package.extras] +dev = ["pre-commit"] + [[package]] name = "rpds-py" version = "0.8.10" description = "Python bindings to Rust's persistent data structures (rpds)" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3411,7 +3315,6 @@ files = [ name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" -category = "dev" optional = false python-versions = ">=3.6,<4" files = [ @@ -3426,7 +3329,6 @@ pyasn1 = ">=0.1.3" name = "s3transfer" version = "0.6.1" description = "An Amazon S3 Transfer Manager" -category = "main" optional = false python-versions = ">= 3.7" files = [ @@ -3444,7 +3346,6 @@ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] name = "safetensors" version = "0.3.1" description = "Fast and Safe Tensor serialization" -category = "dev" optional = false python-versions = "*" files = [ @@ -3505,7 +3406,6 @@ torch = ["torch (>=1.10)"] name = "scikit-learn" version = "1.0.2" description = "A set of python modules for machine learning and data mining" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3559,7 +3459,6 @@ tests = ["black (>=21.6b0)", "flake8 (>=3.8.2)", "matplotlib (>=2.2.3)", "mypy ( name = "scipy" version = "1.10.1" description = "Fundamental algorithms for scientific computing in Python" -category = "dev" optional = false python-versions = "<3.12,>=3.8" files = [ @@ -3598,7 +3497,6 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo name = "setuptools" version = "68.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3615,7 +3513,6 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "single-source" version = "0.3.0" description = "Access to the project version in Python code for PEP 621-style projects" -category = "main" optional = false python-versions = ">=3.6,<4.0" files = [ @@ -3627,7 +3524,6 @@ files = [ name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -3639,7 +3535,6 @@ files = [ name = "smmap" version = "5.0.0" description = "A pure Python implementation of a sliding window memory map manager" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3651,7 +3546,6 @@ files = [ name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3663,7 +3557,6 @@ files = [ name = "soupsieve" version = "2.4.1" description = "A modern CSS selector implementation for Beautiful Soup." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3675,7 +3568,6 @@ files = [ name = "sqlalchemy" version = "1.4.49" description = "Database Abstraction Library" -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3720,7 +3612,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and platform_machine == \"aarch64\" or python_version >= \"3\" and platform_machine == \"ppc64le\" or python_version >= \"3\" and platform_machine == \"x86_64\" or python_version >= \"3\" and platform_machine == \"amd64\" or python_version >= \"3\" and platform_machine == \"AMD64\" or python_version >= \"3\" and platform_machine == \"win32\" or python_version >= \"3\" and platform_machine == \"WIN32\""} +greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"} [package.extras] aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] @@ -3747,7 +3639,6 @@ sqlcipher = ["sqlcipher3-binary"] name = "sqlparse" version = "0.4.4" description = "A non-validating SQL parser." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -3764,7 +3655,6 @@ test = ["pytest", "pytest-cov"] name = "starlette" version = "0.27.0" description = "The little ASGI library that shines." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3783,7 +3673,6 @@ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyam name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3798,7 +3687,6 @@ widechars = ["wcwidth"] name = "tenacity" version = "8.2.2" description = "Retry code until it succeeds" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3813,7 +3701,6 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"] name = "tensorboard" version = "2.12.3" description = "TensorBoard lets you watch Tensors Flow" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3838,7 +3725,6 @@ wheel = ">=0.26" name = "tensorboard-data-server" version = "0.7.1" description = "Fast data loading for TensorBoard" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3851,7 +3737,6 @@ files = [ name = "tensorflow" version = "2.12.0" description = "TensorFlow is an open source machine learning framework for everyone." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3901,7 +3786,6 @@ wrapt = ">=1.11.0,<1.15" name = "tensorflow-estimator" version = "2.12.0" description = "TensorFlow Estimator." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3912,7 +3796,6 @@ files = [ name = "tensorflow-hub" version = "0.12.0" description = "TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models." -category = "dev" optional = false python-versions = "*" files = [ @@ -3931,7 +3814,6 @@ make-nearest-neighbour-index = ["annoy", "apache-beam"] name = "tensorflow-io-gcs-filesystem" version = "0.32.0" description = "TensorFlow IO" -category = "dev" optional = false python-versions = ">=3.7, <3.12" files = [ @@ -3962,7 +3844,6 @@ tensorflow-rocm = ["tensorflow-rocm (>=2.12.0,<2.13.0)"] name = "tensorflow-macos" version = "2.12.0" description = "TensorFlow is an open source machine learning framework for everyone." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4004,7 +3885,6 @@ wrapt = ">=1.11.0,<1.15" name = "termcolor" version = "2.3.0" description = "ANSI color formatting for output in terminal" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4019,7 +3899,6 @@ tests = ["pytest", "pytest-cov"] name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4031,7 +3910,6 @@ files = [ name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4050,7 +3928,6 @@ test = ["flake8", "isort", "pytest"] name = "tokenizers" version = "0.13.3" description = "Fast and Customizable Tokenizers" -category = "dev" optional = false python-versions = "*" files = [ @@ -4105,7 +3982,6 @@ testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4117,7 +3993,6 @@ files = [ name = "torch" version = "1.13.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -4158,7 +4033,6 @@ opt-einsum = ["opt-einsum (>=3.3)"] name = "tornado" version = "6.3.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." -category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -4179,7 +4053,6 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4200,7 +4073,6 @@ telegram = ["requests"] name = "traitlets" version = "5.9.0" description = "Traitlets Python configuration system" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4216,7 +4088,6 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] name = "transformers" version = "4.30.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" -category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -4286,7 +4157,6 @@ vision = ["Pillow"] name = "typer" version = "0.9.0" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4308,7 +4178,6 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6. name = "typing-extensions" version = "4.7.1" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4320,7 +4189,6 @@ files = [ name = "urllib3" version = "1.26.16" description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -4337,7 +4205,6 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "uvicorn" version = "0.21.1" description = "The lightning-fast ASGI server." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4356,7 +4223,6 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", name = "virtualenv" version = "20.23.1" description = "Virtual Python Environment builder" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4377,7 +4243,6 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess name = "waitress" version = "2.1.2" description = "Waitress WSGI server" -category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -4393,7 +4258,6 @@ testing = ["coverage (>=5.0)", "pytest", "pytest-cover"] name = "watchfiles" version = "0.19.0" description = "Simple, modern and high performance file watching and code reload in python." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4428,7 +4292,6 @@ anyio = ">=3.0.0" name = "wcwidth" version = "0.2.6" description = "Measures the displayed width of unicode strings in a terminal" -category = "dev" optional = false python-versions = "*" files = [ @@ -4440,7 +4303,6 @@ files = [ name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" -category = "dev" optional = false python-versions = "*" files = [ @@ -4452,7 +4314,6 @@ files = [ name = "websocket-client" version = "1.6.1" description = "WebSocket client for Python with low level API options" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4469,7 +4330,6 @@ test = ["websockets"] name = "werkzeug" version = "2.3.6" description = "The comprehensive WSGI web application library." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4487,7 +4347,6 @@ watchdog = ["watchdog (>=2.3)"] name = "wheel" version = "0.40.0" description = "A built-package format for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4502,7 +4361,6 @@ test = ["pytest (>=6.0.0)"] name = "wrapt" version = "1.14.1" description = "Module for decorators, wrappers and monkey patching." -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -4576,7 +4434,6 @@ files = [ name = "xgboost" version = "1.7.6" description = "XGBoost Python Package" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4604,7 +4461,6 @@ scikit-learn = ["scikit-learn"] name = "zipp" version = "3.16.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4619,4 +4475,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "ad117168ccf349648a7835583c94a57c88434b44431beb13769938211bd00c42" +content-hash = "57da8ec16d6514ec6974048d8d8aa7ead0b8d89ecd7f66407bfcd7473809a053" diff --git a/pyproject.toml b/pyproject.toml index f921c6901..3421151de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.5.0" +version = "0.5.1" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" @@ -39,6 +39,8 @@ boto3 = "^1.26.157" rich = "^13.4.2" watchfiles = "^0.19.0" huggingface_hub = "^0.16.4" +rich-click = "^1.6.1" +inquirerpy = "^0.3.4" [tool.poetry.group.builder.dependencies] python = ">=3.8,<3.12" @@ -82,7 +84,7 @@ ipykernel = "^6.16.0" dockerfile = "^3.2.0" [tool.poetry.scripts] -truss = 'truss.cli:cli_group' +truss = 'truss.cli:truss_cli' [tool.poetry.group.dev.dependencies] mlflow = "^1.29.0" diff --git a/truss/cli.py b/truss/cli.py index 1ce224c66..93f26353d 100644 --- a/truss/cli.py +++ b/truss/cli.py @@ -1,18 +1,40 @@ import json import logging import os +import sys from functools import wraps from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, Optional, Union -import click +import rich +import rich_click as click import truss -import yaml +from truss.remote.remote_cli import inquire_model_name, inquire_remote_name from truss.remote.remote_factory import RemoteFactory logging.basicConfig(level=logging.INFO) +click.rich_click.COMMAND_GROUPS = { + "truss": [ + { + "name": "Main usage", + "commands": ["init", "push", "watch", "predict"], + "table_styles": { # type: ignore + "row_styles": ["green"], + }, + }, + { + "name": "Advanced Usage", + "commands": ["image", "container", "cleanup"], + "table_styles": { # type: ignore + "row_styles": ["yellow"], + }, + }, + ] +} + + def echo_output(f: Callable[..., object]): @wraps(f) def wrapper(*args, **kwargs): @@ -39,23 +61,26 @@ def print_help() -> None: @click.group(name="truss", invoke_without_command=True) # type: ignore @click.pass_context -@click.option( - "-v", - "--version", - is_flag=True, - show_default=False, - default=False, - help="Show Truss package version.", -) -def cli_group(ctx, version) -> None: +@click.version_option(truss.version()) +def truss_cli(ctx) -> None: + """truss: The simplest way to serve models in production""" if not ctx.invoked_subcommand: - if version: - click.echo(truss.version()) - else: - click.echo(ctx.get_help()) + click.echo(ctx.get_help()) + +@click.group() +def container(): + """Subcommands for truss container""" + pass -@cli_group.command() + +@click.group() +def image(): + """Subcommands for truss image""" + pass + + +@truss_cli.command() @click.argument("target_directory", required=True) @click.option( "-s", @@ -75,7 +100,7 @@ def cli_group(ctx, version) -> None: ) @error_handling def init(target_directory, skip_confirm, trainable) -> None: - """Initializes an empty Truss directory. + """Create a new truss. TARGET_DIRECTORY: A Truss is created in this directory """ @@ -85,7 +110,7 @@ def init(target_directory, skip_confirm, trainable) -> None: click.echo(f"Truss was created in {tr_path}") -@cli_group.command() +@image.command() # type: ignore @click.argument("build_dir") @click.argument("target_directory", required=False) @error_handling @@ -101,12 +126,12 @@ def build_context(build_dir, target_directory: str) -> None: tr.docker_build_setup(build_dir=Path(build_dir)) -@cli_group.command() # type: ignore +@image.command() # type: ignore @click.argument("target_directory", required=False) @click.argument("build_dir", required=False) @error_handling @click.option("--tag", help="Docker image tag") -def build_image(target_directory: str, build_dir: Path, tag) -> None: +def build(target_directory: str, build_dir: Path, tag) -> None: """ Builds the docker image for a Truss. @@ -120,7 +145,7 @@ def build_image(target_directory: str, build_dir: Path, tag) -> None: tr.build_serving_docker_image(build_dir=build_dir, tag=tag) -@cli_group.command() +@image.command() # type: ignore @click.argument("target_directory", required=False) @click.argument("build_dir", required=False) @click.option("--tag", help="Docker build image tag") @@ -129,7 +154,7 @@ def build_image(target_directory: str, build_dir: Path, tag) -> None: "--attach", is_flag=True, default=False, help="Flag for attaching the process" ) @error_handling -def run_image(target_directory: str, build_dir: Path, tag, port, attach) -> None: +def run(target_directory: str, build_dir: Path, tag, port, attach) -> None: """ Runs the docker image for a Truss. @@ -148,12 +173,12 @@ def run_image(target_directory: str, build_dir: Path, tag, port, attach) -> None tr.docker_run(build_dir=build_dir, tag=tag, local_port=port, detach=not attach) -@cli_group.command() +@truss_cli.command() @click.argument("target_directory", required=False, default=os.getcwd()) @click.option( "--remote", type=str, - required=True, + required=False, help="Name of the remote in .trussrc to patch changes to", ) @error_handling @@ -162,24 +187,31 @@ def watch( remote: str, ) -> None: """ - Watches local truss directory for changes and sends patch requests to remote development truss + Seamless remote development with truss TARGET_DIRECTORY: A Truss directory. If none, use current directory. """ # TODO: ensure that provider support draft + if not remote: + remote = inquire_remote_name(RemoteFactory.get_available_config_names()) + remote_provider = RemoteFactory.create(remote=remote) tr = _get_truss_from_directory(target_directory=target_directory) model_name = tr.spec.config.model_name if not model_name: - raise ValueError("'NoneType' model_name value provided in config.yaml") + rich.print( + "🧐 NoneType model_name provided in config.yaml. " + "Please check that you have the correct model name in your config file." + ) + sys.exit(1) - click.echo(f"Watching for changes to truss at: {target_directory} ...") + rich.print(f"👀 Watching for changes to truss at '{target_directory}' ...") remote_provider.sync_truss_to_dev_version_by_name(model_name, target_directory) # type: ignore -@cli_group.command() +@truss_cli.command() @click.option("--target_directory", required=False, help="Directory of truss") @click.option( "--request", @@ -218,7 +250,7 @@ def predict( request_file, ): """ - Invokes the packaged model, either locally or in a Docker container. + Invokes the packaged model TARGET_DIRECTORY: A Truss directory. If none, use current directory. @@ -251,54 +283,12 @@ def predict( ) -@cli_group.command() -@click.option("--target_directory", required=False, help="Directory of truss") -@click.option( - "--build-dir", - type=click.Path(exists=True), - required=False, - help="Directory where context is built", -) -@click.option("--tag", help="Docker build image tag") -@click.option( - "--var", - multiple=True, - help="""Training variables in key=value form where value is string. - For more complex values use vars_yaml_file""", -) -@click.option( - "--vars_yaml_file", - required=False, - help="Training variables from a yaml file", -) -@click.option( - "--local", - is_flag=True, - default=False, - help="Flag to run training locally (not on docker)", -) -@error_handling -@echo_output -def train(target_directory: str, build_dir, tag, var: List[str], vars_yaml_file, local): - """Runs prediction for a Truss in a docker image or locally""" - tr = _get_truss_from_directory(target_directory=target_directory) - if vars_yaml_file is not None: - with Path(vars_yaml_file).open() as vars_file: - variables = yaml.safe_load(vars_file) - else: - variables = _variables_dict_from_option(var) - if local: - return tr.local_train(variables=variables) - - return tr.docker_train(build_dir=build_dir, tag=tag, variables=variables) - - -@cli_group.command() +@truss_cli.command() @click.argument("target_directory", required=False, default=os.getcwd()) @click.option( "--remote", type=str, - required=True, + required=False, help="Name of the remote in .trussrc to push to", ) @click.option("--model-name", type=str, required=False, help="Name of the model") @@ -323,16 +313,17 @@ def push( TARGET_DIRECTORY: A Truss directory. If none, use current directory. """ + if not remote: + remote = inquire_remote_name(RemoteFactory.get_available_config_names()) + remote_provider = RemoteFactory.create(remote=remote) tr = _get_truss_from_directory(target_directory=target_directory) # Push model_name = model_name or tr.spec.config.model_name - if model_name is None: - raise ValueError( - "Model name must be provided either as a flag or in the Truss config" - ) + if not model_name: + model_name = inquire_model_name() # Write model name to config if it's not already there if model_name != tr.spec.config.model_name: @@ -345,41 +336,10 @@ def push( click.echo(f"Model {model_name} was successfully pushed.") -@cli_group.command() +@container.command() # type: ignore @click.argument("target_directory", required=False) -@click.option("--name", type=str, required=False, help="Name of example to run") -@click.option( - "--local", is_flag=True, default=False, help="Flag to run prediction locally" -) @error_handling -@echo_output -def run_example(target_directory: str, name, local): - """ - Runs examples specified in the Truss, over docker. - - TARGET_DIRECTORY: A Truss directory. If none, use current directory. - """ - tr = _get_truss_from_directory(target_directory=target_directory) - predict_fn = tr.docker_predict - if local: - predict_fn = tr.server_predict - - if name is not None: - example = tr.example(name) - click.echo(f"Running example: {name}") - return predict_fn(example.input) - else: - example_outputs = [] - for example in tr.examples(): - click.echo(f"Running example: {example.name}") - example_outputs.append(predict_fn(example.input)) - return example_outputs - - -@cli_group.command() -@click.argument("target_directory", required=False) -@error_handling -def get_container_logs(target_directory) -> None: +def logs(target_directory) -> None: """ Get logs in a container is running for a truss @@ -391,7 +351,7 @@ def get_container_logs(target_directory) -> None: click.echo(log) -@cli_group.command() # type: ignore +@container.command() # type: ignore @click.argument("target_directory", required=False) def kill(target_directory: str) -> None: """ @@ -403,13 +363,13 @@ def kill(target_directory: str) -> None: tr.kill_container() -@cli_group.command() +@container.command() # type: ignore def kill_all() -> None: "Kills all truss containers that are not manually persisted" truss.kill_all() -@cli_group.command() +@truss_cli.command() @error_handling def cleanup() -> None: """ @@ -429,19 +389,8 @@ def _get_truss_from_directory(target_directory: Optional[str] = None): return truss.load(target_directory) -def _variables_dict_from_option(variables_list: List[str]) -> dict: - vars_dict = {} - for var in variables_list: - first_equals_pos = var.find("=") - if first_equals_pos == -1: - raise ValueError( - f"Training variable expected in `key=value` from but found `{var}`", - ) - var_name = var[:first_equals_pos] - var_value = var[first_equals_pos + 1 :] - vars_dict[var_name] = var_value - return vars_dict - +truss_cli.add_command(container) +truss_cli.add_command(image) if __name__ == "__main__": - cli_group() + truss_cli() diff --git a/truss/contexts/local_loader/truss_file_syncer.py b/truss/contexts/local_loader/truss_file_syncer.py index 910ef95c2..fa241eebc 100644 --- a/truss/contexts/local_loader/truss_file_syncer.py +++ b/truss/contexts/local_loader/truss_file_syncer.py @@ -26,6 +26,9 @@ def run(self) -> None: """Watch for files in background and apply appropriate patches.""" from watchfiles import watch + # disable watchfiles logger + logging.getLogger("watchfiles.main").disabled = True + for _ in watch( self.watch_path, watch_filter=self.watch_filter, raise_interrupt=False ): diff --git a/truss/patch/calc_patch.py b/truss/patch/calc_patch.py index 359ed497c..c300b8a89 100644 --- a/truss/patch/calc_patch.py +++ b/truss/patch/calc_patch.py @@ -87,6 +87,7 @@ def _under_unsupported_patch_dir(path: str) -> bool: patches = [] for path in changed_paths["removed"]: if _strictly_under(path, [model_module_path]): + logger.info(f"Created patch to remove model code file: {path}") patches.append( Patch( type=PatchType.MODEL_CODE, @@ -101,6 +102,7 @@ def _under_unsupported_patch_dir(path: str) -> bool: logger.info(f"Patching not supported for removing {path}") return None elif _strictly_under(path, [bundled_packages_path]): + logger.info(f"Created patch to remove package file: {path}") patches.append( Patch( type=PatchType.PACKAGE, @@ -122,6 +124,9 @@ def _under_unsupported_patch_dir(path: str) -> bool: # TODO(pankaj) Add support for empty directories, skip them for now. if not full_path.is_file(): continue + logger.info( + f"Created patch to {action.value.lower()} model code file: {path}" + ) patches.append( Patch( type=PatchType.MODEL_CODE, @@ -138,12 +143,14 @@ def _under_unsupported_patch_dir(path: str) -> bool: yaml.safe_load(previous_truss_signature.config) ) config_patches = calc_config_patches(prev_config, new_config) + if config_patches: + logger.info(f"Created patch to {action.value.lower()} config") patches.extend(config_patches) elif _strictly_under(path, [bundled_packages_path]): full_path = truss_dir / path if not full_path.is_file(): continue - + logger.info(f"Created patch to {action.value.lower()} package file: {path}") patches.append( Patch( type=PatchType.PACKAGE, @@ -223,13 +230,16 @@ def calc_config_patches( Returns None if patch cannot be calculated. Empty list means no relevant differences found. """ - - config_patches = _calc_general_config_patches(prev_config, new_config) - python_requirement_patches = _calc_python_requirements_patches( - prev_config, new_config - ) - system_package_patches = _calc_system_packages_patches(prev_config, new_config) - return [*config_patches, *python_requirement_patches, *system_package_patches] + try: + config_patches = _calc_general_config_patches(prev_config, new_config) + python_requirement_patches = _calc_python_requirements_patches( + prev_config, new_config + ) + system_package_patches = _calc_system_packages_patches(prev_config, new_config) + return [*config_patches, *python_requirement_patches, *system_package_patches] + except Exception as e: + logger.error(f"Failed to calculate config patch with exception: {e}") + raise def _calc_general_config_patches( diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index c2ebbb528..8a3ca97f1 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import requests from truss.remote.baseten.auth import AuthService @@ -64,10 +65,18 @@ def create_model_from_truss( semver_bump, client_version, is_trusted=False, + model_id: Optional[str] = None, ): + if model_id: + mutation = "create_model_version_from_truss" + first_arg = f'model_id: "{model_id}"' + else: + mutation = "create_model_from_truss" + first_arg = f'name: "{model_name}"' + query_string = f""" mutation {{ - create_model_from_truss(name: "{model_name}", + {mutation}({first_arg}, s3_key: "{s3_key}", config: "{config}", semver_bump: "{semver_bump}", @@ -81,7 +90,7 @@ def create_model_from_truss( }} """ resp = self._post_graphql_query(query_string) - return resp["data"]["create_model_from_truss"] + return resp["data"][mutation] def create_development_model_from_truss( self, diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index bcfa5335f..4ec4ebc76 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -def exists_model(api: BasetenApi, model_name: str) -> bool: +def exists_model(api: BasetenApi, model_name: str) -> Optional[str]: """ Check if a model with the given name exists in the Baseten remote. @@ -19,13 +19,13 @@ def exists_model(api: BasetenApi, model_name: str) -> bool: model_name: Name of the model to check for existence Returns: - True if the model exists, False otherwise + model_id if present, otherwise None """ models = api.models() for model in models["models"]: if model["name"] == model_name: - return True - return False + return model["id"] + return None def get_dev_version_info(api: BasetenApi, model_name: str) -> dict: @@ -85,6 +85,7 @@ def create_truss_service( semver_bump: Optional[str] = "MINOR", is_trusted: Optional[bool] = False, is_draft: Optional[bool] = False, + model_id: Optional[str] = None, ) -> Tuple[str, str]: """ Create a model in the Baseten remote. @@ -116,6 +117,7 @@ def create_truss_service( semver_bump, f"truss=={truss.version()}", is_trusted, + model_id, ) return (model_version_json["id"], model_version_json["version_id"]) diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 3de687b4c..1c29ba4d2 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -4,7 +4,6 @@ import yaml from truss.contexts.local_loader.truss_file_syncer import TrussFilesSyncer from truss.local.local_config_handler import LocalConfigHandler -from truss.patch.constants import PATCHABLE_STATUSES from truss.remote.baseten.api import BasetenApi from truss.remote.baseten.auth import AuthService from truss.remote.baseten.core import ( @@ -33,8 +32,7 @@ def push(self, truss_handle: TrussHandle, model_name: str, publish: bool = True) if model_name.isspace(): raise ValueError("Model name cannot be empty") - if exists_model(self._api, model_name): - raise ValueError(f"Model with name {model_name} already exists") + model_id = exists_model(self._api, model_name) gathered_truss = TrussHandle(truss_handle.gather()) encoded_config_str = base64_encoded_json_str( @@ -43,12 +41,14 @@ def push(self, truss_handle: TrussHandle, model_name: str, publish: bool = True) temp_file = archive_truss(gathered_truss) s3_key = upload_truss(self._api, temp_file) + model_id, model_version_id = create_truss_service( api=self._api, model_name=model_name, s3_key=s3_key, config=encoded_config_str, is_draft=not publish, + model_id=model_id, ) return BasetenService( @@ -60,7 +60,11 @@ def push(self, truss_handle: TrussHandle, model_name: str, publish: bool = True) truss_handle=truss_handle, ) - def sync_truss_to_dev_version_by_name(self, model_name: str, target_directory: str): + def sync_truss_to_dev_version_by_name( + self, + model_name: str, + target_directory: str, + ): # verify that development deployment exists for given model name _ = get_dev_version_info( self._api, model_name # pylint: disable=protected-access @@ -76,29 +80,32 @@ def sync_truss_to_dev_version_by_name(self, model_name: str, target_directory: s while True: pass - def patch(self, watch_path: Path, logger: logging.Logger): + def patch( + self, + watch_path: Path, + logger: logging.Logger, + ): try: truss_handle = TrussHandle(watch_path) except yaml.parser.ParserError: - logger.error("Unable to parse config file.") + logger.error("Unable to parse config file") return model_name = truss_handle.spec.config.model_name dev_version = get_dev_version_info(self._api, model_name) # type: ignore truss_hash = dev_version.get("truss_hash", None) truss_signature = dev_version.get("truss_signature", None) LocalConfigHandler.add_signature(truss_hash, truss_signature) - patch_request = truss_handle.calc_patch(truss_hash) + try: + patch_request = truss_handle.calc_patch(truss_hash) + except Exception: + logger.error("Failed to calculate patch") + return if patch_request: if ( patch_request.prev_hash == patch_request.next_hash or len(patch_request.patch_ops) == 0 ): logger.info("No changes observed, skipping deployment") - model_deployment_status = dev_version.get( - "current_model_deployment_status", None - ).get("status", None) - if model_deployment_status not in PATCHABLE_STATUSES: - logger.info(f"Model {model_name} is not ready for patching") resp = self._api.patch_draft_truss(model_name, patch_request) if not resp["succeeded"]: needs_full_deploy = resp.get("needs_full_deploy", None) diff --git a/truss/remote/remote_cli.py b/truss/remote/remote_cli.py new file mode 100644 index 000000000..ffc803032 --- /dev/null +++ b/truss/remote/remote_cli.py @@ -0,0 +1,50 @@ +from typing import List + +import rich +from InquirerPy import inquirer +from truss.remote.remote_factory import RemoteFactory +from truss.remote.truss_remote import RemoteConfig + + +def inquire_remote_config() -> RemoteConfig: + # TODO(bola): extract questions from remote + rich.print("💻 Let's add a Baseten remote!") + remote_url = inquirer.text( + message="🌐 Baseten remote url:", + default="https://app.baseten.co", + qmark="", + ).execute() + api_key = inquirer.secret( + message="🤫 Quiety paste your API_KEY:", + qmark="", + ).execute() + return RemoteConfig( + name="baseten", + configs={ + "remote_provider": "baseten", + "api_key": api_key, + "remote_url": remote_url, + }, + ) + + +def inquire_remote_name(available_remotes: List[str]) -> str: + if len(available_remotes) > 1: + remote = inquirer.select( + "🎮 Which remote do you want to push to?", + qmark="", + choices=available_remotes, + ).execute() + return remote + elif len(available_remotes) == 1: + return available_remotes[0] + remote_config = inquire_remote_config() + RemoteFactory.update_remote_config(remote_config) + return remote_config.name + + +def inquire_model_name() -> str: + return inquirer.text( + "📦 Name this model:", + qmark="", + ).execute() diff --git a/truss/remote/remote_factory.py b/truss/remote/remote_factory.py index ad7fe0767..2010d18dd 100644 --- a/truss/remote/remote_factory.py +++ b/truss/remote/remote_factory.py @@ -1,10 +1,25 @@ -import configparser import inspect +from configparser import DEFAULTSECT, SafeConfigParser +from functools import partial +from operator import is_not from pathlib import Path -from typing import Dict, Type +from typing import Dict, List, Type from truss.remote.baseten import BasetenRemote -from truss.remote.truss_remote import TrussRemote +from truss.remote.truss_remote import RemoteConfig, TrussRemote + +USER_TRUSSRC_PATH = Path("~/.trussrc").expanduser() + + +def load_config() -> SafeConfigParser: + config = SafeConfigParser() + config.read(USER_TRUSSRC_PATH) + return config + + +def update_config(config: SafeConfigParser): + with open(USER_TRUSSRC_PATH, "w") as configfile: + config.write(configfile) class RemoteFactory: @@ -15,22 +30,36 @@ class RemoteFactory: REGISTRY: Dict[str, Type[TrussRemote]] = {"baseten": BasetenRemote} @staticmethod - def load_remote_config(remote_name: str) -> Dict: + def get_available_config_names() -> List[str]: + if not USER_TRUSSRC_PATH.exists(): + return [] + + config = load_config() + return list(filter(partial(is_not, DEFAULTSECT), config.keys())) + + @staticmethod + def load_remote_config(remote_name: str) -> RemoteConfig: """ Load and validate a remote config from the .trussrc file """ - config_path = Path("~/.trussrc").expanduser() - - if not config_path.exists(): - raise FileNotFoundError(f"No .trussrc file found at {config_path}") + if not USER_TRUSSRC_PATH.exists(): + raise FileNotFoundError("No ~/.trussrc file found.") - config = configparser.ConfigParser() - config.read(config_path) + config = load_config() if remote_name not in config: - raise ValueError(f"Service provider {remote_name} not found in .trussrc") + raise ValueError(f"Service provider {remote_name} not found in ~/.trussrc") - return dict(config[remote_name]) + return RemoteConfig(name=remote_name, configs=dict(config[remote_name])) + + @staticmethod + def update_remote_config(remote_config: RemoteConfig): + """ + Load and validate a remote config from the .trussrc file + """ + config = load_config() + config[remote_config.name] = remote_config.configs + update_config(config) @staticmethod def validate_remote_config(remote_config: Dict, remote_name: str): @@ -76,17 +105,17 @@ def required_params(remote: Type[TrussRemote]) -> set: @classmethod def create(cls, remote: str) -> TrussRemote: remote_config = cls.load_remote_config(remote) - cls.validate_remote_config(remote_config, remote) + configs = remote_config.configs + cls.validate_remote_config(configs, remote) - remote_class = cls.REGISTRY[remote_config.pop("remote_provider")] + remote_class = cls.REGISTRY[configs.pop("remote_provider")] remote_params = { - param: remote_config.get(param) - for param in cls.required_params(remote_class) + param: configs.get(param) for param in cls.required_params(remote_class) } # Add any additional params provided by the user in their .trussrc - additional_params = set(remote_config.keys()) - set(remote_params.keys()) + additional_params = set(configs.keys()) - set(remote_params.keys()) for param in additional_params: - remote_params[param] = remote_config.get(param) + remote_params[param] = configs.get(param) return remote_class(**remote_params) # type: ignore diff --git a/truss/remote/truss_remote.py b/truss/remote/truss_remote.py index 5a7f8e479..7ec258b44 100644 --- a/truss/remote/truss_remote.py +++ b/truss/remote/truss_remote.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Dict, Optional import requests @@ -186,3 +187,11 @@ def authenticate(self, **kwargs): **kwargs: Additional keyword arguments for the authentication operation. """ pass + + +@dataclass +class RemoteConfig: + """Class to hold configs for various remotes""" + + name: str + configs: Dict = field(default_factory=dict) diff --git a/truss/templates/server/common/truss_server.py b/truss/templates/server/common/truss_server.py index 2595d3661..cbd626121 100644 --- a/truss/templates/server/common/truss_server.py +++ b/truss/templates/server/common/truss_server.py @@ -10,7 +10,7 @@ import time from collections.abc import Generator from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import AsyncGenerator, Dict, List, Optional, Union import common.errors as errors import shared.util as utils @@ -107,15 +107,15 @@ async def invocations_ready(self) -> Dict[str, Union[str, bool]]: return {} - def invocations( + async def invocations( self, request: Request, body_raw: bytes = Depends(parse_body) ) -> Response: """ This method provides compatibility with Sagemaker hosting for the 'invocations' endpoint. """ - return self.predict(self._model.name, request, body_raw) + return await self.predict(self._model.name, request, body_raw) - def predict( + async def predict( self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body) ) -> Response: """ @@ -132,16 +132,14 @@ def predict( body = json.loads(body_raw) # calls ModelWrapper.__call__, which runs validate, preprocess, predict, and postprocess - response: Union[Dict, Generator] = asyncio.run( - model( - body, - headers=utils.transform_keys(request.headers, lambda key: key.lower()), - ) + response: Union[Dict, Generator] = await model( + body, + headers=utils.transform_keys(request.headers, lambda key: key.lower()), ) # In the case that the model returns a Generator object, return a # StreamingResponse instead. - if isinstance(response, Generator): + if isinstance(response, AsyncGenerator): # media_type in StreamingResponse sets the Content-Type header return StreamingResponse(response, media_type="application/octet-stream") diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 733385a7e..d72449e4d 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -1,3 +1,4 @@ +import asyncio import importlib import inspect import logging @@ -8,11 +9,10 @@ from collections.abc import Generator from enum import Enum from pathlib import Path -from queue import Queue from threading import Lock, Thread -from typing import Any, Dict, Optional, Union +from typing import Any, AsyncGenerator, Dict, Optional, Set, Union -from anyio import to_thread +from anyio import Semaphore, to_thread from common.patches import apply_patches from common.retry import retry from shared.secrets_resolver import SecretsResolver @@ -21,6 +21,7 @@ NUM_LOAD_RETRIES = int(os.environ.get("NUM_LOAD_RETRIES_TRUSS", "3")) STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS = 60 +DEFAULT_PREDICT_CONCURRENCY = 1 class ModelWrapper: @@ -36,8 +37,13 @@ def __init__(self, config: Dict): self.name = MODEL_BASENAME self.ready = False self._load_lock = Lock() - self._predict_lock = Lock() self._status = ModelWrapper.Status.NOT_READY + self._predict_semaphore = Semaphore( + self._config.get("runtime", {}).get( + "predict_concurrency", DEFAULT_PREDICT_CONCURRENCY + ) + ) + self._background_tasks: Set[asyncio.Task] = set() def load(self) -> bool: if self.ready: @@ -123,34 +129,87 @@ def try_load(self): gap_seconds=1.0, ) - def preprocess( + async def preprocess( self, payload: Any, headers: Optional[Dict[str, str]] = None, ) -> Any: if not hasattr(self._model, "preprocess"): return payload - return self._model.preprocess(payload) # type: ignore - def postprocess( + if inspect.iscoroutinefunction(self._model.preprocess): + return await self._model.preprocess(payload) + else: + return await to_thread.run_sync(self._model.preprocess, payload) + + def _predict_sync_with_error_handling(self, payload): + try: + return self._model.predict(payload) + except Exception: + logging.exception("Exception while running predict") + return {"error": {"traceback": traceback.format_exc()}} + + async def _predict_async_with_error_handling(self, payload): + try: + return await self._model.predict(payload) + except Exception: + logging.exception("Exception while running predict") + return {"error": {"traceback": traceback.format_exc()}} + + async def predict( + self, + payload: Any, + headers: Optional[Dict[str, str]] = None, + ) -> Any: + # It's possible for the user's predict function to be a: + # 1. Generator function (function that returns a generator) + # 2. Async generator (function that returns async generator) + # In these cases, just return the generator or async generator, + # as we will be propagating these up. No need for await at this point. + # 3. Coroutine -- in this case, await the predict function as it is async + # 4. Normal function -- in this case, offload to a separate thread to prevent + # blocking the main event loop + if inspect.isasyncgenfunction( + self._model.predict + ) or inspect.isgeneratorfunction(self._model.predict): + return self._model.predict(payload) + + if inspect.iscoroutinefunction(self._model.predict): + return await self._predict_async_with_error_handling(payload) + + return await to_thread.run_sync(self._predict_sync_with_error_handling, payload) + + async def postprocess( self, response: Any, headers: Optional[Dict[str, str]] = None, ) -> Any: + # Similar to the predict function, it is possible for postprocess + # to return either a generator or async generator, in which case + # just return the generator. + # + # It can also return a coroutine or just be a function, in which + # case either await, or offload to a thread respectively. if not hasattr(self._model, "postprocess"): return response - return self._model.postprocess(response) # type: ignore - def predict( - self, - payload: Any, - headers: Optional[Dict[str, str]] = None, - ) -> Any: - try: - return self._model.predict(payload) # type: ignore - except Exception: - logging.exception("Exception while running predict") - return {"error": {"traceback": traceback.format_exc()}} + if inspect.isasyncgenfunction( + self._model.postprocess + ) or inspect.isgeneratorfunction(self._model.postprocess): + return self._model.postprocess(response, headers) + + if inspect.iscoroutinefunction(self._model.postprocess): + return await self._model.postprocess(response) + + return await to_thread.run_sync(self._model.postprocess, response) + + async def write_response_to_queue( + self, queue: asyncio.Queue, generator: AsyncGenerator + ): + async for chunk in generator: + await queue.put(ResponseChunk(chunk)) + + await queue.put(None) async def __call__( self, body: Any, headers: Optional[Dict[str, str]] = None @@ -166,57 +225,49 @@ async def __call__( Generator: In case of streaming response """ - payload = ( - await self.preprocess(body, headers) - if inspect.iscoroutinefunction(self.preprocess) - else self.preprocess(body, headers) - ) + payload = await self.preprocess(body, headers) - return await to_thread.run_sync(self._predict_and_post, payload, headers) + async with self._predict_semaphore: + response = await self.predict(payload, headers) - def _predict_and_post( - self, - payload: Any, - headers: Optional[Dict[str, str]] = None, - ) -> Any: - self._predict_lock.acquire() - defer_lock_release = False - try: - response = self.predict(payload, headers) - response = self.postprocess(response, headers) - if not isinstance(response, Generator): - return response + processed_response = await self.postprocess(response) - # Generator response - if headers and headers.get("accept") == "application/json": - return _convert_streamed_response_to_string(response) + # Streaming cases + if inspect.isgenerator(response) or inspect.isasyncgen(response): + async_generator = _force_async_generator(response) - # Reaching here means streaming response, and need to defer releasing lock - defer_lock_release = True - finally: - if not defer_lock_release: - self._predict_lock.release() - - # Streaming response - response_queue: Queue = Queue() - - def queue_response_chunks(): - # In a background thread, write the response chunks to a queue. - # In the main thread, read data from the queue until a "None" - # is written. This allows to us to use the predict lock only - # around the actual predict, and does not create a dependency - # on the client reading the entire response before releasing - # the lock. - try: - for chunk in response: - response_queue.put(ResponseChunk(chunk)) - response_queue.put(None) - finally: - self._predict_lock.release() - - response_generate_thread = Thread(target=queue_response_chunks) - response_generate_thread.start() - return _response_generator(response_queue) + if headers and headers.get("accept") == "application/json": + # In the case of a streaming response, consume stream + # if the http accept header is set, and json is requested. + return await _convert_streamed_response_to_string(async_generator) + + # To ensure that a partial read from a client does not cause the semaphore + # to stay claimed, we immediately write all of the data from the stream to a + # queue. We then return a new generator that reads from the queue, and then + # exit the semaphore block. + response_queue: asyncio.Queue = asyncio.Queue() + + # This task will be triggered and run in the background. + task = asyncio.create_task( + self.write_response_to_queue(response_queue, async_generator) + ) + self._background_tasks.add(task) + + task.add_done_callback(self._background_tasks.discard) + + async def _response_generator(): + while True: + chunk = await asyncio.wait_for( + response_queue.get(), + timeout=STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS, + ) + if chunk is None: + return + yield chunk.value + + return _response_generator() + + return processed_response class ResponseChunk: @@ -224,21 +275,34 @@ def __init__(self, value): self.value = value -def _response_generator(queue: Queue): +async def _convert_streamed_response_to_string(response: AsyncGenerator): + return "".join([str(chunk) async for chunk in response]) + + +def _force_async_generator(gen: Union[Generator, AsyncGenerator]) -> AsyncGenerator: """ - When returning the stream result, simply read from the response queue until a `None` - is reached. + Takes a generator, and converts it into an async generator if it is not already. """ - while True: - chunk = queue.get(timeout=STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS) - if chunk is None: - return - else: - yield chunk.value + if inspect.isasyncgen(gen): + return gen - -def _convert_streamed_response_to_string(response: Any): - return "".join([str(chunk) for chunk in list(response)]) + async def _convert_generator_to_async(): + """ + Runs each iteration of the generator in an offloaded thread, to ensure + the main loop is not blocked, and yield to create an async generator. + """ + FINAL_GENERATOR_VALUE = object() + while True: + # Note that this is the equivalent of running: + # next(gen, FINAL_GENERATOR_VALUE) on a separate thread, + # ensuring that if there is anything blocking in the generator, + # it does not block the main loop. + chunk = await to_thread.run_sync(next, gen, FINAL_GENERATOR_VALUE) + if chunk == FINAL_GENERATOR_VALUE: + break + yield chunk + + return _convert_generator_to_async() def _signature_accepts_keyword_arg(signature: inspect.Signature, kwarg: str) -> bool: diff --git a/truss/test_data/test_async_truss/config.yaml b/truss/test_data/test_async_truss/config.yaml new file mode 100644 index 000000000..55179ebcf --- /dev/null +++ b/truss/test_data/test_async_truss/config.yaml @@ -0,0 +1,36 @@ +apply_library_patches: true +bundled_packages_dir: packages +data_dir: data +description: null +environment_variables: {} +examples_filename: examples.yaml +external_package_dirs: [] +input_type: Any +live_reload: false +model_class_filename: model.py +model_class_name: Model +model_framework: custom +model_metadata: {} +model_module_dir: model +model_name: null +model_type: custom +python_version: py39 +requirements: [] +resources: + accelerator: null + cpu: 500m + memory: 512Mi + use_gpu: false +secrets: {} +spec_version: '2.0' +system_packages: [] +train: + resources: + accelerator: null + cpu: 500m + memory: 512Mi + use_gpu: false + training_class_filename: train.py + training_class_name: Train + training_module_dir: train + variables: {} diff --git a/truss/test_data/test_async_truss/model/model.py b/truss/test_data/test_async_truss/model/model.py new file mode 100644 index 000000000..d0a49955d --- /dev/null +++ b/truss/test_data/test_async_truss/model/model.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, List + + +class Model: + def __init__(self, **kwargs) -> None: + self._data_dir = kwargs["data_dir"] + self._config = kwargs["config"] + self._secrets = kwargs["secrets"] + self._model = None + + def load(self): + # Load model here and assign to self._model. + pass + + async def preprocess(self, model_input: Dict): + return {"preprocess_value": "value", **model_input} + + async def postprocess(self, response: Dict): + return {"postprocess_value": "value", **response} + + async def predict(self, model_input: Any) -> Dict[str, List]: + return model_input diff --git a/truss/test_data/test_concurrency_truss/config.yaml b/truss/test_data/test_concurrency_truss/config.yaml new file mode 100644 index 000000000..593cbc50d --- /dev/null +++ b/truss/test_data/test_concurrency_truss/config.yaml @@ -0,0 +1,2 @@ +runtime: + predict_concurrency: 2 diff --git a/truss/test_data/test_concurrency_truss/model/model.py b/truss/test_data/test_concurrency_truss/model/model.py new file mode 100644 index 000000000..62b132aa2 --- /dev/null +++ b/truss/test_data/test_concurrency_truss/model/model.py @@ -0,0 +1,20 @@ +import time +from typing import Any, Dict, List + + +class Model: + def __init__(self, **kwargs) -> None: + self._data_dir = kwargs["data_dir"] + self._config = kwargs["config"] + self._secrets = kwargs["secrets"] + self._model = None + + def load(self): + # Load model here and assign to self._model. + print("hello") + pass + + def predict(self, model_input: Any) -> Dict[str, List]: + # Invoke model on model_input and calculate predictions here. + time.sleep(2) + return model_input diff --git a/truss/test_data/test_streaming_async_generator_truss/config.yaml b/truss/test_data/test_streaming_async_generator_truss/config.yaml new file mode 100644 index 000000000..55179ebcf --- /dev/null +++ b/truss/test_data/test_streaming_async_generator_truss/config.yaml @@ -0,0 +1,36 @@ +apply_library_patches: true +bundled_packages_dir: packages +data_dir: data +description: null +environment_variables: {} +examples_filename: examples.yaml +external_package_dirs: [] +input_type: Any +live_reload: false +model_class_filename: model.py +model_class_name: Model +model_framework: custom +model_metadata: {} +model_module_dir: model +model_name: null +model_type: custom +python_version: py39 +requirements: [] +resources: + accelerator: null + cpu: 500m + memory: 512Mi + use_gpu: false +secrets: {} +spec_version: '2.0' +system_packages: [] +train: + resources: + accelerator: null + cpu: 500m + memory: 512Mi + use_gpu: false + training_class_filename: train.py + training_class_name: Train + training_module_dir: train + variables: {} diff --git a/truss/test_data/test_streaming_async_generator_truss/model/model.py b/truss/test_data/test_streaming_async_generator_truss/model/model.py new file mode 100644 index 000000000..92a53f8a2 --- /dev/null +++ b/truss/test_data/test_streaming_async_generator_truss/model/model.py @@ -0,0 +1,18 @@ +from typing import Any, Dict, List + + +class Model: + def __init__(self, **kwargs) -> None: + self._data_dir = kwargs["data_dir"] + self._config = kwargs["config"] + self._secrets = kwargs["secrets"] + self._model = None + + def load(self): + # Load model here and assign to self._model. + pass + + async def predict(self, model_input: Any) -> Dict[str, List]: + # Invoke model on model_input and calculate predictions here. + for i in range(5): + yield str(i) diff --git a/truss/test_data/test_streaming_truss/config.yaml b/truss/test_data/test_streaming_truss/config.yaml index 55179ebcf..7a9032937 100644 --- a/truss/test_data/test_streaming_truss/config.yaml +++ b/truss/test_data/test_streaming_truss/config.yaml @@ -21,6 +21,8 @@ resources: cpu: 500m memory: 512Mi use_gpu: false +runtime: + predict_concurrency: 1 secrets: {} spec_version: '2.0' system_packages: [] diff --git a/truss/tests/remote/test_remote_factory.py b/truss/tests/remote/test_remote_factory.py index f552e0240..bd575b148 100644 --- a/truss/tests/remote/test_remote_factory.py +++ b/truss/tests/remote/test_remote_factory.py @@ -2,7 +2,7 @@ import pytest from truss.remote.remote_factory import RemoteFactory -from truss.remote.truss_remote import TrussRemote +from truss.remote.truss_remote import RemoteConfig, TrussRemote SAMPLE_CONFIG = {"api_key": "test_key", "remote_url": "http://test.com"} @@ -38,11 +38,17 @@ def push(self): def mock_service_config(): - return {"remote_provider": "test_remote", **SAMPLE_CONFIG} + return RemoteConfig( + name="mock-service", + configs={"remote_provider": "test_remote", **SAMPLE_CONFIG}, + ) def mock_incorrect_service_config(): - return {"remote_provider": "nonexistent_remote", **SAMPLE_CONFIG} + return RemoteConfig( + name="mock-incorrect-service", + configs={"remote_provider": "nonexistent_remote", **SAMPLE_CONFIG}, + ) @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) @@ -73,7 +79,8 @@ def test_create_no_service(mock_load_remote_config): @mock.patch("pathlib.Path.exists", return_value=True) def test_load_remote_config(mock_exists, mock_open): service = RemoteFactory.load_remote_config("test") - assert service == {"remote_provider": "test_remote", **SAMPLE_CONFIG} + assert service.name == "test" + assert service.configs == {"remote_provider": "test_remote", **SAMPLE_CONFIG} @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) @@ -104,9 +111,9 @@ def test_required_params(): ) @mock.patch("pathlib.Path.exists", return_value=True) def test_validate_remote_config_no_remote(mock_exists, mock_open): + service = RemoteFactory.load_remote_config("test") with pytest.raises(ValueError): - service = RemoteFactory.load_remote_config("test") - RemoteFactory.validate_remote_config(service, "test") + RemoteFactory.validate_remote_config(service.configs, "test") @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) @@ -115,6 +122,6 @@ def test_validate_remote_config_no_remote(mock_exists, mock_open): ) @mock.patch("pathlib.Path.exists", return_value=True) def test_load_remote_config_no_params(mock_exists, mock_open): + service = RemoteFactory.load_remote_config("test") with pytest.raises(ValueError): - service = RemoteFactory.load_remote_config("test") - RemoteFactory.validate_remote_config(service, "test") + RemoteFactory.validate_remote_config(service.configs, "test") diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py index 33b3e3430..45a61b103 100644 --- a/truss/tests/test_config.py +++ b/truss/tests/test_config.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest import yaml from truss.truss_config import ( @@ -279,3 +281,39 @@ def test_huggingface_cache_multiple_models_mixed_revision(): assert new_config == config.to_dict(verbose=False) assert config.to_dict(verbose=True)["hf_cache"][0].get("revision") is None assert config.to_dict(verbose=True)["hf_cache"][1].get("revision") == "not-main2" + + +def test_empty_config(): + config = TrussConfig() + new_config = generate_default_config() + + assert new_config == config.to_dict(verbose=False) + + +def test_from_yaml(): + yaml_path = Path("test.yaml") + data = {"description": "this is a test"} + with yaml_path.open("w") as yaml_file: + yaml.safe_dump(data, yaml_file) + + result = TrussConfig.from_yaml(yaml_path) + + assert result.description == "this is a test" + + yaml_path.unlink() + + +def test_from_yaml_empty(): + yaml_path = Path("test.yaml") + data = {} + with yaml_path.open("w") as yaml_file: + yaml.safe_dump(data, yaml_file) + + result = TrussConfig.from_yaml(yaml_path) + + # test some attributes (should be default) + assert result.description is None + assert result.spec_version == "2.0" + assert result.bundled_packages_dir == "packages" + + yaml_path.unlink() diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index 43b89922a..c4ff9fdef 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -21,6 +21,26 @@ logger = logging.getLogger(__name__) +class PropagatingThread(Thread): + """ + PropagatingThread allows us to run threads and keep track of exceptions + thrown. + """ + + def run(self): + self.exc = None + try: + self.ret = self._target(*self._args, **self._kwargs) + except BaseException as e: + self.exc = e + + def join(self, timeout=None): + super(PropagatingThread, self).join(timeout) + if self.exc: + raise self.exc + return self.ret + + def test_pytorch_init_arg_validation( pytorch_model_with_init_args, pytorch_model_init_args ): @@ -137,6 +157,93 @@ def _test_invocations(expected_code): assert not _test_invocations(expected_code=200) +@pytest.mark.integration +def test_concurrency_truss(): + # Tests that concurrency limits work correctly + with ensure_kill_all(): + truss_root = Path(__file__).parent.parent.parent.resolve() / "truss" + + truss_dir = truss_root / "test_data" / "test_concurrency_truss" + + tr = TrussHandle(truss_dir) + + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + + # Each request takes 2 seconds, for this thread, we allow + # a concurrency of 2. This means the first two requests will + # succeed within the 2 seconds, and the third will fail, since + # it cannot start until the first two have completed. + def make_request(): + requests.post(full_url, json={}, timeout=3) + + successful_thread_1 = PropagatingThread(target=make_request) + successful_thread_2 = PropagatingThread(target=make_request) + failed_thread = PropagatingThread(target=make_request) + + successful_thread_1.start() + successful_thread_2.start() + # Ensure that the thread to fail starts a little after the others + time.sleep(0.2) + failed_thread.start() + + successful_thread_1.join() + successful_thread_2.join() + with pytest.raises(requests.exceptions.ReadTimeout): + failed_thread.join() + + +@pytest.mark.integration +def test_async_truss(): + with ensure_kill_all(): + truss_root = Path(__file__).parent.parent.parent.resolve() / "truss" + + truss_dir = truss_root / "test_data" / "test_async_truss" + + tr = TrussHandle(truss_dir) + + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + + response = requests.post(full_url, json={}) + assert response.json() == { + "preprocess_value": "value", + "postprocess_value": "value", + } + + +@pytest.mark.integration +def test_async_streaming(): + with ensure_kill_all(): + truss_root = Path(__file__).parent.parent.parent.resolve() / "truss" + + truss_dir = truss_root / "test_data" / "test_streaming_async_generator_truss" + + tr = TrussHandle(truss_dir) + + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + + response = requests.post(full_url, json={}, stream=True) + assert response.headers.get("transfer-encoding") == "chunked" + assert [ + byte_string.decode() for byte_string in list(response.iter_content()) + ] == ["0", "1", "2", "3", "4"] + + predict_non_stream_response = requests.post( + full_url, + json={}, + stream=True, + headers={"accept": "application/json"}, + ) + assert "transfer-encoding" not in predict_non_stream_response.headers + assert predict_non_stream_response.json() == "01234" + + @pytest.mark.integration def test_streaming_truss(): with ensure_kill_all(): diff --git a/truss/truss_config.py b/truss/truss_config.py index 825ed946d..50a38ce2b 100644 --- a/truss/truss_config.py +++ b/truss/truss_config.py @@ -27,6 +27,7 @@ DEFAULT_DATA_DIRECTORY = "data" DEFAULT_EXAMPLES_FILENAME = "examples.yaml" DEFAULT_SPEC_VERSION = "2.0" +DEFAULT_PREDICT_CONCURRENCY = 1 DEFAULT_CPU = "500m" DEFAULT_MEMORY = "512Mi" @@ -107,6 +108,24 @@ def to_list(self, verbose=False) -> List[Dict[str, str]]: return [model.to_dict(verbose=verbose) for model in self.models] +@dataclass +class Runtime: + predict_concurrency: int = DEFAULT_PREDICT_CONCURRENCY + + @staticmethod + def from_dict(d): + predict_concurrency = d.get("predict_concurrency", DEFAULT_PREDICT_CONCURRENCY) + + return Runtime( + predict_concurrency=predict_concurrency, + ) + + def to_dict(self): + return { + "predict_concurrency": self.predict_concurrency, + } + + @dataclass class Resources: cpu: str = DEFAULT_CPU @@ -282,6 +301,7 @@ class TrussConfig: system_packages: List[str] = field(default_factory=list) environment_variables: Dict[str, str] = field(default_factory=dict) resources: Resources = field(default_factory=Resources) + runtime: Runtime = field(default_factory=Runtime) python_version: str = DEFAULT_PYTHON_VERSION examples_filename: str = DEFAULT_EXAMPLES_FILENAME secrets: Dict[str, str] = field(default_factory=dict) @@ -308,9 +328,6 @@ def canonical_python_version(self) -> str: @staticmethod def from_dict(d): config = TrussConfig( - # Users that are calling `load` on an existing Truss - # should default to 1.0 whereas users creating a new Truss - # should default to 2.0. spec_version=d.get("spec_version", DEFAULT_SPEC_VERSION), model_type=d.get("model_type", DEFAULT_MODEL_TYPE), model_framework=ModelFrameworkType( @@ -328,6 +345,7 @@ def from_dict(d): system_packages=d.get("system_packages", []), environment_variables=d.get("environment_variables", {}), resources=Resources.from_dict(d.get("resources", {})), + runtime=Runtime.from_dict(d.get("runtime", {})), python_version=d.get("python_version", DEFAULT_PYTHON_VERSION), model_name=d.get("model_name", None), examples_filename=d.get("examples_filename", DEFAULT_EXAMPLES_FILENAME), @@ -352,7 +370,8 @@ def from_dict(d): @staticmethod def from_yaml(yaml_path: Path): with yaml_path.open() as yaml_file: - return TrussConfig.from_dict(yaml.safe_load(yaml_file)) + raw_data = yaml.safe_load(yaml_file) or {} + return TrussConfig.from_dict(raw_data) def write_to_yaml_file(self, path: Path, verbose: bool = True): with path.open("w") as config_file: @@ -372,6 +391,7 @@ def validate(self): DATACLASS_TO_REQ_KEYS_MAP = { Train: {"variables"}, Resources: {"accelerator", "cpu", "memory", "use_gpu"}, + Runtime: {"predict_concurrency"}, TrussConfig: { "environment_variables", "external_package_dirs",