Skip to content

Commit

Permalink
support auth key from content and from file (#40390)
Browse files Browse the repository at this point in the history
* support auth key from content and from file

* add logs

* fix static check
  • Loading branch information
uzhastik committed Jun 27, 2024
1 parent d0e4b8d commit 91e6e60
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 10 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/ydb/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ versions:
dependencies:
- apache-airflow>=2.7.0
- apache-airflow-providers-common-sql>=1.3.1
- ydb>=3.11.3
- ydb>=3.12.1

integrations:
- integration-name: YDB
Expand Down
15 changes: 11 additions & 4 deletions airflow/providers/ydb/utils/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
# under the License.
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

import ydb
import ydb.iam.auth as auth

from airflow.exceptions import AirflowException

if TYPE_CHECKING:
from airflow.models.connection import Connection

log = logging.getLogger(__name__)


def get_credentials_from_connection(
endpoint: str, database: str, connection: Connection, connection_extra: dict[str, Any] | None = None
Expand Down Expand Up @@ -54,23 +55,29 @@ def get_credentials_from_connection(
database=database,
)

log.info("using login as credentials")
return ydb.StaticCredentials(driver_config, user=connection.login, password=connection.password)

connection_extra = connection_extra or {}
token = connection_extra.get("token")
if token:
log.info("using token as credentials")
return ydb.AccessTokenCredentials(token)

service_account_json_path = connection_extra.get("service_account_json_path")
if service_account_json_path:
return auth.BaseJWTCredentials.from_file(auth.ServiceAccountCredentials, service_account_json_path)
log.info("using service_account_json_path as credentials")
return auth.ServiceAccountCredentials.from_file(service_account_json_path)

service_account_json = connection_extra.get("service_account_json")
if service_account_json:
raise AirflowException("service_account_json parameter is not supported yet")
log.info("using service_account_json as credentials")
return auth.ServiceAccountCredentials.from_content(service_account_json)

use_vm_metadata = connection_extra.get("use_vm_metadata", False)
if use_vm_metadata:
log.info("using vm metadata as credentials")
return auth.MetadataUrlCredentials()

log.info("using anonymous access")
return ydb.AnonymousCredentials()
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,7 @@
"deps": [
"apache-airflow-providers-common-sql>=1.3.1",
"apache-airflow>=2.7.0",
"ydb>=3.11.3"
"ydb>=3.12.1"
],
"devel-deps": [],
"plugins": [],
Expand Down
40 changes: 36 additions & 4 deletions tests/providers/ydb/utils/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ def test_vm_metadata_creds(mock):
mock.assert_called_once()


@patch("ydb.iam.auth.BaseJWTCredentials.from_content")
def test_service_account_json_creds(mock):
mock.return_value = MAGIC_CONST
c = Connection(conn_type="ydb", host="localhost")

credentials = get_credentials_from_connection(
TEST_ENDPOINT, TEST_DATABASE, c, {"service_account_json": "my_json"}
)
assert credentials == MAGIC_CONST
mock.assert_called_once()

assert mock.call_args.args == ("my_json",)


@patch("ydb.iam.auth.BaseJWTCredentials.from_file")
def test_service_account_json_path_creds(mock):
mock.return_value = MAGIC_CONST
Expand All @@ -79,8 +93,7 @@ def test_service_account_json_path_creds(mock):
assert credentials == MAGIC_CONST
mock.assert_called_once()

assert len(mock.call_args.args) == 2
assert mock.call_args.args[1] == "my_path"
assert mock.call_args.args == ("my_path",)


def test_creds_priority():
Expand All @@ -93,6 +106,7 @@ def test_creds_priority():
TEST_DATABASE,
c,
{
"service_account_json": "my_json",
"service_account_json_path": "my_path",
"use_vm_metadata": True,
"token": "my_token",
Expand All @@ -110,6 +124,7 @@ def test_creds_priority():
TEST_DATABASE,
c,
{
"service_account_json": "my_json",
"service_account_json_path": "my_path",
"use_vm_metadata": True,
"token": "my_token",
Expand All @@ -127,14 +142,31 @@ def test_creds_priority():
TEST_DATABASE,
c,
{
"service_account_json": "my_json",
"service_account_json_path": "my_path",
"use_vm_metadata": True,
},
)
assert credentials == MAGIC_CONST
mock.assert_called_once()

# 4. vm metadata
# 4. service account json
with patch("ydb.iam.auth.BaseJWTCredentials.from_content") as mock:
c = Connection(conn_type="ydb", host="localhost")
mock.return_value = MAGIC_CONST
credentials = get_credentials_from_connection(
TEST_ENDPOINT,
TEST_DATABASE,
c,
{
"service_account_json": "my_json",
"use_vm_metadata": True,
},
)
assert credentials == MAGIC_CONST
mock.assert_called_once()

# 5. vm metadata
with patch("ydb.iam.auth.MetadataUrlCredentials") as mock:
c = Connection(conn_type="ydb", host="localhost")
mock.return_value = MAGIC_CONST
Expand All @@ -149,7 +181,7 @@ def test_creds_priority():
assert credentials == MAGIC_CONST
mock.assert_called_once()

# 5. anonymous
# 6. anonymous
with patch("ydb.AnonymousCredentials") as mock:
c = Connection(conn_type="ydb", host="localhost")
mock.return_value = MAGIC_CONST
Expand Down

0 comments on commit 91e6e60

Please sign in to comment.