diff --git a/airflow/providers/databricks/example_dags/example_databricks_repos.py b/airflow/providers/databricks/example_dags/example_databricks_repos.py index 458f7cb8ce72b..e33d32044f5df 100644 --- a/airflow/providers/databricks/example_dags/example_databricks_repos.py +++ b/airflow/providers/databricks/example_dags/example_databricks_repos.py @@ -19,20 +19,32 @@ from airflow import DAG from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator -from airflow.providers.databricks.operators.databricks_repos import DatabricksReposUpdateOperator +from airflow.providers.databricks.operators.databricks_repos import ( + DatabricksReposCreateOperator, + DatabricksReposDeleteOperator, + DatabricksReposUpdateOperator, +) default_args = { 'owner': 'airflow', - 'databricks_conn_id': 'my-shard-pat', + 'databricks_conn_id': 'databricks', } with DAG( - dag_id='example_databricks_operator', + dag_id='example_databricks_repos_operator', schedule_interval='@daily', start_date=datetime(2021, 1, 1), + default_args=default_args, tags=['example'], catchup=False, ) as dag: + # [START howto_operator_databricks_repo_create] + # Example of creating a Databricks Repo + repo_path = "/Repos/user@domain.com/demo-repo" + git_url = "https://github.com/test/test" + create_repo = DatabricksReposCreateOperator(task_id='create_repo', repo_path=repo_path, git_url=git_url) + # [END howto_operator_databricks_repo_create] + # [START howto_operator_databricks_repo_update] # Example of updating a Databricks Repo to the latest code repo_path = "/Repos/user@domain.com/demo-repo" @@ -53,4 +65,10 @@ notebook_task = DatabricksSubmitRunOperator(task_id='notebook_task', json=notebook_task_params) - (update_repo >> notebook_task) + # [START howto_operator_databricks_repo_delete] + # Example of deleting a Databricks Repo + repo_path = "/Repos/user@domain.com/demo-repo" + delete_repo = DatabricksReposDeleteOperator(task_id='delete_repo', repo_path=repo_path) + # [END howto_operator_databricks_repo_delete] + + (create_repo >> update_repo >> notebook_task >> delete_repo) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 977800edb77bc..ffa77570d5ab9 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -333,14 +333,35 @@ def uninstall(self, json: dict) -> None: def update_repo(self, repo_id: str, json: Dict[str, Any]) -> dict: """ + Updates the given Databricks Repo - :param repo_id: - :param json: - :return: + :param repo_id: ID of the Databricks Repo + :param json: payload + :return: metadata from the update """ repos_endpoint = ('PATCH', f'api/2.0/repos/{repo_id}') return self._do_api_call(repos_endpoint, json) + def delete_repo(self, repo_id: str): + """ + Deletes the given Databricks Repo + + :param repo_id: ID of the Databricks Repo + :return: + """ + repos_endpoint = ('DELETE', f'api/2.0/repos/{repo_id}') + self._do_api_call(repos_endpoint) + + def create_repo(self, json: Dict[str, Any]) -> dict: + """ + Creates a Databricks Repo + + :param json: payload + :return: + """ + repos_endpoint = ('POST', 'api/2.0/repos') + return self._do_api_call(repos_endpoint, json) + def get_repo_by_path(self, path: str) -> Optional[str]: """ diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py index ec856a053d005..1a418fd04e625 100644 --- a/airflow/providers/databricks/hooks/databricks_base.py +++ 
b/airflow/providers/databricks/hooks/databricks_base.py @@ -31,6 +31,7 @@ import requests from requests import PreparedRequest, exceptions as requests_exceptions from requests.auth import AuthBase, HTTPBasicAuth +from requests.exceptions import JSONDecodeError from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential from airflow import __version__ @@ -340,6 +341,8 @@ def _do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, request_func = requests.post elif method == 'PATCH': request_func = requests.patch + elif method == 'DELETE': + request_func = requests.delete else: raise AirflowException('Unexpected HTTP Method: ' + method) @@ -361,13 +364,31 @@ def _do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, except requests_exceptions.HTTPError as e: raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}') + @staticmethod + def _get_error_code(exception: BaseException) -> str: + if isinstance(exception, requests_exceptions.HTTPError): + try: + jsn = exception.response.json() + return jsn.get('error_code', '') + except JSONDecodeError: + pass + + return "" + @staticmethod def _retryable_error(exception: BaseException) -> bool: if not isinstance(exception, requests_exceptions.RequestException): return False return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or ( exception.response is not None - and (exception.response.status_code >= 500 or exception.response.status_code == 429) + and ( + exception.response.status_code >= 500 + or exception.response.status_code == 429 + or ( + exception.response.status_code == 400 + and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK' + ) + ) ) diff --git a/airflow/providers/databricks/operators/databricks_repos.py b/airflow/providers/databricks/operators/databricks_repos.py index fc50730d03d06..15543cc509385 100644 --- a/airflow/providers/databricks/operators/databricks_repos.py +++ b/airflow/providers/databricks/operators/databricks_repos.py @@ -17,8 +17,9 @@ # under the License. # """This module contains Databricks operators.""" - +import re from typing import TYPE_CHECKING, Optional, Sequence +from urllib.parse import urlparse from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -28,12 +29,142 @@ from airflow.utils.context import Context +class DatabricksReposCreateOperator(BaseOperator): + """ + Creates a Databricks Repo + using + `POST api/2.0/repos `_ + API endpoint and optionally checking it out to a specific branch or tag. + + :param git_url: Required HTTPS URL of a Git repository + :param git_provider: Optional name of Git provider. Must be provided if we can't guess its name from URL. + :param repo_path: optional path for a repository. Must be in the format ``/Repos/{folder}/{repo-name}``. + If not specified, it will be created in the user's directory. + :param branch: optional name of branch to check out. + :param tag: optional name of tag to checkout. + :param ignore_existing_repo: don't throw exception if repository with given path already exists. + :param databricks_conn_id: Reference to the :ref:`Databricks connection `. + By default and in the common case this will be ``databricks_default``. To use + token based authentication, provide the key ``token`` in the extra field for the + connection and create the key ``host`` and leave the ``host`` field empty. 
+ :param databricks_retry_limit: Amount of times retry if the Databricks backend is + unreachable. Its value must be greater than or equal to 1. + :param databricks_retry_delay: Number of seconds to wait between retries (it + might be a floating point number). + """ + + # Used in airflow.models.BaseOperator + template_fields: Sequence[str] = ('repo_path', 'tag', 'branch') + + __git_providers__ = { + "github.com": "gitHub", + "dev.azure.com": "azureDevOpsServices", + "gitlab.com": "gitLab", + "bitbucket.org": "bitbucketCloud", + } + __aws_code_commit_regexp__ = re.compile(r"^git-codecommit\.[^.]+\.amazonaws.com$") + __repos_path_regexp__ = re.compile(r"/Repos/[^/]+/[^/]+/?$") + + def __init__( + self, + *, + git_url: str, + git_provider: Optional[str] = None, + branch: Optional[str] = None, + tag: Optional[str] = None, + repo_path: Optional[str] = None, + ignore_existing_repo: bool = False, + databricks_conn_id: str = 'databricks_default', + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 1, + **kwargs, + ) -> None: + """Creates a new ``DatabricksReposCreateOperator``.""" + super().__init__(**kwargs) + self.databricks_conn_id = databricks_conn_id + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + self.git_url = git_url + self.ignore_existing_repo = ignore_existing_repo + if git_provider is None: + self.git_provider = self.__detect_repo_provider__(git_url) + if self.git_provider is None: + raise AirflowException( + "git_provider isn't specified and couldn't be guessed" f" for URL {git_url}" + ) + else: + self.git_provider = git_provider + self.repo_path = repo_path + if branch is not None and tag is not None: + raise AirflowException("Only one of branch or tag should be provided, but not both") + self.branch = branch + self.tag = tag + + @staticmethod + def __detect_repo_provider__(url): + provider = None + try: + netloc = urlparse(url).netloc + idx = netloc.rfind("@") + if idx != -1: + netloc = netloc[(idx + 1) :] + netloc = netloc.lower() + provider = DatabricksReposCreateOperator.__git_providers__.get(netloc) + if provider is None and DatabricksReposCreateOperator.__aws_code_commit_regexp__.match(netloc): + provider = "awsCodeCommit" + except ValueError: + pass + return provider + + def _get_hook(self) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + ) + + def execute(self, context: 'Context'): + """ + Creates a Databricks Repo + + :param context: context + :return: Repo ID + """ + payload = { + "url": self.git_url, + "provider": self.git_provider, + } + if self.repo_path is not None: + if not self.__repos_path_regexp__.match(self.repo_path): + raise AirflowException( + f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'" + ) + payload["path"] = self.repo_path + hook = self._get_hook() + existing_repo_id = None + if self.repo_path is not None: + existing_repo_id = hook.get_repo_by_path(self.repo_path) + if existing_repo_id is not None and not self.ignore_existing_repo: + raise AirflowException(f"Repo with path '{self.repo_path}' already exists") + if existing_repo_id is None: + result = hook.create_repo(payload) + repo_id = result["id"] + else: + repo_id = existing_repo_id + # update repo if necessary + if self.branch is not None: + hook.update_repo(str(repo_id), {'branch': str(self.branch)}) + elif self.tag is not None: + hook.update_repo(str(repo_id), {'tag': 
str(self.tag)}) + + return repo_id + + class DatabricksReposUpdateOperator(BaseOperator): """ - Updates specified repository to a given branch or tag using - `api/2.0/repos/ - `_ - API endpoint. + Updates specified repository to a given branch or tag + using `PATCH api/2.0/repos + `_ API endpoint. :param branch: optional name of branch to update to. Should be specified if ``tag`` is omitted :param tag: optional name of tag to update to. Should be specified if ``branch`` is omitted @@ -64,7 +195,7 @@ def __init__( databricks_retry_delay: int = 1, **kwargs, ) -> None: - """Creates a new ``DatabricksSubmitRunOperator``.""" + """Creates a new ``DatabricksReposUpdateOperator``.""" super().__init__(**kwargs) self.databricks_conn_id = databricks_conn_id self.databricks_retry_limit = databricks_retry_limit @@ -76,7 +207,7 @@ def __init__( if repo_id is not None and repo_path is not None: raise AirflowException("Only one of repo_id or repo_path should be provided, but not both") if repo_id is None and repo_path is None: - raise AirflowException("One of repo_id repo_path tag should be provided") + raise AirflowException("One of repo_id or repo_path should be provided") self.repo_path = repo_path self.repo_id = repo_id self.branch = branch @@ -102,3 +233,63 @@ def execute(self, context: 'Context'): result = hook.update_repo(str(self.repo_id), payload) return result['head_commit_id'] + + +class DatabricksReposDeleteOperator(BaseOperator): + """ + Deletes specified repository + using `DELETE api/2.0/repos + `_ API endpoint. + + :param repo_id: optional ID of existing repository. Should be specified if ``repo_path`` is omitted + :param repo_path: optional path of existing repository. Should be specified if ``repo_id`` is omitted + :param databricks_conn_id: Reference to the :ref:`Databricks connection `. + By default and in the common case this will be ``databricks_default``. To use + token based authentication, provide the key ``token`` in the extra field for the + connection and create the key ``host`` and leave the ``host`` field empty. + :param databricks_retry_limit: Amount of times retry if the Databricks backend is + unreachable. Its value must be greater than or equal to 1. + :param databricks_retry_delay: Number of seconds to wait between retries (it + might be a floating point number). 
+ """ + + # Used in airflow.models.BaseOperator + template_fields: Sequence[str] = ('repo_path',) + + def __init__( + self, + *, + repo_id: Optional[str] = None, + repo_path: Optional[str] = None, + databricks_conn_id: str = 'databricks_default', + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 1, + **kwargs, + ) -> None: + """Creates a new ``DatabricksReposDeleteOperator``.""" + super().__init__(**kwargs) + self.databricks_conn_id = databricks_conn_id + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + if repo_id is not None and repo_path is not None: + raise AirflowException("Only one of repo_id or repo_path should be provided, but not both") + if repo_id is None and repo_path is None: + raise AirflowException("One of repo_id repo_path tag should be provided") + self.repo_path = repo_path + self.repo_id = repo_id + + def _get_hook(self) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + ) + + def execute(self, context: 'Context'): + hook = self._get_hook() + if self.repo_path is not None: + self.repo_id = hook.get_repo_by_path(self.repo_path) + if self.repo_id is None: + raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'") + + hook.delete_repo(str(self.repo_id)) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index 7003fe086ceda..ba9b3f0af8bcc 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -57,7 +57,9 @@ integrations: - integration-name: Databricks Repos external-doc-url: https://docs.databricks.com/repos/index.html how-to-guide: + - /docs/apache-airflow-providers-databricks/operators/repos_create.rst - /docs/apache-airflow-providers-databricks/operators/repos_update.rst + - /docs/apache-airflow-providers-databricks/operators/repos_delete.rst logo: /integration-logos/databricks/Databricks.png tags: [service] diff --git a/docs/apache-airflow-providers-databricks/connections/databricks.rst b/docs/apache-airflow-providers-databricks/connections/databricks.rst index 5a8753991c3e8..cb62ada1130c8 100644 --- a/docs/apache-airflow-providers-databricks/connections/databricks.rst +++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst @@ -64,9 +64,9 @@ Password (optional) Extra (optional) Specify the extra parameter (as json dictionary) that can be used in the Databricks connection. - Following parameter should be used if using the *PAT* authentication method: + Following parameter could be used if using the *PAT* authentication method: - * ``token``: Specify PAT to use. Note, the PAT must appear in both the Password field as the token value in Extra. + * ``token``: Specify PAT to use. Consider to switch to specification of PAT in the Password field as it's more secure. Following parameters are necessary if using authentication with AAD token: diff --git a/docs/apache-airflow-providers-databricks/operators/repos_create.rst b/docs/apache-airflow-providers-databricks/operators/repos_create.rst new file mode 100644 index 0000000000000..fc04340796d49 --- /dev/null +++ b/docs/apache-airflow-providers-databricks/operators/repos_create.rst @@ -0,0 +1,69 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +DatabricksReposCreateOperator +============================= + +Use the :class:`~airflow.providers.databricks.operators.DatabricksReposCreateOperator` to create (and optionally check out) a +`Databricks Repo `_ +via the `api/2.0/repos `_ API endpoint. + + +Using the Operator +^^^^^^^^^^^^^^^^^^ + +To use this operator you need to provide at least the ``git_url`` parameter. + +.. list-table:: + :widths: 15 25 + :header-rows: 1 + + * - Parameter + - Input + * - git_url: str + - Required HTTPS URL of a Git repository + * - git_provider: str + - Optional name of the Git provider. Must be provided if it can't be guessed from the URL. See the API documentation for the actual list of supported Git providers. + * - branch: str + - Optional name of the existing Git branch to check out. + * - tag: str + - Optional name of the existing Git tag to check out. + * - repo_path: str + - Optional path to the Databricks Repo, such as ``/Repos/{folder}/{repo-name}``. If not specified, the repo will be created in the user's directory. + * - ignore_existing_repo: bool + - Don't raise an exception if a repository with the given path already exists. + * - databricks_conn_id: string + - the name of the Airflow connection to use. + * - databricks_retry_limit: integer + - number of times to retry if the Databricks backend is unreachable. + * - databricks_retry_delay: decimal + - number of seconds to wait between retries. + +Examples +-------- + +Create a Databricks Repo +^^^^^^^^^^^^^^^^^^^^^^^^ + +An example usage of the DatabricksReposCreateOperator is as follows: + +.. exampleinclude:: /../../airflow/providers/databricks/example_dags/example_databricks_repos.py + :language: python + :start-after: [START howto_operator_databricks_repo_create] + :end-before: [END howto_operator_databricks_repo_create] diff --git a/docs/apache-airflow-providers-databricks/operators/repos_delete.rst b/docs/apache-airflow-providers-databricks/operators/repos_delete.rst new file mode 100644 index 0000000000000..e359deb7c9170 --- /dev/null +++ b/docs/apache-airflow-providers-databricks/operators/repos_delete.rst @@ -0,0 +1,61 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+ + + +DatabricksReposDeleteOperator +============================= + +Use the :class:`~airflow.providers.databricks.operators.DatabricksReposDeleteOperator` to delete an existing +`Databricks Repo `_ +via the `api/2.0/repos/ `_ API endpoint. + + +Using the Operator +^^^^^^^^^^^^^^^^^^ + +To use this operator you need to provide either ``repo_path`` or ``repo_id``. + +.. list-table:: + :widths: 15 25 + :header-rows: 1 + + * - Parameter + - Input + * - repo_path: str + - Path to an existing Databricks Repo, such as ``/Repos/{folder}/{repo-name}`` (required if ``repo_id`` isn't provided). + * - repo_id: str + - ID of an existing Databricks Repo (required if ``repo_path`` isn't provided). + * - databricks_conn_id: string + - the name of the Airflow connection to use. + * - databricks_retry_limit: integer + - number of times to retry if the Databricks backend is unreachable. + * - databricks_retry_delay: decimal + - number of seconds to wait between retries. + +Examples +-------- + +Deleting Databricks Repo by specifying path +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An example usage of the DatabricksReposDeleteOperator is as follows: + +.. exampleinclude:: /../../airflow/providers/databricks/example_dags/example_databricks_repos.py + :language: python + :start-after: [START howto_operator_databricks_repo_delete] + :end-before: [END howto_operator_databricks_repo_delete] diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 0d1bd09cdaa20..c467234e09bc2 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -571,7 +571,7 @@ def test_no_wait_for_termination(self, db_mock_class): db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_not_called() - def test_init_exeption_with_job_name_and_job_id(self): + def test_init_exception_with_job_name_and_job_id(self): exception_message = "Argument 'job_name' is not allowed with argument 'job_id'" with pytest.raises(AirflowException, match=exception_message): diff --git a/tests/providers/databricks/operators/test_databricks_repos.py b/tests/providers/databricks/operators/test_databricks_repos.py index ad8ccdc82ef8b..aaf03261ba7fa 100644 --- a/tests/providers/databricks/operators/test_databricks_repos.py +++ b/tests/providers/databricks/operators/test_databricks_repos.py @@ -19,7 +19,14 @@ import unittest from unittest import mock -from airflow.providers.databricks.operators.databricks_repos import DatabricksReposUpdateOperator +import pytest + +from airflow import AirflowException +from airflow.providers.databricks.operators.databricks_repos import ( + DatabricksReposCreateOperator, + DatabricksReposDeleteOperator, + DatabricksReposUpdateOperator, +) TASK_ID = 'databricks-operator' DEFAULT_CONN_ID = 'databricks_default' @@ -29,7 +36,7 @@ class TestDatabricksReposUpdateOperator(unittest.TestCase): @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook') def test_update_with_id(self, db_mock_class): """ - Test the execute function in case where the run is successful. + Test the execute function using Repo ID. 
""" op = DatabricksReposUpdateOperator(task_id=TASK_ID, branch="releases", repo_id="123") db_mock = db_mock_class.return_value @@ -46,7 +53,7 @@ def test_update_with_id(self, db_mock_class): @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook') def test_update_with_path(self, db_mock_class): """ - Test the execute function in case where the run is successful. + Test the execute function using Repo path. """ op = DatabricksReposUpdateOperator( task_id=TASK_ID, tag="v1.0.0", repo_path="/Repos/user@domain.com/test-repo" @@ -62,3 +69,144 @@ def test_update_with_path(self, db_mock_class): ) db_mock.update_repo.assert_called_once_with('123', {'tag': 'v1.0.0'}) + + def test_init_exception(self): + """ + Tests handling of incorrect parameters passed to ``__init__`` + """ + with pytest.raises( + AirflowException, match="Only one of repo_id or repo_path should be provided, but not both" + ): + DatabricksReposUpdateOperator(task_id=TASK_ID, repo_id="abc", repo_path="path", branch="abc") + + with pytest.raises(AirflowException, match="One of repo_id or repo_path should be provided"): + DatabricksReposUpdateOperator(task_id=TASK_ID, branch="abc") + + with pytest.raises( + AirflowException, match="Only one of branch or tag should be provided, but not both" + ): + DatabricksReposUpdateOperator(task_id=TASK_ID, repo_id="123", branch="123", tag="123") + + with pytest.raises(AirflowException, match="One of branch or tag should be provided"): + DatabricksReposUpdateOperator(task_id=TASK_ID, repo_id="123") + + +class TestDatabricksReposDeleteOperator(unittest.TestCase): + @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook') + def test_delete_with_id(self, db_mock_class): + """ + Test the execute function using Repo ID. + """ + op = DatabricksReposDeleteOperator(task_id=TASK_ID, repo_id="123") + db_mock = db_mock_class.return_value + db_mock.delete_repo.return_value = None + + op.execute(None) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + ) + + db_mock.delete_repo.assert_called_once_with('123') + + @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook') + def test_delete_with_path(self, db_mock_class): + """ + Test the execute function using Repo path. + """ + op = DatabricksReposDeleteOperator(task_id=TASK_ID, repo_path="/Repos/user@domain.com/test-repo") + db_mock = db_mock_class.return_value + db_mock.get_repo_by_path.return_value = '123' + db_mock.delete_repo.return_value = None + + op.execute(None) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + ) + + db_mock.delete_repo.assert_called_once_with('123') + + def test_init_exception(self): + """ + Tests handling of incorrect parameters passed to ``__init__`` + """ + with pytest.raises( + AirflowException, match="Only one of repo_id or repo_path should be provided, but not both" + ): + DatabricksReposDeleteOperator(task_id=TASK_ID, repo_id="abc", repo_path="path") + + with pytest.raises(AirflowException, match="One of repo_id repo_path tag should be provided"): + DatabricksReposDeleteOperator(task_id=TASK_ID) + + +class TestDatabricksReposCreateOperator(unittest.TestCase): + @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook') + def test_create_plus_checkout(self, db_mock_class): + """ + Test the execute function creating new Repo. 
+ """ + git_url = "https://github.com/test/test" + repo_path = '/Repos/Project1/test-repo' + op = DatabricksReposCreateOperator( + task_id=TASK_ID, git_url=git_url, repo_path=repo_path, branch="releases" + ) + db_mock = db_mock_class.return_value + db_mock.update_repo.return_value = {'head_commit_id': '123456'} + db_mock.create_repo.return_value = {'id': '123', 'branch': 'main'} + db_mock.get_repo_by_path.return_value = None + + op.execute(None) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + ) + + db_mock.create_repo.assert_called_once_with({'url': git_url, 'provider': 'gitHub', 'path': repo_path}) + db_mock.update_repo.assert_called_once_with('123', {'branch': 'releases'}) + + @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook') + def test_create_ignore_existing_plus_checkout(self, db_mock_class): + """ + Test the execute function creating new Repo. + """ + git_url = "https://github.com/test/test" + repo_path = '/Repos/Project1/test-repo' + op = DatabricksReposCreateOperator( + task_id=TASK_ID, + git_url=git_url, + repo_path=repo_path, + branch="releases", + ignore_existing_repo=True, + ) + db_mock = db_mock_class.return_value + db_mock.update_repo.return_value = {'head_commit_id': '123456'} + db_mock.get_repo_by_path.return_value = '123' + + op.execute(None) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + ) + + db_mock.get_repo_by_path.assert_called_once_with(repo_path) + db_mock.update_repo.assert_called_once_with('123', {'branch': 'releases'}) + + def test_init_exception(self): + """ + Tests handling of incorrect parameters passed to ``__init__`` + """ + git_url = "https://github.com/test/test" + repo_path = '/Repos/test-repo' + exception_message = ( + f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{repo_path}'" + ) + + with pytest.raises(AirflowException, match=exception_message): + op = DatabricksReposCreateOperator(task_id=TASK_ID, git_url=git_url, repo_path=repo_path) + op.execute(None) + + with pytest.raises( + AirflowException, match="Only one of branch or tag should be provided, but not both" + ): + DatabricksReposCreateOperator(task_id=TASK_ID, git_url=git_url, branch="123", tag="123")