Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DatabricksHook #19835

Merged
merged 4 commits into from
Dec 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 73 additions & 64 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@

USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}

RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']

# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
Expand All @@ -64,7 +66,9 @@
class RunState:
"""Utility class for the run state concept of Databricks runs."""

def __init__(self, life_cycle_state: str, result_state: str, state_message: str) -> None:
def __init__(
self, life_cycle_state: str, result_state: str = '', state_message: str = '', *args, **kwargs
) -> None:
self.life_cycle_state = life_cycle_state
self.result_state = result_state
self.state_message = state_message
Expand Down Expand Up @@ -131,7 +135,11 @@ def __init__(
) -> None:
super().__init__()
self.databricks_conn_id = databricks_conn_id
self.databricks_conn = None
self.databricks_conn = self.get_connection(databricks_conn_id)
if 'host' in self.databricks_conn.extra_dejson:
self.host = self._parse_host(self.databricks_conn.extra_dejson['host'])
else:
self.host = self._parse_host(self.databricks_conn.host)
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
raise ValueError('Retry limit must be greater than equal to 1')
Expand Down Expand Up @@ -173,13 +181,11 @@ def _get_aad_token(self, resource: str) -> str:
:param resource: resource to issue token to
:return: AAD token, or raise an exception
"""
if resource in self.aad_tokens:
d = self.aad_tokens[resource]
now = int(time.time())
if d['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): # it expires in more than 2 minutes
return d['token']
self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
aad_token = self.aad_tokens.get(resource)
if aad_token and self._is_aad_token_valid(aad_token):
return aad_token['token']

self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
attempt_num = 1
while True:
try:
Expand Down Expand Up @@ -235,21 +241,53 @@ def _get_aad_token(self, resource: str) -> str:
attempt_num += 1
sleep(self.retry_delay)

def _fill_aad_tokens(self, headers: dict) -> str:
def _get_aad_headers(self) -> dict:
"""
Fills headers if necessary (SPN is outside of the workspace) and generates AAD token
:param headers: dictionary with headers to fill-in
:return: AAD token
Fills AAD headers if necessary (SPN is outside of the workspace)
:return: dictionary with filled AAD headers
"""
# SP is outside of the workspace
headers = {}
if 'azure_resource_id' in self.databricks_conn.extra_dejson:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep this check inside the function, because it could be called by accident in the future. Maybe call it _fill_aad_headers_if_needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about calling it _get_aad_headers(), which would return either an empty dict or a filled dict? We also wouldn't need the input argument headers in that case.

Then we could construct headers like:

aad_headers = self._get_aad_headers()
headers = {**USER_AGENT_HEADER.copy(), **aad_headers}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was thinking of something like this. It's easier to use because the logic for adding headers is incorporated inside the function.

mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
'azure_resource_id'
]
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
return headers

return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
@staticmethod
def _is_aad_token_valid(aad_token: dict) -> bool:
    """
    Tell whether an AAD token can still be used without refreshing it.

    :param aad_token: dict with properties of AAD token; only its ``expires_on`` entry is read
    :type aad_token: dict
    :return: true if ``expires_on`` is later than the current time plus
        ``TOKEN_REFRESH_LEAD_TIME``, false otherwise
    :rtype: bool
    """
    # Anything expiring at or before this point is treated as already stale.
    refresh_deadline = int(time.time()) + TOKEN_REFRESH_LEAD_TIME
    return aad_token['expires_on'] > refresh_deadline

@staticmethod
def _check_azure_metadata_service() -> None:
    """
    Check for Azure Metadata Service
    https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service

    Sends a GET request to ``AZURE_METADATA_SERVICE_TOKEN_URL`` and verifies that the JSON reply
    contains a ``compute`` entry with an ``azEnvironment`` key.

    :raises AirflowException: if the request fails, the reply can't be decoded, or the reply
        lacks the expected keys
    """
    try:
        jsn = requests.get(
            AZURE_METADATA_SERVICE_TOKEN_URL,
            params={"api-version": "2021-02-01"},
            headers={"Metadata": "true"},
            timeout=2,
        ).json()
        if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
            raise AirflowException(
                f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
            )
    except (requests_exceptions.RequestException, ValueError) as e:
        # Chain the original error so the underlying network/decoding failure stays visible
        # in the traceback instead of being flattened into the message only.
        raise AirflowException(f"Can't reach Azure Metadata Service: {e}") from e

def _do_api_call(self, endpoint_info, json):
"""
Expand All @@ -265,14 +303,10 @@ def _do_api_call(self, endpoint_info, json):
:rtype: dict
"""
method, endpoint = endpoint_info
url = f'https://{self.host}/{endpoint}'

self.databricks_conn = self.get_connection(self.databricks_conn_id)

headers = USER_AGENT_HEADER.copy()
if 'host' in self.databricks_conn.extra_dejson:
host = self._parse_host(self.databricks_conn.extra_dejson['host'])
else:
host = self.databricks_conn.host
aad_headers = self._get_aad_headers()
headers = {**USER_AGENT_HEADER.copy(), **aad_headers}

if 'token' in self.databricks_conn.extra_dejson:
self.log.info(
Expand All @@ -285,34 +319,16 @@ def _do_api_call(self, endpoint_info, json):
elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
if self.databricks_conn.login == "" or self.databricks_conn.password == "":
raise AirflowException("Azure SPN credentials aren't provided")

self.log.info('Using AAD Token for SPN. ')
auth = _TokenAuth(self._fill_aad_tokens(headers))
self.log.info('Using AAD Token for SPN.')
auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE))
elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
self.log.info('Using AAD Token for managed identity.')
# check for Azure Metadata Service
# https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
try:
jsn = requests.get(
AZURE_METADATA_SERVICE_TOKEN_URL,
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
timeout=2,
).json()
if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
raise AirflowException(
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
)
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")

auth = _TokenAuth(self._fill_aad_tokens(headers))
self._check_azure_metadata_service()
auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE))
else:
self.log.info('Using basic auth.')
auth = (self.databricks_conn.login, self.databricks_conn.password)

url = f'https://{self._parse_host(host)}/{endpoint}'

if method == 'GET':
request_func = requests.get
elif method == 'POST':
Expand Down Expand Up @@ -356,31 +372,31 @@ def _do_api_call(self, endpoint_info, json):
def _log_request_error(self, attempt_num: int, error: str) -> None:
self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)

def run_now(self, json: dict) -> str:
def run_now(self, json: dict) -> int:
Comment on lines -359 to +375
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it break existing code? For example, if people are concatenating this result with a log string without using str?

Copy link
Contributor Author

@eskarimov eskarimov Nov 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't break existing code; actually, it's the opposite - if someone assumes that the output is a str because of the function signature, then their code would break, because the actual returned type is int.

"""
Utility function to call the ``api/2.0/jobs/run-now`` endpoint.

