Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use pre-commit for CI style checks. #3062

Merged
merged 11 commits into from
Dec 15, 2022
29 changes: 23 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,45 @@
# To run: `pre-commit run --all-files`
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: debug-statements
- id: mixed-line-ending
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 22.10.0
hooks:
- id: black
language_version: python3
exclude: versioneer.py
args: [--target-version=py38]
files: ^python/
- repo: https://github.com/PyCQA/flake8
rev: 3.8.4
rev: 6.0.0
hooks:
- id: flake8
args: [--config=python/.flake8]
files: ^python/
args: ["--config=setup.cfg"]
files: python/.*$
types: [file]
types_or: [python] # TODO: Enable [python, cython]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will enable flake8 on Cython files in a follow-up PR. Commit 2e4f34f is what I'll put in that PR.

additional_dependencies: ["flake8-force"]
- repo: https://github.com/asottile/yesqa
rev: v1.3.0
hooks:
- id: yesqa
additional_dependencies:
- flake8==3.8.4
- flake8==6.0.0
# - repo: https://github.com/pre-commit/mirrors-clang-format
# rev: v11.1.0
# hooks:
# - id: clang-format
# types_or: [c, c++, cuda]
# args: ["-fallback-style=none", "-style=file", "-i"]
- repo: local
hooks:
- id: copyright-check
name: copyright-check
entry: python ./ci/checks/copyright.py --git-modified-only --update-current-year
language: python
pass_filenames: false
additional_dependencies: [gitpython]
217 changes: 127 additions & 90 deletions ci/checks/copyright.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file was updated to better support pre-commit (I copied these changes from cudf). It also supports autofixing now, so if the style checks fail, it will tell you exactly which files need to change and what they should say.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,46 +13,43 @@
# limitations under the License.
#

import datetime
import re
import argparse
import io
import datetime
import os
import re
import sys

SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))

# Add the scripts dir for gitutils
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR,
"../../cpp/scripts")))

# Now import gitutils. Ignore flake8 error here since there is no other way to
# set up imports
import gitutils # noqa: E402
import git

FilesToCheck = [
re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"),
re.compile(r"CMakeLists[.]txt$"),
re.compile(r"CMakeLists_standalone[.]txt$"),
re.compile(r"setup[.]cfg$"),
re.compile(r"[.]flake8[.]cython$"),
re.compile(r"meta[.]yaml$")
re.compile(r"meta[.]yaml$"),
]
ExemptFiles = [
re.compile(r"cpp/include/cudf_test/cxxopts.hpp"),
re.compile(r"versioneer[.]py"),
]
ExemptFiles = ['versioneer.py']

# this will break starting at year 10000, which is probably OK :)
CheckSimple = re.compile(
r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)")
r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)"
)
CheckDouble = re.compile(
r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501
)


def checkThisFile(f):
# This check covers things like symlinks which point to files that DNE
if not (os.path.exists(f)):
return False
if gitutils and gitutils.isFileEmpty(f):
if isinstance(f, git.Diff):
if f.deleted_file or f.b_blob.size == 0:
return False
f = f.b_path
elif not os.path.exists(f) or os.stat(f).st_size == 0:
# This check covers things like symlinks which point to files that DNE
return False
for exempt in ExemptFiles:
if exempt.search(f):
Expand All @@ -63,36 +60,90 @@ def checkThisFile(f):
return False


def modifiedFiles():
"""Get a set of all modified files, as Diff objects.

The files returned have been modified in git since the merge base of HEAD
and the upstream of the target branch. We return the Diff objects so that
we can read only the staged changes.
"""
repo = git.Repo()
# Use the environment variable TARGET_BRANCH (defined in CI) if possible
target_branch = os.environ.get("TARGET_BRANCH")
if target_branch is None:
# Fall back to the closest branch if not on CI
target_branch = repo.git.describe(
all=True, tags=True, match="branch-*", abbrev=0
).lstrip("heads/")

