From 91e6e6055b3241aae7e1593bd9b855682c733e7d Mon Sep 17 00:00:00 2001 From: uzhastik Date: Thu, 27 Jun 2024 16:57:34 +0300 Subject: [PATCH] support auth key from content and from file (#40390) * support auth key from content and from file * add logs * fix static check --- airflow/providers/ydb/provider.yaml | 2 +- airflow/providers/ydb/utils/credentials.py | 15 +++++-- generated/provider_dependencies.json | 2 +- tests/providers/ydb/utils/test_credentials.py | 40 +++++++++++++++++-- 4 files changed, 49 insertions(+), 10 deletions(-) diff --git a/airflow/providers/ydb/provider.yaml b/airflow/providers/ydb/provider.yaml index 944df9204390..e6d21dcbb12a 100644 --- a/airflow/providers/ydb/provider.yaml +++ b/airflow/providers/ydb/provider.yaml @@ -30,7 +30,7 @@ versions: dependencies: - apache-airflow>=2.7.0 - apache-airflow-providers-common-sql>=1.3.1 - - ydb>=3.11.3 + - ydb>=3.12.1 integrations: - integration-name: YDB diff --git a/airflow/providers/ydb/utils/credentials.py b/airflow/providers/ydb/utils/credentials.py index 61e08ac10919..db468accf51f 100644 --- a/airflow/providers/ydb/utils/credentials.py +++ b/airflow/providers/ydb/utils/credentials.py @@ -16,16 +16,17 @@ # under the License. from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any import ydb import ydb.iam.auth as auth -from airflow.exceptions import AirflowException - if TYPE_CHECKING: from airflow.models.connection import Connection +log = logging.getLogger(__name__) + def get_credentials_from_connection( endpoint: str, database: str, connection: Connection, connection_extra: dict[str, Any] | None = None @@ -54,23 +55,29 @@ def get_credentials_from_connection( database=database, ) + log.info("using login as credentials") return ydb.StaticCredentials(driver_config, user=connection.login, password=connection.password) connection_extra = connection_extra or {} token = connection_extra.get("token") if token: + log.info("using token as credentials") return ydb.AccessTokenCredentials(token) service_account_json_path = connection_extra.get("service_account_json_path") if service_account_json_path: - return auth.BaseJWTCredentials.from_file(auth.ServiceAccountCredentials, service_account_json_path) + log.info("using service_account_json_path as credentials") + return auth.ServiceAccountCredentials.from_file(service_account_json_path) service_account_json = connection_extra.get("service_account_json") if service_account_json: - raise AirflowException("service_account_json parameter is not supported yet") + log.info("using service_account_json as credentials") + return auth.ServiceAccountCredentials.from_content(service_account_json) use_vm_metadata = connection_extra.get("use_vm_metadata", False) if use_vm_metadata: + log.info("using vm metadata as credentials") return auth.MetadataUrlCredentials() + log.info("using anonymous access") return ydb.AnonymousCredentials() diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index f4760e5a0aea..9f0e85f5cd49 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1343,7 +1343,7 @@ "deps": [ "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow>=2.7.0", - "ydb>=3.11.3" + "ydb>=3.12.1" ], "devel-deps": [], "plugins": [], diff --git a/tests/providers/ydb/utils/test_credentials.py b/tests/providers/ydb/utils/test_credentials.py index 864069ddbcc8..af0d58cddd19 100644 --- a/tests/providers/ydb/utils/test_credentials.py +++ b/tests/providers/ydb/utils/test_credentials.py @@ -68,6 +68,20 @@ def test_vm_metadata_creds(mock): mock.assert_called_once() +@patch("ydb.iam.auth.BaseJWTCredentials.from_content") +def test_service_account_json_creds(mock): + mock.return_value = MAGIC_CONST + c = Connection(conn_type="ydb", host="localhost") + + credentials = get_credentials_from_connection( + TEST_ENDPOINT, TEST_DATABASE, c, {"service_account_json": "my_json"} + ) + assert credentials == MAGIC_CONST + mock.assert_called_once() + + assert mock.call_args.args == ("my_json",) + + @patch("ydb.iam.auth.BaseJWTCredentials.from_file") def test_service_account_json_path_creds(mock): mock.return_value = MAGIC_CONST @@ -79,8 +93,7 @@ def test_service_account_json_path_creds(mock): assert credentials == MAGIC_CONST mock.assert_called_once() - assert len(mock.call_args.args) == 2 - assert mock.call_args.args[1] == "my_path" + assert mock.call_args.args == ("my_path",) def test_creds_priority(): @@ -93,6 +106,7 @@ def test_creds_priority(): TEST_DATABASE, c, { + "service_account_json": "my_json", "service_account_json_path": "my_path", "use_vm_metadata": True, "token": "my_token", @@ -110,6 +124,7 @@ def test_creds_priority(): TEST_DATABASE, c, { + "service_account_json": "my_json", "service_account_json_path": "my_path", "use_vm_metadata": True, "token": "my_token", @@ -127,6 +142,7 @@ def test_creds_priority(): TEST_DATABASE, c, { + "service_account_json": "my_json", "service_account_json_path": "my_path", "use_vm_metadata": True, }, @@ -134,7 +150,23 @@ def test_creds_priority(): assert credentials == MAGIC_CONST mock.assert_called_once() - # 4. vm metadata + # 4. service account json + with patch("ydb.iam.auth.BaseJWTCredentials.from_content") as mock: + c = Connection(conn_type="ydb", host="localhost") + mock.return_value = MAGIC_CONST + credentials = get_credentials_from_connection( + TEST_ENDPOINT, + TEST_DATABASE, + c, + { + "service_account_json": "my_json", + "use_vm_metadata": True, + }, + ) + assert credentials == MAGIC_CONST + mock.assert_called_once() + + # 5. vm metadata with patch("ydb.iam.auth.MetadataUrlCredentials") as mock: c = Connection(conn_type="ydb", host="localhost") mock.return_value = MAGIC_CONST @@ -149,7 +181,7 @@ def test_creds_priority(): assert credentials == MAGIC_CONST mock.assert_called_once() - # 5. anonymous + # 6. anonymous with patch("ydb.AnonymousCredentials") as mock: c = Connection(conn_type="ydb", host="localhost") mock.return_value = MAGIC_CONST