:param json: The data used in the body of the request to the ``run-now`` endpoint.
:type json: dict
:return: the run_id as a string
:return: the run_id as an int
:rtype: int
"""
response = self._do_api_call(RUN_NOW_ENDPOINT, json)
return response['run_id']

def submit_run(self, json: dict) -> str:
def submit_run(self, json: dict) -> int:
"""
Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint.

:param json: The data used in the body of the request to the ``submit`` endpoint.
:type json: dict
:return: the run_id as a string
:return: the run_id as an int
:rtype: int
"""
response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
return response['run_id']

def get_run_page_url(self, run_id: str) -> str:
def get_run_page_url(self, run_id: int) -> str:
"""
Retrieves run_page_url.

Expand All @@ -391,19 +407,19 @@ def get_run_page_url(self, run_id: str) -> str:
response = self._do_api_call(GET_RUN_ENDPOINT, json)
return response['run_page_url']

def get_job_id(self, run_id: str) -> str:
def get_job_id(self, run_id: int) -> int:
"""
Retrieves job_id from run_id.

:param run_id: id of the run
:type run_id: str
:type run_id: int
:return: Job id for given Databricks run
"""
json = {'run_id': run_id}
response = self._do_api_call(GET_RUN_ENDPOINT, json)
return response['job_id']

def get_run_state(self, run_id: str) -> RunState:
def get_run_state(self, run_id: int) -> RunState:
"""
Retrieves run state of the run.

Expand All @@ -421,13 +437,9 @@ def get_run_state(self, run_id: str) -> RunState:
json = {'run_id': run_id}
response = self._do_api_call(GET_RUN_ENDPOINT, json)
state = response['state']
life_cycle_state = state['life_cycle_state']
# result_state may not be in the state if not terminal
result_state = state.get('result_state', None)
state_message = state['state_message']
return RunState(life_cycle_state, result_state, state_message)
return RunState(**state)

def get_run_state_str(self, run_id: str) -> str:
def get_run_state_str(self, run_id: int) -> str:
"""
Return the string representation of RunState.

Expand All @@ -440,7 +452,7 @@ def get_run_state_str(self, run_id: str) -> str:
)
return run_state_str

def get_run_state_lifecycle(self, run_id: str) -> str:
def get_run_state_lifecycle(self, run_id: int) -> str:
"""
Returns the lifecycle state of the run

Expand All @@ -449,7 +461,7 @@ def get_run_state_lifecycle(self, run_id: str) -> str:
"""
return self.get_run_state(run_id).life_cycle_state

