Skip to content

Commit

Permalink
refactor: register repo checks with a decorator
Browse files Browse the repository at this point in the history
Removes a bit of redundancy / potential typos inherent in the
CHECK list.

This will also make it easier to break repo_checks.py into multiple modules,
if we ever decide to do that.
  • Loading branch information
kdmccormick committed Aug 7, 2024
1 parent 7e4a5c8 commit 165e7d6
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions edx_repo_tools/repo_checks/repo_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import importlib.resources
import re
import textwrap
import typing as t
from functools import cache
from itertools import chain
from pprint import pformat
Expand Down Expand Up @@ -104,11 +105,25 @@ class Check:
(is_relevant, check, fix, and dry_run).
"""

_registered = {}

def __init__(self, api: GhApi, org: str, repo: str):
self.api = api
self.org_name = org
self.repo_name = repo

@staticmethod
def register(subclass: type[t.Self]) -> type[t.Self]:
"""
Decorate a Check subclass so that it will be available in main()
"""
Check._registered[subclass.__name__] = subclass
return subclass

@staticmethod
def get_registered_checks() -> dict[str, type[t.Self]]:
return Check._registered.copy()

def is_relevant(self) -> bool:
"""
Checks to see if the given check is relevant to run on the
Expand Down Expand Up @@ -152,6 +167,7 @@ def dry_run(self):
raise NotImplementedError


@Check.register
class EnsureRepoSettings(Check):
"""
There are certain settings that we agree we want to be set a specific way on all repos. This check
Expand Down Expand Up @@ -240,6 +256,7 @@ def fix(self, dry_run=False):
return steps


@Check.register
class EnsureNoAdminOrMaintainTeams(Check):
"""
Teams should not be granted `admin` or `maintain` access to a repository unless the access
Expand Down Expand Up @@ -309,6 +326,7 @@ def fix(self, dry_run=False):
return steps


@Check.register
class EnsureWorkflowTemplates(Check):
"""
There are certain github action workflows that we to exist on all
Expand Down Expand Up @@ -594,6 +612,7 @@ def fix(self, dry_run=False):
return steps


@Check.register
class EnsureLabels(Check):
"""
All repos in the org should have certain labels.
Expand Down Expand Up @@ -782,6 +801,7 @@ def fix(self, dry_run=False):
raise


@Check.register
class RequireTriageTeamAccess(RequireTeamPermission):
"""
Ensure that the openedx-triage team grants Triage access to every public repo in the org.
Expand All @@ -797,6 +817,7 @@ def is_relevant(self):
return is_public(self.api, self.org_name, self.repo_name)


@Check.register
class RequiredCLACheck(Check):
"""
This class validates the following:
Expand Down Expand Up @@ -1057,6 +1078,7 @@ def _get_update_params_from_get_branch_protection(self):
return params


@Check.register
class EnsureNoDirectRepoAccessToUsers(Check):
"""
Users should not have direct repo access
Expand Down Expand Up @@ -1114,19 +1136,6 @@ def fix(self, dry_run=False):
return steps


CHECKS = [
RequiredCLACheck,
RequireTriageTeamAccess,
EnsureLabels,
EnsureWorkflowTemplates,
EnsureNoAdminOrMaintainTeams,
EnsureRepoSettings,
EnsureNoDirectRepoAccessToUsers,
]
CHECKS_BY_NAME = {check_cls.__name__: check_cls for check_cls in CHECKS}
CHECKS_BY_NAME_LOWER = {check_cls.__name__.lower(): check_cls for check_cls in CHECKS}


@click.command()
@click.option(
"--github-token",
Expand Down Expand Up @@ -1154,7 +1163,7 @@ def fix(self, dry_run=False):
"check_names",
default=None,
multiple=True,
type=click.Choice(CHECKS_BY_NAME.keys(), case_sensitive=False),
type=click.Choice(Check.get_registered_checks().keys(), case_sensitive=False),
help=f"Limit to specific check(s), case-insensitive.",
)
@click.option(
Expand Down Expand Up @@ -1193,9 +1202,9 @@ def main(org, dry_run, _github_token, check_names, repos, start_at):
click.secho("No Actual Changes Being Made", fg="yellow")

if check_names:
active_checks = [CHECKS_BY_NAME[check_name] for check_name in check_names]
active_checks = [Check.get_registered_checks()[check_name] for check_name in check_names]
else:
active_checks = CHECKS
active_checks = list(Check.get_registered_checks().values())
click.secho(f"The following checks will be run:", fg="magenta", bold=True)
active_checks_string = "\n".join(
"\t" + check_cls.__name__ for check_cls in active_checks
Expand Down

0 comments on commit 165e7d6

Please sign in to comment.