diff --git a/providers/github/docs/connections/github.rst b/providers/github/docs/connections/github.rst index 8934207875b10..20a5cbe19d3ba 100644 --- a/providers/github/docs/connections/github.rst +++ b/providers/github/docs/connections/github.rst @@ -20,11 +20,16 @@ GitHub Connection ==================== -The GitHub connection type provides connection to a GitHub or GitHub Enterprise. +The GitHub connection provides two authentication mechanisms: + - Token-based authentication + - GitHub App authentication + +For Token-based authentication, you must provide an access token. +For GitHub App authentication, you must configure the connection's Extras field with the required GitHub App parameters. Configuring the Connection -------------------------- -Access Token (required) +Access Token (optional) Personal Access token with required permissions. - GitHub - Create token - https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token/ - GitHub Enterprise - Create token - https://docs.github.com/en/enterprise-cloud@latest/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token/ @@ -40,3 +45,27 @@ Host (optional) .. code-block:: https://{hostname}/api/v3 + +GitHub App authentication +-------------------------------- + +You can authenticate using a GitHub App installation by setting the extra field of your connection, instead of using a token. + +- ``key_path``: Path to the private key file used for GitHub App authentication. +- ``app_id``: The application ID. +- ``installation_id``: The ID of the app installation. +- ``token_permissions``: A dictionary of permissions. - Properties of permissions - https://docs.github.com/en/rest/apps/apps?apiVersion=2022-11-28#create-an-installation-access-token-for-an-app + +Example "extras" field: + +.. code-block:: json + + { + "key_path": "FAKE_KEY.pem", + "app_id": "123456s", + "installation_id": 123456789, + "token_permissions": { + "issues":"write", + "contents":"read" + } + } diff --git a/providers/github/src/airflow/providers/github/hooks/github.py b/providers/github/src/airflow/providers/github/hooks/github.py index d0b60d3908c1f..2e93e5e1c633e 100644 --- a/providers/github/src/airflow/providers/github/hooks/github.py +++ b/providers/github/src/airflow/providers/github/hooks/github.py @@ -21,9 +21,8 @@ from typing import TYPE_CHECKING -from github import Github as GithubClient +from github import Auth, Github as GithubClient -from airflow.exceptions import AirflowException from airflow.providers.common.compat.sdk import BaseHook @@ -55,17 +54,35 @@ def get_conn(self) -> GithubClient: conn = self.get_connection(self.github_conn_id) access_token = conn.password host = conn.host - - # Currently the only method of authenticating to GitHub in Airflow is via a token. This is not the - # only means available, but raising an exception to enforce this method for now. - # TODO: When/If other auth methods are implemented this exception should be removed/modified. - if not access_token: - raise AirflowException("An access token is required to authenticate to GitHub.") + extras = conn.extra_dejson or {} + + if access_token: + auth: Auth.Auth = Auth.Token(access_token) + elif extras: + if key_path := extras.get("key_path"): + if not key_path.endswith(".pem"): + raise ValueError("Unrecognised key file: expected a .pem private key") + with open(key_path) as key_file: + private_key = key_file.read() + else: + raise ValueError("No key_path provided for GitHub App authentication.") + + app_id = extras.get("app_id") + installation_id = extras.get("installation_id") + if not isinstance(installation_id, int): + raise ValueError("The provided installation_id should be integer.") + if not isinstance(app_id, (str | int)): + raise ValueError("The provided app_id should be integer or string.") + token_permissions = extras.get("token_permissions", None) + + auth = Auth.AppAuth(app_id, private_key).get_installation_auth(installation_id, token_permissions) + else: + raise ValueError("No access token or authentication method provided.") if not host: - self.client = GithubClient(login_or_token=access_token) + self.client = GithubClient(auth=auth) else: - self.client = GithubClient(login_or_token=access_token, base_url=host) + self.client = GithubClient(auth=auth, base_url=host) return self.client diff --git a/providers/github/tests/unit/github/hooks/test_github.py b/providers/github/tests/unit/github/hooks/test_github.py index f7e294f9c915e..6982637a3646d 100644 --- a/providers/github/tests/unit/github/hooks/test_github.py +++ b/providers/github/tests/unit/github/hooks/test_github.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from unittest.mock import Mock, patch +from unittest.mock import Mock, mock_open, patch import pytest from github import BadCredentialsException, Github, NamedUser @@ -26,6 +26,7 @@ from airflow.providers.github.hooks.github import GithubHook github_client_mock = Mock(name="github_client_for_test") +github_app_client_mock = Mock(name="github_app_client_for_test") class TestGithubHook: @@ -40,6 +41,19 @@ def setup_connections(self, create_connection_without_db): host="https://mygithub.com/api/v3", ) ) + create_connection_without_db( + Connection( + conn_id="github_app_conn", + conn_type="github", + host="https://mygithub.com/api/v3", + extra={ + "app_id": "123456", + "installation_id": 654321, + "key_path": "FAKE_PRIVATE_KEY.pem", + "token_permissions": {"issues": "write", "pull_requests": "read"}, + }, + ) + ) @patch( "airflow.providers.github.hooks.github.GithubClient", autospec=True, return_value=github_client_mock @@ -51,8 +65,14 @@ def test_github_client_connection(self, github_mock): assert isinstance(github_hook.client, Mock) assert github_hook.client.name == github_mock.return_value.name - def test_connection_success(self): - hook = GithubHook() + @pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"]) + @patch( + "airflow.providers.github.hooks.github.open", + new_callable=mock_open, + read_data="FAKE_PRIVATE_KEY_CONTENT", + ) + def test_connection_success(self, mock_file, conn_id): + hook = GithubHook(github_conn_id=conn_id) hook.client = Mock(spec=Github) hook.client.get_user.return_value = NamedUser.NamedUser @@ -61,8 +81,14 @@ def test_connection_success(self): assert status is True assert msg == "Successfully connected to GitHub." - def test_connection_failure(self): - hook = GithubHook() + @pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"]) + @patch( + "airflow.providers.github.hooks.github.open", + new_callable=mock_open, + read_data="FAKE_PRIVATE_KEY_CONTENT", + ) + def test_connection_failure(self, mock_file, conn_id): + hook = GithubHook(github_conn_id=conn_id) hook.client.get_user = Mock( side_effect=BadCredentialsException( status=401, @@ -74,3 +100,66 @@ def test_connection_failure(self): assert status is False assert msg == '401 {"message": "Bad credentials"}' + + @pytest.mark.parametrize( + ( + "conn_id", + "extra", + "expected_error_message", + ), + [ + # Wrong key file extension + ( + "invalid_key_path", + {"app_id": "1", "installation_id": 1, "key_path": "wrong_ext.txt"}, + "Unrecognised key file: expected a .pem private key", + ), + # Missing key_path + ( + "missing_key_path", + {"app_id": "1", "installation_id": 1}, + "No key_path provided for GitHub App authentication.", + ), + # installation_id is not integer + ( + "invalid_install_id", + {"app_id": "1", "installation_id": "654321_string", "key_path": "key.pem"}, + "The provided installation_id should be integer.", + ), + # app_id is not integer or string + ( + "invalid_app_id", + {"app_id": ["123456_list"], "installation_id": 1, "key_path": "key.pem"}, + "The provided app_id should be integer or string.", + ), + # No access token or authentication method provided + ( + "no_auth_conn", + {}, + "No access token or authentication method provided.", + ), + ], + ) + @patch("airflow.providers.github.hooks.github.GithubHook.get_connection") + @patch( + "airflow.providers.github.hooks.github.open", + new_callable=mock_open, + read_data="FAKE_PRIVATE_KEY_CONTENT", + ) + def test_get_conn_value_error_cases( + self, + mock_file, + get_connection_mock, + conn_id, + extra, + expected_error_message, + ): + mock_conn = Connection( + conn_id=conn_id, + conn_type="github", + extra=extra, + ) + get_connection_mock.return_value = mock_conn + + with pytest.raises(ValueError, match=expected_error_message): + GithubHook(github_conn_id=conn_id) diff --git a/providers/github/tests/unit/github/operators/test_github.py b/providers/github/tests/unit/github/operators/test_github.py index b6c6cf681fbb4..25176e222a76a 100644 --- a/providers/github/tests/unit/github/operators/test_github.py +++ b/providers/github/tests/unit/github/operators/test_github.py @@ -17,14 +17,18 @@ # under the License. from __future__ import annotations -from unittest.mock import Mock, patch +from unittest.mock import Mock, mock_open, patch import pytest from airflow.models import Connection from airflow.models.dag import DAG from airflow.providers.github.operators.github import GithubOperator -from airflow.utils import timezone + +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] DEFAULT_DATE = timezone.datetime(2017, 1, 1) github_client_mock = Mock(name="github_client_for_test") @@ -42,26 +46,47 @@ def setup_connections(self, create_connection_without_db): host="https://mygithub.com/api/v3", ) ) + create_connection_without_db( + Connection( + conn_id="github_app_conn", + conn_type="github", + host="https://mygithub.com/api/v3", + extra={ + "app_id": "123456", + "installation_id": 654321, + "key_path": "FAKE_PRIVATE_KEY.pem", + "token_permissions": {"issues": "write", "pull_requests": "read"}, + }, + ) + ) def setup_class(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG("test_dag_id", schedule=None, default_args=args) self.dag = dag - def test_operator_init_with_optional_args(self): + @pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"]) + def test_operator_init_with_optional_args(self, conn_id): github_operator = GithubOperator( task_id="github_list_repos", github_method="get_user", + github_conn_id=conn_id, ) assert github_operator.github_method_args == {} assert github_operator.result_processor is None + @pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"]) @pytest.mark.db_test @patch( "airflow.providers.github.hooks.github.GithubClient", autospec=True, return_value=github_client_mock ) - def test_find_repos(self, github_mock, dag_maker): + @patch( + "airflow.providers.github.hooks.github.open", + new_callable=mock_open, + read_data="FAKE_PRIVATE_KEY_CONTENT", + ) + def test_find_repos(self, mock_file, github_mock, dag_maker, conn_id): class MockRepository: pass @@ -74,6 +99,7 @@ class MockRepository: task_id="github-test", github_method="get_repo", github_method_args={"full_name_or_id": "apache/airflow"}, + github_conn_id=conn_id, result_processor=lambda r: r.full_name, ) dr = dag_maker.create_dagrun() diff --git a/providers/github/tests/unit/github/sensors/test_github.py b/providers/github/tests/unit/github/sensors/test_github.py index b885be43ddedd..a9a414bae377e 100644 --- a/providers/github/tests/unit/github/sensors/test_github.py +++ b/providers/github/tests/unit/github/sensors/test_github.py @@ -17,14 +17,18 @@ # under the License. from __future__ import annotations -from unittest.mock import Mock, patch +from unittest.mock import Mock, mock_open, patch import pytest from airflow.models import Connection from airflow.models.dag import DAG from airflow.providers.github.sensors.github import GithubTagSensor -from airflow.utils import timezone + +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] DEFAULT_DATE = timezone.datetime(2017, 1, 1) github_client_mock = Mock(name="github_client_for_test") @@ -42,18 +46,37 @@ def setup_connections(self, create_connection_without_db): host="https://mygithub.com/api/v3", ) ) + create_connection_without_db( + Connection( + conn_id="github_app_conn", + conn_type="github", + host="https://mygithub.com/api/v3", + extra={ + "app_id": "123456", + "installation_id": 654321, + "key_path": "FAKE_PRIVATE_KEY.pem", + "token_permissions": {"issues": "write", "pull_requests": "read"}, + }, + ) + ) def setup_class(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG("test_dag_id", schedule=None, default_args=args) self.dag = dag + @pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"]) @patch( "airflow.providers.github.hooks.github.GithubClient", autospec=True, return_value=github_client_mock, ) - def test_github_tag_created(self, github_mock): + @patch( + "airflow.providers.github.hooks.github.open", + new_callable=mock_open, + read_data="FAKE_PRIVATE_KEY_CONTENT", + ) + def test_github_tag_created(self, mock_file, github_mock, conn_id): class MockTag: pass @@ -63,12 +86,13 @@ class MockTag: github_mock.return_value.get_repo.return_value.get_tags.return_value = [tag] github_tag_sensor = GithubTagSensor( - task_id="search-ticket-test", + task_id=f"search-ticket-test-{conn_id}", tag_name="v1.0", repository_name="pateash/jetbrains_settings", timeout=60, poke_interval=10, dag=self.dag, + github_conn_id=conn_id, ) github_tag_sensor.execute({})