Skip to content

Commit

Permalink
Display notice when copyrighted file is copied or renamed (#52)
Browse files Browse the repository at this point in the history
Also proposes requiring Python 3.10, since
rapidsai/build-planning#88 is being rolled
out.

Fixes #48
  • Loading branch information
KyleFromNVIDIA authored Aug 27, 2024
1 parent cca7f87 commit 80e3e1b
Show file tree
Hide file tree
Showing 5 changed files with 593 additions and 168 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ jobs:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'
- name: Build & Test
run: ./ci/build-test.sh
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
]
requires-python = ">=3.9"
requires-python = ">=3.10"
dependencies = [
"PyYAML",
"bashlex",
Expand Down
94 changes: 71 additions & 23 deletions src/rapids_pre_commit_hooks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import git

from .lint import Linter, LintMain
from .lint import Linter, LintMain, LintWarning

COPYRIGHT_RE: re.Pattern = re.compile(
r"Copyright *(?:\(c\))? *(?P<years>(?P<first_year>\d{4})(-(?P<last_year>\d{4}))?),?"
Expand Down Expand Up @@ -59,21 +59,59 @@ def append_stripped(start: int, item: re.Match):
return lines


def add_copy_rename_note(
linter: Linter,
warning: LintWarning,
change_type: str,
old_filename: Optional[Union[str, os.PathLike[str]]],
):
CHANGE_VERBS = {
"C": "copied",
"R": "renamed",
}
try:
change_verb = CHANGE_VERBS[change_type]
except KeyError:
pass
else:
warning.add_note(
(0, len(linter.content)),
f"file was {change_verb} from '{old_filename}' and is assumed to share "
"history with it",
)
warning.add_note(
(0, len(linter.content)),
"change file contents if you want its copyright dates to only be "
"determined by its own edit history",
)


def apply_copyright_revert(
linter: Linter, old_match: re.Match, new_match: re.Match
linter: Linter,
change_type: str,
old_filename: Optional[Union[str, os.PathLike[str]]],
old_match: re.Match,
new_match: re.Match,
) -> None:
if old_match.group("years") == new_match.group("years"):
warning_pos = new_match.span()
else:
warning_pos = new_match.span("years")
linter.add_warning(
w = linter.add_warning(
warning_pos,
"copyright is not out of date and should not be updated",
).add_replacement(new_match.span(), old_match.group())
)
w.add_replacement(new_match.span(), old_match.group())
add_copy_rename_note(linter, w, change_type, old_filename)


def apply_copyright_update(linter: Linter, match: re.Match, year: int) -> None:
linter.add_warning(match.span("years"), "copyright is out of date").add_replacement(
def apply_copyright_update(
linter: Linter,
match: re.Match,
year: int,
) -> None:
w = linter.add_warning(match.span("years"), "copyright is out of date")
w.add_replacement(
match.span(),
COPYRIGHT_REPLACEMENT.format(
first_year=match.group("first_year"),
Expand All @@ -82,7 +120,12 @@ def apply_copyright_update(linter: Linter, match: re.Match, year: int) -> None:
)


def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None:
def apply_copyright_check(
linter: Linter,
change_type: str,
old_filename: Optional[Union[str, os.PathLike[str]]],
old_content: Optional[str],
) -> None:
if linter.content != old_content:
current_year = datetime.datetime.now().year
new_copyright_matches = match_copyright(linter.content)
Expand All @@ -97,7 +140,9 @@ def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None:
old_copyright_matches, new_copyright_matches
):
if old_match.group() != new_match.group():
apply_copyright_revert(linter, old_match, new_match)
apply_copyright_revert(
linter, change_type, old_filename, old_match, new_match
)
elif new_copyright_matches:
for match in new_copyright_matches:
if (
Expand Down Expand Up @@ -233,22 +278,24 @@ def try_get_ref(remote: "git.Remote") -> Optional["git.Reference"]:

def get_changed_files(
args: argparse.Namespace,
) -> dict[Union[str, os.PathLike[str]], Optional["git.Blob"]]:
) -> dict[Union[str, os.PathLike[str]], tuple[str, Optional["git.Blob"]]]:
try:
repo = git.Repo()
except git.InvalidGitRepositoryError:
return {
os.path.relpath(os.path.join(dirpath, filename), "."): None
os.path.relpath(os.path.join(dirpath, filename), "."): ("A", None)
for dirpath, dirnames, filenames in os.walk(".")
for filename in filenames
}

changed_files: dict[Union[str, os.PathLike[str]], Optional["git.Blob"]] = {
f: None for f in repo.untracked_files
}
changed_files: dict[
Union[str, os.PathLike[str]], tuple[str, Optional["git.Blob"]]
] = {f: ("A", None) for f in repo.untracked_files}
target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args)
if target_branch_upstream_commit is None:
changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()})
changed_files.update(
{blob.path: ("A", None) for _, blob in repo.index.iter_blobs()}
)
return changed_files

