Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion tests/unit/oidc/models/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def test_github_publisher_missing_claims(self, monkeypatch, missing):
]
assert scope.fingerprint == [missing]

def test_github_publisher_missing_optional_claims(self, monkeypatch):
def test_github_publisher_missing_optional_claims(self, metrics, monkeypatch):
publisher = github.GitHubPublisher(
repository_name="fakerepo",
repository_owner="fakeowner",
Expand All @@ -352,6 +352,7 @@ def test_github_publisher_missing_optional_claims(self, monkeypatch):

service_ = pretend.stub(
jwt_identifier_exists=pretend.call_recorder(lambda s: False),
metrics=metrics,
)

signed_claims = {
Expand Down Expand Up @@ -433,6 +434,23 @@ def test_check_repository(self, truth, claim, valid):
check = github.GitHubPublisher.__required_verifiable_claims__["repository"]
assert check(truth, claim, pretend.stub()) == valid

def test_check_event_name_emits_metrics(self, metrics):
check = github.GitHubPublisher.__required_verifiable_claims__["event_name"]
publisher_service = pretend.stub(metrics=metrics)

assert check(
"throwaway",
"pull_request_target",
pretend.stub(),
publisher_service=publisher_service,
)
assert metrics.increment.calls == [
pretend.call(
"warehouse.oidc.claim",
tags=["publisher:GitHub", "event_name:pull_request_target"],
),
]

@pytest.mark.parametrize(
("claim", "ref", "sha", "valid", "expected"),
[
Expand Down
21 changes: 20 additions & 1 deletion warehouse/oidc/models/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
if typing.TYPE_CHECKING:
from sqlalchemy.orm import Session

from warehouse.oidc.services import OIDCPublisherService

GITHUB_OIDC_ISSUER_URL = "https://token.actions.githubusercontent.com"

# This expression matches the workflow filename component of a GitHub
Expand Down Expand Up @@ -152,6 +154,18 @@ def _check_sub(
return f"{org}:{repo}".lower() == ground_truth.lower()


def _check_event_name(
ground_truth: str, signed_claim: str, _all_signed_claims: SignedClaims, **kwargs
) -> bool:
# Log the event name
publisher_service: OIDCPublisherService = kwargs["publisher_service"]
publisher_service.metrics.increment(
"warehouse.oidc.claim", tags=["publisher:GitHub", f"event_name:{signed_claim}"]
)
# Always permit all event names for now
return True


class GitHubPublisherMixin:
"""
Common functionality for both pending and concrete GitHub OIDC publishers.
Expand All @@ -170,6 +184,7 @@ class GitHubPublisherMixin:
"repository_owner_id": check_claim_binary(str.__eq__),
"job_workflow_ref": _check_job_workflow_ref,
"jti": check_existing_jti,
"event_name": _check_event_name,
}

__required_unverifiable_claims__: set[str] = {"ref", "sha"}
Expand All @@ -186,7 +201,6 @@ class GitHubPublisherMixin:
"run_attempt",
"head_ref",
"base_ref",
"event_name",
"ref_type",
"repository_id",
"workflow",
Expand Down Expand Up @@ -275,6 +289,11 @@ def jti(self) -> str:
"""Placeholder value for JTI."""
return "placeholder"

@property
def event_name(self) -> str:
"""Placeholder value for event_name (not used)"""
return "placeholder"

def publisher_url(self, claims: SignedClaims | None = None) -> str:
base = self.publisher_base_url
sha = claims.get("sha") if claims else None
Expand Down