diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml new file mode 100644 index 00000000..74b93608 --- /dev/null +++ b/.github/workflows/token-federation-test.yml @@ -0,0 +1,78 @@ +name: Token Federation Test + +# Tests token federation functionality with GitHub Actions OIDC tokens +on: + # Manual trigger with required inputs + workflow_dispatch: + inputs: + databricks_host: + description: 'Databricks host URL (e.g., example.cloud.databricks.com)' + required: true + databricks_http_path: + description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)' + required: true + identity_federation_client_id: + description: 'Identity federation client ID' + required: true + + # Run on PRs that might affect token federation + pull_request: + branches: [main] + paths: + - 'src/databricks/sql/auth/**' + - 'examples/token_federation_*.py' + - 'tests/token_federation/**' + - '.github/workflows/token-federation-test.yml' + + # Run on push to main that affects token federation + push: + branches: [main] + paths: + - 'src/databricks/sql/auth/**' + - 'examples/token_federation_*.py' + - 'tests/token_federation/**' + - '.github/workflows/token-federation-test.yml' + +permissions: + id-token: write # Required for GitHub OIDC token + contents: read + +jobs: + test-token-federation: + name: Test Token Federation + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + cache: 'pip' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . 
+ pip install pyarrow + + - name: Get GitHub OIDC token + id: get-id-token + uses: actions/github-script@v7 + with: + script: | + const token = await core.getIDToken('https://github.com/databricks') + core.setSecret(token) + core.setOutput('token', token) + + - name: Test token federation with GitHub OIDC token + env: + DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} + DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }} + IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} + OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} + run: python tests/token_federation/github_oidc_test.py diff --git a/poetry.lock b/poetry.lock index 1bc396c9..67880458 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "astroid" @@ -6,6 +6,7 @@ version = "3.2.4" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "astroid-3.2.4-py3-none-any.whl", hash = "sha256:413658a61eeca6202a59231abb473f932038fbcbf1666587f66d482083413a25"}, {file = "astroid-3.2.4.tar.gz", hash = "sha256:0e14202810b30da1b735827f78f5157be2bbd4a7a59b7707ca0bfc2fb4c0063a"}, @@ -20,6 +21,7 @@ version = "22.12.0" description = "The uncompromising code formatter." 
optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, @@ -55,6 +57,7 @@ version = "2025.1.31" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -66,6 +69,7 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -167,6 +171,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -181,6 +186,8 @@ version = "0.4.6" description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -192,6 +199,7 @@ version = "0.3.9" description = "serialize all of Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, @@ -207,6 +215,7 @@ version = "2.0.0" description = "An implementation of lxml.xmlfile for the standard library" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"}, {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"}, @@ -218,6 +227,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -232,6 +243,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = 
"sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -246,6 +258,7 @@ version = "2.1.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -257,6 +270,7 @@ version = "5.13.2" description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -271,6 +285,7 @@ version = "4.3.3" description = "LZ4 Bindings for Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "lz4-4.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b891880c187e96339474af2a3b2bfb11a8e4732ff5034be919aa9029484cd201"}, {file = "lz4-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:222a7e35137d7539c9c33bb53fcbb26510c5748779364014235afc62b0ec797f"}, @@ -321,6 +336,7 @@ version = "0.7.0" description = "McCabe checker, plugin for flake8" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, @@ -332,6 +348,7 @@ version = "1.14.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ 
{file = "mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb"}, {file = "mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0"}, @@ -391,6 +408,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -402,6 +420,8 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] +markers = "python_version < \"3.10\"" files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -439,6 +459,8 @@ version = "2.2.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.10" +groups = ["main", "dev"] +markers = "python_version >= \"3.10\"" files = [ {file = "numpy-2.2.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8146f3550d627252269ac42ae660281d673eb6f8b32f113538e0cc2a9aed42b9"}, {file = "numpy-2.2.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e642d86b8f956098b564a45e6f6ce68a22c2c97a04f5acd3f221f57b8cb850ae"}, @@ -503,6 +525,7 @@ version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "oauthlib-3.2.2-py3-none-any.whl", 
hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, @@ -519,6 +542,7 @@ version = "3.1.5" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"}, {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"}, @@ -533,6 +557,7 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -544,6 +569,8 @@ version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" files = [ {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, @@ -573,11 +600,7 @@ files = [ ] [package.dependencies] -numpy = [ - {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, -] +numpy = {version = ">=1.20.3", markers = "python_version < \"3.10\""} python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.1" @@ -611,6 
+634,8 @@ version = "2.2.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" files = [ {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, @@ -657,7 +682,11 @@ files = [ ] [package.dependencies] -numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""} +numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, +] python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.7" @@ -693,6 +722,7 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -704,6 +734,7 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -720,6 +751,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -735,6 +767,8 @@ version = "17.0.0" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, @@ -786,6 +820,8 @@ version = "19.0.1" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"}, {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"}, @@ -834,12 +870,51 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pyjwt" +version = "2.9.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = 
">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"}, + {file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + +[[package]] +name = "pyjwt" +version = "2.10.1" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"}, + {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pylint" version = "3.2.7" description = "python code static checker" optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "pylint-3.2.7-py3-none-any.whl", hash = "sha256:02f4aedeac91be69fb3b4bea997ce580a4ac68ce58b89eaefeaf06749df73f4b"}, {file = "pylint-3.2.7.tar.gz", hash = "sha256:1b7a721b575eaeaa7d39db076b6e7743c993ea44f57979127c517c6c572c803e"}, @@ -851,7 +926,7 @@ colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < 
\"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=0.3.6", markers = "python_version == \"3.11\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -870,6 +945,7 @@ version = "7.4.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, @@ -892,6 +968,7 @@ version = "0.5.2" description = "A py.test plugin that parses environment files before running tests" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"}, {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"}, @@ -907,6 +984,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -921,6 +999,7 @@ version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash 
= "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -935,6 +1014,7 @@ version = "2025.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, @@ -946,6 +1026,7 @@ version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -967,6 +1048,7 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -978,6 +1060,7 @@ version = "0.20.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, ] @@ -996,6 +1079,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1037,6 +1122,7 @@ version = "0.13.2" description = "Style preserving TOML library" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde"}, {file = "tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79"}, @@ -1048,6 +1134,7 @@ version = "4.13.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "typing_extensions-4.13.0-py3-none-any.whl", hash = "sha256:c8dd92cc0d6425a97c18fbb9d1954e5ff92c1ca881a309c45f06ebc0b79058e5"}, {file = "typing_extensions-4.13.0.tar.gz", hash = "sha256:0a4ac55a5820789d87e297727d229866c9650f6521b64206413c4fbada24d95b"}, @@ -1059,6 +1146,7 @@ version = "2025.2" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, @@ -1070,13 +1158,14 @@ version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -1085,6 +1174,6 @@ zstd = ["zstandard (>=0.18.0)"] pyarrow = ["pyarrow", "pyarrow"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0bd6a6a019693a69a3da5ae312cea625ea73dfc5832b1e4051c7c7d1e76553d8" +content-hash = "aa36901ed7501adeeba5384352904ba06a34d298e400e926201e0fd57f6b6678" diff --git a/pyproject.toml b/pyproject.toml index 7b95a509..7d326b2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,11 +25,12 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] python-dateutil = "^2.8.0" +PyJWT = ">=2.0.0" [tool.poetry.extras] pyarrow = ["pyarrow"] -[tool.poetry.dev-dependencies] +[tool.poetry.group.dev.dependencies] pytest = "^7.1.2" mypy = "^1.10.1" pylint = ">=2.12.0" diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 347934ee..c679879f 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -5,6 +5,7 @@ AuthProvider, AccessTokenAuthProvider, ExternalAuthProvider, + CredentialsProvider, DatabricksOAuthProvider, ) @@ -12,6 +13,9 @@ class AuthType(Enum): DATABRICKS_OAUTH = "databricks-oauth" AZURE_OAUTH = "azure-oauth" + # TODO: Token federation should be a feature that works with different auth types, + # not an auth type itself. This will be refactored in a future change. 
+ TOKEN_FEDERATION = "token-federation" # other supported types (access_token) can be inferred # we can add more types as needed later @@ -29,6 +33,7 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, + identity_federation_client_id: Optional[str] = None, ): self.hostname = hostname self.access_token = access_token @@ -40,11 +45,44 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider + self.identity_federation_client_id = identity_federation_client_id def get_auth_provider(cfg: ClientContext): + # TODO: In a future refactoring, token federation should be a feature that wraps + # any auth provider, not a separate auth type. The code below treats it as an auth type + # for backward compatibility, but this approach will be revised. + if cfg.credentials_provider: + # If token federation is enabled and credentials provider is provided, + # wrap the credentials provider with DatabricksTokenFederationProvider + if cfg.auth_type == AuthType.TOKEN_FEDERATION.value: + from databricks.sql.auth.token_federation import ( + DatabricksTokenFederationProvider, + ) + + federation_provider = DatabricksTokenFederationProvider( + cfg.credentials_provider, + cfg.hostname, + cfg.identity_federation_client_id, + ) + return ExternalAuthProvider(federation_provider) + + # If not token federation, just use the credentials provider directly return ExternalAuthProvider(cfg.credentials_provider) + + # If we don't have a credentials provider but have token federation auth type with access token + if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: + # If only access_token is provided with token federation, use create_token_federation_provider + from databricks.sql.auth.token_federation import ( + create_token_federation_provider, + ) + + federation_provider = create_token_federation_provider( + 
cfg.access_token, cfg.hostname, cfg.identity_federation_client_id + ) + return ExternalAuthProvider(federation_provider) + if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: assert cfg.oauth_redirect_port_range is not None assert cfg.oauth_client_id is not None @@ -112,6 +150,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): "Please use OAuth or access token instead." ) + # TODO: Future refactoring needed: + # - Add a use_token_federation flag that can be combined with any auth type + # - Remove TOKEN_FEDERATION as an auth_type and properly handle the underlying auth type + # - Maintain backward compatibility during transition cfg = ClientContext( hostname=normalize_host_name(hostname), auth_type=auth_type, @@ -125,5 +167,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): else redirect_port_range, oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), + identity_federation_client_id=kwargs.get("identity_federation_client_id"), ) return get_auth_provider(cfg) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 64eb91bb..c425f088 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -26,10 +26,16 @@ class CredentialsProvider(abc.ABC): @abc.abstractmethod def auth_type(self) -> str: + """ + Returns the authentication type for this provider + """ ... @abc.abstractmethod def __call__(self, *args, **kwargs) -> HeaderFactory: + """ + Configure and return a HeaderFactory that provides authentication headers + """ ... 
diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py new file mode 100644 index 00000000..7f3f147d --- /dev/null +++ b/src/databricks/sql/auth/token_federation.py @@ -0,0 +1,474 @@ +import base64 +import json +import logging +from datetime import datetime, timezone, timedelta +from typing import Dict, Optional, Any, Tuple +from urllib.parse import urlparse + +import requests +from requests.exceptions import RequestException + +from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory +from databricks.sql.auth.endpoint import ( + get_oauth_endpoints, + infer_cloud_from_host, +) + +logger = logging.getLogger(__name__) + +TOKEN_EXCHANGE_PARAMS = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "scope": "sql", + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "return_original_token_if_authenticated": "true", +} + +TOKEN_REFRESH_BUFFER_SECONDS = 10 + + +class Token: + """Represents an OAuth token with expiry information.""" + + def __init__( + self, + access_token: str, + token_type: str, + refresh_token: str = "", + expiry: Optional[datetime] = None, + ): + self.access_token = access_token + self.token_type = token_type + self.refresh_token = refresh_token + + # Ensure expiry is timezone-aware + if expiry is None: + self.expiry = datetime.now(tz=timezone.utc) + elif expiry.tzinfo is None: + # Convert naive datetime to aware datetime + self.expiry = expiry.replace(tzinfo=timezone.utc) + else: + self.expiry = expiry + + def is_expired(self) -> bool: + """Check if the token is expired.""" + return datetime.now(tz=timezone.utc) >= self.expiry + + def needs_refresh(self) -> bool: + """Check if the token needs to be refreshed soon.""" + buffer_time = timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS) + return datetime.now(tz=timezone.utc) >= (self.expiry - buffer_time) + + def __str__(self) -> str: + return f"{self.token_type} {self.access_token}" + + +class 
class DatabricksTokenFederationProvider(CredentialsProvider):
    """
    Implementation of the Credential Provider that exchanges the third party access token
    for a Databricks InHouse Token. This class exchanges the access token if the issued token
    is not from the same host as the Databricks host.
    """

    def __init__(
        self,
        credentials_provider: CredentialsProvider,
        hostname: str,
        identity_federation_client_id: Optional[str] = None,
    ):
        """
        Initialize the token federation provider.

        Args:
            credentials_provider: The underlying credentials provider
            hostname: The Databricks hostname
            identity_federation_client_id: Optional client ID for identity federation
        """
        self.credentials_provider = credentials_provider
        self.hostname = hostname
        self.identity_federation_client_id = identity_federation_client_id
        self.external_provider_headers: Dict[str, str] = {}
        self.token_endpoint: Optional[str] = None
        self.idp_endpoints = None
        self.openid_config = None
        # Cache of the last successful exchange, used to decide when a
        # proactive refresh is needed.
        self.last_exchanged_token: Optional[Token] = None
        self.last_external_token: Optional[str] = None

    def auth_type(self) -> str:
        """Return the auth type from the underlying credentials provider."""
        return self.credentials_provider.auth_type()

    @property
    def host(self) -> str:
        """
        Alias for hostname to maintain compatibility with code expecting a host attribute.

        Returns:
            str: The hostname value
        """
        return self.hostname

    def __call__(self, *args, **kwargs) -> HeaderFactory:
        """
        Configure and return a HeaderFactory that provides authentication headers.

        This is called by the ExternalAuthProvider to get headers for authentication.
        """
        # First call the underlying credentials provider to get its headers
        header_factory = self.credentials_provider(*args, **kwargs)

        # Initialize OIDC discovery
        self._init_oidc_discovery()

        def get_headers() -> Dict[str, str]:
            # Get headers from the underlying provider
            self.external_provider_headers = header_factory()

            # Extract the token from the headers
            token_type, access_token = self._extract_token_info_from_header(
                self.external_provider_headers
            )

            try:
                # If the last exchanged token came from this same external token
                # and is about to expire, refresh proactively.
                if (
                    self.last_exchanged_token
                    and self.last_external_token == access_token
                    and self.last_exchanged_token.needs_refresh()
                ):
                    logger.info(
                        "Exchanged token approaching expiry, refreshing with fresh external token..."
                    )
                    return self._refresh_token(access_token, token_type)

                # Parse the JWT to get claims
                token_claims = self._parse_jwt_claims(access_token)

                # Exchange only when the issuer is a different host than Databricks
                if self._is_same_host(token_claims.get("iss", ""), self.hostname):
                    logger.debug("Token from same host, no exchange needed")
                    return self.external_provider_headers
                else:
                    logger.debug("Token from different host, attempting exchange")
                    return self._try_token_exchange_or_fallback(
                        access_token, token_type
                    )
            except Exception as e:
                logger.error(f"Error processing token: {str(e)}")
                # Fall back to original headers in case of error
                return self.external_provider_headers

        return get_headers

    def _init_oidc_discovery(self):
        """Initialize OIDC discovery to find token endpoint.

        Falls back to ``https://<host>/oidc/v1/token`` when discovery fails.
        """
        if self.token_endpoint is not None:
            return

        try:
            # Use the existing OIDC discovery mechanism
            use_azure_auth = infer_cloud_from_host(self.hostname) == "azure"
            self.idp_endpoints = get_oauth_endpoints(self.hostname, use_azure_auth)

            if self.idp_endpoints:
                openid_config_url = self.idp_endpoints.get_openid_config_url(
                    self.hostname
                )

                # Fetch the OpenID configuration. requests has no default
                # timeout, so set one to avoid hanging connection setup.
                response = requests.get(openid_config_url, timeout=30)
                if response.status_code == 200:
                    self.openid_config = response.json()
                    # Extract token endpoint from OpenID config
                    self.token_endpoint = self.openid_config.get("token_endpoint")
                    logger.info(f"Discovered token endpoint: {self.token_endpoint}")
                else:
                    logger.warning(
                        f"Failed to fetch OpenID configuration from {openid_config_url}: {response.status_code}"
                    )

            # Fallback to default token endpoint if discovery fails
            if not self.token_endpoint:
                hostname = self._format_hostname(self.hostname)
                self.token_endpoint = f"{hostname}oidc/v1/token"
                logger.info(f"Using default token endpoint: {self.token_endpoint}")
        except Exception as e:
            logger.warning(
                f"OIDC discovery failed: {str(e)}. Using default token endpoint."
            )
            hostname = self._format_hostname(self.hostname)
            self.token_endpoint = f"{hostname}oidc/v1/token"
            logger.info(
                f"Using default token endpoint after error: {self.token_endpoint}"
            )

    def _format_hostname(self, hostname: str) -> str:
        """Format hostname to ensure it has proper https:// prefix and trailing slash."""
        if not hostname.startswith("https://"):
            hostname = f"https://{hostname}"
        if not hostname.endswith("/"):
            hostname = f"{hostname}/"
        return hostname

    def _extract_token_info_from_header(
        self, headers: Dict[str, str]
    ) -> Tuple[str, str]:
        """Extract token type and token value from authorization header.

        Raises:
            ValueError: if the Authorization header is missing or malformed.
        """
        auth_header = headers.get("Authorization")
        if not auth_header:
            raise ValueError("No Authorization header found")

        parts = auth_header.split(" ", 1)
        if len(parts) != 2:
            raise ValueError(f"Invalid Authorization header format: {auth_header}")

        return parts[0], parts[1]

    def _parse_jwt_claims(self, token: str) -> Dict[str, Any]:
        """Parse JWT token claims without validation.

        JWT segments are base64url-encoded (RFC 7515). BUG FIX: the original
        used ``base64.b64decode``, which (with validate=False) silently
        discards the '-'/'_' characters of the base64url alphabet and
        corrupts the payload; ``urlsafe_b64decode`` is required.
        """
        try:
            # Split the token
            parts = token.split(".")
            if len(parts) != 3:
                raise ValueError("Invalid JWT format")

            # Get the payload part (second part)
            payload = parts[1]

            # Restore the padding stripped by JWT encoding
            # (-len % 4 is a no-op when the length is already a multiple of 4)
            payload += "=" * (-len(payload) % 4)

            # Decode and parse JSON
            decoded = base64.urlsafe_b64decode(payload)
            return json.loads(decoded)
        except Exception as e:
            logger.error(f"Failed to parse JWT: {str(e)}")
            raise

    def _is_same_host(self, url1: str, url2: str) -> bool:
        """Check if two URLs have the same host."""
        try:
            host1 = urlparse(url1).netloc
            host2 = urlparse(url2).netloc
            # If host1 is empty, it's not a valid URL, so we return False
            if not host1:
                return False
            return host1 == host2
        except Exception as e:
            logger.error(f"Failed to parse URLs: {str(e)}")
            return False

    def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
        """
        Attempt to refresh an expired token by first getting a fresh external token
        and then exchanging it for a new Databricks token.

        Args:
            access_token: The original external access token (will be replaced)
            token_type: The token type (Bearer, etc.)

        Returns:
            The headers with the fresh token
        """
        try:
            logger.info(
                "Refreshing token using proactive approach (getting fresh external token first)"
            )

            # Get a fresh token from the underlying credentials provider
            # instead of reusing the same access_token
            fresh_headers = self.credentials_provider()()

            # Extract the fresh token from the headers
            auth_header = fresh_headers.get("Authorization", "")
            if not auth_header:
                logger.error("No Authorization header in fresh headers")
                return self.external_provider_headers

            parts = auth_header.split(" ", 1)
            if len(parts) != 2:
                logger.error(f"Invalid Authorization header format: {auth_header}")
                return self.external_provider_headers

            fresh_token_type = parts[0]
            fresh_access_token = parts[1]

            # Check if we got the same token back
            if fresh_access_token == access_token:
                logger.warning(
                    "Credentials provider returned the same token during refresh"
                )

            # Perform a new token exchange with the fresh token
            refreshed_token = self._exchange_token(fresh_access_token)

            # Update the stored token
            self.last_exchanged_token = refreshed_token
            self.last_external_token = fresh_access_token

            # Create new headers with the refreshed token
            headers = dict(fresh_headers)  # Use the fresh headers as base
            headers[
                "Authorization"
            ] = f"{refreshed_token.token_type} {refreshed_token.access_token}"

            logger.info(
                f"Successfully refreshed token, new expiry: {refreshed_token.expiry}"
            )
            return headers
        except Exception as e:
            logger.error(
                f"Token refresh failed, falling back to original token: {str(e)}"
            )
            # If refresh fails, fall back to the original headers
            return self.external_provider_headers

    def _try_token_exchange_or_fallback(
        self, access_token: str, token_type: str
    ) -> Dict[str, str]:
        """Try to exchange the token or fall back to the original token."""
        try:
            # Exchange the token
            exchanged_token = self._exchange_token(access_token)

            # Store the exchanged token for potential refresh later
            self.last_exchanged_token = exchanged_token
            self.last_external_token = access_token

            # Create new headers with the exchanged token
            headers = dict(self.external_provider_headers)
            headers[
                "Authorization"
            ] = f"{exchanged_token.token_type} {exchanged_token.access_token}"
            return headers
        except Exception as e:
            logger.error(
                f"Token exchange failed, falling back to using external token: {str(e)}"
            )
            # Fall back to original headers
            return self.external_provider_headers

    def _exchange_token(self, access_token: str) -> Token:
        """
        Exchange an external token for a Databricks token.

        Args:
            access_token: The external token to exchange

        Returns:
            A Token object containing the exchanged token

        Raises:
            ValueError: if the token endpoint cannot be determined or the
                exchange request fails.
        """
        if not self.token_endpoint:
            self._init_oidc_discovery()

        # Ensure token_endpoint is set
        if not self.token_endpoint:
            raise ValueError("Token endpoint could not be determined")

        # Create request parameters
        params = dict(TOKEN_EXCHANGE_PARAMS)
        params["subject_token"] = access_token

        # Add client ID if available
        if self.identity_federation_client_id:
            params["client_id"] = self.identity_federation_client_id

        # Set up headers
        headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"}

        try:
            # Make the token exchange request (bounded, so a stalled endpoint
            # cannot hang the caller; failures are handled by callers' fallback)
            response = requests.post(
                self.token_endpoint, data=params, headers=headers, timeout=30
            )
            response.raise_for_status()

            # Parse the response
            resp_data = response.json()

            # Create a token from the response
            token = Token(
                access_token=resp_data.get("access_token"),
                token_type=resp_data.get("token_type", "Bearer"),
                refresh_token=resp_data.get("refresh_token", ""),
            )

            # Prefer the standard OAuth expires_in field for the expiry.
            # BUG FIX: the original detected "expires_in missing" by comparing
            # token.expiry == datetime.now(tz=timezone.utc) afterwards; two
            # separate now() calls differ by microseconds, so that comparison
            # was almost never true and the JWT-exp fallback was dead code.
            # Track success with an explicit flag instead.
            expiry_set = False
            if "expires_in" in resp_data and resp_data["expires_in"]:
                try:
                    # Calculate expiry by adding expires_in seconds to current time
                    expires_in_seconds = int(resp_data["expires_in"])
                    token.expiry = datetime.now(tz=timezone.utc) + timedelta(
                        seconds=expires_in_seconds
                    )
                    expiry_set = True
                    logger.debug(f"Token expiry set from expires_in: {token.expiry}")
                except (ValueError, TypeError) as e:
                    logger.warning(
                        f"Could not parse expires_in from response: {str(e)}"
                    )

            # If expires_in wasn't usable, try the exp claim of the JWT itself
            if not expiry_set:
                try:
                    token_claims = self._parse_jwt_claims(token.access_token)
                    exp_time = token_claims.get("exp")
                    if exp_time:
                        token.expiry = datetime.fromtimestamp(exp_time, tz=timezone.utc)
                        logger.debug(
                            f"Token expiry set from JWT exp claim: {token.expiry}"
                        )
                except Exception as e:
                    logger.warning(f"Could not parse expiry from token: {str(e)}")

            return token
        except RequestException as e:
            logger.error(f"Failed to perform token exchange: {str(e)}")
            raise ValueError(f"Request error during token exchange: {str(e)}")


