diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml
index b6db61a3c..147672284 100644
--- a/.github/workflows/code-quality-checks.yml
+++ b/.github/workflows/code-quality-checks.yml
@@ -156,6 +156,11 @@ jobs:
       #----------------------------------------------
       - name: Black
         run: poetry run black --check src
+      #----------------------------------------------
+      # pylint the code
+      #----------------------------------------------
+      - name: Pylint
+        run: poetry run pylint --rcfile=pylintrc src
 
   check-types:
     runs-on: ubuntu-latest
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 000000000..888c02f2e
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,13 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.5.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: check-added-large-files
+
+- repo: https://github.com/psf/black
+  rev: 22.3.0
+  hooks:
+  - id: black
+    exclude: '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/'
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 0cb258769..ccd0b2566 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -145,9 +145,11 @@ The `PySQLStagingIngestionTestSuite` namespace requires a cluster running DBR ve
 
 The suites marked `[not documented]` require additional configuration which will be documented at a later time.
 
-### Code formatting
+### Code formatting and linting
 
-This project uses [Black](https://pypi.org/project/black/).
+This project uses [Black](https://pypi.org/project/black/) for code formatting and [Pylint](https://pylint.org/) for linting.
+
+#### Black
 
 ```
 poetry run python3 -m black src --check
@@ -157,6 +159,25 @@ Remove the `--check` flag to write reformatted files to disk.
 
 To simplify reviews you can format your changes in a separate commit.
 
+#### Pylint
+
+```
+poetry run pylint --rcfile=pylintrc src
+```
+
+#### Pre-commit hooks
+
+We use [pre-commit](https://pre-commit.com/) to automatically run Black and other checks before each commit.
+
+To set up pre-commit hooks:
+
+```bash
+# Set up the git hooks
+poetry run pre-commit install
+```
+
+This will set up the hooks defined in `.pre-commit-config.yaml` to run automatically on each commit.
+
 ### Change a pinned dependency version
 
 Modify the dependency specification (syntax can be found [here](https://python-poetry.org/docs/dependency-specification/)) in `pyproject.toml` and run one of the following in your terminal:
diff --git a/poetry.lock b/poetry.lock
index b68d1a3fb..25a261cb2 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -63,6 +63,18 @@ files = [
     {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"},
 ]
 
+[[package]]
+name = "cfgv"
+version = "3.4.0"
+description = "Validate configuration and produce human readable error messages."
+optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.4.1" @@ -209,6 +221,18 @@ files = [ graph = ["objgraph (>=1.7.2)"] profile = ["gprof2dot (>=2022.7.29)"] +[[package]] +name = "distlib" +version = "0.3.9" +description = "Distribution utilities" +optional = false +python-versions = "*" +groups = ["dev"] +files = [ + {file = "distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87"}, + {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, +] + [[package]] name = "et-xmlfile" version = "2.0.0" @@ -237,6 +261,74 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "filelock" +version = "3.16.1" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version < \"3.10\"" +files = [ + {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, + {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] +typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""] + +[[package]] +name = "filelock" +version = "3.18.0" +description = "A platform independent file lock." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de"}, + {file = "filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"] +typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""] + +[[package]] +name = "identify" +version = "2.6.1" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version < \"3.10\"" +files = [ + {file = "identify-2.6.1-py2.py3-none-any.whl", hash = "sha256:53863bcac7caf8d2ed85bd20312ea5dcfc22226800f6d6881f232d861db5a8f0"}, + {file = "identify-2.6.1.tar.gz", hash = "sha256:91478c5fb7c3aac5ff7bf9b4344f803843dc586832d5f110d672b19aa1984c98"}, +] + +[package.extras] +license = ["ukkonen"] + +[[package]] +name = "identify" +version = "2.6.12" +description = "File identification library for Python" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "identify-2.6.12-py2.py3-none-any.whl", hash = "sha256:ad9672d5a72e0d2ff7c5c8809b62dfa60458626352fb0eb7b55e69bdc45334a2"}, + {file = "identify-2.6.12.tar.gz", hash = "sha256:d8de45749f1efb108badef65ee8386f0f7bb19a7f26185f74de6367bffbaf0e6"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.10" @@ -414,6 +506,18 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "numpy" version = "1.24.4" @@ -761,6 +865,46 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "3.5.0" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version < \"3.10\"" +files = [ + {file = "pre_commit-3.5.0-py2.py3-none-any.whl", hash = "sha256:841dc9aef25daba9a0238cd27984041fa0467b4199fc4852e27950664919f660"}, + {file = "pre_commit-3.5.0.tar.gz", hash = "sha256:5804465c675b659b0862f07907f96295d490822a450c4c40e747d0b1c6ebcb32"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + +[[package]] +name = "pre-commit" +version = "3.8.0" +description = "A framework for managing and maintaining multi-language pre-commit hooks." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f"}, + {file = "pre_commit-3.8.0.tar.gz", hash = "sha256:8bb6494d4a20423842e198980c9ecf9f96607a07ea29549e180eef9ae80fe7af"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "pyarrow" version = "17.0.0" @@ -1020,6 +1164,69 @@ files = [ {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, ] +[[package]] +name = "pyyaml" +version = "6.0.2" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = 
"sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, +] + [[package]] name = "requests" version = "2.32.3" @@ -1170,10 +1377,31 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.31.2" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "virtualenv-20.31.2-py3-none-any.whl", hash = "sha256:36efd0d9650ee985f0cad72065001e66d49a6f24eb44d98980f630686243cf11"}, + {file = "virtualenv-20.31.2.tar.gz", hash = "sha256:e10c0a9d02835e592521be48b332b6caee6887f332c111aa79a09b9e79efc2af"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] + [extras] pyarrow = ["pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0305d9a30397e4baa3d02d0a920989a901ba08749b93bd1c433886f151ed2cdc" +content-hash = "6d06d9516129c459ebbd7bc22538e0e3061dc25035347405ea5308a360c6857a" diff --git 
a/pylintrc b/pylintrc new file mode 100644 index 000000000..1a339592c --- /dev/null +++ b/pylintrc @@ -0,0 +1,22 @@ +[MASTER] +ignore=thrift_api + +[MESSAGES CONTROL] +disable=too-many-arguments, + too-many-locals, + too-many-public-methods, + too-many-branches, + too-many-statements, + fixme, + missing-docstring, + line-too-long, + too-few-public-methods, + too-many-instance-attributes, + too-many-lines + +[FORMAT] +max-line-length=100 + +[REPORTS] +output-format=text +reports=yes diff --git a/pyproject.toml b/pyproject.toml index 9b862d7ac..0a36ee54d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ mypy = "^1.10.1" pylint = ">=2.12.0" black = "^22.3.0" pytest-dotenv = "^0.5.2" +pre-commit = {version = "^3.5.0"} numpy = [ { version = ">=1.16.6", python = ">=3.8,<3.11" }, { version = ">=1.23.4", python = ">=3.11" }, @@ -55,6 +56,7 @@ ignore_missing_imports = "true" exclude = ['ttypes\.py$', 'TCLIService\.py$'] [tool.black] +line-length = 100 exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/' [tool.pytest.ini_options] diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index d3af2f5c8..1118be1d6 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -38,8 +38,7 @@ def filter(self, record): ) else: record.args = tuple( - (self.redact(arg) if isinstance(arg, str) else arg) - for arg in record.args + (self.redact(arg) if isinstance(arg, str) else arg) for arg in record.args ) return True diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 26c1f3708..ba5cf2141 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -68,9 +68,7 @@ def __init__( try: idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth") if not idp_endpoint: - raise NotImplementedError( - f"OAuth is not supported for host ${hostname}" - ) + raise NotImplementedError(f"OAuth is not supported for host ${hostname}") # Convert to the corresponding scopes in the corresponding IdP cloud_scopes = idp_endpoint.get_scopes_mapping(scopes) @@ -179,9 +177,7 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider): AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/" DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token" - DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = ( - "X-Databricks-Azure-Workspace-Resource-Id" - ) + DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = "X-Databricks-Azure-Workspace-Resource-Id" def __init__( self, @@ -195,9 +191,7 @@ def __init__( self.azure_client_id = azure_client_id self.azure_client_secret = azure_client_secret self.azure_workspace_resource_id = azure_workspace_resource_id - self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host( - hostname - ) + self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(hostname) def auth_type(self) -> str: return AuthType.AZURE_SP_M2M.value @@ -211,9 +205,7 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource: ) def __call__(self, *args, **kwargs) -> HeaderFactory: - inner = self.get_token_source( - resource=get_effective_azure_login_app_id(self.hostname) - ) + inner = self.get_token_source(resource=get_effective_azure_login_app_id(self.hostname)) cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE) def header_factory() -> Dict[str, str]: diff --git a/src/databricks/sql/auth/endpoint.py 
b/src/databricks/sql/auth/endpoint.py index 5cb26ae3e..85530192b 100644 --- a/src/databricks/sql/auth/endpoint.py +++ b/src/databricks/sql/auth/endpoint.py @@ -58,9 +58,7 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]: def is_supported_databricks_oauth_host(hostname: str) -> bool: host = hostname.lower().replace("https://", "").split("/")[0] - domains = ( - DATABRICKS_AWS_DOMAINS + DATABRICKS_GCP_DOMAINS + DATABRICKS_OAUTH_AZURE_DOMAINS - ) + domains = DATABRICKS_AWS_DOMAINS + DATABRICKS_GCP_DOMAINS + DATABRICKS_OAUTH_AZURE_DOMAINS return any(e for e in domains if host.endswith(e)) @@ -106,7 +104,9 @@ def get_authorization_url(self, hostname: str): return f"{get_databricks_oidc_url(hostname)}/oauth2/v2.0/authorize" def get_openid_config_url(self, hostname: str): - return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration" + return ( + "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration" + ) class InHouseOAuthEndpointCollection(OAuthEndpointCollection): @@ -123,9 +123,7 @@ def get_openid_config_url(self, hostname: str): return f"{idp_url}/.well-known/oauth-authorization-server" -def get_oauth_endpoints( - hostname: str, use_azure_auth: bool -) -> Optional[OAuthEndpointCollection]: +def get_oauth_endpoints(hostname: str, use_azure_auth: bool) -> Optional[OAuthEndpointCollection]: cloud = infer_cloud_from_host(hostname) if cloud in [CloudType.AWS, CloudType.GCP]: diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index aa3184d88..e84442e32 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -41,9 +41,7 @@ def __init__(self, access_token: str, token_type: str, refresh_token: str): def is_expired(self) -> bool: try: - decoded_token = jwt.decode( - self.access_token, options={"verify_signature": False} - ) + decoded_token = jwt.decode(self.access_token, options={"verify_signature": False}) exp_time = decoded_token.get("exp") current_time = time.time() buffer_time = 30 # 30 seconds buffer @@ -134,9 +132,7 @@ def __fetch_well_known_config(self, hostname: str): def __get_challenge(): verifier_string = OAuthManager.__token_urlsafe(32) digest = hashlib.sha256(verifier_string.encode("UTF-8")).digest() - challenge_string = ( - base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "") - ) + challenge_string = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "") return verifier_string, challenge_string def __get_authorization_code(self, client, auth_url, scope, state, challenge): @@ -158,9 +154,7 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge): logger.info(f"Opening {auth_req_uri}") webbrowser.open_new(auth_req_uri) - logger.info( - f"Listening for OAuth authorization callback at {redirect_url}" - ) + logger.info(f"Listening for OAuth authorization callback at {redirect_url}") httpd.handle_request() self.redirect_port = port break @@ -182,9 +176,7 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge): raise RuntimeError(msg) # This is a kludge because the parsing library expects https callbacks # We should probably set it up using https - full_redirect_url = ( - f"https://localhost:{self.redirect_port}/{handler.request_path}" - ) + full_redirect_url = f"https://localhost:{self.redirect_port}/{handler.request_path}" try: authorization_code_response = client.parse_request_uri_response( full_redirect_url, state=state @@ -197,9 +189,7 @@ def __get_authorization_code(self, 
client, auth_url, scope, state, challenge): def __send_auth_code_token_request( self, client, token_request_url, redirect_url, code, verifier ): - token_request_body = client.prepare_request_body( - code=code, redirect_uri=redirect_url - ) + token_request_body = client.prepare_request_body(code=code, redirect_uri=redirect_url) data = f"{token_request_body}&code_verifier={verifier}" return self.__send_token_request(token_request_url, data) @@ -227,15 +217,11 @@ def __send_refresh_token_request(self, hostname, refresh_token): def __get_tokens_from_response(oauth_response): access_token = oauth_response["access_token"] refresh_token = ( - oauth_response["refresh_token"] - if "refresh_token" in oauth_response - else None + oauth_response["refresh_token"] if "refresh_token" in oauth_response else None ) return access_token, refresh_token - def check_and_refresh_access_token( - self, hostname: str, access_token: str, refresh_token: str - ): + def check_and_refresh_access_token(self, hostname: str, access_token: str, refresh_token: str): now = datetime.now(tz=timezone.utc) # If we can't decode an expiration time, this will be expired by default. expiration_time = now @@ -246,9 +232,7 @@ def check_and_refresh_access_token( # an unnecessary signature verification. access_token_payload = access_token.split(".")[1] # add padding - access_token_payload = access_token_payload + "=" * ( - -len(access_token_payload) % 4 - ) + access_token_payload = access_token_payload + "=" * (-len(access_token_payload) % 4) decoded = json.loads(base64.standard_b64decode(access_token_payload)) expiration_time = datetime.fromtimestamp(decoded["exp"], tz=timezone.utc) except Exception as e: @@ -265,13 +249,9 @@ def check_and_refresh_access_token( raise RuntimeError(msg) # Try to refresh using the refresh token - logger.debug( - f"Attempting to refresh OAuth access token that expired on {expiration_time}" - ) + logger.debug(f"Attempting to refresh OAuth access token that expired on {expiration_time}") oauth_response = self.__send_refresh_token_request(hostname, refresh_token) - fresh_access_token, fresh_refresh_token = self.__get_tokens_from_response( - oauth_response - ) + fresh_access_token, fresh_refresh_token = self.__get_tokens_from_response(oauth_response) return fresh_access_token, fresh_refresh_token, True def get_tokens(self, hostname: str, scope=None): @@ -285,9 +265,7 @@ def get_tokens(self, hostname: str, scope=None): client = oauthlib.oauth2.WebApplicationClient(self.client_id) try: - auth_response = self.__get_authorization_code( - client, auth_url, scope, state, challenge - ) + auth_response = self.__get_authorization_code(client, auth_url, scope, state, challenge) except OAuth2Error as e: msg = f"OAuth Authorization Error: {e.description}" logger.error(msg) @@ -359,6 +337,4 @@ def refresh(self) -> Token: oauth_response.refresh_token, ) else: - raise Exception( - f"Failed to get token: {response.status_code} {response.text}" - ) + raise Exception(f"Failed to get token: {response.status_code} {response.text}") diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 432ac687d..7d08aaaf1 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -167,9 +167,7 @@ def __private_init__( new_object.command_type = command_type return new_object - def new( - self, **urllib3_incremented_counters: typing.Any - ) -> "DatabricksRetryPolicy": + def new(self, **urllib3_incremented_counters: typing.Any) -> "DatabricksRetryPolicy": """This method is responsible for 
passing the entire Retry state to its next iteration. urllib3 calls Retry.new() between successive requests as part of its `.increment()` method @@ -435,9 +433,7 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: "Failed requests are retried by default per configured DatabricksRetryPolicy", ) - def is_retry( - self, method: str, status_code: int, has_retry_after: bool = False - ) -> bool: + def is_retry(self, method: str, status_code: int, has_retry_after: bool = False) -> bool: """ Called by urllib3 when determining whether or not to retry diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index f0daae162..bdc1e187a 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -199,9 +199,7 @@ def flush(self): self.headers = self.__resp.headers logger.info( - "HTTP Response with status code {}, message: {}".format( - self.code, self.message - ) + "HTTP Response with status code {}, message: {}".format(self.code, self.message) ) @staticmethod diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index b4cd78cf8..99fbf91f8 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -242,15 +242,11 @@ def read(self) -> Optional[OAuthToken]: self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) - auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) + auth_provider = get_python_sql_connector_auth_provider(server_hostname, **kwargs) self.server_telemetry_enabled = True self.client_telemetry_enabled = kwargs.get("enable_telemetry", False) - self.telemetry_enabled = ( - self.client_telemetry_enabled and self.server_telemetry_enabled - ) + self.telemetry_enabled = self.client_telemetry_enabled and self.server_telemetry_enabled user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: @@ -262,9 +258,7 @@ def read(self) -> Optional[OAuthToken]: ) if user_agent_entry: - useragent_header = "{}/{} ({})".format( - USER_AGENT_NAME, __version__, user_agent_entry - ) + useragent_header = "{}/{} ({})".format(USER_AGENT_NAME, __version__, user_agent_entry) else: useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) @@ -272,9 +266,7 @@ def read(self) -> Optional[OAuthToken]: self._ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility - tls_verify=not kwargs.get( - "_tls_no_verify", False - ), # by default - verify cert and host + tls_verify=not kwargs.get("_tls_no_verify", False), # by default - verify cert and host tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), tls_client_cert_file=kwargs.get("_tls_client_cert_file"), @@ -369,8 +361,7 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): if self.open: logger.debug( - "Closing unclosed connection for session " - "{}".format(self.get_session_id_hex()) + "Closing unclosed connection for session " "{}".format(self.get_session_id_hex()) ) try: self._close(close_cursors=False) @@ -453,13 +444,9 @@ def _close(self, close_cursors=True) -> None: logger.info("Session was closed by a prior request") except DatabaseError as e: if "Invalid SessionHandle" in str(e): - logger.warning( - f"Attempted to close session that was already closed: {e}" - ) + logger.warning(f"Attempted to close session that was 
already closed: {e}") else: - logger.warning( - f"Attempt to close session raised an exception at the server: {e}" - ) + logger.warning(f"Attempt to close session raised an exception at the server: {e}") except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") @@ -553,9 +540,7 @@ def _all_dbsql_parameters_are_named(self, params: List[TDbsqlParameter]) -> bool """Return True if all members of the list have a non-null .name attribute""" return all([i.name is not None for i in params]) - def _normalize_tparametersequence( - self, params: TParameterSequence - ) -> List[TDbsqlParameter]: + def _normalize_tparametersequence(self, params: TParameterSequence) -> List[TDbsqlParameter]: """Retains the same order as the input list.""" output: List[TDbsqlParameter] = [] @@ -567,12 +552,9 @@ def _normalize_tparametersequence( return output - def _normalize_tparameterdict( - self, params: TParameterDict - ) -> List[TDbsqlParameter]: + def _normalize_tparameterdict(self, params: TParameterDict) -> List[TDbsqlParameter]: return [ - dbsql_parameter_from_primitive(value=value, name=name) - for name, value in params.items() + dbsql_parameter_from_primitive(value=value, name=name) for name, value in params.items() ] def _normalize_tparametercollection( @@ -645,8 +627,7 @@ def _prepare_native_parameters( stmt = stmt output = [ - p.as_tspark_param(named=param_structure == ParameterStructure.NAMED) - for p in params + p.as_tspark_param(named=param_structure == ParameterStructure.NAMED) for p in params ] return stmt, output @@ -665,9 +646,7 @@ def _check_not_closed(self): session_id_hex=self.connection.get_session_id_hex(), ) - def _handle_staging_operation( - self, staging_allowed_local_path: Union[None, str, List[str]] - ): + def _handle_staging_operation(self, staging_allowed_local_path: Union[None, str, List[str]]): """Fetch the HTTP request instruction from a staging ingestion command and call the designated handler. 
@@ -685,9 +664,7 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) - abs_staging_allowed_local_paths = [ - os.path.abspath(i) for i in _staging_allowed_local_paths - ] + abs_staging_allowed_local_paths = [os.path.abspath(i) for i in _staging_allowed_local_paths] assert self.active_result_set is not None row = self.active_result_set.fetchone() @@ -716,9 +693,7 @@ def _handle_staging_operation( ) # May be real headers, or could be json string - headers = ( - json.loads(row.headers) if isinstance(row.headers, str) else row.headers - ) + headers = json.loads(row.headers) if isinstance(row.headers, str) else row.headers handler_args = { "presigned_url": row.presignedUrl, @@ -815,9 +790,7 @@ def _handle_staging_get( fp.write(r.content) @log_latency(StatementType.SQL) - def _handle_staging_remove( - self, presigned_url: str, headers: Optional[dict] = None - ): + def _handle_staging_remove(self, presigned_url: str, headers: Optional[dict] = None): """Make an HTTP DELETE request to the presigned_url""" r = requests.delete(url=presigned_url, headers=headers) @@ -866,9 +839,7 @@ def execute( :returns self """ - logger.debug( - "Cursor.execute(operation=%s, parameters=%s)", operation, parameters - ) + logger.debug("Cursor.execute(operation=%s, parameters=%s)", operation, parameters) param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: @@ -1007,9 +978,7 @@ def get_async_execution_result(self): operation_state = self.get_query_state() if operation_state == ttypes.TOperationState.FINISHED_STATE: - execute_response = self.thrift_backend.get_execution_result( - self.active_op_handle, self - ) + execute_response = self.thrift_backend.get_execution_result(self.active_op_handle, self) self.active_result_set = ResultSet( self.connection, execute_response, @@ -1259,8 +1228,7 @@ def cancel(self) -> None: self.thrift_backend.cancel_command(self.active_op_handle) else: logger.warning( - "Attempting to cancel a command, but there is no " - "currently executing command" + "Attempting to cancel a command, but there is no " "currently executing command" ) def close(self) -> None: @@ -1406,9 +1374,7 @@ def _convert_arrow_table(self, table): ResultRow = Row(*column_names) if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] + return [ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())] # Need to use nullable types, as otherwise type can change when there are missing values. 
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types @@ -1455,11 +1421,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): + while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) results = pyarrow.concat_tables([results, partial_results]) @@ -1480,8 +1442,7 @@ def merge_columnar(self, result1, result2): raise ValueError("The columns in the results don't match") merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) + result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns) ] return ColumnTable(merged_result, result1.column_names) @@ -1497,11 +1458,7 @@ def fetchmany_columnar(self, size: int): n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows - ): + while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) results = self.merge_columnar(results, partial_results) @@ -1518,9 +1475,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): + if isinstance(results, ColumnTable) and isinstance(partial_results, ColumnTable): results = self.merge_columnar(results, partial_results) else: results = pyarrow.concat_tables([results, partial_results]) @@ -1529,10 +1484,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table # Valid only for metadata commands result set if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } + data = {name: col for name, col in zip(results.column_names, results.column_table)} return pyarrow.Table.from_pydict(data) return results diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 7e96cd323..182db2468 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -41,9 +41,7 @@ def __init__( self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) self._ssl_options = ssl_options - def get_next_downloaded_file( - self, next_row_offset: int - ) -> Union[DownloadedFile, None]: + def get_next_downloaded_file(self, next_row_offset: int) -> Union[DownloadedFile, None]: """ Get next file that starts at given offset. 
@@ -90,9 +88,7 @@ def _schedule_downloads(self): len(self._pending_links) > 0 ): link = self._pending_links.pop(0) - logger.debug( - "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount) - ) + logger.debug("- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)) handler = ResultSetDownloadHandler( settings=self._downloadable_result_settings, link=link, diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 228e07d6c..4c4c3f854 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -86,9 +86,7 @@ def run(self) -> DownloadedFile: ) # Check if link is already expired or is expiring - ResultSetDownloadHandler._validate_link( - self.link, self.settings.link_expiry_buffer_secs - ) + ResultSetDownloadHandler._validate_link(self.link, self.settings.link_expiry_buffer_secs) session = requests.Session() session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) @@ -145,10 +143,7 @@ def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int): but may expire before the file has fully downloaded. """ current_time = int(time.time()) - if ( - link.expiryTime <= current_time - or link.expiryTime - current_time <= expiry_buffer_secs - ): + if link.expiryTime <= current_time or link.expiryTime - current_time <= expiry_buffer_secs: raise Error("CloudFetch link has expired") @staticmethod diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 65235f630..8df013d5e 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -13,18 +13,14 @@ class Error(Exception): `context`: Optional extra context about the error. MUST be JSON serializable """ - def __init__( - self, message=None, context=None, session_id_hex=None, *args, **kwargs - ): + def __init__(self, message=None, context=None, session_id_hex=None, *args, **kwargs): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} error_name = self.__class__.__name__ if session_id_hex: - telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) + telemetry_client = TelemetryClientFactory.get_telemetry_client(session_id_hex) telemetry_client.export_failure_log(error_name, self.message) def __str__(self): diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py index 13a966126..d0849efa9 100644 --- a/src/databricks/sql/experimental/oauth_persistence.py +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -74,8 +74,6 @@ def read(self, hostname: str) -> Optional[OAuthToken]: ) logger.error(msg) raise Exception(msg) - return OAuthToken( - token_as_json["access_token"], token_as_json["refresh_token"] - ) + return OAuthToken(token_as_json["access_token"], token_as_json["refresh_token"]) except Exception as e: return None diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 0b0c564da..b9effc5ab 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -73,10 +73,7 @@ def get_execution_result(self) -> ExecutionResultFormat: return ExecutionResultFormat.FORMAT_UNSPECIFIED def get_retry_count(self) -> int: - if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy - ): + if hasattr(self.thrift_backend, "retry_policy") and self.thrift_backend.retry_policy: return 
len(self.thrift_backend.retry_policy.history) return 0 @@ -110,10 +107,7 @@ def get_execution_result(self) -> ExecutionResultFormat: return ExecutionResultFormat.FORMAT_UNSPECIFIED def get_retry_count(self) -> int: - if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy - ): + if hasattr(self.thrift_backend, "retry_policy") and self.thrift_backend.retry_policy: return len(self.thrift_backend.retry_policy.history) return 0 @@ -208,9 +202,7 @@ def _safe_call(func_to_call): retry_count=_safe_call(extractor.get_retry_count), ) - telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) + telemetry_client = TelemetryClientFactory.get_telemetry_client(session_id_hex) telemetry_client.export_latency_log( latency_ms=duration_ms, sql_execution_event=sql_exec_event, diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 5eb8c6ed0..99b70dbb3 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -160,9 +160,7 @@ def _export_event(self, event): with self._lock: self._events_batch.append(event) if len(self._events_batch) >= self._batch_size: - logger.debug( - "Batch size limit reached (%s), flushing events", self._batch_size - ) + logger.debug("Batch size limit reached (%s), flushing events", self._batch_size) self._flush() def _flush(self): @@ -314,9 +312,7 @@ class TelemetryClientFactory: It uses a thread pool to handle asynchronous operations. """ - _clients: Dict[ - str, BaseTelemetryClient - ] = {} # Map of session_id_hex -> BaseTelemetryClient + _clients: Dict[str, BaseTelemetryClient] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.RLock() # Thread safety for factory operations @@ -330,14 +326,10 @@ def _initialize(cls): if not cls._initialized: cls._clients = {} - cls._executor = ThreadPoolExecutor( - max_workers=10 - ) # Thread pool for async operations + cls._executor = ThreadPoolExecutor(max_workers=10) # Thread pool for async operations cls._install_exception_hook() cls._initialized = True - logger.debug( - "TelemetryClientFactory initialized with thread pool (max_workers=10)" - ) + logger.debug("TelemetryClientFactory initialized with thread pool (max_workers=10)") @classmethod def _install_exception_hook(cls): @@ -380,9 +372,7 @@ def initialize_telemetry_client( session_id_hex, ) if telemetry_enabled: - TelemetryClientFactory._clients[ - session_id_hex - ] = TelemetryClient( + TelemetryClientFactory._clients[session_id_hex] = TelemetryClient( telemetry_enabled=telemetry_enabled, session_id_hex=session_id_hex, auth_provider=auth_provider, @@ -390,9 +380,7 @@ def initialize_telemetry_client( executor=TelemetryClientFactory._executor, ) else: - TelemetryClientFactory._clients[ - session_id_hex - ] = NoopTelemetryClient() + TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail @@ -401,9 +389,7 @@ def initialize_telemetry_client( @staticmethod def get_telemetry_client(session_id_hex): """Get the telemetry client for a specific connection""" - return TelemetryClientFactory._clients.get( - session_id_hex, NoopTelemetryClient() - ) + return TelemetryClientFactory._clients.get(session_id_hex, NoopTelemetryClient()) @staticmethod def 
close(session_id_hex): @@ -411,20 +397,14 @@ def close(session_id_hex): with TelemetryClientFactory._lock: if ( - telemetry_client := TelemetryClientFactory._clients.pop( - session_id_hex, None - ) + telemetry_client := TelemetryClientFactory._clients.pop(session_id_hex, None) ) is not None: - logger.debug( - "Removing telemetry client for connection %s", session_id_hex - ) + logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() # Shutdown executor if no more clients if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: - logger.debug( - "No more telemetry clients, shutting down thread pool executor" - ) + logger.debug("No more telemetry clients, shutting down thread pool executor") try: TelemetryClientFactory._executor.shutdown(wait=True) except Exception as e: diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 78683ac31..292f647a1 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -152,13 +152,9 @@ def __init__( self.staging_allowed_local_path = staging_allowed_local_path self._initialize_retry_args(kwargs) - self._use_arrow_native_complex_types = kwargs.get( - "_use_arrow_native_complex_types", True - ) + self._use_arrow_native_complex_types = kwargs.get("_use_arrow_native_complex_types", True) self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True) - self._use_arrow_native_timestamps = kwargs.get( - "_use_arrow_native_timestamps", True - ) + self._use_arrow_native_timestamps = kwargs.get("_use_arrow_native_timestamps", True) # Cloud fetch self.max_download_threads = kwargs.get("max_download_threads", 10) @@ -233,15 +229,11 @@ def _initialize_retry_args(self, kwargs): given_or_default = type_(kwargs.get(key, default)) bound = _bound(min, max, given_or_default) setattr(self, key, bound) - logger.debug( - "retry parameter: {} given_or_default {}".format(key, given_or_default) - ) + logger.debug("retry parameter: {} given_or_default {}".format(key, given_or_default)) if bound != given_or_default: logger.warning( "Override out of policy retry parameter: " - + "{} given {}, restricted to {}".format( - key, given_or_default, bound - ) + + "{} given {}, restricted to {}".format(key, given_or_default, bound) ) # Fail on retry delay min > max; consider later adding fail on min > duration? 
@@ -272,9 +264,7 @@ def _extract_error_message_from_headers(headers): if THRIFT_ERROR_MESSAGE_HEADER in headers: err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER] if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers: - if ( - err_msg - ): # We don't expect both to be set, but log both here just in case + if err_msg: # We don't expect both to be set, but log both here just in case err_msg = "Thriftserver error: {}, Databricks error: {}".format( err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER] ) @@ -294,10 +284,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): max_attempts = self._retry_stop_after_attempts_count max_duration_s = self._retry_stop_after_attempts_duration - if ( - error_info.retry_delay is not None - and elapsed + error_info.retry_delay > max_duration_s - ): + if error_info.retry_delay is not None and elapsed + error_info.retry_delay > max_duration_s: no_retry_reason = NoRetryReason.OUT_OF_TIME elif error_info.retry_delay is not None and attempt >= max_attempts: no_retry_reason = NoRetryReason.OUT_OF_ATTEMPTS @@ -393,9 +380,7 @@ def attempt_request(attempt): response = method(request) # We need to call type(response) here because thrift doesn't implement __name__ attributes for thrift responses - logger.debug( - "Received response: {}()".format(type(response).__name__) - ) + logger.debug("Received response: {}()".format(type(response).__name__)) unsafe_logger.debug("Received response: {}".format(response)) return response @@ -512,10 +497,7 @@ def _check_initial_namespace(self, catalog, schema, response): if not (catalog or schema): return - if ( - response.serverProtocolVersion - < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4 - ): + if response.serverProtocolVersion < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4: raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", @@ -532,10 +514,7 @@ def _check_initial_namespace(self, catalog, schema, response): def _check_session_configuration(self, session_configuration): # This client expects timetampsAsString to be false, so we do not allow users to modify that - if ( - session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() - != "false" - ): + if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false": raise Error( "Invalid session configuration: {} cannot be changed " "while using the Databricks SQL connector, it must be false not {}".format( @@ -548,18 +527,14 @@ def _check_session_configuration(self, session_configuration): def open_session(self, session_configuration, catalog, schema): try: self._transport.open() - session_configuration = { - k: str(v) for (k, v) in (session_configuration or {}).items() - } + session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()} self._check_session_configuration(session_configuration) # We want to receive proper Timestamp arrow types. # We set it also in confOverlay in TExecuteStatementReq on a per query basic, # but it doesn't hurt to also set for the whole session. 
session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false" if catalog or schema: - initial_namespace = ttypes.TNamespace( - catalogName=catalog, schemaName=schema - ) + initial_namespace = ttypes.TNamespace(catalogName=catalog, schemaName=schema) else: initial_namespace = None @@ -574,9 +549,7 @@ def open_session(self, session_configuration, catalog, schema): self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) self._session_id_hex = ( - self.handle_to_hex_id(response.sessionHandle) - if response.sessionHandle - else None + self.handle_to_hex_id(response.sessionHandle) if response.sessionHandle else None ) return response except: @@ -590,9 +563,7 @@ def close_session(self, session_handle) -> None: finally: self._transport.close() - def _check_command_not_in_error_or_closed_state( - self, op_handle, get_operations_resp - ): + def _check_command_not_in_error_or_closed_state(self, op_handle, get_operations_resp): if get_operations_resp.operationState == ttypes.TOperationState.ERROR_STATE: if get_operations_resp.displayMessage: raise ServerOperationError( @@ -619,10 +590,7 @@ def _check_command_not_in_error_or_closed_state( "Command {} unexpectedly closed server side".format( op_handle and self.guid_to_hex_id(op_handle.operationId.guid) ), - { - "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid) - }, + {"operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid)}, session_id_hex=self._session_id_hex, ) @@ -732,8 +700,7 @@ def _col_to_description(col, session_id_hex=None): @staticmethod def _hive_schema_to_description(t_table_schema, session_id_hex=None): return [ - ThriftBackend._col_to_description(col, session_id_hex) - for col in t_table_schema.columns + ThriftBackend._col_to_description(col, session_id_hex) for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -878,8 +845,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): op_handle, initial_operation_status_resp ) operation_state = ( - initial_operation_status_resp - and initial_operation_status_resp.operationState + initial_operation_status_resp and initial_operation_status_resp.operationState ) while not operation_state or operation_state in [ ttypes.TOperationState.RUNNING_STATE, @@ -983,9 +949,7 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): req = ttypes.TGetCatalogsReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), ) resp = self.make_request(self._client.GetCatalogs, req) return self._handle_execute_response(resp, cursor) @@ -1003,9 +967,7 @@ def get_schemas( req = ttypes.TGetSchemasReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, ) @@ -1027,9 +989,7 @@ def get_tables( req = ttypes.TGetTablesReq( sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, tableName=table_name, @@ -1053,9 +1013,7 @@ def get_columns( req = ttypes.TGetColumnsReq( 
sessionHandle=session_handle, - getDirectResults=ttypes.TSparkGetDirectResults( - maxRows=max_rows, maxBytes=max_bytes - ), + getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes), catalogName=catalog_name, schemaName=schema_name, tableName=table_name, @@ -1135,9 +1093,7 @@ def close_command(self, op_handle): def cancel_command(self, active_op_handle): logger.debug( - "Cancelling command {}".format( - self.guid_to_hex_id(active_op_handle.operationId.guid) - ) + "Cancelling command {}".format(self.guid_to_hex_id(active_op_handle.operationId.guid)) ) req = ttypes.TCancelOperationReq(active_op_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 233808777..2238d5943 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -81,9 +81,7 @@ def build_queue( arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes ) - converted_arrow_table = convert_decimals_in_arrow_table( - arrow_table, description - ) + converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description) return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: column_table, column_names = convert_column_based_set_to_column_table( @@ -135,10 +133,7 @@ def slice(self, curr_index, length): return ColumnTable(sliced_column_table, self.column_names) def __eq__(self, other): - return ( - self.column_table == other.column_table - and self.column_names == other.column_names - ) + return self.column_table == other.column_table and self.column_names == other.column_names class ColumnQueue(ResultSetQueue): @@ -155,9 +150,7 @@ def next_n_rows(self, num_rows): return slice def remaining_rows(self): - slice = self.column_table.slice( - self.cur_row_index, self.n_valid_rows - self.cur_row_index - ) + slice = self.column_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index) self.cur_row_index += slice.num_rows return slice @@ -193,9 +186,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": return slice def remaining_rows(self) -> "pyarrow.Table": - slice = self.arrow_table.slice( - self.cur_row_index, self.n_valid_rows - self.cur_row_index - ) + slice = self.arrow_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index) self.cur_row_index += slice.num_rows return slice @@ -310,14 +301,10 @@ def remaining_rows(self) -> "pyarrow.Table": def _create_next_table(self) -> Union["pyarrow.Table", None]: logger.debug( - "CloudFetchQueue: Trying to get downloaded file for row {}".format( - self.start_row_index - ) + "CloudFetchQueue: Trying to get downloaded file for row {}".format(self.start_row_index) ) # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue - downloaded_file = self.download_manager.get_next_downloaded_file( - self.start_row_index - ) + downloaded_file = self.download_manager.get_next_downloaded_file(self.start_row_index) if not downloaded_file: logger.debug( "CloudFetchQueue: Cannot find downloaded file for row {}".format( @@ -383,9 +370,7 @@ class NoRetryReason(Enum): class RequestErrorInfo( - namedtuple( - "RequestErrorInfo_", "error error_message retry_delay http_code method request" - ) + namedtuple("RequestErrorInfo_", "error error_message retry_delay http_code method request") ): @property def request_session_id(self): @@ -415,9 +400,7 @@ def full_info_logging_context( 
] ) - log_base_data_dict["no-retry-reason"] = ( - no_retry_reason and no_retry_reason.value - ) + log_base_data_dict["no-retry-reason"] = no_retry_reason and no_retry_reason.value log_base_data_dict["bounded-retry-delay"] = self.retry_delay log_base_data_dict["attempt"] = "{}/{}".format(attempt, max_attempts) log_base_data_dict["elapsed-seconds"] = "{}/{}".format(elapsed, max_duration) @@ -451,9 +434,7 @@ def escape_args(self, parameters): elif isinstance(parameters, (list, tuple)): return tuple(self.escape_item(x) for x in parameters) else: - raise exc.ProgrammingError( - "Unsupported param format: {}".format(parameters) - ) + raise exc.ProgrammingError("Unsupported param format: {}".format(parameters)) def escape_number(self, item): return item @@ -537,9 +518,7 @@ def _may_contain_inline_positional_markers(operation: str) -> bool: return interpolated != operation -def _interpolate_named_markers( - operation: str, parameters: List[TDbsqlParameter] -) -> str: +def _interpolate_named_markers(operation: str, parameters: List[TDbsqlParameter]) -> str: """Replace all instances of `%(param)s` in `operation` with `:param`. If `operation` contains no instances of `%(param)s` then the input string is returned unchanged. @@ -590,9 +569,8 @@ def transform_paramstyle( str """ output = operation - if ( - param_structure == ParameterStructure.POSITIONAL - and _may_contain_inline_positional_markers(operation) + if param_structure == ParameterStructure.POSITIONAL and _may_contain_inline_positional_markers( + operation ): logger.warning( "It looks like this query may contain un-named query markers like `%s`" @@ -605,9 +583,7 @@ def transform_paramstyle( return output -def create_arrow_table_from_arrow_file( - file_bytes: bytes, description -) -> "pyarrow.Table": +def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table": arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) @@ -625,11 +601,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema n_rows = 0 for arrow_batch in arrow_batches: n_rows += arrow_batch.rowCount - ba += ( - lz4.frame.decompress(arrow_batch.batch) - if lz4_compressed - else arrow_batch.batch - ) + ba += lz4.frame.decompress(arrow_batch.batch) if lz4_compressed else arrow_batch.batch arrow_table = pyarrow.ipc.open_stream(ba).read_all() return arrow_table, n_rows @@ -667,17 +639,13 @@ def convert_to_assigned_datatypes_in_column_table(column_table, description): converted_column_table = [] for i, col in enumerate(column_table): if description[i][1] == "decimal": - converted_column_table.append( - tuple(v if v is None else Decimal(v) for v in col) - ) + converted_column_table.append(tuple(v if v is None else Decimal(v) for v in col)) elif description[i][1] == "date": converted_column_table.append( tuple(v if v is None else datetime.date.fromisoformat(v) for v in col) ) elif description[i][1] == "timestamp": - converted_column_table.append( - tuple((v if v is None else parser.parse(v)) for v in col) - ) + converted_column_table.append(tuple((v if v is None else parser.parse(v)) for v in col)) else: converted_column_table.append(col) diff --git a/tests/e2e/common/core_tests.py b/tests/e2e/common/core_tests.py index 3f0fdc05d..e89289efc 100644 --- a/tests/e2e/common/core_tests.py +++ b/tests/e2e/common/core_tests.py @@ -4,18 +4,15 @@ TypeFailure = namedtuple( "TypeFailure", - "query,columnType,resultType,resultValue," - 
"actualValue,actualType,description,conf", + "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf", ) ResultFailure = namedtuple( "ResultFailure", - "query,columnType,resultType,resultValue," - "actualValue,actualType,description,conf", + "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf", ) ExecFailure = namedtuple( "ExecFailure", - "query,columnType,resultType,resultValue," - "actualValue,actualType,description,conf,error", + "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf,error", ) @@ -61,9 +58,7 @@ def run_tests_on_queries(self, default_conf): for query, columnType, rowValueType, answer in self.range_queries: with self.cursor(default_conf) as cursor: failures.extend( - self.run_query( - cursor, query, columnType, rowValueType, answer, default_conf - ) + self.run_query(cursor, query, columnType, rowValueType, answer, default_conf) ) failures.extend( self.run_range_query( @@ -74,9 +69,7 @@ def run_tests_on_queries(self, default_conf): for query, columnType, rowValueType, answer in self.queries: with self.cursor(default_conf) as cursor: failures.extend( - self.run_query( - cursor, query, columnType, rowValueType, answer, default_conf - ) + self.run_query(cursor, query, columnType, rowValueType, answer, default_conf) ) if failures: @@ -91,9 +84,7 @@ def run_query(self, cursor, query, columnType, rowValueType, answer, conf): try: cursor.execute(full_query) (result,) = cursor.fetchone() - if not all( - cursor.description[0][1] == type for type in expected_column_types - ): + if not all(cursor.description[0][1] == type for type in expected_column_types): return [ TypeFailure( full_query, @@ -159,10 +150,7 @@ def run_range_query(self, cursor, query, columnType, rowValueType, expected, con if len(rows) <= 0: break for index, (result, id) in enumerate(rows): - if not all( - cursor.description[0][1] == type - for type in expected_column_types - ): + if not all(cursor.description[0][1] == type for type in expected_column_types): return [ TypeFailure( full_query, @@ -175,10 +163,7 @@ def run_range_query(self, cursor, query, columnType, rowValueType, expected, con conf, ) ] - if ( - self.validate_row_value_type - and type(result) is not rowValueType - ): + if self.validate_row_value_type and type(result) is not rowValueType: return [ TypeFailure( full_query, diff --git a/tests/e2e/common/decimal_tests.py b/tests/e2e/common/decimal_tests.py index 0029f30cb..cf94a46dc 100644 --- a/tests/e2e/common/decimal_tests.py +++ b/tests/e2e/common/decimal_tests.py @@ -38,9 +38,7 @@ class DecimalTestsMixin: ), ] - @pytest.mark.parametrize( - "decimal, expected_value, expected_type", decimal_and_expected_results - ) + @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results) def test_decimals(self, decimal, expected_value, expected_type): with self.cursor({}) as cursor: query = "SELECT CAST ({})".format(decimal) @@ -54,9 +52,7 @@ def test_decimals(self, decimal, expected_value, expected_type): ) def test_multi_decimals(self, decimals, expected_values, expected_type): with self.cursor({}) as cursor: - union_str = " UNION ".join( - ["(SELECT CAST ({}))".format(dec) for dec in decimals] - ) + union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals]) query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str) cursor.execute(query) diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index 
1181ef154..b96e45b4b 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -36,9 +36,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): num_fetches = max(math.ceil(n / 10000), 1) latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 print( - "Fetched {} rows with an avg latency of {} per fetch, ".format( - n, latency_ms - ) + "Fetched {} rows with an avg latency of {} per fetch, ".format(n, latency_ms) + "assuming 10K fetch size." ) @@ -57,14 +55,10 @@ def test_query_with_large_wide_result_set(self): cursor.connection.lz4_compression = lz4_compression uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) cursor.execute( - "SELECT id, {uuids} FROM RANGE({rows})".format( - uuids=uuids, rows=rows - ) + "SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows) ) assert lz4_compression == cursor.active_result_set.lz4_compressed - for row_id, row in enumerate( - self.fetch_rows(cursor, rows, fetchmany_size) - ): + for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): assert row[0] == row_id # Verify no rows are dropped in the middle. assert len(row[1]) == 36 diff --git a/tests/e2e/common/predicates.py b/tests/e2e/common/predicates.py index 61de69fd3..bdafd0c2c 100644 --- a/tests/e2e/common/predicates.py +++ b/tests/e2e/common/predicates.py @@ -97,9 +97,7 @@ def test_some_pyhive_v1_stuff(): def validate_version(version): v = parse_version(str(version)) # assert that we get a PEP-440 Version back -- LegacyVersion doesn't have major/minor. - assert hasattr(v, "major"), ( - 'Module has incompatible "Legacy" version: ' + version - ) + assert hasattr(v, "major"), 'Module has incompatible "Legacy" version: ' + version return (v.major, v.minor, v.micro) mod_version = validate_version(module.__version__) diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b5d01a45d..870997b1c 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -99,9 +99,7 @@ def mock_sequential_server_responses(responses: List[dict]): # Each resp should have these members: for resp in responses: - _mock = MagicMock( - headers=resp["headers"], msg=resp["headers"], status=resp["status"] - ) + _mock = MagicMock(headers=resp["headers"], msg=resp["headers"], status=resp["status"]) _mock.get_redirect_location.return_value = ( False if resp["redirect_location"] is None else resp["redirect_location"] ) @@ -180,9 +178,7 @@ def test_retry_exponential_backoff(self): retry_policy["_retry_delay_min"] = 1 time_start = time.time() - with mocked_server_response( - status=429, headers={"Retry-After": "8"} - ) as mock_obj: + with mocked_server_response(status=429, headers={"Retry-After": "8"}) as mock_obj: with pytest.raises(RequestError) as cm: with self.connection(extra_params=retry_policy) as conn: pass @@ -262,9 +258,7 @@ def test_retry_dangerous_codes(self): assert isinstance(cm.value.args[1], UnsafeToRetryError) # Prove that these codes are retried if forced by the user - with self.connection( - extra_params={**self._retry_policy, **additional_settings} - ) as conn: + with self.connection(extra_params={**self._retry_policy, **additional_settings}) as conn: with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: with mocked_server_response(status=dangerous_code): @@ -334,9 +328,7 @@ def test_retry_abort_close_operation_on_404(self, caplog): curs.execute("SELECT 1") with mock_sequential_server_responses(responses): curs.close() 
- assert ( - "Operation was canceled by a prior request" in caplog.text - ) + assert "Operation was canceled by a prior request" in caplog.text def test_retry_max_redirects_raises_too_many_redirects_exception(self): """GIVEN the connector is configured with a custom max_redirects @@ -347,9 +339,7 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): max_redirects, expected_call_count = 1, 2 # Code 302 is a redirect - with mocked_server_response( - status=302, redirect_location="/foo.bar" - ) as mock_obj: + with mocked_server_response(status=302, redirect_location="/foo.bar") as mock_obj: with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ @@ -371,9 +361,7 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): _stop_after_attempts_count is enforced. """ # Code 302 is a redirect - with mocked_server_response( - status=302, redirect_location="/foo.bar/" - ) as mock_obj: + with mocked_server_response(status=302, redirect_location="/foo.bar/") as mock_obj: with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ @@ -399,9 +387,7 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): with pytest.raises(RequestError) as cm: with mock_sequential_server_responses(responses): - with self.connection( - extra_params={**self._retry_policy, **additional_settings} - ): + with self.connection(extra_params={**self._retry_policy, **additional_settings}): pass # The error should be the result of the 500, not because of too many requests. @@ -421,12 +407,8 @@ def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog) assert "it will have no affect!" in caplog.text def test_retry_legacy_behavior_warns_user(self, caplog): - with self.connection( - extra_params={**self._retry_policy, "_enable_v3_retries": False} - ): - assert ( - "Legacy retry behavior is enabled for this connection." in caplog.text - ) + with self.connection(extra_params={**self._retry_policy, "_enable_v3_retries": False}): + assert "Legacy retry behavior is enabled for this connection." 
in caplog.text def test_403_not_retried(self): """GIVEN the server returns a code 403 diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 825f830f3..e63a06d58 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -41,9 +41,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): with open(fh, "wb") as fp: fp.write(original_text) - with self.connection( - extra_params={"staging_allowed_local_path": temp_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' OVERWRITE" @@ -53,9 +51,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): new_fh, new_temp_path = tempfile.mkstemp() - with self.connection( - extra_params={"staging_allowed_local_path": new_temp_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -75,19 +71,17 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail - with pytest.raises( - Error, match="Staging operation over HTTP was unsuccessful: 404" - ): + with pytest.raises(Error, match="Staging operation over HTTP was unsuccessful: 404"): cursor = conn.cursor() - query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" + query = ( + f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" + ) cursor.execute(query) os.remove(temp_path) os.remove(new_temp_path) - def test_staging_ingestion_put_fails_without_staging_allowed_local_path( - self, ingestion_user - ): + def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, ingestion_user): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path """ @@ -99,9 +93,7 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path( with open(fh, "wb") as fp: fp.write(original_text) - with pytest.raises( - Error, match="You must provide at least one staging_allowed_local_path" - ): + with pytest.raises(Error, match="You must provide at least one staging_allowed_local_path"): with self.connection() as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" @@ -127,16 +119,12 @@ def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_p Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): - with self.connection( - extra_params={"staging_allowed_local_path": base_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set( - self, ingestion_user - ): + def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self, ingestion_user): """PUT a file into the staging location twice. First command should succeed. 
Second should fail.""" fh, temp_path = tempfile.mkstemp() @@ -147,22 +135,16 @@ def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set( fp.write(original_text) def perform_put(): - with self.connection( - extra_params={"staging_allowed_local_path": temp_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" cursor.execute(query) def perform_remove(): try: - remove_query = ( - f"REMOVE 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" - ) + remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" - with self.connection( - extra_params={"staging_allowed_local_path": "/"} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: cursor = conn.cursor() cursor.execute(remove_query) except Exception: @@ -196,9 +178,7 @@ def test_staging_ingestion_fails_to_modify_another_staging_user(self): fp.write(original_text) def perform_put(): - with self.connection( - extra_params={"staging_allowed_local_path": temp_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv' OVERWRITE" cursor.execute(query) @@ -206,16 +186,12 @@ def perform_put(): def perform_remove(): remove_query = f"REMOVE 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv'" - with self.connection( - extra_params={"staging_allowed_local_path": "/"} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: cursor = conn.cursor() cursor.execute(remove_query) def perform_get(): - with self.connection( - extra_params={"staging_allowed_local_path": temp_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() query = f"GET 'stage://tmp/{some_other_user}/tmp/11/15/file1.csv' TO '{temp_path}'" cursor.execute(query) @@ -256,9 +232,7 @@ def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowe query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_empty_local_path_fails_to_parse_at_server( - self, ingestion_user - ): + def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, ingestion_user): staging_allowed_local_path = "/var/www/html" target_file = "" @@ -270,9 +244,7 @@ def test_staging_ingestion_empty_local_path_fails_to_parse_at_server( query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_invalid_staging_path_fails_at_server( - self, ingestion_user - ): + def test_staging_ingestion_invalid_staging_path_fails_at_server(self, ingestion_user): staging_allowed_local_path = "/var/www/html" target_file = "index.html" @@ -306,9 +278,7 @@ def generate_file_and_path_and_queries(): original_text = "hello world!".encode("utf-8") fp.write(original_text) put_query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv' OVERWRITE" - remove_query = ( - f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" - ) + remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" return fh, temp_path, put_query, remove_query ( diff --git a/tests/e2e/common/timestamp_tests.py 
b/tests/e2e/common/timestamp_tests.py index 70ded7d00..5f7cd0281 100644 --- a/tests/e2e/common/timestamp_tests.py +++ b/tests/e2e/common/timestamp_tests.py @@ -48,24 +48,18 @@ def assertTimestampsEqual(self, result, expected): def multi_query(self, n_rows=10): row_sql = "SELECT " + ", ".join( - [ - "TIMESTAMP('{}')".format(ts) - for (ts, _) in self.timestamp_and_expected_results - ] + ["TIMESTAMP('{}')".format(ts) for (ts, _) in self.timestamp_and_expected_results] ) query = " UNION ALL ".join([row_sql for _ in range(n_rows)]) expected_matrix = [ - [dt for (_, dt) in self.timestamp_and_expected_results] - for _ in range(n_rows) + [dt for (_, dt) in self.timestamp_and_expected_results] for _ in range(n_rows) ] return query, expected_matrix def test_timestamps(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: for timestamp, expected in self.timestamp_and_expected_results: - cursor.execute( - "SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp) - ) + cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) result = cursor.fetchone()[0] self.assertTimestampsEqual(result, expected) diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 72e2f5020..cccb03a84 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -40,21 +40,19 @@ def test_uc_volume_life_cycle(self, catalog, schema): with open(fh, "wb") as fp: fp.write(original_text) - with self.connection( - extra_params={"staging_allowed_local_path": temp_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" + query = ( + f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" + ) cursor.execute(query) # GET should succeed new_fh, new_temp_path = tempfile.mkstemp() - with self.connection( - extra_params={"staging_allowed_local_path": new_temp_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -74,9 +72,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail - with pytest.raises( - Error, match="Staging operation over HTTP was unsuccessful: 404" - ): + with pytest.raises(Error, match="Staging operation over HTTP was unsuccessful: 404"): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -84,9 +80,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): os.remove(temp_path) os.remove(new_temp_path) - def test_uc_volume_put_fails_without_staging_allowed_local_path( - self, catalog, schema - ): + def test_uc_volume_put_fails_without_staging_allowed_local_path(self, catalog, schema): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path """ @@ -98,9 +92,7 @@ def test_uc_volume_put_fails_without_staging_allowed_local_path( with open(fh, "wb") as fp: fp.write(original_text) - with pytest.raises( - Error, match="You must provide at least one staging_allowed_local_path" - ): + with pytest.raises(Error, match="You must provide at least one staging_allowed_local_path"): with self.connection() as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' 
INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" @@ -126,16 +118,12 @@ def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): - with self.connection( - extra_params={"staging_allowed_local_path": base_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set( - self, catalog, schema - ): + def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set(self, catalog, schema): """PUT a file into the staging location twice. First command should succeed. Second should fail.""" fh, temp_path = tempfile.mkstemp() @@ -146,22 +134,16 @@ def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set( fp.write(original_text) def perform_put(): - with self.connection( - extra_params={"staging_allowed_local_path": temp_path} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" cursor.execute(query) def perform_remove(): try: - remove_query = ( - f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" - ) + remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" - with self.connection( - extra_params={"staging_allowed_local_path": "/"} - ) as conn: + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: cursor = conn.cursor() cursor.execute(remove_query) except Exception: @@ -230,9 +212,7 @@ def test_uc_volume_invalid_volume_path_fails_at_server(self, catalog, schema): query = f"PUT '{target_file}' INTO '/Volumes/RANDOMSTRINGOFCHARACTERS/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_supports_multiple_staging_allowed_local_path_values( - self, catalog, schema - ): + def test_uc_volume_supports_multiple_staging_allowed_local_path_values(self, catalog, schema): """staging_allowed_local_path may be either a path-like object or a list of path-like objects. 
This test confirms that two configured base paths: @@ -252,9 +232,7 @@ def generate_file_and_path_and_queries(): original_text = "hello world!".encode("utf-8") fp.write(original_text) put_query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv' OVERWRITE" - remove_query = ( - f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv'" - ) + remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv'" return fh, temp_path, put_query, remove_query ( diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index c8a3a0781..678562e2d 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -77,9 +77,7 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): ) def test_read_complex_types_as_string(self, field, table_fixture): """Confirms the return type of a complex type that is returned as a string""" - with self.cursor( - extra_params={"_use_arrow_native_complex_types": False} - ) as cursor: + with self.cursor(extra_params={"_use_arrow_native_complex_types": False}) as cursor: result = cursor.execute( "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 042fcc10a..60c872a47 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -61,9 +61,7 @@ # manually decorate DecimalTestsMixin to need arrow support for name in loader.getTestCaseNames(DecimalTestsMixin, "test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( - fn - ) + decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")(fn) setattr(DecimalTestsMixin, name, decorated) @@ -74,9 +72,7 @@ class PySQLPytestTestCase: error_type = Error conf_to_disable_rate_limit_retries = {"_retry_stop_after_attempts_count": 1} - conf_to_disable_temporarily_unavailable_retries = { - "_retry_stop_after_attempts_count": 1 - } + conf_to_disable_temporarily_unavailable_retries = {"_retry_stop_after_attempts_count": 1} arraysize = 1000 buffer_size_bytes = 104857600 POLLING_INTERVAL = 2 @@ -114,9 +110,7 @@ def connection(self, extra_params=()): @contextmanager def cursor(self, extra_params=()): with self.connection(extra_params) as conn: - cursor = conn.cursor( - arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes - ) + cursor = conn.cursor(arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes) try: yield cursor finally: @@ -478,21 +472,15 @@ def test_escape_single_quotes(self): table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) # Test escape syntax directly cursor.execute( - "CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format( - table_name - ) - ) - cursor.execute( - "SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name) + "CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name) ) + cursor.execute("SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name)) rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" # Test escape syntax in parameter cursor.execute( - "SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format( - table_name, table_name - ), + "SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), parameters={"var": "you're"}, ) rows = cursor.fetchall() @@ -521,9 +509,7 @@ def test_get_catalogs(self): cursor.catalogs() cursor.fetchall() catalogs_desc = cursor.description - 
assert catalogs_desc == [ - ("TABLE_CAT", "string", None, None, None, None, None) - ] + assert catalogs_desc == [("TABLE_CAT", "string", None, None, None, None, None)] @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") def test_get_arrow(self): @@ -684,9 +670,7 @@ def test_socket_timeout_user_defined(self): def test_ssp_passthrough(self): for enable_ansi in (True, False): - with self.cursor( - {"session_configuration": {"ansi_mode": enable_ansi}} - ) as cursor: + with self.cursor({"session_configuration": {"ansi_mode": enable_ansi}}) as cursor: cursor.execute("SET ansi_mode") assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)] @@ -694,9 +678,7 @@ def test_ssp_passthrough(self): def test_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: for timestamp, expected in self.timestamp_and_expected_results: - cursor.execute( - "SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp) - ) + cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) arrow_table = cursor.fetchmany_arrow(1) if self.should_add_timezone(): ts_type = pyarrow.timestamp("us", tz="Etc/UTC") @@ -707,9 +689,7 @@ def test_timestamps_arrow(self): # To work consistently across different local timezones, we specify the timezone # of the expected result to # be UTC (what it should be by default on the server) - aware_timestamp = expected and expected.replace( - tzinfo=datetime.timezone.utc - ) + aware_timestamp = expected and expected.replace(tzinfo=datetime.timezone.utc) assert result_value == ( aware_timestamp and aware_timestamp.timestamp() * 1000000 ), "timestamp {} did not match {}".format(timestamp, expected) @@ -719,16 +699,14 @@ def test_multi_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: query, expected = self.multi_query() expected = [ - [self.maybe_add_timezone_to_timestamp(ts) for ts in row] - for row in expected + [self.maybe_add_timezone_to_timestamp(ts) for ts in row] for row in expected ] cursor.execute(query) table = cursor.fetchall_arrow() # Transpose columnar result to list of rows list_of_cols = [c.to_pylist() for c in table] result = [ - [col[row_index] for col in list_of_cols] - for row_index in range(table.num_rows) + [col[row_index] for col in list_of_cols] for row_index in range(table.num_rows) ] assert result == expected @@ -745,9 +723,7 @@ def test_timezone_with_timestamp(self): cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)") arrow_result_table = cursor.fetchmany_arrow(1) - arrow_result_value = ( - arrow_result_table.column(0).combine_chunks()[0].value - ) + arrow_result_value = arrow_result_table.column(0).combine_chunks()[0].value ts_type = pyarrow.timestamp("us", tz="Europe/Amsterdam") assert arrow_result_table.field(0).type == ts_type @@ -818,9 +794,7 @@ class HTTP429Suite(Client429ResponseMixin, PySQLPytestTestCase): class HTTP503Suite(Client503ResponseMixin, PySQLPytestTestCase): # 503Response suite gets custom error here vs PyODBC def test_retry_disabled(self): - self._test_retry_disabled_with_message( - "TEMPORARILY_UNAVAILABLE", OperationalError - ) + self._test_retry_disabled_with_message("TEMPORARILY_UNAVAILABLE", OperationalError) class TestPySQLUnityCatalogSuite(PySQLPytestTestCase): diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 79def9b72..05676e357 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -179,8 
+179,12 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column) :paramstyle: This is a no-op but is included to make the test-code easier to read. """ - INSERT_QUERY = f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" - SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" + INSERT_QUERY = ( + f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" + ) + SELECT_QUERY = ( + f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" + ) DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table" with self.connection(extra_params={"use_inline_params": True}) as conn: @@ -278,9 +282,7 @@ def _parse_to_common_type(self, value): """ if value is None: return None - elif isinstance(value, (Sequence, np.ndarray)) and not isinstance( - value, (str, bytes) - ): + elif isinstance(value, (Sequence, np.ndarray)) and not isinstance(value, (str, bytes)): return tuple(value) elif isinstance(value, dict): return tuple(value.items()) @@ -307,8 +309,7 @@ def _recursive_compare(self, actual, expected): if len(actual_parsed) != len(expected_parsed): return False return all( - self._recursive_compare(o1, o2) - for o1, o2 in zip(actual_parsed, expected_parsed) + self._recursive_compare(o1, o2) for o1, o2 in zip(actual_parsed, expected_parsed) ) return actual_parsed == expected_parsed @@ -385,15 +386,11 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) If a user explicitly sets use_inline_params, don't warn them about it. """ - extra_args = ( - {"use_inline_params": use_inline_params} if use_inline_params else {} - ) + extra_args = {"use_inline_params": use_inline_params} if use_inline_params else {} with self.connection(extra_params=extra_args) as conn: with conn.cursor() as cursor: - with self.patch_server_supports_native_params( - supports_native_params=True - ): + with self.patch_server_supports_native_params(supports_native_params=True): cursor.execute("SELECT %(p)s", parameters={"p": 1}) if use_inline_params is True: assert ( @@ -535,9 +532,7 @@ def test_inline_ordinals_can_break_sql(self): query = "SELECT 'samsonite', %s WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] with self.cursor(extra_params={"use_inline_params": True}) as cursor: - with pytest.raises( - TypeError, match="not enough arguments for format string" - ): + with pytest.raises(TypeError, match="not enough arguments for format string"): cursor.execute(query, parameters=params) def test_inline_named_dont_break_sql(self): diff --git a/tests/unit/test_arrow_queue.py b/tests/unit/test_arrow_queue.py index 6c195bf10..bd63b8cb5 100644 --- a/tests/unit/test_arrow_queue.py +++ b/tests/unit/test_arrow_queue.py @@ -18,21 +18,13 @@ def make_arrow_table(batch): return pa.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema) def test_fetchmany_respects_n_rows(self): - arrow_table = self.make_arrow_table( - [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] - ) + arrow_table = self.make_arrow_table([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]) aq = ArrowQueue(arrow_table, 3) - self.assertEqual( - aq.next_n_rows(2), self.make_arrow_table([[0, 1, 2], [3, 4, 5]]) - ) + self.assertEqual(aq.next_n_rows(2), self.make_arrow_table([[0, 1, 2], [3, 4, 5]])) self.assertEqual(aq.next_n_rows(2), self.make_arrow_table([[6, 7, 8]])) def test_fetch_remaining_rows_respects_n_rows(self): - arrow_table = self.make_arrow_table( - [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] - 
) + arrow_table = self.make_arrow_table([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]) aq = ArrowQueue(arrow_table, 3) self.assertEqual(aq.next_n_rows(1), self.make_arrow_table([[0, 1, 2]])) - self.assertEqual( - aq.remaining_rows(), self.make_arrow_table([[3, 4, 5], [6, 7, 8]]) - ) + self.assertEqual(aq.remaining_rows(), self.make_arrow_table([[3, 4, 5], [6, 7, 8]])) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 8bf914708..e6193006b 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -114,9 +114,7 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh): ) self.assertEqual(auth_provider.oauth_manager.port_range, [8020]) self.assertEqual(auth_provider.oauth_manager.client_id, client_id) - self.assertEqual( - oauth_persistence.read(host).refresh_token, refresh_token - ) + self.assertEqual(oauth_persistence.read(host).refresh_token, refresh_token) mock_get_tokens.assert_called_with(hostname=host, scope=expected_scopes) headers = {} @@ -184,9 +182,7 @@ def test_get_python_sql_connector_basic_auth(self): } with self.assertRaises(ValueError) as e: get_python_sql_connector_auth_provider("foo.cloud.databricks.com", **kwargs) - self.assertIn( - "Username/password authentication is no longer supported", str(e.exception) - ) + self.assertIn("Username/password authentication is no longer supported", str(e.exception)) @patch.object(DatabricksOAuthProvider, "_initial_get_token") def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): @@ -229,9 +225,7 @@ def token_source(self): client_secret="client_secret", ) - def test_no_token_refresh__when_token_is_not_expired( - self, token_source, indefinite_token - ): + def test_no_token_refresh__when_token_is_not_expired(self, token_source, indefinite_token): with patch.object(token_source, "refresh") as mock_get_token: mock_get_token.return_value = indefinite_token @@ -285,16 +279,11 @@ def test_provider_credentials(self, credential_provider): test_token = Token("access_token", "Bearer", "refresh_token") - with patch.object( - credential_provider, "get_token_source" - ) as mock_get_token_source: + with patch.object(credential_provider, "get_token_source") as mock_get_token_source: mock_get_token_source.return_value = MagicMock() mock_get_token_source.return_value.get_token.return_value = test_token headers = credential_provider()() assert headers["Authorization"] == f"Bearer {test_token.access_token}" - assert ( - headers["X-Databricks-Azure-SP-Management-Token"] - == test_token.access_token - ) + assert headers["X-Databricks-Azure-SP-Management-Token"] == test_token.access_token diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 44c84d790..3d681b1f0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -185,9 +185,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): with self.subTest(closed=closed): # Set initial state based on whether the command is already closed initial_state = ( - TOperationState.FINISHED_STATE - if not closed - else TOperationState.CLOSED_STATE + TOperationState.FINISHED_STATE if not closed else TOperationState.CLOSED_STATE ) # Mock the execute response with controlled state @@ -209,9 +207,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): cursor = connection.cursor() # Mock execute_command to return our execute response - cursor.thrift_backend.execute_command = Mock( - return_value=mock_execute_response - ) + cursor.thrift_backend.execute_command = 
Mock(return_value=mock_execute_response) # Execute a command cursor.execute("SELECT 1") @@ -227,10 +223,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): assert active_result_set.has_been_closed_server_side is True # 2. op_state should always be CLOSED after close() - assert ( - active_result_set.op_state - == connection.thrift_backend.CLOSED_OP_STATE - ) + assert active_result_set.op_state == connection.thrift_backend.CLOSED_OP_STATE # 3. Backend close_command should be called appropriately if not closed: @@ -254,9 +247,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) - def test_arraysize_buffer_size_passthrough( - self, mock_cursor_class, mock_client_class - ): + def test_arraysize_buffer_size_passthrough(self, mock_cursor_class, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.cursor(arraysize=999, buffer_size_bytes=1234) kwargs = mock_cursor_class.call_args[1] @@ -289,9 +280,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_thrift_backend = Mock() mock_results = Mock() mock_connection.open = True - result_set = client.ResultSet( - mock_connection, mock_results_response, mock_thrift_backend - ) + result_set = client.ResultSet(mock_connection, mock_results_response, mock_thrift_backend) result_set.results = mock_results result_set.close() @@ -302,16 +291,12 @@ def test_closing_result_set_hard_closes_commands(self): mock_results.close.assert_called_once() @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executing_multiple_commands_uses_the_most_recent_command( - self, mock_result_set_class - ): + def test_executing_multiple_commands_uses_the_most_recent_command(self, mock_result_set_class): mock_result_sets = [Mock(), Mock()] mock_result_set_class.side_effect = mock_result_sets - cursor = client.Cursor( - connection=Mock(), thrift_backend=ThriftBackendMockFactory.new() - ) + cursor = client.Cursor(connection=Mock(), thrift_backend=ThriftBackendMockFactory.new()) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -462,13 +447,9 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect( - _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS - ) + databricks.sql.connect(_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS) - self.assertEqual( - mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 - ) + self.assertEqual(mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): @@ -500,15 +481,9 @@ def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem) + self.assertEqual(mock_client_class.return_value.open_session.call_args[0][1], mock_cat) + 
self.assertEqual(mock_client_class.return_value.open_session.call_args[0][2], mock_schem) def test_execute_parameter_passthrough(self): mock_thrift_backend = ThriftBackendMockFactory.new() @@ -624,9 +599,7 @@ def test_disable_pandas_respected(self, mock_thrift_backend_class): mock_aq = Mock() mock_aq.remaining_rows.return_value = mock_table mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq - mock_thrift_backend.execute_command.return_value.has_been_closed_server_side = ( - True - ) + mock_thrift_backend.execute_command.return_value.has_been_closed_server_side = True mock_con = Mock() mock_con.disable_pandas = True @@ -771,9 +744,7 @@ def mock_close_normal(): except Exception as e: self.fail(f"Connection close should handle exceptions: {e}") - self.assertEqual( - cursors_closed, [1, 2], "Both cursors should have close called" - ) + self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called") if __name__ == "__main__": diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 7dec4e680..3e120a689 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -20,9 +20,7 @@ def create_result_link( row_count: int = 8000, bytes_num: int = 20971520, ): - return TSparkArrowResultLink( - file_link, None, start_row_offset, row_count, bytes_num - ) + return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) def create_result_links(self, num_files: int, start_row_offset: int = 0): result_links = [] @@ -188,10 +186,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert result.num_rows == 7 assert queue.table_row_index == 3 assert ( - result - == pyarrow.concat_tables( - [self.make_arrow_table(), self.make_arrow_table()] - )[:7] + result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] ) @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") @@ -288,9 +283,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): assert result == self.make_arrow_table() @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") - def test_remaining_rows_multiple_tables_fully_returned( - self, mock_create_next_table - ): + def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [ self.make_arrow_table(), self.make_arrow_table(), @@ -312,10 +305,7 @@ def test_remaining_rows_multiple_tables_fully_returned( assert mock_create_next_table.call_count == 3 assert result.num_rows == 5 assert ( - result - == pyarrow.concat_tables( - [self.make_arrow_table(), self.make_arrow_table()] - )[3:] + result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[3:] ) @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) diff --git a/tests/unit/test_column_queue.py b/tests/unit/test_column_queue.py index 234af88ee..130b589b2 100644 --- a/tests/unit/test_column_queue.py +++ b/tests/unit/test_column_queue.py @@ -8,18 +8,14 @@ def make_column_table(table): return ColumnTable(table, [f"col_{i}" for i in range(n_cols)]) def test_fetchmany_respects_n_rows(self): - column_table = self.make_column_table( - [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]] - ) + column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) column_queue = ColumnQueue(column_table) assert column_queue.next_n_rows(2) == column_table.slice(0, 2) assert column_queue.next_n_rows(2) == 
column_table.slice(2, 2) def test_fetch_remaining_rows_respects_n_rows(self): - column_table = self.make_column_table( - [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]] - ) + column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) column_queue = ColumnQueue(column_table) assert column_queue.next_n_rows(2) == column_table.slice(0, 2) diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 64edbdebe..61519247d 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -11,9 +11,7 @@ class DownloadManagerTests(unittest.TestCase): Unit tests for checking download manager logic. """ - def create_download_manager( - self, links, max_download_threads=10, lz4_compressed=True - ): + def create_download_manager(self, links, max_download_threads=10, lz4_compressed=True): return download_manager.ResultFileDownloadManager( links, max_download_threads, @@ -28,9 +26,7 @@ def create_result_link( row_count: int = 8000, bytes_num: int = 20971520, ): - return TSparkArrowResultLink( - file_link, None, start_row_offset, row_count, bytes_num - ) + return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) def create_result_links(self, num_files: int, start_row_offset: int = 0): result_links = [] @@ -63,9 +59,7 @@ def test_add_file_links_success(self): def test_schedule_downloads(self, mock_submit): max_download_threads = 4 links = self.create_result_links(num_files=10) - manager = self.create_download_manager( - links, max_download_threads=max_download_threads - ) + manager = self.create_download_manager(links, max_download_threads=max_download_threads) manager._schedule_downloads() assert mock_submit.call_count == max_download_threads diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 2a3b715b5..962237c60 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -26,9 +26,7 @@ def test_run_link_expired(self, mock_time): result_link = Mock() # Already expired result_link.expiryTime = 999 - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(Error) as context: d.run() @@ -42,9 +40,7 @@ def test_run_link_past_expiry_buffer(self, mock_time): result_link = Mock() # Within the expiry buffer time result_link.expiryTime = 1004 - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(Error) as context: d.run() @@ -62,9 +58,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): settings.use_proxy = False result_link = Mock(expiryTime=1001) - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() self.assertTrue("404" in str(context.exception)) @@ -81,9 +75,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) file = d.run() assert 
file.file_bytes == b"1234567890" * 10 @@ -104,9 +96,7 @@ def test_run_compressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -114,29 +104,21 @@ def test_run_compressed_successful(self, mock_time, mock_session): @patch("requests.Session.get", side_effect=ConnectionError("foo")) @patch("time.time", return_value=1000) def test_download_connection_error(self, mock_time, mock_session): - settings = Mock( - link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True - ) + settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) result_link = Mock(bytesNum=100, expiryTime=1001) mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(ConnectionError): d.run() @patch("requests.Session.get", side_effect=TimeoutError("foo")) @patch("time.time", return_value=1000) def test_download_timeout(self, mock_time, mock_session): - settings = Mock( - link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True - ) + settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) result_link = Mock(bytesNum=100, expiryTime=1001) mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() - ) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(TimeoutError): d.run() diff --git a/tests/unit/test_endpoint.py b/tests/unit/test_endpoint.py index 1f7d7cddd..5a864dfb8 100644 --- a/tests/unit/test_endpoint.py +++ b/tests/unit/test_endpoint.py @@ -29,16 +29,12 @@ def test_infer_cloud_from_host(self): for expected_type, host in param_list: with self.subTest(expected_type or "None", expected_type=expected_type): self.assertEqual(infer_cloud_from_host(host), expected_type) - self.assertEqual( - infer_cloud_from_host(f"https://{host}/to/path"), expected_type - ) + self.assertEqual(infer_cloud_from_host(f"https://{host}/to/path"), expected_type) def test_oauth_endpoint(self): scopes = ["offline_access", "sql", "admin"] scopes2 = ["sql", "admin"] - azure_scope = ( - f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation" - ) + azure_scope = f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation" param_list = [ ( @@ -99,12 +95,8 @@ def test_oauth_endpoint(self): ) in param_list: with self.subTest(cloud_type): endpoint = get_oauth_endpoints(host, use_azure_auth) - self.assertEqual( - endpoint.get_authorization_url(host), expected_auth_url - ) - self.assertEqual( - endpoint.get_openid_config_url(host), expected_config_url - ) + self.assertEqual(endpoint.get_authorization_url(host), expected_auth_url) + self.assertEqual(endpoint.get_openid_config_url(host), expected_config_url) self.assertEqual(endpoint.get_scopes_mapping(scopes), expected_scopes) 
self.assertEqual(endpoint.get_scopes_mapping(scopes2), expected_scope2) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 71766f2cb..94e7bd473 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -22,9 +22,7 @@ def make_arrow_table(batch): n_cols = len(batch[0]) if batch else 0 schema = pa.schema({"col%s" % i: pa.uint32() for i in range(n_cols)}) cols = [[batch[row][col] for row in range(len(batch))] for col in range(n_cols)] - return schema, pa.Table.from_pydict( - dict(zip(schema.names, cols)), schema=schema - ) + return schema, pa.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema) @staticmethod def make_arrow_queue(batch): @@ -54,8 +52,7 @@ def make_dummy_result_set_from_initial_results(initial_results): ) num_cols = len(initial_results[0]) if initial_results else 0 rs.description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) + (f"col{col_id}", "integer", None, None, None, None, None) for col_id in range(num_cols) ] return rs @@ -115,39 +112,29 @@ def test_fetchmany_with_initial_results(self): [2], [3], ] # This is a list of rows, each row with 1 col - dummy_result_set = self.make_dummy_result_set_from_initial_results( - initial_results_1 - ) + dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) self.assertEqualRowValues(dummy_result_set.fetchmany(3), [[1], [2], [3]]) # Fetch in small amounts initial_results_2 = [[1], [2], [3], [4]] - dummy_result_set = self.make_dummy_result_set_from_initial_results( - initial_results_2 - ) + dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_2) self.assertEqualRowValues(dummy_result_set.fetchmany(1), [[1]]) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[2], [3]]) self.assertEqualRowValues(dummy_result_set.fetchmany(1), [[4]]) # Fetch too many initial_results_3 = [[2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results( - initial_results_3 - ) + dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_3) self.assertEqualRowValues(dummy_result_set.fetchmany(5), [[2], [3]]) # Empty results initial_results_4 = [[]] - dummy_result_set = self.make_dummy_result_set_from_initial_results( - initial_results_4 - ) + dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_4) self.assertEqualRowValues(dummy_result_set.fetchmany(0), []) def test_fetch_many_without_initial_results(self): # Fetch all in one go; single batch - batch_list_1 = [ - [[1], [2], [3]] - ] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchmany(3), [[1], [2], [3]]) @@ -177,9 +164,7 @@ def test_fetch_many_without_initial_results(self): # Fetch too many; multiple batches batch_list_6 = [[[1]], [[2], [3], [4]], [[5], [6]]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_6) - self.assertEqualRowValues( - dummy_result_set.fetchmany(100), [[1], [2], [3], [4], [5], [6]] - ) + self.assertEqualRowValues(dummy_result_set.fetchmany(100), [[1], [2], [3], [4], [5], [6]]) # Fetch 0; 1 empty batch batch_list_7 = [[]] @@ -193,25 +178,19 @@ def test_fetch_many_without_initial_results(self): def test_fetchall_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = 
self.make_dummy_result_set_from_initial_results( - initial_results_1 - ) + dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3]]) def test_fetchall_without_initial_results(self): # Fetch all, single batch - batch_list_1 = [ - [[1], [2], [3]] - ] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3]]) # Fetch all, multiple batches batch_list_2 = [[[1], [2]], [[3]], [[4], [5], [6]]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2) - self.assertEqualRowValues( - dummy_result_set.fetchall(), [[1], [2], [3], [4], [5], [6]] - ) + self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3], [4], [5], [6]]) batch_list_3 = [[]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_3) @@ -219,16 +198,12 @@ def test_fetchall_without_initial_results(self): def test_fetchmany_fetchall_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results( - initial_results_1 - ) + dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[1], [2]]) self.assertEqualRowValues(dummy_result_set.fetchall(), [[3]]) def test_fetchmany_fetchall_without_initial_results(self): - batch_list_1 = [ - [[1], [2], [3]] - ] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[1], [2]]) self.assertEqualRowValues(dummy_result_set.fetchall(), [[3]]) @@ -240,9 +215,7 @@ def test_fetchmany_fetchall_without_initial_results(self): def test_fetchone_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results( - initial_results_1 - ) + dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) self.assertSequenceEqual(dummy_result_set.fetchone(), [1]) self.assertSequenceEqual(dummy_result_set.fetchone(), [2]) self.assertSequenceEqual(dummy_result_set.fetchone(), [3]) diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 552872221..e282ef768 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -48,9 +48,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): ] return rs - @pytest.mark.skip( - reason="Test has not been updated for latest connector API (June 2022)" - ) + @pytest.mark.skip(reason="Test has not been updated for latest connector API (June 2022)") def test_benchmark_fetchall(self): print("preparing dummy arrow table") arrow_table = FetchBenchmarkTests.make_arrow_table(10, 25000) @@ -60,9 +58,7 @@ def test_benchmark_fetchall(self): start_time = time.time() count = 0 while time.time() < start_time + benchmark_seconds: - dummy_result_set = self.make_dummy_result_set_from_initial_results( - arrow_table - ) + dummy_result_set = self.make_dummy_result_set_from_initial_results(arrow_table) res = dummy_result_set.fetchall() for _ in res: pass diff --git 
a/tests/unit/test_param_escaper.py b/tests/unit/test_param_escaper.py index 9b6b9c246..4f3cf09b3 100644 --- a/tests/unit/test_param_escaper.py +++ b/tests/unit/test_param_escaper.py @@ -47,29 +47,23 @@ def test_escape_string_that_includes_special_characters(self): # Testing for the presence of these characters: '"/\😂 assert ( - pe.escape_string("his name was 'robert palmer'") - == r"'his name was \'robert palmer\''" + pe.escape_string("his name was 'robert palmer'") == r"'his name was \'robert palmer\''" ) # These tests represent the same user input in the several ways it can be written in Python # Each argument to `escape_string` evaluates to the same bytes. But Python lets us write it differently. assert ( - pe.escape_string('his name was "robert palmer"') - == "'his name was \"robert palmer\"'" + pe.escape_string('his name was "robert palmer"') == "'his name was \"robert palmer\"'" ) assert ( - pe.escape_string('his name was "robert palmer"') - == "'his name was \"robert palmer\"'" + pe.escape_string('his name was "robert palmer"') == "'his name was \"robert palmer\"'" ) assert ( pe.escape_string("his name was {}".format('"robert palmer"')) == "'his name was \"robert palmer\"'" ) - assert ( - pe.escape_string("his name was robert / palmer") - == r"'his name was robert / palmer'" - ) + assert pe.escape_string("his name was robert / palmer") == r"'his name was robert / palmer'" # If you need to include a single backslash, use an r-string to prevent Python from raising a # DeprecationWarning for an invalid escape sequence @@ -78,18 +72,14 @@ def test_escape_string_that_includes_special_characters(self): == r"'his name was robert \\/ palmer'" ) assert ( - pe.escape_string("his name was robert \\ palmer") - == r"'his name was robert \\ palmer'" + pe.escape_string("his name was robert \\ palmer") == r"'his name was robert \\ palmer'" ) assert ( pe.escape_string("his name was robert \\\\ palmer") == r"'his name was robert \\\\ palmer'" ) - assert ( - pe.escape_string("his name was robert palmer 😂") - == r"'his name was robert palmer 😂'" - ) + assert pe.escape_string("his name was robert palmer 😂") == r"'his name was robert palmer 😂'" # Adding the test from PR #56 to prove escape behaviour @@ -235,15 +225,10 @@ class TestInlineToNativeTransformer(object): ), ), ) - def test_transformer( - self, label: str, query: str, params: Dict[str, Any], expected: str - ): + def test_transformer(self, label: str, query: str, params: Dict[str, Any], expected: str): _params = [ - dbsql_parameter_from_primitive(value=value, name=name) - for name, value in params.items() + dbsql_parameter_from_primitive(value=value, name=name) for name, value in params.items() ] - output = transform_paramstyle( - query, _params, param_structure=ParameterStructure.NAMED - ) + output = transform_paramstyle(query, _params, param_structure=ParameterStructure.NAMED) assert output == expected diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 249730789..6349856ba 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -141,9 +141,7 @@ class TestDbsqlParameter: (ArrayParameter, Primitive.ARRAY, "ARRAY"), ), ) - def test_cast_expression( - self, _type: TDbsqlParameter, prim: Primitive, expect_cast_expr: str - ): + def test_cast_expression(self, _type: TDbsqlParameter, prim: Primitive, expect_cast_expr: str): p = _type(prim.value) assert p._cast_expr() == expect_cast_expr @@ -201,13 +199,9 @@ def test_tspark_param_value(self, t: TDbsqlParameter, prim): type="MAP", value=None, 
arguments=[ - TSparkParameterValueArg( - type="STRING", value="a", arguments=None - ), + TSparkParameterValueArg(type="STRING", value="a", arguments=None), TSparkParameterValueArg(type="INT", value="1", arguments=None), - TSparkParameterValueArg( - type="STRING", value="b", arguments=None - ), + TSparkParameterValueArg(type="STRING", value="b", arguments=None), TSparkParameterValueArg(type="INT", value="2", arguments=None), ], ), @@ -225,36 +219,20 @@ def test_tspark_param_value(self, t: TDbsqlParameter, prim): type="MAP", value=None, arguments=[ - TSparkParameterValueArg( - type="STRING", value="a", arguments=None - ), - TSparkParameterValueArg( - type="INT", value="1", arguments=None - ), - TSparkParameterValueArg( - type="STRING", value="b", arguments=None - ), - TSparkParameterValueArg( - type="INT", value="2", arguments=None - ), + TSparkParameterValueArg(type="STRING", value="a", arguments=None), + TSparkParameterValueArg(type="INT", value="1", arguments=None), + TSparkParameterValueArg(type="STRING", value="b", arguments=None), + TSparkParameterValueArg(type="INT", value="2", arguments=None), ], ), TSparkParameterValueArg( type="MAP", value=None, arguments=[ - TSparkParameterValueArg( - type="STRING", value="c", arguments=None - ), - TSparkParameterValueArg( - type="INT", value="3", arguments=None - ), - TSparkParameterValueArg( - type="STRING", value="d", arguments=None - ), - TSparkParameterValueArg( - type="INT", value="4", arguments=None - ), + TSparkParameterValueArg(type="STRING", value="c", arguments=None), + TSparkParameterValueArg(type="INT", value="3", arguments=None), + TSparkParameterValueArg(type="STRING", value="d", arguments=None), + TSparkParameterValueArg(type="INT", value="4", arguments=None), ], ), ], diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 897a1d111..cf8992fbf 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -35,9 +35,7 @@ def test_sleep__no_retry_after(self, t_mock, retry_policy, error_history): retry_policy.sleep(HTTPResponse(status=503)) expected_backoff_time = max( - self.calculate_backoff_time( - 0, retry_policy.delay_min, retry_policy.delay_max - ), + self.calculate_backoff_time(0, retry_policy.delay_min, retry_policy.delay_max), retry_policy.delay_max, ) t_mock.assert_called_with(expected_backoff_time) @@ -66,9 +64,7 @@ def test_sleep__no_retry_after_header__multiple_retries(self, t_mock, retry_poli ) # Asserts if the sleep value was called in the expected order - t_mock.assert_has_calls( - [call(expected_time) for expected_time in expected_backoff_times] - ) + t_mock.assert_has_calls([call(expected_time) for expected_time in expected_backoff_times]) @patch("time.sleep") def test_excessive_retry_attempts_error(self, t_mock, retry_policy): diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index f57f75562..a3cf9940e 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -8,7 +8,7 @@ NoopTelemetryClient, TelemetryClientFactory, TelemetryHelper, - BaseTelemetryClient + BaseTelemetryClient, ) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( @@ -24,7 +24,7 @@ def mock_telemetry_client(): session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() - + return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -43,7 +43,7 @@ def test_noop_client_behavior(self): client1 = NoopTelemetryClient() client2 = NoopTelemetryClient() assert 
client1 is client2 - + # Test that all methods can be called without exceptions client1.export_initial_telemetry_log(MagicMock(), "test-agent") client1.export_failure_log("TestError", "Test message") @@ -58,61 +58,61 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): """Test the complete event batching and flushing flow.""" client = mock_telemetry_client client._batch_size = 3 # Small batch for testing - + # Mock the network call - with patch.object(client, '_send_telemetry') as mock_send: + with patch.object(client, "_send_telemetry") as mock_send: # Add events one by one - should not flush yet client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() assert len(client._events_batch) == 2 - + # Third event should trigger flush client._export_event("event3") mock_send.assert_called_once() assert len(client._events_batch) == 0 # Batch cleared after flush - - @patch('requests.post') + + @patch("requests.post") def test_network_request_flow(self, mock_post, mock_telemetry_client): """Test the complete network request flow with authentication.""" mock_post.return_value.status_code = 200 client = mock_telemetry_client - + # Create mock events mock_events = [MagicMock() for _ in range(2)] for i, event in enumerate(mock_events): event.to_json.return_value = f'{{"event": "{i}"}}' - + # Send telemetry client._send_telemetry(mock_events) - + # Verify request was submitted to executor client._executor.submit.assert_called_once() args, kwargs = client._executor.submit.call_args - + # Verify correct function and URL assert args[0] == requests.post - assert args[1] == 'https://test-host.com/telemetry-ext' - assert kwargs['headers']['Authorization'] == 'Bearer test-token' - + assert args[1] == "https://test-host.com/telemetry-ext" + assert kwargs["headers"]["Authorization"] == "Bearer test-token" + # Verify request body structure - request_data = kwargs['data'] + request_data = kwargs["data"] assert '"uploadTime"' in request_data assert '"protoLogs"' in request_data def test_telemetry_logging_flows(self, mock_telemetry_client): """Test all telemetry logging methods work end-to-end.""" client = mock_telemetry_client - - with patch.object(client, '_export_event') as mock_export: + + with patch.object(client, "_export_event") as mock_export: # Test initial log client.export_initial_telemetry_log(MagicMock(), "test-agent") assert mock_export.call_count == 1 - + # Test failure log client.export_failure_log("TestError", "Error message") assert mock_export.call_count == 2 - + # Test latency log client.export_latency_log(150, "EXECUTE_STATEMENT", "stmt-123") assert mock_export.call_count == 3 @@ -120,14 +120,14 @@ def test_telemetry_logging_flows(self, mock_telemetry_client): def test_error_handling_resilience(self, mock_telemetry_client): """Test that telemetry errors don't break the client.""" client = mock_telemetry_client - + # Test that exceptions in telemetry don't propagate - with patch.object(client, '_export_event', side_effect=Exception("Test error")): + with patch.object(client, "_export_event", side_effect=Exception("Test error")): # These should not raise exceptions client.export_initial_telemetry_log(MagicMock(), "test-agent") client.export_failure_log("TestError", "Error message") client.export_latency_log(100, "EXECUTE_STATEMENT", "stmt-123") - + # Test executor submission failure client._executor.submit.side_effect = Exception("Thread pool error") client._send_telemetry([MagicMock()]) # Should not raise @@ -140,7 +140,7 @@ def 
test_system_configuration_caching(self): """Test that system configuration is cached and contains expected data.""" config1 = TelemetryHelper.get_driver_system_configuration() config2 = TelemetryHelper.get_driver_system_configuration() - + # Should be cached (same instance) assert config1 is config2 @@ -153,7 +153,7 @@ def test_auth_mechanism_detection(self): (MagicMock(), AuthMech.OTHER), # Unknown provider (None, None), ] - + for provider, expected in test_cases: assert TelemetryHelper.get_auth_mechanism(provider) == expected @@ -164,18 +164,21 @@ def test_auth_flow_detection(self): oauth_with_tokens._access_token = "test-access-token" oauth_with_tokens._refresh_token = "test-refresh-token" assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH - + # Test OAuth with browser-based auth oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) oauth_with_browser._access_token = None oauth_with_browser._refresh_token = None oauth_with_browser.oauth_manager = MagicMock() - assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION - + assert ( + TelemetryHelper.get_auth_flow(oauth_with_browser) + == AuthFlow.BROWSER_BASED_AUTHENTICATION + ) + # Test non-OAuth provider pat_auth = AccessTokenAuthProvider("test-token") assert TelemetryHelper.get_auth_flow(pat_auth) is None - + # Test None auth provider assert TelemetryHelper.get_auth_flow(None) is None @@ -202,24 +205,24 @@ def test_client_lifecycle_flow(self): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") - + # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, - host_url="test-host.com" + host_url="test-host.com", ) - + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex - + # Close client - with patch.object(client, 'close') as mock_close: + with patch.object(client, "close") as mock_close: TelemetryClientFactory.close(session_id_hex) mock_close.assert_called_once() - + # Should get NoopTelemetryClient after close client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) @@ -227,31 +230,33 @@ def test_client_lifecycle_flow(self): def test_disabled_telemetry_flow(self): """Test that disabled telemetry uses NoopTelemetryClient.""" session_id_hex = "test-session" - + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, session_id_hex=session_id_hex, auth_provider=None, - host_url="test-host.com" + host_url="test-host.com", ) - + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" - + # Simulate initialization error - with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', - side_effect=Exception("Init error")): + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient", + side_effect=Exception("Init error"), + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id, auth_provider=AccessTokenAuthProvider("token"), - host_url="test-host.com" + host_url="test-host.com", ) - + # Should fall back to 
NoopTelemetryClient client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) @@ -260,25 +265,25 @@ def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" - + # Initialize multiple clients for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session, auth_provider=AccessTokenAuthProvider("token"), - host_url="test-host.com" + host_url="test-host.com", ) - + # Factory should be initialized assert TelemetryClientFactory._initialized is True assert TelemetryClientFactory._executor is not None - + # Close first client - factory should stay initialized TelemetryClientFactory.close(session1) assert TelemetryClientFactory._initialized is True - + # Close second client - factory should shut down TelemetryClientFactory.close(session2) assert TelemetryClientFactory._initialized is False - assert TelemetryClientFactory._executor is None \ No newline at end of file + assert TelemetryClientFactory._executor is None diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 458ea9a82..7866da93d 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -86,9 +86,7 @@ def test_make_request_checks_thrift_status_code(self): def _make_type_desc(self, type): return ttypes.TTypeDesc( - types=[ - ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type)) - ] + types=[ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type))] ) def _make_fake_thrift_backend(self): @@ -159,9 +157,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) - self.assertIn( - "expected server to use a protocol version", str(cm.exception) - ) + self.assertIn("expected server to use a protocol version", str(cm.exception)) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @@ -190,9 +186,7 @@ def test_headers_are_set(self, t_http_client_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - t_http_client_class.return_value.setCustomHeaders.assert_called_with( - {"header": "value"} - ) + t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) def test_proxy_headers_are_set(self): @@ -212,9 +206,7 @@ def test_proxy_headers_are_set(self): @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") - def test_tls_cert_args_are_propagated( - self, mock_create_default_context, t_http_client_class - ): + def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class): mock_cert_key_file = Mock() mock_cert_key_password = Mock() mock_trusted_ca_file = Mock() @@ -245,9 +237,7 @@ def test_tls_cert_args_are_propagated( ) self.assertTrue(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual( - t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options - ) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) @patch("databricks.sql.types.create_default_context") def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context): @@ -286,9 +276,7 @@ def test_tls_cert_args_are_used_by_http_client(self, 
mock_create_default_context self.assertEqual(conn_pool.ca_certs, mock_ssl_options.tls_trusted_ca_file) self.assertEqual(conn_pool.cert_file, mock_ssl_options.tls_client_cert_file) self.assertEqual(conn_pool.key_file, mock_ssl_options.tls_client_cert_key_file) - self.assertEqual( - conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password - ) + self.assertEqual(conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password) def test_tls_no_verify_is_respected_by_http_client(self): from databricks.sql.auth.thrift_http_client import THttpClient @@ -308,9 +296,7 @@ def test_tls_no_verify_is_respected_by_http_client(self): @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") - def test_tls_no_verify_is_respected( - self, mock_create_default_context, t_http_client_class - ): + def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): mock_ssl_options = SSLOptions(tls_verify=False) mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() @@ -326,9 +312,7 @@ def test_tls_no_verify_is_respected( self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE) - self.assertEqual( - t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options - ) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") @@ -350,9 +334,7 @@ def test_tls_verify_hostname_is_respected( self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual( - t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options - ) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): @@ -410,9 +392,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): ssl_options=SSLOptions(), _socket_timeout=129, ) - self.assertEqual( - t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000 - ) + self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000) ThriftBackend( "hostname", 123, @@ -431,9 +411,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - self.assertEqual( - t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 - ) + self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000) ThriftBackend( "hostname", 123, @@ -443,9 +421,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): ssl_options=SSLOptions(), _socket_timeout=None, ) - self.assertEqual( - t_http_client_class.return_value.setTimeout.call_args[0][0], None - ) + self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None) def test_non_primitive_types_raise_error(self): columns = [ @@ -457,9 +433,7 @@ def test_non_primitive_types_raise_error(self): columnName="column 2", typeDesc=ttypes.TTypeDesc( types=[ - ttypes.TTypeEntry( - userDefinedTypeEntry=ttypes.TUserDefinedTypeEntry("foo") - ) + ttypes.TTypeEntry(userDefinedTypeEntry=ttypes.TUserDefinedTypeEntry("foo")) ] ), ), @@ -516,12 +490,8 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): 
type=ttypes.TTypeId.DECIMAL_TYPE, typeQualifiers=ttypes.TTypeQualifiers( qualifiers={ - "precision": ttypes.TTypeQualifierValue( - i32Value=10 - ), - "scale": ttypes.TTypeQualifierValue( - i32Value=100 - ), + "precision": ttypes.TTypeQualifierValue(i32Value=10), + "scale": ttypes.TTypeQualifierValue(i32Value=100), } ), ) @@ -602,12 +572,8 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertIn("some information about the error", str(cm.exception)) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) - def test_handle_execute_response_sets_compression_in_direct_results( - self, build_queue - ): + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + def test_handle_execute_response_sets_compression_in_direct_results(self, build_queue): for resp_type in self.execute_response_types: lz4Compressed = Mock() resultSet = MagicMock() @@ -637,15 +603,11 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( - t_execute_resp, Mock() - ) + execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_handle_execute_response_checks_operation_state_in_polls( - self, tcli_service_class - ): + def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value error_resp = ttypes.TGetOperationStatusResp( @@ -661,9 +623,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( for op_state_resp, exec_resp_type in itertools.product( [error_resp, closed_resp], self.execute_response_types ): - with self.subTest( - op_state_resp=op_state_resp, exec_resp_type=exec_resp_type - ): + with self.subTest(op_state_resp=op_state_resp, exec_resp_type=exec_resp_type): tcli_service_instance = tcli_service_class.return_value t_execute_resp = exec_resp_type( status=self.okay_status, @@ -705,9 +665,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): directResults=None, operationHandle=self.operation_handle, ) - tcli_service_instance.GetOperationStatus.return_value = ( - t_get_operation_status_resp - ) + tcli_service_instance.GetOperationStatus.return_value = t_get_operation_status_resp tcli_service_instance.ExecuteStatement.return_value = t_execute_resp thrift_backend = ThriftBackend( @@ -769,9 +727,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resp_1 = resp_type( status=self.okay_status, directResults=ttypes.TSparkDirectResults( - operationStatus=ttypes.TGetOperationStatusResp( - status=self.bad_status - ), + operationStatus=ttypes.TGetOperationStatusResp(status=self.bad_status), resultSetMetadata=None, resultSet=None, closeOperation=None, @@ -782,9 +738,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): status=self.okay_status, directResults=ttypes.TSparkDirectResults( operationStatus=None, - resultSetMetadata=ttypes.TGetResultSetMetadataResp( - status=self.bad_status - ), + resultSetMetadata=ttypes.TGetResultSetMetadataResp(status=self.bad_status), resultSet=None, closeOperation=None, ), @@ -826,9 +780,7 @@ def 
test_handle_execute_response_checks_direct_results_for_error_statuses(self): self.assertIn("this is a bad error", str(cm.exception)) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class - ): + def test_handle_execute_response_can_handle_without_direct_results(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value for resp_type in self.execute_response_types: @@ -855,9 +807,7 @@ def test_handle_execute_response_can_handle_without_direct_results( operationState=ttypes.TOperationState.FINISHED_STATE, ) - tcli_service_instance.GetResultSetMetadata.return_value = ( - self.metadata_resp - ) + tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp tcli_service_instance.GetOperationStatus.side_effect = [ op_state_1, op_state_2, @@ -936,13 +886,9 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - tcli_service_instance.GetResultSetMetadata.return_value = ( - t_get_result_set_metadata_resp - ) + tcli_service_instance.GetResultSetMetadata.return_value = t_get_result_set_metadata_resp thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( - t_execute_resp, Mock() - ) + execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @@ -973,9 +919,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): thrift_backend._hive_schema_to_arrow_schema.call_args[0][0], ) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue @@ -1006,20 +950,14 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( operationHandle=self.operation_handle, ) - tcli_service_instance.GetResultSetMetadata.return_value = ( - self.metadata_resp - ) + tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) + execute_response = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual(has_more_rows, execute_response.has_more_rows) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue @@ -1054,12 +992,8 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( ) tcli_service_instance.FetchResults.return_value = fetch_results_resp - tcli_service_instance.GetOperationStatus.return_value = ( - operation_status_resp - ) - tcli_service_instance.GetResultSetMetadata.return_value = ( - self.metadata_resp - ) + tcli_service_instance.GetOperationStatus.return_value = operation_status_resp + tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp thrift_backend = self._make_fake_thrift_backend() 
thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -1086,8 +1020,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): startRowOffset=0, rows=[], arrowBatches=[ - ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) - for _ in range(10) + ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) for _ in range(10) ], ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( @@ -1129,9 +1062,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class - ): + def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response @@ -1153,14 +1084,10 @@ def test_execute_statement_calls_client_and_handle_execute_response( self.assertEqual(req.getDirectResults, get_direct_results) self.assertEqual(req.statement, "foo") # Check response handling - thrift_backend._handle_execute_response.assert_called_with( - response, cursor_mock - ) + thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class - ): + def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response @@ -1181,14 +1108,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) self.assertEqual(req.getDirectResults, get_direct_results) # Check response handling - thrift_backend._handle_execute_response.assert_called_with( - response, cursor_mock - ) + thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class - ): + def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response @@ -1218,14 +1141,10 @@ def test_get_schemas_calls_client_and_handle_execute_response( self.assertEqual(req.catalogName, "catalog_pattern") self.assertEqual(req.schemaName, "schema_pattern") # Check response handling - thrift_backend._handle_execute_response.assert_called_with( - response, cursor_mock - ) + thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class - ): + def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response @@ -1259,14 +1178,10 @@ def test_get_tables_calls_client_and_handle_execute_response( self.assertEqual(req.tableName, "table_pattern") self.assertEqual(req.tableTypes, ["type1", "type2"]) # Check 
response handling - thrift_backend._handle_execute_response.assert_called_with( - response, cursor_mock - ) + thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class - ): + def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response @@ -1300,9 +1215,7 @@ def test_get_columns_calls_client_and_handle_execute_response( self.assertEqual(req.tableName, "table_pattern") self.assertEqual(req.columnName, "column_pattern") # Check response handling - thrift_backend._handle_execute_response.assert_called_with( - response, cursor_mock - ) + thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_open_session_user_provided_session_id_optional(self, tcli_service_class): @@ -1355,9 +1268,7 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_non_arrow_non_column_based_set_triggers_exception( - self, tcli_service_class - ): + def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1386,9 +1297,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( with self.assertRaises(OperationalError) as cm: thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) - self.assertIn( - "Expected results to be in Arrow or column based format", str(cm.exception) - ) + self.assertIn("Expected results to be in Arrow or column based format", str(cm.exception)) def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() @@ -1426,23 +1335,17 @@ def test_create_arrow_table_calls_correct_conversion_method( description = Mock() t_col_set = ttypes.TRowSet(columns=cols) - thrift_backend._create_arrow_table( - t_col_set, lz4_compressed, schema, description - ) + thrift_backend._create_arrow_table(t_col_set, lz4_compressed, schema, description) convert_arrow_mock.assert_not_called() convert_col_mock.assert_called_once_with(cols, description) t_arrow_set = ttypes.TRowSet(arrowBatches=arrow_batches) thrift_backend._create_arrow_table(t_arrow_set, lz4_compressed, schema, Mock()) - convert_arrow_mock.assert_called_once_with( - arrow_batches, lz4_compressed, schema - ) + convert_arrow_mock.assert_called_once_with(arrow_batches, lz4_compressed, schema) @patch("lz4.frame.decompress") @patch("pyarrow.ipc.open_stream") - def test_convert_arrow_based_set_to_arrow_table( - self, open_stream_mock, lz4_decompress_mock - ): + def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_decompress_mock): thrift_backend = ThriftBackend( "foobar", 443, @@ -1482,23 +1385,15 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self): t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))), ttypes.TColumn( - stringVal=ttypes.TStringColumn( - values=["s1", "s2", "s3"], nulls=bytes(1) - ) - ), - ttypes.TColumn( - doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1)) + stringVal=ttypes.TStringColumn(values=["s1", 
"s2", "s3"], nulls=bytes(1)) ), + ttypes.TColumn(doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1))), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn( - values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1) - ) + binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1)) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( - t_cols, description - ) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) self.assertEqual(n_rows, 3) # Check schema, column names and types @@ -1523,29 +1418,19 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self): description = [(name,) for name in field_names] t_cols = [ + ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes([1]))), ttypes.TColumn( - i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes([1])) - ), - ttypes.TColumn( - stringVal=ttypes.TStringColumn( - values=["s1", "s2", "s3"], nulls=bytes([2]) - ) + stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes([2])) ), ttypes.TColumn( - doubleVal=ttypes.TDoubleColumn( - values=[1.15, 2.2, 3.3], nulls=bytes([4]) - ) + doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes([4])) ), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn( - values=[b"\x11", b"\x22", b"\x33"], nulls=bytes([3]) - ) + binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes([3])) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( - t_cols, description - ) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) self.assertEqual(n_rows, 3) # Check data @@ -1561,23 +1446,15 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))), ttypes.TColumn( - stringVal=ttypes.TStringColumn( - values=["s1", "s2", "s3"], nulls=bytes(1) - ) + stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1)) ), + ttypes.TColumn(doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1))), ttypes.TColumn( - doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1)) - ), - ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn( - values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1) - ) + binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1)) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( - t_cols, description - ) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) self.assertEqual(n_rows, 3) # Check schema, column names and types @@ -1623,12 +1500,8 @@ def test_handle_execute_response_sets_active_op_handle(self): self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch( - "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" - ) - @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory - ) + @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") + @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class ): @@ -1640,9 +1513,7 @@ def test_make_request_will_retry_GetOperationStatus( this_gos_name = "GetOperationStatus" mock_GetOperationStatus.__name__ = 
this_gos_name - mock_GetOperationStatus.side_effect = OSError( - errno.ETIMEDOUT, "Connection timed out" - ) + mock_GetOperationStatus.side_effect = OSError(errno.ETIMEDOUT, "Connection timed out") protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(t_transport_class) client = Client(protocol) @@ -1671,18 +1542,12 @@ def test_make_request_will_retry_GetOperationStatus( self.assertEqual( NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"] ) - self.assertEqual( - f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"] - ) + self.assertEqual(f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"]) # Unusual OSError code - mock_GetOperationStatus.side_effect = OSError( - errno.EEXIST, "File does not exist" - ) + mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") - with self.assertLogs( - "databricks.sql.thrift_backend", level=logging.WARNING - ) as cm: + with self.assertLogs("databricks.sql.thrift_backend", level=logging.WARNING) as cm: with self.assertRaises(RequestError): thrift_backend.make_request(client.GetOperationStatus, req) @@ -1698,12 +1563,8 @@ def test_make_request_will_retry_GetOperationStatus( cm.output[0], ) - @patch( - "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" - ) - @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory - ) + @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") + @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos ): @@ -1748,14 +1609,10 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( self.assertEqual( NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"] ) - self.assertEqual( - f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"] - ) + self.assertEqual(f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"]) @patch("thrift.transport.THttpClient.THttpClient") - def test_make_request_wont_retry_if_error_code_not_429_or_503( - self, t_transport_class - ): + def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_class): t_transport_instance = t_transport_class.return_value t_transport_instance.code = 430 t_transport_instance.headers = {"Retry-After": "1"} @@ -1778,9 +1635,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( self.assertIn("This method fails", str(cm.exception.message_with_context())) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory - ) + @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self, mock_retry_policy, t_transport_class ): @@ -1812,9 +1667,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self.assertEqual(mock_method.call_count, 14) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - def test_make_request_will_read_error_message_headers_if_set( - self, t_transport_class - ): + def test_make_request_will_read_error_message_headers_if_set(self, t_transport_class): t_transport_instance = t_transport_class.return_value mock_method = Mock() mock_method.__name__ = "method name" @@ -1861,18 +1714,12 @@ def make_table_and_desc( ): 
int_col = [int_constant for _ in range(height)] decimal_col = [decimal_constant for _ in range(height)] - data = OrderedDict( - {"col{}".format(i): int_col for i in range(width - n_decimal_cols)} - ) - decimals = OrderedDict( - {"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)} - ) + data = OrderedDict({"col{}".format(i): int_col for i in range(width - n_decimal_cols)}) + decimals = OrderedDict({"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)}) data.update(decimals) int_desc = [("", "int")] * (width - n_decimal_cols) - decimal_desc = [ - ("", "decimal", None, None, precision, scale, None) - ] * n_decimal_cols + decimal_desc = [("", "decimal", None, None, precision, scale, None)] * n_decimal_cols description = int_desc + decimal_desc table = pyarrow.Table.from_pydict(data) @@ -1911,30 +1758,20 @@ def test_arrow_decimal_conversion(self): else: self.assertEqual( decimal_converted_table.field(i).type, - pyarrow.decimal128( - precision=precision, scale=scale - ), + pyarrow.decimal128(precision=precision, scale=scale), ) int_col = [int_constant for _ in range(height)] decimal_col = [Decimal(decimal_constant) for _ in range(height)] expected_result = OrderedDict( - { - "col{}".format(i): int_col - for i in range(width - n_decimal_cols) - } + {"col{}".format(i): int_col for i in range(width - n_decimal_cols)} ) decimals = OrderedDict( - { - "col_dec{}".format(i): decimal_col - for i in range(n_decimal_cols) - } + {"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)} ) expected_result.update(decimals) - self.assertEqual( - decimal_converted_table.to_pydict(), expected_result - ) + self.assertEqual(decimal_converted_table.to_pydict(), expected_result) @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_passthrough(self, mock_http_client): @@ -1967,8 +1804,7 @@ def test_retry_args_bounding(self, mock_http_client): for i in range(2): retry_delay_args = { - k: v[i][0] - for (k, v) in retry_delay_test_args_and_expected_values.items() + k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } backend = ThriftBackend( "foobar", @@ -1980,8 +1816,7 @@ def test_retry_args_bounding(self, mock_http_client): **retry_delay_args, ) retry_delay_expected_vals = { - k: v[i][1] - for (k, v) in retry_delay_test_args_and_expected_values.items() + k: v[i][1] for (k, v) in retry_delay_test_args_and_expected_values.items() } for arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) @@ -2060,16 +1895,12 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): backend.open_session({}, cat, schem) - open_session_req = tcli_client_class.return_value.OpenSession.call_args[ - 0 - ][0] + open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertEqual(open_session_req.initialNamespace.catalogName, cat) self.assertEqual(open_session_req.initialNamespace.schemaName, schem) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_can_use_multiple_catalogs_is_set_in_open_session_req( - self, tcli_client_class - ): + def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp @@ -2087,9 +1918,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req( self.assertTrue(open_session_req.canUseMultipleCatalogs) @patch("databricks.sql.thrift_backend.TCLIService.Client", 
autospec=True) - def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( - self, tcli_client_class - ): + def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value backend = ThriftBackend( @@ -2182,9 +2011,7 @@ def test_execute_command_sets_complex_type_fields_correctly( **complex_arg_types, ) thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) - t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ - 0 - ][0] + t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0] # If the value is unset, the native type should default to True self.assertEqual( t_execute_statement_req.useArrowNativeTypes.timestampAsArrow, @@ -2198,9 +2025,7 @@ def test_execute_command_sets_complex_type_fields_correctly( t_execute_statement_req.useArrowNativeTypes.complexTypesAsArrow, complex_arg_types.get("_use_arrow_native_complex_types", True), ) - self.assertFalse( - t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow - ) + self.assertFalse(t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow) if __name__ == "__main__": diff --git a/tests/unit/test_thrift_field_ids.py b/tests/unit/test_thrift_field_ids.py index a4bba439d..4b850db55 100644 --- a/tests/unit/test_thrift_field_ids.py +++ b/tests/unit/test_thrift_field_ids.py @@ -31,11 +31,7 @@ def test_all_thrift_field_ids_are_within_allowed_range(self): # Get all classes from the ttypes module for name, obj in inspect.getmembers(ttypes): - if ( - inspect.isclass(obj) - and hasattr(obj, "thrift_spec") - and obj.thrift_spec is not None - ): + if inspect.isclass(obj) and hasattr(obj, "thrift_spec") and obj.thrift_spec is not None: self._check_class_field_ids(obj, name, violations)
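The hunks above are mechanical Black reformatting: wrapped call arguments, comprehensions, and assertions are folded back onto a single line wherever the result still fits, and string literals in the telemetry tests are normalized to double quotes. This is the behaviour Black produces when its permitted line length is raised above the 88-character default. As a minimal sketch only (the `line_length` value of 100 is an assumption inferred from the width of the collapsed lines, not something stated in this diff), the collapse can be reproduced with Black's Python API:

```python
# Hedged sketch: reproduce the collapse pattern seen in the hunks above with
# Black's Python API. The line_length value of 100 is an assumption inferred
# from the width of the collapsed lines; it is not stated anywhere in this diff.
import black

WRAPPED = (
    "d = downloader.ResultSetDownloadHandler(\n"
    "    settings, result_link, ssl_options=SSLOptions()\n"
    ")\n"
)

# With a longer permitted line, Black folds the call back onto a single line,
# matching the "+" side of the hunks above. Black only parses the string, so
# the undefined names inside WRAPPED are irrelevant here.
collapsed = black.format_str(WRAPPED, mode=black.Mode(line_length=100))
print(collapsed)
```

Running the sketch prints the same single-line call that appears on the `+` side of the downloader hunks; with Black's default line length the wrapped form on the `-` side is left unchanged.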