Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More operators for Databricks Repos #22422

Merged
merged 5 commits into from
Mar 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,32 @@

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
from airflow.providers.databricks.operators.databricks_repos import DatabricksReposUpdateOperator
from airflow.providers.databricks.operators.databricks_repos import (
DatabricksReposCreateOperator,
DatabricksReposDeleteOperator,
DatabricksReposUpdateOperator,
)

default_args = {
'owner': 'airflow',
'databricks_conn_id': 'my-shard-pat',
'databricks_conn_id': 'databricks',
}

with DAG(
dag_id='example_databricks_operator',
dag_id='example_databricks_repos_operator',
schedule_interval='@daily',
start_date=datetime(2021, 1, 1),
default_args=default_args,
tags=['example'],
catchup=False,
) as dag:
# [START howto_operator_databricks_repo_create]
# Example of creating a Databricks Repo
repo_path = "/Repos/user@domain.com/demo-repo"
git_url = "https://github.com/test/test"
create_repo = DatabricksReposCreateOperator(task_id='create_repo', repo_path=repo_path, git_url=git_url)
# [END howto_operator_databricks_repo_create]

# [START howto_operator_databricks_repo_update]
# Example of updating a Databricks Repo to the latest code
repo_path = "/Repos/user@domain.com/demo-repo"
Expand All @@ -53,4 +65,10 @@

notebook_task = DatabricksSubmitRunOperator(task_id='notebook_task', json=notebook_task_params)

(update_repo >> notebook_task)
# [START howto_operator_databricks_repo_delete]
# Example of deleting a Databricks Repo
repo_path = "/Repos/user@domain.com/demo-repo"
delete_repo = DatabricksReposDeleteOperator(task_id='delete_repo', repo_path=repo_path)
# [END howto_operator_databricks_repo_delete]

(create_repo >> update_repo >> notebook_task >> delete_repo)
27 changes: 24 additions & 3 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,35 @@ def uninstall(self, json: dict) -> None:

def update_repo(self, repo_id: str, json: Dict[str, Any]) -> dict:
"""
Updates given Databricks Repos

:param repo_id:
:param json:
:return:
:param repo_id: ID of Databricks Repos
:param json: payload
:return: metadata from update
"""
repos_endpoint = ('PATCH', f'api/2.0/repos/{repo_id}')
return self._do_api_call(repos_endpoint, json)

def delete_repo(self, repo_id: str):
"""
Deletes given Databricks Repos

:param repo_id: ID of Databricks Repos
:return:
"""
repos_endpoint = ('DELETE', f'api/2.0/repos/{repo_id}')
self._do_api_call(repos_endpoint)

def create_repo(self, json: Dict[str, Any]) -> dict:
"""
Creates a Databricks Repos

:param json: payload
:return:
"""
repos_endpoint = ('POST', 'api/2.0/repos')
return self._do_api_call(repos_endpoint, json)

def get_repo_by_path(self, path: str) -> Optional[str]:
"""

Expand Down
23 changes: 22 additions & 1 deletion airflow/providers/databricks/hooks/databricks_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import requests
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import JSONDecodeError
from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential

from airflow import __version__
Expand Down Expand Up @@ -340,6 +341,8 @@ def _do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str,
request_func = requests.post
elif method == 'PATCH':
request_func = requests.patch
elif method == 'DELETE':
request_func = requests.delete
else:
raise AirflowException('Unexpected HTTP Method: ' + method)

Expand All @@ -361,13 +364,31 @@ def _do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str,
except requests_exceptions.HTTPError as e:
raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')

@staticmethod
def _get_error_code(exception: BaseException) -> str:
if isinstance(exception, requests_exceptions.HTTPError):
try:
jsn = exception.response.json()
return jsn.get('error_code', '')
except JSONDecodeError:
pass

return ""

@staticmethod
def _retryable_error(exception: BaseException) -> bool:
if not isinstance(exception, requests_exceptions.RequestException):
return False
return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
exception.response is not None
and (exception.response.status_code >= 500 or exception.response.status_code == 429)
and (
exception.response.status_code >= 500
or exception.response.status_code == 429
or (
exception.response.status_code == 400
and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
)
)
)


Expand Down
205 changes: 198 additions & 7 deletions airflow/providers/databricks/operators/databricks_repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
# under the License.
#
"""This module contains Databricks operators."""

import re
from typing import TYPE_CHECKING, Optional, Sequence
from urllib.parse import urlparse

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
Expand All @@ -28,12 +29,142 @@
from airflow.utils.context import Context


class DatabricksReposCreateOperator(BaseOperator):
"""
Creates a Databricks Repo
using
`POST api/2.0/repos <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/create-repo>`_
API endpoint and optionally checking it out to a specific branch or tag.