class SimpleCredentialsProvider(CredentialsProvider):
    """A simple credentials provider that returns fixed headers."""

    def __init__(
        self, token: str, token_type: str = "Bearer", auth_type_value: str = "token"
    ):
        self.token = token
        self.token_type = token_type
        self._auth_type = auth_type_value

    def auth_type(self) -> str:
        return self._auth_type

    def __call__(self, *args, **kwargs) -> HeaderFactory:
        def get_headers() -> Dict[str, str]:
            return {"Authorization": f"{self.token_type} {self.token}"}

        return get_headers


def create_token_federation_provider(
    token: str,
    hostname: str,
    identity_federation_client_id: Optional[str] = None,
    token_type: str = "Bearer",
) -> DatabricksTokenFederationProvider:
    """
    Create a token federation provider using a simple token.

    Args:
        token: The token to use
        hostname: The Databricks hostname
        identity_federation_client_id: Optional client ID for identity federation
        token_type: The token type (default: "Bearer")

    Returns:
        A DatabricksTokenFederationProvider
    """
    provider = SimpleCredentialsProvider(token, token_type)
    return DatabricksTokenFederationProvider(
        provider, hostname, identity_federation_client_id
    )
+ + Args: + token: The token to use + hostname: The Databricks hostname + identity_federation_client_id: Optional client ID for identity federation + token_type: The token type (default: "Bearer") + + Returns: + A DatabricksTokenFederationProvider + """ + provider = SimpleCredentialsProvider(token, token_type) + return DatabricksTokenFederationProvider( + provider, hostname, identity_federation_client_id + ) diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py new file mode 100755 index 00000000..79fc40b3 --- /dev/null +++ b/tests/token_federation/github_oidc_test.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 + +""" +Test script for Databricks SQL token federation with GitHub Actions OIDC tokens. + +This script tests the Databricks SQL connector with token federation +using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse, +runs a simple query, and shows the connected user. +""" + +import os +import sys +import json +import base64 +import logging +from databricks import sql + + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def decode_jwt(token): + """ + Decode and return the claims from a JWT token. + + Args: + token: The JWT token string + + Returns: + dict: The decoded token claims or None if decoding fails + """ + try: + parts = token.split(".") + if len(parts) != 3: + raise ValueError("Invalid JWT format") + + payload = parts[1] + # Add padding if needed + padding = '=' * (4 - len(payload) % 4) + payload += padding + + decoded = base64.b64decode(payload) + return json.loads(decoded) + except Exception as e: + logger.error(f"Failed to decode token: {str(e)}") + return None + + +def get_environment_variables(): + """ + Get required environment variables for the test. 
+ + Returns: + tuple: (github_token, host, http_path, identity_federation_client_id) + + Raises: + SystemExit: If any required environment variable is missing + """ + github_token = os.environ.get("OIDC_TOKEN") + if not github_token: + logger.error("GitHub OIDC token not available") + sys.exit(1) + + host = os.environ.get("DATABRICKS_HOST_FOR_TF") + http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") + identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") + + if not host or not http_path: + logger.error("Missing Databricks connection parameters") + sys.exit(1) + + return github_token, host, http_path, identity_federation_client_id + + +def display_token_info(claims): + """Display token claims for debugging.""" + if not claims: + logger.warning("No token claims available to display") + return + + logger.info("=== GitHub OIDC Token Claims ===") + logger.info(f"Token issuer: {claims.get('iss')}") + logger.info(f"Token subject: {claims.get('sub')}") + logger.info(f"Token audience: {claims.get('aud')}") + logger.info(f"Token expiration: {claims.get('exp', 'unknown')}") + logger.info(f"Repository: {claims.get('repository', 'unknown')}") + logger.info(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") + logger.info(f"Event name: {claims.get('event_name', 'unknown')}") + logger.info("===============================") + + +def test_databricks_connection(host, http_path, github_token, identity_federation_client_id): + """ + Test connection to Databricks using token federation. 
+ + Args: + host: Databricks host + http_path: Databricks HTTP path + github_token: GitHub OIDC token + identity_federation_client_id: Identity federation client ID + + Returns: + bool: True if the test is successful, False otherwise + """ + logger.info("=== Testing Connection via Connector ===") + logger.info(f"Connecting to Databricks at {host}{http_path}") + logger.info(f"Using client ID: {identity_federation_client_id}") + + connection_params = { + "server_hostname": host, + "http_path": http_path, + "access_token": github_token, + "auth_type": "token-federation", + "identity_federation_client_id": identity_federation_client_id, + } + + try: + with sql.connect(**connection_params) as connection: + logger.info("Connection established successfully") + + # Execute a simple query + cursor = connection.cursor() + cursor.execute("SELECT 1 + 1 as result") + result = cursor.fetchall() + logger.info(f"Query result: {result[0][0]}") + + # Show current user + cursor.execute("SELECT current_user() as user") + result = cursor.fetchall() + logger.info(f"Connected as user: {result[0][0]}") + + logger.info("Token federation test successful!") + return True + except Exception as e: + logger.error(f"Error connecting to Databricks: {str(e)}") + return False + + +def main(): + """Main entry point for the test script.""" + try: + # Get environment variables + github_token, host, http_path, identity_federation_client_id = get_environment_variables() + + # Display token claims + claims = decode_jwt(github_token) + display_token_info(claims) + + # Test Databricks connection + success = test_databricks_connection( + host, http_path, github_token, identity_federation_client_id + ) + + if not success: + logger.error("Token federation test failed") + sys.exit(1) + + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/unit/test_token_federation.py 
#!/usr/bin/env python3

"""
Unit tests for token federation functionality in the Databricks SQL connector.
"""

import unittest
from unittest.mock import patch, MagicMock
import json
from datetime import datetime, timezone, timedelta

from databricks.sql.auth.token_federation import (
    Token,
    DatabricksTokenFederationProvider,
    SimpleCredentialsProvider,
    create_token_federation_provider,
    TOKEN_REFRESH_BUFFER_SECONDS,
)


class TestToken(unittest.TestCase):
    """Unit tests covering the Token value object."""

    def test_token_initialization(self):
        """A Token keeps the access token, type and refresh token it was given."""
        tok = Token("access_token_value", "Bearer", "refresh_token_value")
        self.assertEqual(tok.access_token, "access_token_value")
        self.assertEqual(tok.token_type, "Bearer")
        self.assertEqual(tok.refresh_token, "refresh_token_value")

    def test_token_is_expired(self):
        """is_expired is True for past expiries and False for future ones."""
        one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1)
        self.assertTrue(Token("access_token", "Bearer", expiry=one_hour_ago).is_expired())

        one_hour_ahead = datetime.now(tz=timezone.utc) + timedelta(hours=1)
        self.assertFalse(Token("access_token", "Bearer", expiry=one_hour_ahead).is_expired())

    def test_token_needs_refresh(self):
        """needs_refresh honours the real TOKEN_REFRESH_BUFFER_SECONDS window."""
        # Already-expired token must need a refresh.
        one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1)
        self.assertTrue(Token("access_token", "Bearer", expiry=one_hour_ago).needs_refresh())

        # Expiry just inside the buffer window: refresh needed.
        inside_buffer = datetime.now(tz=timezone.utc) + timedelta(
            seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1
        )
        self.assertTrue(Token("access_token", "Bearer", expiry=inside_buffer).needs_refresh())

        # Expiry safely beyond the buffer window: no refresh needed.
        beyond_buffer = datetime.now(tz=timezone.utc) + timedelta(
            seconds=TOKEN_REFRESH_BUFFER_SECONDS + 10
        )
        self.assertFalse(Token("access_token", "Bearer", expiry=beyond_buffer).needs_refresh())


class TestSimpleCredentialsProvider(unittest.TestCase):
    """Unit tests covering SimpleCredentialsProvider."""

    def test_simple_credentials_provider(self):
        """The provider reports its auth type and emits a fixed Authorization header."""
        creds = SimpleCredentialsProvider("token_value", "Bearer", "custom_auth_type")
        self.assertEqual(creds.auth_type(), "custom_auth_type")

        factory = creds()
        self.assertEqual(factory(), {"Authorization": "Bearer token_value"})


class TestTokenFederationProvider(unittest.TestCase):
    """Unit tests covering DatabricksTokenFederationProvider."""

    def test_host_property(self):
        """host is a read-through alias for hostname."""
        base_provider = SimpleCredentialsProvider("token")
        fed = DatabricksTokenFederationProvider(base_provider, "example.com", "client_id")
        self.assertEqual(fed.host, "example.com")
        self.assertEqual(fed.hostname, "example.com")

    @patch("databricks.sql.auth.token_federation.requests.get")
    @patch("databricks.sql.auth.token_federation.get_oauth_endpoints")
    def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get):
        """Discovery picks up the advertised endpoint; on error it falls back."""
        # Stub the endpoint helper so discovery has a config URL to fetch.
        endpoints_stub = MagicMock()
        endpoints_stub.get_openid_config_url.return_value = (
            "https://example.com/openid-config"
        )
        mock_get_endpoints.return_value = endpoints_stub

        # Stub a successful OpenID configuration response.
        ok_response = MagicMock()
        ok_response.status_code = 200
        ok_response.json.return_value = {"token_endpoint": "https://example.com/token"}
        mock_requests_get.return_value = ok_response

        fed = DatabricksTokenFederationProvider(
            SimpleCredentialsProvider("token"), "example.com", "client_id"
        )
        fed._init_oidc_discovery()
        self.assertEqual(fed.token_endpoint, "https://example.com/token")

        # When the HTTP fetch blows up, discovery must fall back to the default.
        mock_requests_get.side_effect = Exception("Connection error")
        fed.token_endpoint = None
        fed._init_oidc_discovery()
        self.assertEqual(fed.token_endpoint, "https://example.com/oidc/v1/token")

    @patch(
        "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims"
    )
    @patch(
        "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token"
    )
    @patch(
        "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host"
    )
    def test_token_refresh(
        self, mock_is_same_host, mock_exchange_token, mock_parse_jwt
    ):
        """An exchanged token nearing expiry triggers a refresh with a fresh external token."""
        # External token looks foreign, so an exchange is always attempted.
        mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"}
        mock_is_same_host.return_value = False

        # Header factory used on the first pass.
        first_headers = {"Authorization": "Bearer initial_token"}
        first_factory = MagicMock()
        first_factory.return_value = first_headers

        # Header factory handed out when the provider asks again during refresh.
        refreshed_headers = {"Authorization": "Bearer fresh_token"}
        refresh_factory = MagicMock()
        refresh_factory.return_value = refreshed_headers

        creds_provider = MagicMock()
        creds_provider.return_value = first_factory

        fed = DatabricksTokenFederationProvider(
            creds_provider, "example.com", "client_id"
        )

        # First exchange yields a long-lived token.
        far_expiry = datetime.now(tz=timezone.utc) + timedelta(hours=1)
        mock_exchange_token.return_value = Token(
            "exchanged_token_1", "Bearer", expiry=far_expiry
        )

        # Initial call: should exchange the initial external token.
        headers_factory = fed()
        headers = headers_factory()
        mock_exchange_token.assert_called_with("initial_token")
        self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1")

        mock_exchange_token.reset_mock()

        # Force the cached exchanged token into the refresh window.
        almost_expired = datetime.now(tz=timezone.utc) + timedelta(
            seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1
        )
        fed.last_exchanged_token = Token(
            "exchanged_token_1", "Bearer", expiry=almost_expired
        )
        fed.last_external_token = "initial_token"

        # The next pull from the credentials provider returns the fresh factory.
        creds_provider.return_value = refresh_factory

        # The refresh exchange yields a new Databricks token.
        mock_exchange_token.return_value = Token(
            "exchanged_token_2", "Bearer", expiry=far_expiry
        )

        # Second call: must refresh using the fresh external token.
        headers = headers_factory()
        mock_exchange_token.assert_called_once_with("fresh_token")
        self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2")


class TestTokenFederationFactory(unittest.TestCase):
    """Unit tests covering the create_token_federation_provider factory."""

    def test_create_token_federation_provider(self):
        """The factory wires a SimpleCredentialsProvider into a federation provider."""
        provider = create_token_federation_provider(
            "token_value", "example.com", "client_id", "Bearer"
        )

        self.assertIsInstance(provider, DatabricksTokenFederationProvider)
        self.assertEqual(provider.hostname, "example.com")
        self.assertEqual(provider.identity_federation_client_id, "client_id")

        # The wrapped provider carries the token and type given to the factory.
        self.assertEqual(provider.credentials_provider.token, "token_value")
        self.assertEqual(provider.credentials_provider.token_type, "Bearer")


if __name__ == "__main__":
    unittest.main()