From 79df90a4ae196b6240c68f17dc1435813d012b67 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Mon, 29 Aug 2022 16:30:08 -0700 Subject: [PATCH 1/4] refactor submission method and add command API as defualt --- dbt/adapters/spark/impl.py | 108 +-------- dbt/adapters/spark/python_submissions.py | 281 +++++++++++++++++++++++ 2 files changed, 290 insertions(+), 99 deletions(-) create mode 100644 dbt/adapters/spark/python_submissions.py diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 12c42ab98..e78d71bfb 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -1,7 +1,4 @@ import re -import requests -import time -import base64 from concurrent.futures import Future from dataclasses import dataclass from typing import Any, Dict, Iterable, List, Optional, Union @@ -20,6 +17,7 @@ from dbt.adapters.spark import SparkConnectionManager from dbt.adapters.spark import SparkRelation from dbt.adapters.spark import SparkColumn +from dbt.adapters.spark.python_submissions import python_submission_helpers from dbt.adapters.base import BaseRelation from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER from dbt.events import AdapterLogger @@ -394,105 +392,17 @@ def submit_python_job(self, parsed_model: dict, compiled_code: str, timeout=None # of `None` which evaluates to True! # TODO limit this function to run only when doing the materialization of python nodes - # assuming that for python job running over 1 day user would mannually overwrite this - schema = getattr(parsed_model, "schema", self.config.credentials.schema) - identifier = parsed_model["alias"] - if not timeout: - timeout = 60 * 60 * 24 - if timeout <= 0: - raise ValueError("Timeout must larger than 0") - - auth_header = {"Authorization": f"Bearer {self.connections.profile.credentials.token}"} - - # create new dir - if not self.connections.profile.credentials.user: - raise ValueError("Need to supply user in profile to submit python job") - # it is safe to call mkdirs even if dir already exists and have content inside - work_dir = f"/Users/{self.connections.profile.credentials.user}/{schema}" - response = requests.post( - f"https://{self.connections.profile.credentials.host}/api/2.0/workspace/mkdirs", - headers=auth_header, - json={ - "path": work_dir, - }, - ) - if response.status_code != 200: - raise dbt.exceptions.RuntimeException( - f"Error creating work_dir for python notebooks\n {response.content!r}" + submission_method = parsed_model["config"].get("submission_method", "commands") + if submission_method not in python_submission_helpers: + raise NotImplementedError( + "Submission method {} is not supported".format(submission_method) ) - - # add notebook - b64_encoded_content = base64.b64encode(compiled_code.encode()).decode() - response = requests.post( - f"https://{self.connections.profile.credentials.host}/api/2.0/workspace/import", - headers=auth_header, - json={ - "path": f"{work_dir}/{identifier}", - "content": b64_encoded_content, - "language": "PYTHON", - "overwrite": True, - "format": "SOURCE", - }, + job_helper = python_submission_helpers[submission_method]( + parsed_model, self.connections.profile.credentials ) - if response.status_code != 200: - raise dbt.exceptions.RuntimeException( - f"Error creating python notebook.\n {response.content!r}" - ) - - # submit job - submit_response = requests.post( - f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/submit", - headers=auth_header, - json={ - "run_name": "debug task", - "existing_cluster_id": 
self.connections.profile.credentials.cluster, - "notebook_task": { - "notebook_path": f"{work_dir}/{identifier}", - }, - }, - ) - if submit_response.status_code != 200: - raise dbt.exceptions.RuntimeException( - f"Error creating python run.\n {response.content!r}" - ) - - # poll until job finish - state = None - start = time.time() - run_id = submit_response.json()["run_id"] - terminal_states = ["TERMINATED", "SKIPPED", "INTERNAL_ERROR"] - while state not in terminal_states and time.time() - start < timeout: - time.sleep(1) - resp = requests.get( - f"https://{self.connections.profile.credentials.host}" - f"/api/2.1/jobs/runs/get?run_id={run_id}", - headers=auth_header, - ) - json_resp = resp.json() - state = json_resp["state"]["life_cycle_state"] - # logger.debug(f"Polling.... in state: {state}") - if state != "TERMINATED": - raise dbt.exceptions.RuntimeException( - "python model run ended in state" - f"{state} with state_message\n{json_resp['state']['state_message']}" - ) - - # get end state to return to user - run_output = requests.get( - f"https://{self.connections.profile.credentials.host}" - f"/api/2.1/jobs/runs/get-output?run_id={run_id}", - headers=auth_header, - ) - json_run_output = run_output.json() - result_state = json_run_output["metadata"]["state"]["result_state"] - if result_state != "SUCCESS": - raise dbt.exceptions.RuntimeException( - "Python model failed with traceback as:\n" - "(Note that the line number here does not " - "match the line number in your code due to dbt templating)\n" - f"{json_run_output['error_trace']}" - ) + job_helper.submit(compiled_code) + # we don't really get any useful information back from the job submission other than success return self.connections.get_response(None) def standardize_grants_dict(self, grants_table: agate.Table) -> dict: diff --git a/dbt/adapters/spark/python_submissions.py b/dbt/adapters/spark/python_submissions.py new file mode 100644 index 000000000..7698f6825 --- /dev/null +++ b/dbt/adapters/spark/python_submissions.py @@ -0,0 +1,281 @@ +import base64 +import time +import requests +from typing import Any, Dict + +import dbt.exceptions + +DEFAULT_POLLING_INTERVAL = 3 + + +class BasePythonJobHelper: + def __init__(self, parsed_model, credentials): + self.check_credentials(credentials) + self.credentials = credentials + self.identifier = parsed_model["alias"] + self.schema = getattr(parsed_model, "schema", self.credentials.schema) + self.parsed_model = parsed_model + self.timeout = self.get_timeout() + self.polling_interval = DEFAULT_POLLING_INTERVAL + + def get_timeout(self): + timeout = self.parsed_model["config"].get("timeout", 60 * 60 * 24) + if timeout <= 0: + raise ValueError("Timeout must be a positive integer") + return timeout + + def check_credentials(self, credentials): + raise NotImplementedError( + "Overwrite this method to check specific requirement for current submission method" + ) + + def submit(self, compiled_code): + raise NotImplementedError( + "BasePythonJobHelper is an abstract class and you should implement submit method." + ) + + def polling( + self, + status_func, + status_func_kwargs, + get_state_func, + terminal_states, + expected_end_state, + get_state_msg_func, + ): + state = None + start = time.time() + exceeded_timeout = False + response = {} + while state not in terminal_states: + if time.time() - start > self.timeout: + exceeded_timeout = True + break + # TODO should we do exponential backoff? 
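+            # Fixed-interval polling: sleep, call the status function again, and
+            # re-check the state until a terminal state is reached or the timeout
+            # above is exceeded.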
+ time.sleep(self.polling_interval) + response = status_func(**status_func_kwargs) + state = get_state_func(response) + if exceeded_timeout: + raise dbt.exceptions.RuntimeException("python model run timed out") + if state != expected_end_state: + raise dbt.exceptions.RuntimeException( + "python model run ended in state" + f"{state} with state_message\n{get_state_msg_func(response)}" + ) + return response + + +class DBNotebookPythonJobHelper(BasePythonJobHelper): + def __init__(self, parsed_model, credentials): + super().__init__(parsed_model, credentials) + self.auth_header = {"Authorization": f"Bearer {self.credentials.token}"} + + def check_credentials(self, credentials): + if not credentials.user: + raise ValueError("Databricks user is required for notebook submission method.") + + def _create_work_dir(self, path): + response = requests.post( + f"https://{self.credentials.host}/api/2.0/workspace/mkdirs", + headers=self.auth_header, + json={ + "path": path, + }, + ) + if response.status_code != 200: + raise dbt.exceptions.RuntimeException( + f"Error creating work_dir for python notebooks\n {response.content!r}" + ) + + def _upload_notebook(self, path, compiled_code): + b64_encoded_content = base64.b64encode(compiled_code.encode()).decode() + response = requests.post( + f"https://{self.credentials.host}/api/2.0/workspace/import", + headers=self.auth_header, + json={ + "path": path, + "content": b64_encoded_content, + "language": "PYTHON", + "overwrite": True, + "format": "SOURCE", + }, + ) + if response.status_code != 200: + raise dbt.exceptions.RuntimeException( + f"Error creating python notebook.\n {response.content!r}" + ) + + def _submit_notebook(self, path): + submit_response = requests.post( + f"https://{self.credentials.host}/api/2.1/jobs/runs/submit", + headers=self.auth_header, + json={ + "run_name": self.identifier, # should there be an UUID, or also add schema + "existing_cluster_id": self.credentials.cluster, + "notebook_task": { + "notebook_path": path, + }, + }, + ) + if submit_response.status_code != 200: + raise dbt.exceptions.RuntimeException( + f"Error creating python run.\n {submit_response.content!r}" + ) + return submit_response.json()["run_id"] + + def submit(self, compiled_code): + # it is safe to call mkdirs even if dir already exists and have content inside + work_dir = f"/Users/{self.credentials.user}/{self.schema}/" + self._create_work_dir(work_dir) + + # add notebook + whole_file_path = f"{work_dir}{self.identifier}" + self._upload_notebook(whole_file_path, compiled_code) + + # submit job + run_id = self._submit_notebook(whole_file_path) + + self.polling( + status_func=requests.get, + status_func_kwargs={ + "url": f"https://{self.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}", + "headers": self.auth_header, + }, + get_state_func=lambda response: response.json()["state"]["life_cycle_state"], + terminal_states=("TERMINATED", "SKIPPED", "INTERNAL_ERROR"), + expected_end_state="TERMINATED", + get_state_msg_func=lambda response: response.json()["state"]["state_message"], + ) + + # get end state to return to user + run_output = requests.get( + f"https://{self.credentials.host}" f"/api/2.1/jobs/runs/get-output?run_id={run_id}", + headers=self.auth_header, + ) + json_run_output = run_output.json() + result_state = json_run_output["metadata"]["state"]["result_state"] + if result_state != "SUCCESS": + raise dbt.exceptions.RuntimeException( + "Python model failed with traceback as:\n" + "(Note that the line number here does not " + "match the line number in your 
code due to dbt templating)\n" + f"{json_run_output['error_trace']}" + ) + + +class DBContext: + def __init__(self, credentials): + self.auth_header = {"Authorization": f"Bearer {credentials.token}"} + self.cluster = credentials.cluster + self.host = credentials.host + + def create(self) -> str: + # https://docs.databricks.com/dev-tools/api/1.2/index.html#create-an-execution-context + response = requests.post( + f"https://{self.host}/api/1.2/contexts/create", + headers=self.auth_header, + json={ + "clusterId": self.cluster, + "language": "python", + }, + ) + if response.status_code != 200: + raise dbt.exceptions.RuntimeException( + f"Error creating an execution context.\n {response.content!r}" + ) + return response.json()["id"] + + def destroy(self, context_id: str) -> str: + # https://docs.databricks.com/dev-tools/api/1.2/index.html#delete-an-execution-context + response = requests.post( + f"https://{self.host}/api/1.2/contexts/destroy", + headers=self.auth_header, + json={ + "clusterId": self.cluster, + "contextId": context_id, + }, + ) + if response.status_code != 200: + raise dbt.exceptions.RuntimeException( + f"Error deleting an execution context.\n {response.content!r}" + ) + return response.json()["id"] + + +class DBCommand: + def __init__(self, credentials): + self.auth_header = {"Authorization": f"Bearer {credentials.token}"} + self.cluster = credentials.cluster + self.host = credentials.host + + def execute(self, context_id: str, command: str) -> str: + # https://docs.databricks.com/dev-tools/api/1.2/index.html#run-a-command + response = requests.post( + f"https://{self.host}/api/1.2/commands/execute", + headers=self.auth_header, + json={ + "clusterId": self.cluster, + "contextId": context_id, + "language": "python", + "command": command, + }, + ) + if response.status_code != 200: + raise dbt.exceptions.RuntimeException( + f"Error creating a command.\n {response.content!r}" + ) + return response.json()["id"] + + def status(self, context_id: str, command_id: str) -> Dict[str, Any]: + # https://docs.databricks.com/dev-tools/api/1.2/index.html#get-information-about-a-command + response = requests.get( + f"https://{self.host}/api/1.2/commands/status", + headers=self.auth_header, + params={ + "clusterId": self.cluster, + "contextId": context_id, + "commandId": command_id, + }, + ) + if response.status_code != 200: + raise dbt.exceptions.RuntimeException( + f"Error getting status of command.\n {response.content!r}" + ) + return response.json() + + +class DBCommandsApiPythonJobHelper(BasePythonJobHelper): + def check_credentials(self, credentials): + if not credentials.cluster: + raise ValueError("Databricks cluster is required for commands submission method.") + + def submit(self, compiled_code): + context = DBContext(self.credentials) + command = DBCommand(self.credentials) + context_id = context.create() + try: + command_id = command.execute(context_id, compiled_code) + # poll until job finish + response = self.polling( + status_func=command.status, + status_func_kwargs={ + "context_id": context_id, + "command_id": command_id, + }, + get_state_func=lambda response: response["status"], + terminal_states=("Cancelled", "Error", "Finished"), + expected_end_state="Finished", + get_state_msg_func=lambda response: response.json()["results"]["data"], + ) + if response["results"]["resultType"] == "error": + raise dbt.exceptions.RuntimeException( + f"Python model failed with traceback as:\n" f"{response['results']['cause']}" + ) + finally: + context.destroy(context_id) + + 
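+# Maps the `submission_method` model config value to the helper class that
+# submits the compiled Python code; impl.py falls back to "commands" when the
+# config is not set.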
+python_submission_helpers = { + "notebook": DBNotebookPythonJobHelper, + "commands": DBCommandsApiPythonJobHelper, +} From 2db018cb7ff87cfa148de154e8575305d1e2a4d0 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Mon, 29 Aug 2022 16:46:43 -0700 Subject: [PATCH 2/4] update run_name and add changelog --- .../Under the Hood-20220829-164426.yaml | 7 ++++++ dbt/adapters/spark/python_submissions.py | 25 +++++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20220829-164426.yaml diff --git a/.changes/unreleased/Under the Hood-20220829-164426.yaml b/.changes/unreleased/Under the Hood-20220829-164426.yaml new file mode 100644 index 000000000..bf58971f2 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20220829-164426.yaml @@ -0,0 +1,7 @@ +kind: Under the Hood +body: Submit python model with Command API by default. Adjusted run name +time: 2022-08-29T16:44:26.509138-07:00 +custom: + Author: ChenyuLInx + Issue: "424" + PR: "442" diff --git a/dbt/adapters/spark/python_submissions.py b/dbt/adapters/spark/python_submissions.py index 7698f6825..167338919 100644 --- a/dbt/adapters/spark/python_submissions.py +++ b/dbt/adapters/spark/python_submissions.py @@ -2,6 +2,7 @@ import time import requests from typing import Any, Dict +import uuid import dbt.exceptions @@ -72,7 +73,9 @@ def __init__(self, parsed_model, credentials): def check_credentials(self, credentials): if not credentials.user: - raise ValueError("Databricks user is required for notebook submission method.") + raise ValueError( + "Databricks user is required for notebook submission method." + ) def _create_work_dir(self, path): response = requests.post( @@ -110,7 +113,7 @@ def _submit_notebook(self, path): f"https://{self.credentials.host}/api/2.1/jobs/runs/submit", headers=self.auth_header, json={ - "run_name": self.identifier, # should there be an UUID, or also add schema + "run_name": f"{self.schema}-{self.identifier}-{uuid.uuid4()}", "existing_cluster_id": self.credentials.cluster, "notebook_task": { "notebook_path": path, @@ -141,15 +144,20 @@ def submit(self, compiled_code): "url": f"https://{self.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}", "headers": self.auth_header, }, - get_state_func=lambda response: response.json()["state"]["life_cycle_state"], + get_state_func=lambda response: response.json()["state"][ + "life_cycle_state" + ], terminal_states=("TERMINATED", "SKIPPED", "INTERNAL_ERROR"), expected_end_state="TERMINATED", - get_state_msg_func=lambda response: response.json()["state"]["state_message"], + get_state_msg_func=lambda response: response.json()["state"][ + "state_message" + ], ) # get end state to return to user run_output = requests.get( - f"https://{self.credentials.host}" f"/api/2.1/jobs/runs/get-output?run_id={run_id}", + f"https://{self.credentials.host}" + f"/api/2.1/jobs/runs/get-output?run_id={run_id}", headers=self.auth_header, ) json_run_output = run_output.json() @@ -247,7 +255,9 @@ def status(self, context_id: str, command_id: str) -> Dict[str, Any]: class DBCommandsApiPythonJobHelper(BasePythonJobHelper): def check_credentials(self, credentials): if not credentials.cluster: - raise ValueError("Databricks cluster is required for commands submission method.") + raise ValueError( + "Databricks cluster is required for commands submission method." 
+ ) def submit(self, compiled_code): context = DBContext(self.credentials) @@ -269,7 +279,8 @@ def submit(self, compiled_code): ) if response["results"]["resultType"] == "error": raise dbt.exceptions.RuntimeException( - f"Python model failed with traceback as:\n" f"{response['results']['cause']}" + f"Python model failed with traceback as:\n" + f"{response['results']['cause']}" ) finally: context.destroy(context_id) From 965fdbfe6003278f03a4bf4b3ed2839439dab7c4 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Mon, 29 Aug 2022 16:51:36 -0700 Subject: [PATCH 3/4] fix format --- dbt/adapters/spark/python_submissions.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/dbt/adapters/spark/python_submissions.py b/dbt/adapters/spark/python_submissions.py index 167338919..66c89fcd7 100644 --- a/dbt/adapters/spark/python_submissions.py +++ b/dbt/adapters/spark/python_submissions.py @@ -73,9 +73,7 @@ def __init__(self, parsed_model, credentials): def check_credentials(self, credentials): if not credentials.user: - raise ValueError( - "Databricks user is required for notebook submission method." - ) + raise ValueError("Databricks user is required for notebook submission method.") def _create_work_dir(self, path): response = requests.post( @@ -144,20 +142,15 @@ def submit(self, compiled_code): "url": f"https://{self.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}", "headers": self.auth_header, }, - get_state_func=lambda response: response.json()["state"][ - "life_cycle_state" - ], + get_state_func=lambda response: response.json()["state"]["life_cycle_state"], terminal_states=("TERMINATED", "SKIPPED", "INTERNAL_ERROR"), expected_end_state="TERMINATED", - get_state_msg_func=lambda response: response.json()["state"][ - "state_message" - ], + get_state_msg_func=lambda response: response.json()["state"]["state_message"], ) # get end state to return to user run_output = requests.get( - f"https://{self.credentials.host}" - f"/api/2.1/jobs/runs/get-output?run_id={run_id}", + f"https://{self.credentials.host}" f"/api/2.1/jobs/runs/get-output?run_id={run_id}", headers=self.auth_header, ) json_run_output = run_output.json() @@ -255,9 +248,7 @@ def status(self, context_id: str, command_id: str) -> Dict[str, Any]: class DBCommandsApiPythonJobHelper(BasePythonJobHelper): def check_credentials(self, credentials): if not credentials.cluster: - raise ValueError( - "Databricks cluster is required for commands submission method." 
- ) + raise ValueError("Databricks cluster is required for commands submission method.") def submit(self, compiled_code): context = DBContext(self.credentials) @@ -279,8 +270,7 @@ def submit(self, compiled_code): ) if response["results"]["resultType"] == "error": raise dbt.exceptions.RuntimeException( - f"Python model failed with traceback as:\n" - f"{response['results']['cause']}" + f"Python model failed with traceback as:\n" f"{response['results']['cause']}" ) finally: context.destroy(context_id) From b4e177ee14e22abd477777eae7322c777febbde5 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Tue, 30 Aug 2022 12:58:11 -0700 Subject: [PATCH 4/4] pr feedback --- dbt/adapters/spark/impl.py | 6 +++--- dbt/adapters/spark/python_submissions.py | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index e78d71bfb..6e97ce1f5 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -17,7 +17,7 @@ from dbt.adapters.spark import SparkConnectionManager from dbt.adapters.spark import SparkRelation from dbt.adapters.spark import SparkColumn -from dbt.adapters.spark.python_submissions import python_submission_helpers +from dbt.adapters.spark.python_submissions import PYTHON_SUBMISSION_HELPERS from dbt.adapters.base import BaseRelation from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER from dbt.events import AdapterLogger @@ -394,11 +394,11 @@ def submit_python_job(self, parsed_model: dict, compiled_code: str, timeout=None # TODO limit this function to run only when doing the materialization of python nodes # assuming that for python job running over 1 day user would mannually overwrite this submission_method = parsed_model["config"].get("submission_method", "commands") - if submission_method not in python_submission_helpers: + if submission_method not in PYTHON_SUBMISSION_HELPERS: raise NotImplementedError( "Submission method {} is not supported".format(submission_method) ) - job_helper = python_submission_helpers[submission_method]( + job_helper = PYTHON_SUBMISSION_HELPERS[submission_method]( parsed_model, self.connections.profile.credentials ) job_helper.submit(compiled_code) diff --git a/dbt/adapters/spark/python_submissions.py b/dbt/adapters/spark/python_submissions.py index 66c89fcd7..ea172ef03 100644 --- a/dbt/adapters/spark/python_submissions.py +++ b/dbt/adapters/spark/python_submissions.py @@ -7,6 +7,8 @@ import dbt.exceptions DEFAULT_POLLING_INTERVAL = 3 +SUBMISSION_LANGUAGE = "python" +DEFAULT_TIMEOUT = 60 * 60 * 24 class BasePythonJobHelper: @@ -20,7 +22,7 @@ def __init__(self, parsed_model, credentials): self.polling_interval = DEFAULT_POLLING_INTERVAL def get_timeout(self): - timeout = self.parsed_model["config"].get("timeout", 60 * 60 * 24) + timeout = self.parsed_model["config"].get("timeout", DEFAULT_TIMEOUT) if timeout <= 0: raise ValueError("Timeout must be a positive integer") return timeout @@ -177,7 +179,7 @@ def create(self) -> str: headers=self.auth_header, json={ "clusterId": self.cluster, - "language": "python", + "language": SUBMISSION_LANGUAGE, }, ) if response.status_code != 200: @@ -217,7 +219,7 @@ def execute(self, context_id: str, command: str) -> str: json={ "clusterId": self.cluster, "contextId": context_id, - "language": "python", + "language": SUBMISSION_LANGUAGE, "command": command, }, ) @@ -276,7 +278,7 @@ def submit(self, compiled_code): context.destroy(context_id) -python_submission_helpers = { +PYTHON_SUBMISSION_HELPERS = { "notebook": DBNotebookPythonJobHelper, 
"commands": DBCommandsApiPythonJobHelper, }