:param git_url: Required HTTPS URL of a Git repository
:param git_provider: Optional name of Git provider. Must be provided if we can't guess its name from URL.
:param repo_path: optional path for a repository. Must be in the format ``/Repos/{folder}/{repo-name}``.
If not specified, it will be created in the user's directory.
:param branch: optional name of branch to check out.
:param tag: optional name of tag to checkout.
:param ignore_existing_repo: don't throw exception if repository with given path already exists.
:param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
By default and in the common case this will be ``databricks_default``. To use
token based authentication, provide the key ``token`` in the extra field for the
connection and create the key ``host`` and leave the ``host`` field empty.
:param databricks_retry_limit: Amount of times retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = ('repo_path', 'tag', 'branch')

__git_providers__ = {
"github.com": "gitHub",
"dev.azure.com": "azureDevOpsServices",
"gitlab.com": "gitLab",
"bitbucket.org": "bitbucketCloud",
}
__aws_code_commit_regexp__ = re.compile(r"^git-codecommit\.[^.]+\.amazonaws.com$")
__repos_path_regexp__ = re.compile(r"/Repos/[^/]+/[^/]+/?$")

def __init__(
self,
*,
git_url: str,
git_provider: Optional[str] = None,
branch: Optional[str] = None,
tag: Optional[str] = None,
repo_path: Optional[str] = None,
ignore_existing_repo: bool = False,
databricks_conn_id: str = 'databricks_default',
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
**kwargs,
) -> None:
"""Creates a new ``DatabricksReposCreateOperator``."""
super().__init__(**kwargs)
self.databricks_conn_id = databricks_conn_id
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.git_url = git_url
self.ignore_existing_repo = ignore_existing_repo
if git_provider is None:
self.git_provider = self.__detect_repo_provider__(git_url)
alexott marked this conversation as resolved.
Show resolved Hide resolved
if self.git_provider is None:
raise AirflowException(
"git_provider isn't specified and couldn't be guessed" f" for URL {git_url}"
)
else:
self.git_provider = git_provider
self.repo_path = repo_path
if branch is not None and tag is not None:
raise AirflowException("Only one of branch or tag should be provided, but not both")
self.branch = branch
self.tag = tag

@staticmethod
def __detect_repo_provider__(url):
provider = None
try:
netloc = urlparse(url).netloc
idx = netloc.rfind("@")
if idx != -1:
netloc = netloc[(idx + 1) :]
netloc = netloc.lower()
provider = DatabricksReposCreateOperator.__git_providers__.get(netloc)
if provider is None and DatabricksReposCreateOperator.__aws_code_commit_regexp__.match(netloc):
provider = "awsCodeCommit"
except ValueError:
pass
return provider

def _get_hook(self) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
)

def execute(self, context: 'Context'):
"""
Creates a Databricks Repo

:param context: context
:return: Repo ID
"""
payload = {
"url": self.git_url,
"provider": self.git_provider,
}
if self.repo_path is not None:
if not self.__repos_path_regexp__.match(self.repo_path):
raise AirflowException(
f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'"
)
payload["path"] = self.repo_path
hook = self._get_hook()
existing_repo_id = None
if self.repo_path is not None:
existing_repo_id = hook.get_repo_by_path(self.repo_path)
if existing_repo_id is not None and not self.ignore_existing_repo:
raise AirflowException(f"Repo with path '{self.repo_path}' already exists")
if existing_repo_id is None:
result = hook.create_repo(payload)
repo_id = result["id"]
else:
repo_id = existing_repo_id
# update repo if necessary
if self.branch is not None:
hook.update_repo(str(repo_id), {'branch': str(self.branch)})
elif self.tag is not None:
hook.update_repo(str(repo_id), {'tag': str(self.tag)})