for merge_base in repo.merge_base(
Expand All @@ -262,9 +309,9 @@ def get_changed_files(
)
for diff in diffs:
if diff.change_type == "A":
changed_files[diff.b_path] = None
changed_files[diff.b_path] = (diff.change_type, None)
elif diff.change_type != "D":
changed_files[diff.b_path] = diff.a_blob
changed_files[diff.b_path] = (diff.change_type, diff.a_blob)

return changed_files

Expand Down Expand Up @@ -313,16 +360,17 @@ def the_check(linter: Linter, args: argparse.Namespace):
return

try:
changed_file = changed_files[git_filename]
change_type, changed_file = changed_files[git_filename]
except KeyError:
return

old_content = (
changed_file.data_stream.read().decode()
if changed_file is not None
else None
)
apply_copyright_check(linter, old_content)
if changed_file is None:
old_filename = None
old_content = None
else:
old_filename = changed_file.path
old_content = changed_file.data_stream.read().decode()
apply_copyright_check(linter, change_type, old_filename, old_content)

return the_check

Expand Down
22 changes: 5 additions & 17 deletions src/rapids_pre_commit_hooks/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,13 @@
import functools
import re
import warnings
from collections.abc import Callable, Generator, Iterable
from collections.abc import Callable
from itertools import pairwise
from typing import Optional

from rich.console import Console
from rich.markup import escape


# Taken from Python docs
# (https://docs.python.org/3.12/library/itertools.html#itertools.pairwise)
# Replace with itertools.pairwise after dropping Python 3.9 support
def _pairwise(iterable: Iterable) -> Generator:
# pairwise('ABCDEFG') → AB BC CD DE EF FG
iterator = iter(iterable)
a = next(iterator, None)
for b in iterator:
yield a, b
a = b


_PosType = tuple[int, int]


Expand Down Expand Up @@ -66,9 +54,9 @@ class LintWarning:
pos: _PosType
msg: str
replacements: list[Replacement] = dataclasses.field(
default_factory=list, init=False
default_factory=list, kw_only=True
)
notes: list[Note] = dataclasses.field(default_factory=list, init=False)
notes: list[Note] = dataclasses.field(default_factory=list, kw_only=True)

def add_replacement(self, pos: _PosType, newtext: str) -> None:
self.replacements.append(Replacement(pos, newtext))
Expand Down Expand Up @@ -102,7 +90,7 @@ def fix(self) -> str:
key=lambda replacement: replacement.pos,
)

for r1, r2 in _pairwise(sorted_replacements):
for r1, r2 in pairwise(sorted_replacements):
if r1.pos[1] > r2.pos[0]:
raise OverlappingReplacementsError(f"{r1} overlaps with {r2}")

Expand Down
Loading

0 comments on commit 80e3e1b

Please sign in to comment.