Add additional maven repo url to databricks submission param #1001

Merged (1 commit) on Jan 20, 2023
feathr_project/feathr/spark_provider/_databricks_submission.py (85 changes: 54 additions & 31 deletions)
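The substantive change in this file sits in submit_feathr_job: when no main JAR path is supplied, the launcher falls back to the Maven package from get_maven_artifact_fullname() and, with this PR, also attaches a "repo" field naming two additional Maven repositories. A minimal sketch of the resulting libraries entry in the Databricks run submission payload (the coordinates value is a placeholder; everything else mirrors the diff below):

# Illustrative sketch only: the `libraries` entry built when main_jar_path is not set.
# The coordinates string is a placeholder for whatever get_maven_artifact_fullname() returns.
libraries = [
    {
        "maven": {
            "coordinates": "com.linkedin.feathr:feathr_2.12:<version>",  # placeholder
            # extra resolvers added by this PR, passed as a comma-separated string
            "repo": "https://repository.mulesoft.org/nexus/content/repositories/public/,https://linkedin.jfrog.io/artifactory/open-source/",
        }
    }
]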
@@ -61,7 +61,8 @@ def __init__(
        self.auth_headers["Accept"] = "application/json"
        self.auth_headers["Authorization"] = f"Bearer {token_value}"
        self.databricks_work_dir = databricks_work_dir
        self.api_client = ApiClient(
            host=self.workspace_instance_url, token=token_value)

    def upload_or_get_cloud_path(self, local_path_or_cloud_src_path: str, tar_dir_path: Optional[str] = None):
        """
@@ -71,18 +72,21 @@ def upload_or_get_cloud_path(self, local_path_or_cloud_src_pa
        if local_path_or_cloud_src_path.startswith('dbfs') and tar_dir_path is not None:
            if not tar_dir_path.startswith('dbfs'):
                raise RuntimeError(
                    f"Failed to copy files from dbfs directory: {local_path_or_cloud_src_path}. {tar_dir_path} is not a valid target directory path"
                )
            if not self.cloud_dir_exists(local_path_or_cloud_src_path):
                raise RuntimeError(
                    f"Source folder:{local_path_or_cloud_src_path} doesn't exist. Please make sure it's a valid path")
            if self.cloud_dir_exists(tar_dir_path):
                logger.warning(
                    'Target cloud directory {} already exists. Please use another one.', tar_dir_path)
                return tar_dir_path
            DbfsApi(self.api_client).cp(recursive=True, overwrite=False,
                                        src=local_path_or_cloud_src_path, dst=tar_dir_path)
            logger.info('{} is copied to location: {}',
                        local_path_or_cloud_src_path, tar_dir_path)
            return tar_dir_path

        src_parse_result = urlparse(local_path_or_cloud_src_path)
        file_name = os.path.basename(local_path_or_cloud_src_path)
        # returned paths for the uploaded file. Note that we cannot use os.path.join here, since in Windows system it will yield paths like this:
@@ -98,13 +102,13 @@ def upload_or_get_cloud_path(self, local_path_or_cloud_src_pa
            r = requests.post(url=self.workspace_instance_url+'/api/2.0/dbfs/put',
                              headers=self.auth_headers, files=files, data={'overwrite': 'true', 'path': cloud_dest_path})
            logger.info('{} is downloaded and then uploaded to location: {}',
                        local_path_or_cloud_src_path, cloud_dest_path)
        elif src_parse_result.scheme.startswith('dbfs'):
            # passed a cloud path
            logger.info(
                'Skip uploading file {} as the file starts with dbfs:/', local_path_or_cloud_src_path)
            cloud_dest_path = local_path_or_cloud_src_path
        elif src_parse_result.scheme.startswith(('wasb', 's3', 'gs')):
            # if the path starts with a location that's not a local path
            logger.error(
                "File {} cannot be downloaded. Please upload the file to dbfs manually.", local_path_or_cloud_src_path
@@ -115,14 +119,17 @@ def upload_or_get_cloud_path(self, local_path_or_cloud_src_pa
        else:
            # else it should be a local file path or dir
            if os.path.isdir(local_path_or_cloud_src_path):
                logger.info("Uploading folder {}",
                            local_path_or_cloud_src_path)
                dest_paths = []
                for item in Path(local_path_or_cloud_src_path).glob('**/*.conf'):
                    cloud_dest_path = self._upload_local_file_to_workspace(
                        item.resolve())
                    dest_paths.extend([cloud_dest_path])
                cloud_dest_path = ','.join(dest_paths)
            else:
                cloud_dest_path = self._upload_local_file_to_workspace(
                    local_path_or_cloud_src_path)
        return cloud_dest_path

    def _upload_local_file_to_workspace(self, local_path: str) -> str:
@@ -137,9 +144,11 @@ def _upload_local_file_to_workspace(self, local_path: str) -> str:
        # `local_path_or_http_path` will be either string or PathLib object, so normalize it to string
        local_path = str(local_path)
        try:
            DbfsApi(self.api_client).cp(recursive=True, overwrite=True,
                                        src=local_path, dst=cloud_dest_path)
        except RuntimeError as e:
            raise RuntimeError(
                f"The source path: {local_path}, or the destination path: {cloud_dest_path}, is/are not valid.") from e
        return cloud_dest_path

    def submit_feathr_job(
@@ -182,7 +191,8 @@ def submit_feathr_job(
        submission_params["run_name"] = job_name
        cfg = configuration.copy()
        if "existing_cluster_id" in submission_params:
            logger.info(
                "Using an existing general purpose cluster to run the feathr job...")
            if cfg:
                logger.warning(
                    "Spark execution configuration will be ignored. To use job-specific spark configs, please use a new job cluster or set the configs via Databricks UI."
@@ -201,7 +211,8 @@ def submit_feathr_job(
                submission_params["new_cluster"]["spark_conf"] = cfg

            if job_tags:
                custom_tags = submission_params["new_cluster"].get(
                    "custom_tags", {})
                for tag, value in job_tags.items():
                    custom_tags[tag] = value

@@ -214,10 +225,15 @@ def submit_feathr_job(

        # the feathr main jar file is anyway needed regardless it's pyspark or scala spark
        if not main_jar_path:
            logger.info(
                f"Main JAR file is not set, using default package '{get_maven_artifact_fullname()}' from Maven")
            submission_params['libraries'][0]['maven'] = {
                "coordinates": get_maven_artifact_fullname()}
            # add additional maven repos
            submission_params["libraries"][0]["maven"]["repo"] = "https://repository.mulesoft.org/nexus/content/repositories/public/,https://linkedin.jfrog.io/artifactory/open-source/"
        else:
            submission_params["libraries"][0]["jar"] = self.upload_or_get_cloud_path(
                main_jar_path)
        # see here for the submission parameter definition https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/2.0/jobs#--request-structure-6
        if python_files:
            # this is a pyspark job. definition here: https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/2.0/jobs#--sparkpythontask
@@ -228,7 +244,8 @@ def submit_feathr_job(
            }
            # indicates this is a pyspark job
            # `setdefault` method will get the value of the "spark_python_task" item, if the "spark_python_task" item does not exist, insert "spark_python_task" with the value "param_and_file_dict":
            submission_params.setdefault(
                "spark_python_task", param_and_file_dict)
        else:
            # this is a scala spark job
            submission_params["spark_jar_task"]["parameters"] = arguments
@@ -247,7 +264,8 @@ def submit_feathr_job(

        result = RunsApi(self.api_client).get_run(self.res_job_id)
        self.job_url = result["run_page_url"]
        logger.info(
            "Feathr job Submitted Successfully. View more details here: {}", self.job_url)

        # return ID as the submission result
        return self.res_job_id
@@ -264,10 +282,12 @@ def wait_for_completion(self, timeout_seconds: Optional[int] = 600) -> bool:
            if status in {"SUCCESS"}:
                return True
            elif status in {"INTERNAL_ERROR", "FAILED", "TIMEDOUT", "CANCELED"}:
                result = RunsApi(self.api_client).get_run_output(
                    self.res_job_id)
                # See here for the returned fields: https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/2.0/jobs#--response-structure-8
                # print out logs and stack trace if the job has failed
                logger.error(
                    "Feathr job has failed. Please visit this page to view error message: {}", self.job_url)
                if "error" in result:
                    logger.error("Error Code: {}", result["error"])
                if "error_trace" in result:
@@ -283,7 +303,8 @@ def get_status(self) -> str:
        result = RunsApi(self.api_client).get_run(self.res_job_id)
        # first try to get result state. it might not be available, and if that's the case, try to get life_cycle_state
        # see result structure: https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/2.0/jobs#--response-structure-6
        res_state = result["state"].get(
            "result_state") or result["state"]["life_cycle_state"]
        assert res_state is not None
        return res_state

@@ -308,7 +329,8 @@ def get_job_tags(self) -> Dict[str, str]:
        result = RunsApi(self.api_client).get_run(self.res_job_id)

        if "new_cluster" in result["cluster_spec"]:
            custom_tags = result["cluster_spec"]["new_cluster"].get(
                "custom_tags")
            return custom_tags
        else:
            # this is not a new cluster; it's an existing cluster.
@@ -326,18 +348,19 @@ def download_result(self, result_path: str, local_folder: str):
                'Currently only paths starting with dbfs is supported for downloading results from a databricks cluster. The path should start with "dbfs:" .'
            )

        DbfsApi(self.api_client).cp(recursive=True,
                                    overwrite=True, src=result_path, dst=local_folder)

    def cloud_dir_exists(self, dir_path: str):
        """
        Check if a directory of hdfs already exists
        """
        if not dir_path.startswith('dbfs'):
            raise RuntimeError(
                'Currently only paths starting with dbfs is supported. The paths should start with \"dbfs:\" .')

        try:
            DbfsApi(self.api_client).list_files(DbfsPath(dir_path))
            return True
        except:
            return False
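
The status handling in get_status and wait_for_completion relies on the Runs API returning result_state once a run has terminated and only life_cycle_state before that. A self-contained sketch of that fallback, with a hard-coded response dict standing in for RunsApi(self.api_client).get_run(...):

# Self-contained sketch of the result_state / life_cycle_state fallback shown above.
# The response dict is hard-coded for illustration; a real run would come from the Runs API.
result = {"state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}}

res_state = result["state"].get("result_state") or result["state"]["life_cycle_state"]
assert res_state is not None

if res_state in {"SUCCESS"}:
    print("Feathr job finished:", res_state)
elif res_state in {"INTERNAL_ERROR", "FAILED", "TIMEDOUT", "CANCELED"}:
    print("Feathr job failed:", res_state)
else:
    print("Feathr job still running:", res_state)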