Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions providers/github/docs/connections/github.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@

GitHub Connection
====================
The GitHub connection type provides a connection to GitHub or GitHub Enterprise.
The GitHub connection provides two authentication mechanisms:

- Token-based authentication
- GitHub App authentication

For Token-based authentication, you must provide an access token.
For GitHub App authentication, you must configure the connection's Extras field with the required GitHub App parameters.

Configuring the Connection
--------------------------
Access Token (required)
Access Token (optional)
Personal Access token with required permissions.
- GitHub - Create token - https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token/
- GitHub Enterprise - Create token - https://docs.github.com/en/enterprise-cloud@latest/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token/
Expand All @@ -40,3 +45,27 @@ Host (optional)
.. code-block::

https://{hostname}/api/v3

GitHub App authentication
--------------------------------

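You can authenticate as a GitHub App installation, instead of using a token, by setting the Extras field of your connection. Leave the access token empty in this case: when an access token is set, it takes precedence and the GitHub App settings in Extras are ignored.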

- ``key_path``: Path to the private key file used for GitHub App authentication. The file name must end with ``.pem``.
- ``app_id``: The application ID (string or integer).
- ``installation_id``: The ID of the app installation (must be an integer).
- ``token_permissions``: (optional) A dictionary of permissions to request for the installation access token. For the available permission properties, see https://docs.github.com/en/rest/apps/apps?apiVersion=2022-11-28#create-an-installation-access-token-for-an-app

Example "extras" field:

.. code-block:: json

{
"key_path": "FAKE_KEY.pem",
"app_id": "123456",
"installation_id": 123456789,
"token_permissions": {
"issues":"write",
"contents":"read"
}
}
37 changes: 27 additions & 10 deletions providers/github/src/airflow/providers/github/hooks/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@

from typing import TYPE_CHECKING

from github import Github as GithubClient
from github import Auth, Github as GithubClient

from airflow.exceptions import AirflowException
from airflow.providers.common.compat.sdk import BaseHook


Expand Down Expand Up @@ -55,17 +54,35 @@ def get_conn(self) -> GithubClient:
conn = self.get_connection(self.github_conn_id)
access_token = conn.password
host = conn.host

# Currently the only method of authenticating to GitHub in Airflow is via a token. This is not the
# only means available, but raising an exception to enforce this method for now.
# TODO: When/If other auth methods are implemented this exception should be removed/modified.
if not access_token:
raise AirflowException("An access token is required to authenticate to GitHub.")
extras = conn.extra_dejson or {}

if access_token:
auth: Auth.Auth = Auth.Token(access_token)
elif extras:
if key_path := extras.get("key_path"):
if not key_path.endswith(".pem"):
raise ValueError("Unrecognised key file: expected a .pem private key")
with open(key_path) as key_file:
private_key = key_file.read()
else:
raise ValueError("No key_path provided for GitHub App authentication.")

app_id = extras.get("app_id")
installation_id = extras.get("installation_id")
if not isinstance(installation_id, int):
raise ValueError("The provided installation_id should be integer.")
if not isinstance(app_id, (str | int)):
raise ValueError("The provided app_id should be integer or string.")
token_permissions = extras.get("token_permissions", None)

auth = Auth.AppAuth(app_id, private_key).get_installation_auth(installation_id, token_permissions)
else:
raise ValueError("No access token or authentication method provided.")

if not host:
self.client = GithubClient(login_or_token=access_token)
self.client = GithubClient(auth=auth)
else:
self.client = GithubClient(login_or_token=access_token, base_url=host)
self.client = GithubClient(auth=auth, base_url=host)

return self.client

Expand Down
99 changes: 94 additions & 5 deletions providers/github/tests/unit/github/hooks/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from unittest.mock import Mock, patch
from unittest.mock import Mock, mock_open, patch

import pytest
from github import BadCredentialsException, Github, NamedUser
Expand All @@ -26,6 +26,7 @@
from airflow.providers.github.hooks.github import GithubHook

github_client_mock = Mock(name="github_client_for_test")
github_app_client_mock = Mock(name="github_app_client_for_test")


class TestGithubHook:
Expand All @@ -40,6 +41,19 @@ def setup_connections(self, create_connection_without_db):
host="https://mygithub.com/api/v3",
)
)
create_connection_without_db(
Connection(
conn_id="github_app_conn",
conn_type="github",
host="https://mygithub.com/api/v3",
extra={
"app_id": "123456",
"installation_id": 654321,
"key_path": "FAKE_PRIVATE_KEY.pem",
"token_permissions": {"issues": "write", "pull_requests": "read"},
},
)
)

@patch(
"airflow.providers.github.hooks.github.GithubClient", autospec=True, return_value=github_client_mock
Expand All @@ -51,8 +65,14 @@ def test_github_client_connection(self, github_mock):
assert isinstance(github_hook.client, Mock)
assert github_hook.client.name == github_mock.return_value.name

def test_connection_success(self):
hook = GithubHook()
@pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"])
@patch(
"airflow.providers.github.hooks.github.open",
new_callable=mock_open,
read_data="FAKE_PRIVATE_KEY_CONTENT",
)
def test_connection_success(self, mock_file, conn_id):
hook = GithubHook(github_conn_id=conn_id)
hook.client = Mock(spec=Github)
hook.client.get_user.return_value = NamedUser.NamedUser

Expand All @@ -61,8 +81,14 @@ def test_connection_success(self):
assert status is True
assert msg == "Successfully connected to GitHub."

def test_connection_failure(self):
hook = GithubHook()
@pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"])
@patch(
"airflow.providers.github.hooks.github.open",
new_callable=mock_open,
read_data="FAKE_PRIVATE_KEY_CONTENT",
)
def test_connection_failure(self, mock_file, conn_id):
hook = GithubHook(github_conn_id=conn_id)
hook.client.get_user = Mock(
side_effect=BadCredentialsException(
status=401,
Expand All @@ -74,3 +100,66 @@ def test_connection_failure(self):

assert status is False
assert msg == '401 {"message": "Bad credentials"}'

@pytest.mark.parametrize(
(
"conn_id",
"extra",
"expected_error_message",
),
[
# Wrong key file extension
(
"invalid_key_path",
{"app_id": "1", "installation_id": 1, "key_path": "wrong_ext.txt"},
"Unrecognised key file: expected a .pem private key",
),
# Missing key_path
(
"missing_key_path",
{"app_id": "1", "installation_id": 1},
"No key_path provided for GitHub App authentication.",
),
# installation_id is not integer
(
"invalid_install_id",
{"app_id": "1", "installation_id": "654321_string", "key_path": "key.pem"},
"The provided installation_id should be integer.",
),
# app_id is not integer or string
(
"invalid_app_id",
{"app_id": ["123456_list"], "installation_id": 1, "key_path": "key.pem"},
"The provided app_id should be integer or string.",
),
# No access token or authentication method provided
(
"no_auth_conn",
{},
"No access token or authentication method provided.",
),
],
)
@patch("airflow.providers.github.hooks.github.GithubHook.get_connection")
@patch(
"airflow.providers.github.hooks.github.open",
new_callable=mock_open,
read_data="FAKE_PRIVATE_KEY_CONTENT",
)
def test_get_conn_value_error_cases(
self,
mock_file,
get_connection_mock,
conn_id,
extra,
expected_error_message,
):
mock_conn = Connection(
conn_id=conn_id,
conn_type="github",
extra=extra,
)
get_connection_mock.return_value = mock_conn

with pytest.raises(ValueError, match=expected_error_message):
GithubHook(github_conn_id=conn_id)
34 changes: 30 additions & 4 deletions providers/github/tests/unit/github/operators/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
# under the License.
from __future__ import annotations

from unittest.mock import Mock, patch
from unittest.mock import Mock, mock_open, patch

import pytest

from airflow.models import Connection
from airflow.models.dag import DAG
from airflow.providers.github.operators.github import GithubOperator
from airflow.utils import timezone

try:
from airflow.sdk import timezone
except ImportError:
from airflow.utils import timezone # type: ignore[attr-defined,no-redef]

DEFAULT_DATE = timezone.datetime(2017, 1, 1)
github_client_mock = Mock(name="github_client_for_test")
Expand All @@ -42,26 +46,47 @@ def setup_connections(self, create_connection_without_db):
host="https://mygithub.com/api/v3",
)
)
create_connection_without_db(
Connection(
conn_id="github_app_conn",
conn_type="github",
host="https://mygithub.com/api/v3",
extra={
"app_id": "123456",
"installation_id": 654321,
"key_path": "FAKE_PRIVATE_KEY.pem",
"token_permissions": {"issues": "write", "pull_requests": "read"},
},
)
)

def setup_class(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG("test_dag_id", schedule=None, default_args=args)
self.dag = dag

def test_operator_init_with_optional_args(self):
@pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"])
def test_operator_init_with_optional_args(self, conn_id):
github_operator = GithubOperator(
task_id="github_list_repos",
github_method="get_user",
github_conn_id=conn_id,
)

assert github_operator.github_method_args == {}
assert github_operator.result_processor is None

@pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"])
@pytest.mark.db_test
@patch(
"airflow.providers.github.hooks.github.GithubClient", autospec=True, return_value=github_client_mock
)
def test_find_repos(self, github_mock, dag_maker):
@patch(
"airflow.providers.github.hooks.github.open",
new_callable=mock_open,
read_data="FAKE_PRIVATE_KEY_CONTENT",
)
def test_find_repos(self, mock_file, github_mock, dag_maker, conn_id):
class MockRepository:
pass

Expand All @@ -74,6 +99,7 @@ class MockRepository:
task_id="github-test",
github_method="get_repo",
github_method_args={"full_name_or_id": "apache/airflow"},
github_conn_id=conn_id,
result_processor=lambda r: r.full_name,
)
dr = dag_maker.create_dagrun()
Expand Down
32 changes: 28 additions & 4 deletions providers/github/tests/unit/github/sensors/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
# under the License.
from __future__ import annotations

from unittest.mock import Mock, patch
from unittest.mock import Mock, mock_open, patch

import pytest

from airflow.models import Connection
from airflow.models.dag import DAG
from airflow.providers.github.sensors.github import GithubTagSensor
from airflow.utils import timezone

try:
from airflow.sdk import timezone
except ImportError:
from airflow.utils import timezone # type: ignore[attr-defined,no-redef]

DEFAULT_DATE = timezone.datetime(2017, 1, 1)
github_client_mock = Mock(name="github_client_for_test")
Expand All @@ -42,18 +46,37 @@ def setup_connections(self, create_connection_without_db):
host="https://mygithub.com/api/v3",
)
)
create_connection_without_db(
Connection(
conn_id="github_app_conn",
conn_type="github",
host="https://mygithub.com/api/v3",
extra={
"app_id": "123456",
"installation_id": 654321,
"key_path": "FAKE_PRIVATE_KEY.pem",
"token_permissions": {"issues": "write", "pull_requests": "read"},
},
)
)

def setup_class(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG("test_dag_id", schedule=None, default_args=args)
self.dag = dag

@pytest.mark.parametrize("conn_id", ["github_default", "github_app_conn"])
@patch(
"airflow.providers.github.hooks.github.GithubClient",
autospec=True,
return_value=github_client_mock,
)
def test_github_tag_created(self, github_mock):
@patch(
"airflow.providers.github.hooks.github.open",
new_callable=mock_open,
read_data="FAKE_PRIVATE_KEY_CONTENT",
)
def test_github_tag_created(self, mock_file, github_mock, conn_id):
class MockTag:
pass

Expand All @@ -63,12 +86,13 @@ class MockTag:
github_mock.return_value.get_repo.return_value.get_tags.return_value = [tag]

github_tag_sensor = GithubTagSensor(
task_id="search-ticket-test",
task_id=f"search-ticket-test-{conn_id}",
tag_name="v1.0",
repository_name="pateash/jetbrains_settings",
timeout=60,
poke_interval=10,
dag=self.dag,
github_conn_id=conn_id,
)

github_tag_sensor.execute({})
Expand Down