upstream_target_branch = None
if target_branch in repo.heads:
# Use the tracking branch of the local reference if it exists. This
# returns None if no tracking branch is set.
upstream_target_branch = repo.heads[target_branch].tracking_branch()
if upstream_target_branch is None:
# Fall back to the remote with the newest target_branch. This code
# path is used on CI because the only local branch reference is
# current-pr-branch, and thus target_branch is not in repo.heads.
# This also happens if no tracking branch is defined for the local
# target_branch. We use the remote with the latest commit if
# multiple remotes are defined.
candidate_branches = [
remote.refs[target_branch] for remote in repo.remotes
if target_branch in remote.refs
]
if len(candidate_branches) > 0:
upstream_target_branch = sorted(
candidate_branches,
key=lambda branch: branch.commit.committed_datetime,
)[-1]
else:
# If no remotes are defined, try to use the local version of the
# target_branch. If this fails, the repo configuration must be very
# strange and we can fix this script on a case-by-case basis.
upstream_target_branch = repo.heads[target_branch]
merge_base = repo.merge_base("HEAD", upstream_target_branch.commit)[0]
diff = merge_base.diff()
changed_files = {f for f in diff if f.b_path is not None}
return changed_files


def getCopyrightYears(line):
res = CheckSimple.search(line)
if res:
return (int(res.group(1)), int(res.group(1)))
return int(res.group(1)), int(res.group(1))
res = CheckDouble.search(line)
if res:
return (int(res.group(1)), int(res.group(2)))
return (None, None)
return int(res.group(1)), int(res.group(2))
return None, None


def replaceCurrentYear(line, start, end):
# first turn a simple regex into double (if applicable). then update years
res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line)
res = CheckDouble.sub(
r"Copyright (c) {:04d}-{:04d}, NVIDIA CORPORATION".format(start, end),
res)
rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION",
res,
)
return res


def checkCopyright(f, update_current_year):
"""
Checks for copyright headers and their years
"""
"""Checks for copyright headers and their years."""
errs = []
thisYear = datetime.datetime.now().year
lineNum = 0
crFound = False
yearMatched = False
with io.open(f, "r", encoding="utf-8") as fp:
lines = fp.readlines()

if isinstance(f, git.Diff):
path = f.b_path
lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True)
else:
path = f
with open(f, encoding="utf-8") as fp:
lines = fp.readlines()

for line in lines:
lineNum += 1
start, end = getCopyrightYears(line)
Expand All @@ -101,20 +152,19 @@ def checkCopyright(f, update_current_year):
crFound = True
if start > end:
e = [
f,
path,
lineNum,
"First year after second year in the copyright "
"header (manual fix required)",
None
None,
]
errs.append(e)
if thisYear < start or thisYear > end:
elif thisYear < start or thisYear > end:
e = [
f,
path,
lineNum,
"Current year not included in the "
"copyright header",
None
"Current year not included in the copyright header",
None,
]
if thisYear < start:
e[-1] = replaceCurrentYear(line, thisYear, end)
Expand All @@ -123,15 +173,14 @@ def checkCopyright(f, update_current_year):
errs.append(e)
else:
yearMatched = True
fp.close()
# copyright header itself not found
if not crFound:
e = [
f,
path,
0,
"Copyright header missing or formatted incorrectly "
"(manual fix required)",
None
None,
]
errs.append(e)
# even if the year matches a copyright header, make the check pass
Expand All @@ -141,21 +190,19 @@ def checkCopyright(f, update_current_year):
if update_current_year:
errs_update = [x for x in errs if x[-1] is not None]
if len(errs_update) > 0:
print("File: {}. Changing line(s) {}".format(
f, ', '.join(str(x[1]) for x in errs if x[-1] is not None)))
lines_changed = ", ".join(str(x[1]) for x in errs_update)
print(f"File: {path}. Changing line(s) {lines_changed}")
for _, lineNum, __, replacement in errs_update:
lines[lineNum - 1] = replacement
with io.open(f, "w", encoding="utf-8") as out_file:
for new_line in lines:
out_file.write(new_line)
errs = [x for x in errs if x[-1] is None]
with open(path, "w", encoding="utf-8") as out_file:
out_file.writelines(lines)

