-
Notifications
You must be signed in to change notification settings - Fork 14.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor DatabricksHook #19835
Refactor DatabricksHook #19835
Changes from all commits
f770b76
98d91c0
6234779
9d08a12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,8 @@ | |
|
||
USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'} | ||
|
||
RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] | ||
|
||
# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token | ||
# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints | ||
AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com" | ||
|
@@ -64,7 +66,9 @@ | |
class RunState: | ||
"""Utility class for the run state concept of Databricks runs.""" | ||
|
||
def __init__(self, life_cycle_state: str, result_state: str, state_message: str) -> None: | ||
def __init__( | ||
self, life_cycle_state: str, result_state: str = '', state_message: str = '', *args, **kwargs | ||
) -> None: | ||
self.life_cycle_state = life_cycle_state | ||
self.result_state = result_state | ||
self.state_message = state_message | ||
|
@@ -131,7 +135,11 @@ def __init__( | |
) -> None: | ||
super().__init__() | ||
self.databricks_conn_id = databricks_conn_id | ||
self.databricks_conn = None | ||
self.databricks_conn = self.get_connection(databricks_conn_id) | ||
if 'host' in self.databricks_conn.extra_dejson: | ||
self.host = self._parse_host(self.databricks_conn.extra_dejson['host']) | ||
else: | ||
self.host = self._parse_host(self.databricks_conn.host) | ||
self.timeout_seconds = timeout_seconds | ||
if retry_limit < 1: | ||
raise ValueError('Retry limit must be greater than equal to 1') | ||
|
@@ -173,13 +181,11 @@ def _get_aad_token(self, resource: str) -> str: | |
:param resource: resource to issue token to | ||
:return: AAD token, or raise an exception | ||
""" | ||
if resource in self.aad_tokens: | ||
d = self.aad_tokens[resource] | ||
now = int(time.time()) | ||
if d['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): # it expires in more than 2 minutes | ||
return d['token'] | ||
self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...") | ||
aad_token = self.aad_tokens.get(resource) | ||
if aad_token and self._is_aad_token_valid(aad_token): | ||
return aad_token['token'] | ||
|
||
self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...') | ||
attempt_num = 1 | ||
while True: | ||
try: | ||
|
@@ -235,21 +241,53 @@ def _get_aad_token(self, resource: str) -> str: | |
attempt_num += 1 | ||
sleep(self.retry_delay) | ||
|
||
def _fill_aad_tokens(self, headers: dict) -> str: | ||
def _get_aad_headers(self) -> dict: | ||
""" | ||
Fills headers if necessary (SPN is outside of the workspace) and generates AAD token | ||
:param headers: dictionary with headers to fill-in | ||
:return: AAD token | ||
Fills AAD headers if necessary (SPN is outside of the workspace) | ||
:return: dictionary with filled AAD headers | ||
""" | ||
# SP is outside of the workspace | ||
headers = {} | ||
if 'azure_resource_id' in self.databricks_conn.extra_dejson: | ||
mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT) | ||
headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[ | ||
'azure_resource_id' | ||
] | ||
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token | ||
return headers | ||
|
||
return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE) | ||
@staticmethod | ||
def _is_aad_token_valid(aad_token: dict) -> bool: | ||
""" | ||
Utility function to check AAD token hasn't expired yet | ||
:param aad_token: dict with properties of AAD token | ||
:type aad_token: dict | ||
:return: true if token is valid, false otherwise | ||
:rtype: bool | ||
""" | ||
now = int(time.time()) | ||
if aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): | ||
return True | ||
return False | ||
|
||
@staticmethod | ||
def _check_azure_metadata_service() -> None: | ||
""" | ||
Check for Azure Metadata Service | ||
https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service | ||
""" | ||
try: | ||
jsn = requests.get( | ||
AZURE_METADATA_SERVICE_TOKEN_URL, | ||
params={"api-version": "2021-02-01"}, | ||
headers={"Metadata": "true"}, | ||
timeout=2, | ||
).json() | ||
if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']: | ||
raise AirflowException( | ||
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" | ||
) | ||
except (requests_exceptions.RequestException, ValueError) as e: | ||
raise AirflowException(f"Can't reach Azure Metadata Service: {e}") | ||
|
||
def _do_api_call(self, endpoint_info, json): | ||
""" | ||
|
@@ -265,14 +303,10 @@ def _do_api_call(self, endpoint_info, json): | |
:rtype: dict | ||
""" | ||
method, endpoint = endpoint_info | ||
url = f'https://{self.host}/{endpoint}' | ||
|
||
self.databricks_conn = self.get_connection(self.databricks_conn_id) | ||
|
||
headers = USER_AGENT_HEADER.copy() | ||
if 'host' in self.databricks_conn.extra_dejson: | ||
host = self._parse_host(self.databricks_conn.extra_dejson['host']) | ||
else: | ||
host = self.databricks_conn.host | ||
aad_headers = self._get_aad_headers() | ||
headers = {**USER_AGENT_HEADER.copy(), **aad_headers} | ||
|
||
if 'token' in self.databricks_conn.extra_dejson: | ||
self.log.info( | ||
|
@@ -285,34 +319,16 @@ def _do_api_call(self, endpoint_info, json): | |
elif 'azure_tenant_id' in self.databricks_conn.extra_dejson: | ||
if self.databricks_conn.login == "" or self.databricks_conn.password == "": | ||
raise AirflowException("Azure SPN credentials aren't provided") | ||
|
||
self.log.info('Using AAD Token for SPN. ') | ||
auth = _TokenAuth(self._fill_aad_tokens(headers)) | ||
self.log.info('Using AAD Token for SPN.') | ||
auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)) | ||
elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): | ||
self.log.info('Using AAD Token for managed identity.') | ||
# check for Azure Metadata Service | ||
# https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service | ||
try: | ||
jsn = requests.get( | ||
AZURE_METADATA_SERVICE_TOKEN_URL, | ||
params={"api-version": "2021-02-01"}, | ||
headers={"Metadata": "true"}, | ||
timeout=2, | ||
).json() | ||
if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']: | ||
raise AirflowException( | ||
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" | ||
) | ||
except (requests_exceptions.RequestException, ValueError) as e: | ||
raise AirflowException(f"Can't reach Azure Metadata Service: {e}") | ||
|
||
auth = _TokenAuth(self._fill_aad_tokens(headers)) | ||
self._check_azure_metadata_service() | ||
auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)) | ||
else: | ||
self.log.info('Using basic auth.') | ||
auth = (self.databricks_conn.login, self.databricks_conn.password) | ||
|
||
url = f'https://{self._parse_host(host)}/{endpoint}' | ||
|
||
if method == 'GET': | ||
request_func = requests.get | ||
elif method == 'POST': | ||
|
@@ -356,31 +372,31 @@ def _do_api_call(self, endpoint_info, json): | |
def _log_request_error(self, attempt_num: int, error: str) -> None: | ||
self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error) | ||
|
||
def run_now(self, json: dict) -> str: | ||
def run_now(self, json: dict) -> int: | ||
Comment on lines
-359
to
+375
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can it break existing code? For example, if people are using this result to concatenate with a log string without using There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It won't break existing code; actually, it's the opposite - if someone assumes that output is |
||
""" | ||
Utility function to call the ``api/2.0/jobs/run-now`` endpoint. | ||
|
||
:param json: The data used in the body of the request to the ``run-now`` endpoint. | ||
:type json: dict | ||
:return: the run_id as a string | ||
:return: the run_id as an int | ||
:rtype: str | ||
""" | ||
response = self._do_api_call(RUN_NOW_ENDPOINT, json) | ||
return response['run_id'] | ||
|
||
def submit_run(self, json: dict) -> str: | ||
def submit_run(self, json: dict) -> int: | ||
""" | ||
Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint. | ||
|
||
:param json: The data used in the body of the request to the ``submit`` endpoint. | ||
:type json: dict | ||
:return: the run_id as a string | ||
:return: the run_id as an int | ||
:rtype: str | ||
""" | ||
response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json) | ||
return response['run_id'] | ||
|
||
def get_run_page_url(self, run_id: str) -> str: | ||
def get_run_page_url(self, run_id: int) -> str: | ||
""" | ||
Retrieves run_page_url. | ||
|
||
|
@@ -391,19 +407,19 @@ def get_run_page_url(self, run_id: str) -> str: | |
response = self._do_api_call(GET_RUN_ENDPOINT, json) | ||
return response['run_page_url'] | ||
|
||
def get_job_id(self, run_id: str) -> str: | ||
def get_job_id(self, run_id: int) -> int: | ||
""" | ||
Retrieves job_id from run_id. | ||
|
||
:param run_id: id of the run | ||
:type run_id: str | ||
:type run_id: int | ||
:return: Job id for given Databricks run | ||
""" | ||
json = {'run_id': run_id} | ||
response = self._do_api_call(GET_RUN_ENDPOINT, json) | ||
return response['job_id'] | ||
|
||
def get_run_state(self, run_id: str) -> RunState: | ||
def get_run_state(self, run_id: int) -> RunState: | ||
""" | ||
Retrieves run state of the run. | ||
|
||
|
@@ -421,13 +437,9 @@ def get_run_state(self, run_id: str) -> RunState: | |
json = {'run_id': run_id} | ||
response = self._do_api_call(GET_RUN_ENDPOINT, json) | ||
state = response['state'] | ||
life_cycle_state = state['life_cycle_state'] | ||
# result_state may not be in the state if not terminal | ||
result_state = state.get('result_state', None) | ||
state_message = state['state_message'] | ||
return RunState(life_cycle_state, result_state, state_message) | ||
return RunState(**state) | ||
|
||
def get_run_state_str(self, run_id: str) -> str: | ||
def get_run_state_str(self, run_id: int) -> str: | ||
""" | ||
Return the string representation of RunState. | ||
|
||
|
@@ -440,7 +452,7 @@ def get_run_state_str(self, run_id: str) -> str: | |
) | ||
return run_state_str | ||
|
||
def get_run_state_lifecycle(self, run_id: str) -> str: | ||
def get_run_state_lifecycle(self, run_id: int) -> str: | ||
""" | ||
Returns the lifecycle state of the run | ||
|
||
|
@@ -449,7 +461,7 @@ def get_run_state_lifecycle(self, run_id: str) -> str: | |
""" | ||
return self.get_run_state(run_id).life_cycle_state | ||
|
||
def get_run_state_result(self, run_id: str) -> str: | ||
def get_run_state_result(self, run_id: int) -> str: | ||
""" | ||
Returns the resulting state of the run | ||
|
||
|
@@ -458,7 +470,7 @@ def get_run_state_result(self, run_id: str) -> str: | |
""" | ||
return self.get_run_state(run_id).result_state | ||
|
||
def get_run_state_message(self, run_id: str) -> str: | ||
def get_run_state_message(self, run_id: int) -> str: | ||
""" | ||
Returns the state message for the run | ||
|
||
|
@@ -467,7 +479,7 @@ def get_run_state_message(self, run_id: str) -> str: | |
""" | ||
return self.get_run_state(run_id).state_message | ||
|
||
def cancel_run(self, run_id: str) -> None: | ||
def cancel_run(self, run_id: int) -> None: | ||
""" | ||
Cancels the run. | ||
|
||
|
@@ -531,9 +543,6 @@ def _retryable_error(exception) -> bool: | |
) | ||
|
||
|
||
RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] | ||
|
||
|
||
class _TokenAuth(AuthBase): | ||
""" | ||
Helper class for requests Auth field. AuthBase requires you to implement the __call__ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would keep this check inside the function, because it could be called by accident in the future. Maybe call it `_fill_aad_headers_if_needed`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think if we call it `_get_aad_headers()`, which would return either an empty dict or a filled dict? Also, we won't need the input arg `headers` in this case.
Then we could construct headers like:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I was thinking of something like this. It's easier to use because the logic of adding headers is incorporated inside the function...