def get_run_state_result(self, run_id: str) -> str:
def get_run_state_result(self, run_id: int) -> str:
"""
Returns the resulting state of the run

Expand All @@ -458,7 +470,7 @@ def get_run_state_result(self, run_id: str) -> str:
"""
return self.get_run_state(run_id).result_state

def get_run_state_message(self, run_id: str) -> str:
def get_run_state_message(self, run_id: int) -> str:
"""
Returns the state message for the run

Expand All @@ -467,7 +479,7 @@ def get_run_state_message(self, run_id: str) -> str:
"""
return self.get_run_state(run_id).state_message

def cancel_run(self, run_id: str) -> None:
def cancel_run(self, run_id: int) -> None:
"""
Cancels the run.

Expand Down Expand Up @@ -531,9 +543,6 @@ def _retryable_error(exception) -> bool:
)


RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']


class _TokenAuth(AuthBase):
"""
Helper class for requests Auth field. AuthBase requires you to implement the __call__
Expand Down
56 changes: 55 additions & 1 deletion tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import itertools
import json
import time
import unittest
from unittest import mock

Expand All @@ -31,9 +32,11 @@
from airflow.providers.databricks.hooks.databricks import (
AZURE_DEFAULT_AD_ENDPOINT,
AZURE_MANAGEMENT_ENDPOINT,
AZURE_METADATA_SERVICE_TOKEN_URL,
AZURE_TOKEN_SERVICE_URL,
DEFAULT_DATABRICKS_SCOPE,
SUBMIT_RUN_ENDPOINT,
TOKEN_REFRESH_LEAD_TIME,
DatabricksHook,
RunState,
)
Expand Down Expand Up @@ -63,7 +66,7 @@
}
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
JAR_PARAMS = ["param1", "param2"]
RESULT_STATE = None # type: None
RESULT_STATE = ''
LIBRARIES = [
{"jar": "dbfs:/mnt/libraries/library.jar"},
{"maven": {"coordinates": "org.jsoup:jsoup:1.7.2", "exclusions": ["slf4j:slf4j"]}},
Expand Down Expand Up @@ -520,6 +523,14 @@ def test_uninstall_libs_on_cluster(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

def test_is_aad_token_valid_returns_true(self):
    """A token expiring 10 units past the refresh lead time from now must be reported valid."""
    expires_on = int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10
    token_info = {'token': 'my_token', 'expires_on': expires_on}
    self.assertTrue(self.hook._is_aad_token_valid(token_info))

def test_is_aad_token_valid_returns_false(self):
    """A token whose ``expires_on`` equals the current time must be reported invalid."""
    token_info = dict(token='my_token', expires_on=int(time.time()))
    self.assertFalse(self.hook._is_aad_token_valid(token_info))


class TestDatabricksHookToken(unittest.TestCase):
"""
Expand Down Expand Up @@ -762,3 +773,46 @@ def test_submit_run(self, mock_requests):
assert kwargs['auth'].token == TOKEN
assert kwargs['headers']['X-Databricks-Azure-Workspace-Resource-Id'] == '/Some/resource'
assert kwargs['headers']['X-Databricks-Azure-SP-Management-Token'] == TOKEN


class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase):
    """
    Tests for DatabricksHook when auth is done with AAD leveraging Managed Identity authentication
    """

    @provide_session
    def setUp(self, session=None):
        """Configure the default connection for managed identity auth and build the hook."""
        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
        conn.host = HOST
        conn.extra = json.dumps({'use_azure_managed_identity': True})
        session.commit()
        self.hook = DatabricksHook()

    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
    def test_submit_run(self, mock_requests):
        """Check the metadata-service request and the auth token used when submitting a run."""
        mock_requests.codes.ok = 200
        # GET replies are consumed in order: the first looks like Azure instance metadata,
        # the second carries the AAD token payload.
        metadata_reply = create_successful_response_mock({'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}})
        token_reply = create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE))
        mock_requests.get.side_effect = [metadata_reply, token_reply]
        mock_requests.post.side_effect = [create_successful_response_mock({'run_id': '1'})]
        type(mock_requests.post.return_value).status_code = mock.PropertyMock(return_value=200)

        payload = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
        run_id = self.hook.submit_run(payload)

        # The very first recorded call on the mocked module must target the metadata service.
        _, first_args, first_kwargs = mock_requests.method_calls[0]
        assert first_args[0] == AZURE_METADATA_SERVICE_TOKEN_URL
        assert first_kwargs['params']['api-version'] > '2018-02-01'
        assert first_kwargs['headers']['Metadata'] == 'true'

        assert run_id == '1'
        _, post_kwargs = mock_requests.post.call_args
        assert post_kwargs['auth'].token == TOKEN