return errs


def getAllFilesUnderDir(root, pathFilter=None):
retList = []
for (dirpath, dirnames, filenames) in os.walk(root):
for dirpath, dirnames, filenames in os.walk(root):
for fn in filenames:
filePath = os.path.join(dirpath, fn)
if pathFilter(filePath):
Expand All @@ -170,47 +217,37 @@ def checkCopyright_main():
it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch"
"""
retVal = 0
global ExemptFiles

argparser = argparse.ArgumentParser(
"Checks for a consistent copyright header in git's modified files")
argparser.add_argument("--update-current-year",
dest='update_current_year',
action="store_true",
required=False,
help="If set, "
"update the current year if a header "
"is already present and well formatted.")
argparser.add_argument("--git-modified-only",
dest='git_modified_only',
action="store_true",
required=False,
help="If set, "
"only files seen as modified by git will be "
"processed.")
argparser.add_argument("--exclude",
dest='exclude',
action="append",
required=False,
default=["python/cuml/_thirdparty/"],
help=("Exclude the paths specified (regexp). "
"Can be specified multiple times."))

(args, dirs) = argparser.parse_known_args()
try:
ExemptFiles = ExemptFiles + [pathName for pathName in args.exclude]
ExemptFiles = [re.compile(file) for file in ExemptFiles]
except re.error as reException:
print("Regular expression error:")
print(reException)
return 1
"Checks for a consistent copyright header in git's modified files"
)
argparser.add_argument(
"--update-current-year",
dest="update_current_year",
action="store_true",
required=False,
help="If set, "
"update the current year if a header is already "
"present and well formatted.",
)
argparser.add_argument(
"--git-modified-only",
dest="git_modified_only",
action="store_true",
required=False,
help="If set, "
"only files seen as modified by git will be "
"processed.",
)

args, dirs = argparser.parse_known_args()

if args.git_modified_only:
files = gitutils.modifiedFiles(pathFilter=checkThisFile)
files = [f for f in modifiedFiles() if checkThisFile(f)]
else:
files = []
for d in [os.path.abspath(d) for d in dirs]:
if not (os.path.isdir(d)):
if not os.path.isdir(d):
raise ValueError(f"{d} is not a directory.")
files += getAllFilesUnderDir(d, pathFilter=checkThisFile)

Expand All @@ -219,24 +256,24 @@ def checkCopyright_main():
errors += checkCopyright(f, args.update_current_year)

if len(errors) > 0:
print("Copyright headers incomplete in some of the files!")
if any(e[-1] is None for e in errors):
print("Copyright headers incomplete in some of the files!")
for e in errors:
print(" %s:%d Issue: %s" % (e[0], e[1], e[2]))
print("")
n_fixable = sum(1 for e in errors if e[-1] is not None)
path_parts = os.path.abspath(__file__).split(os.sep)
file_from_repo = os.sep.join(path_parts[path_parts.index("ci"):])
if n_fixable > 0:
print(("You can run `python {} --git-modified-only "
"--update-current-year` to fix {} of these "
"errors.\n").format(file_from_repo, n_fixable))
file_from_repo = os.sep.join(path_parts[path_parts.index("ci") :])
if n_fixable > 0 and not args.update_current_year:
print(
f"You can run `python {file_from_repo} --git-modified-only "
"--update-current-year` and stage the results in git to "
f"fix {n_fixable} of these errors.\n"
)
retVal = 1
else:
print("Copyright check passed")

return retVal


if __name__ == "__main__":
import sys
sys.exit(checkCopyright_main())
Loading