return repo_id


class DatabricksReposUpdateOperator(BaseOperator):
"""
Updates specified repository to a given branch or tag using
`api/2.0/repos/
<https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/update-repo>`_
API endpoint.
Updates specified repository to a given branch or tag
using `PATCH api/2.0/repos
<https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/update-repo>`_ API endpoint.

:param branch: optional name of branch to update to. Should be specified if ``tag`` is omitted
:param tag: optional name of tag to update to. Should be specified if ``branch`` is omitted
Expand Down Expand Up @@ -64,7 +195,7 @@ def __init__(
databricks_retry_delay: int = 1,
**kwargs,
) -> None:
"""Creates a new ``DatabricksSubmitRunOperator``."""
"""Creates a new ``DatabricksReposUpdateOperator``."""
super().__init__(**kwargs)
self.databricks_conn_id = databricks_conn_id
self.databricks_retry_limit = databricks_retry_limit
Expand All @@ -76,7 +207,7 @@ def __init__(
if repo_id is not None and repo_path is not None:
raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
if repo_id is None and repo_path is None:
raise AirflowException("One of repo_id repo_path tag should be provided")
raise AirflowException("One of repo_id or repo_path should be provided")
self.repo_path = repo_path
self.repo_id = repo_id
self.branch = branch
Expand All @@ -102,3 +233,63 @@ def execute(self, context: 'Context'):

result = hook.update_repo(str(self.repo_id), payload)
return result['head_commit_id']


class DatabricksReposDeleteOperator(BaseOperator):
"""
Deletes specified repository
using `DELETE api/2.0/repos
<https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/delete-repo>`_ API endpoint.

:param repo_id: optional ID of existing repository. Should be specified if ``repo_path`` is omitted
:param repo_path: optional path of existing repository. Should be specified if ``repo_id`` is omitted
:param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
By default and in the common case this will be ``databricks_default``. To use
token based authentication, provide the key ``token`` in the extra field for the
connection and create the key ``host`` and leave the ``host`` field empty.
:param databricks_retry_limit: Amount of times retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = ('repo_path',)

def __init__(
self,
*,
repo_id: Optional[str] = None,
repo_path: Optional[str] = None,
databricks_conn_id: str = 'databricks_default',
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
**kwargs,
) -> None:
"""Creates a new ``DatabricksReposDeleteOperator``."""
super().__init__(**kwargs)
self.databricks_conn_id = databricks_conn_id
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
if repo_id is not None and repo_path is not None:
raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
if repo_id is None and repo_path is None:
raise AirflowException("One of repo_id repo_path tag should be provided")
self.repo_path = repo_path
self.repo_id = repo_id

def _get_hook(self) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
)

def execute(self, context: 'Context'):
hook = self._get_hook()
if self.repo_path is not None:
self.repo_id = hook.get_repo_by_path(self.repo_path)
if self.repo_id is None:
raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")

hook.delete_repo(str(self.repo_id))
2 changes: 2 additions & 0 deletions airflow/providers/databricks/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ integrations:
- integration-name: Databricks Repos
external-doc-url: https://docs.databricks.com/repos/index.html
how-to-guide:
- /docs/apache-airflow-providers-databricks/operators/repos_create.rst
- /docs/apache-airflow-providers-databricks/operators/repos_update.rst
- /docs/apache-airflow-providers-databricks/operators/repos_delete.rst
logo: /integration-logos/databricks/Databricks.png
tags: [service]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ Password (optional)
Extra (optional)
Specify the extra parameter (as json dictionary) that can be used in the Databricks connection.

Following parameter should be used if using the *PAT* authentication method:
Following parameter could be used if using the *PAT* authentication method:
alexott marked this conversation as resolved.
Show resolved Hide resolved

* ``token``: Specify PAT to use. Note, the PAT must appear in both the Password field as the token value in Extra.
* ``token``: Specify PAT to use. Consider to switch to specification of PAT in the Password field as it's more secure.

Following parameters are necessary if using authentication with AAD token:

Expand Down
Loading