Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pyupgrade as pre-commit hook #1484

Merged
merged 4 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ["3.7", "3.9"]
# Numba doesn't yet support 3.11, so we test primarly on 3.10, and
# below we add an include for 3.11 without Numba.
python-version: ["3.8", "3.10"]
fast-compile: [0]
float32: [0]
install-numba: [1]
Expand All @@ -78,12 +80,12 @@ jobs:
- "tests/tensor/test_elemwise.py tests/tensor/rewriting/test_basic.py tests/tensor/rewriting/test_math.py"
- "tests/tensor/nnet/test_conv.py"
include:
- python-version: "3.7"
- python-version: "3.8"
fast-compile: 1
float32: 1
install-numba: 1
part: "tests --ignore=tests/tensor/nnet --ignore=tests/tensor/signal"
- python-version: "3.7"
- python-version: "3.8"
fast-compile: 1
float32: 0
install-numba: 1
Expand Down
9 changes: 6 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ repos:
aesara/tensor/var\.py|
)$
- id: check-merge-conflict
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: ["--py38-plus"]
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
# NOTE: flake8 v6 requires python >=3.8.1 but
# Aesara still supports python 3.7.
rev: 5.0.4
rev: 6.0.0
hooks:
- id: flake8
- repo: https://github.com/pycqa/isort
Expand Down
3 changes: 1 addition & 2 deletions aesara/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import time
import warnings
from itertools import chain
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Tuple, Type, Union

import numpy as np
from typing_extensions import Literal

import aesara
import aesara.compile.profiling
Expand Down
4 changes: 1 addition & 3 deletions aesara/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

import logging
import warnings
from typing import Optional, Tuple, Union

from typing_extensions import Literal
from typing import Literal, Optional, Tuple, Union

from aesara.compile.function.types import Supervisor
from aesara.configdefaults import config
Expand Down
5 changes: 1 addition & 4 deletions aesara/configparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,7 @@ def get_config_hash(self):
)
return hash_from_code(
"\n".join(
[
"{} = {}".format(cv.name, cv.__get__(self, self.__class__))
for cv in all_opts
]
[f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts]
)
)

Expand Down
40 changes: 15 additions & 25 deletions aesara/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Callable,
Dict,
List,
Literal,
Mapping,
MutableSequence,
Optional,
Expand All @@ -18,7 +19,6 @@
)

import numpy as np
from typing_extensions import Literal

import aesara
from aesara.compile.ops import ViewOp
Expand Down Expand Up @@ -91,10 +91,8 @@ def grad_not_implemented(op, x_pos, x, comment=""):

return (
NullType(
(
"This variable is Null because the grad method for "
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
)
"This variable is Null because the grad method for "
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
)
)()

Expand All @@ -114,10 +112,8 @@ def grad_undefined(op, x_pos, x, comment=""):

return (
NullType(
(
"This variable is Null because the grad method for "
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
)
"This variable is Null because the grad method for "
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
)
)()

Expand Down Expand Up @@ -1270,14 +1266,12 @@ def try_to_copy_if_needed(var):
# We therefore don't allow it because its usage has become
# so muddied.
raise TypeError(
(
f"{node.op}.grad returned None for a gradient term, "
"this is prohibited. Instead of None,"
"return zeros_like(input), disconnected_type(),"
" or a NullType variable such as those made with "
"the grad_undefined or grad_unimplemented helper "
"functions."
)
f"{node.op}.grad returned None for a gradient term, "
"this is prohibited. Instead of None,"
"return zeros_like(input), disconnected_type(),"
" or a NullType variable such as those made with "
"the grad_undefined or grad_unimplemented helper "
"functions."
)

# Check that the gradient term for this input
Expand Down Expand Up @@ -1396,10 +1390,8 @@ def access_grad_cache(var):

if hasattr(var, "ndim") and term.ndim != var.ndim:
raise ValueError(
(
f"{node.op}.grad returned a term with"
f" {int(term.ndim)} dimensions, but {int(var.ndim)} are required."
)
f"{node.op}.grad returned a term with"
f" {int(term.ndim)} dimensions, but {int(var.ndim)} are required."
)

terms.append(term)
Expand Down Expand Up @@ -1761,10 +1753,8 @@ def verify_grad(
for i, p in enumerate(pt):
if p.dtype not in ("float16", "float32", "float64"):
raise TypeError(
(
"verify_grad can work only with floating point "
f'inputs, but input {i} has dtype "{p.dtype}".'
)
"verify_grad can work only with floating point "
f'inputs, but input {i} has dtype "{p.dtype}".'
)

_type_tol = dict( # relative error tolerances for different types
Expand Down
8 changes: 3 additions & 5 deletions aesara/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,11 +1183,9 @@ def clone_replace(
items = []
else:
raise ValueError(
(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
Expand Down
3 changes: 1 addition & 2 deletions aesara/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Set,
Expand All @@ -16,8 +17,6 @@
cast,
)

from typing_extensions import Literal

import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, AtomicVariable, Variable, applys_between
Expand Down
10 changes: 4 additions & 6 deletions aesara/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@
Dict,
List,
Optional,
Protocol,
Sequence,
Text,
Tuple,
TypeVar,
Union,
cast,
)

from typing_extensions import Protocol

import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, NoParams, Variable
Expand Down Expand Up @@ -496,7 +494,7 @@ def prepare_node(
node: Apply,
storage_map: Optional[StorageMapType],
compute_map: Optional[ComputeMapType],
impl: Optional[Text],
impl: Optional[str],
) -> None:
"""Make any special modifications that the `Op` needs before doing :meth:`Op.make_thunk`.

Expand Down Expand Up @@ -573,7 +571,7 @@ def make_thunk(
storage_map: StorageMapType,
compute_map: ComputeMapType,
no_recycling: List[Variable],
impl: Optional[Text] = None,
impl: Optional[str] = None,
) -> ThunkType:
r"""Create a thunk.

Expand Down Expand Up @@ -676,7 +674,7 @@ def get_test_value(v: Any) -> Any:
return v.get_test_value()


def missing_test_message(msg: Text) -> None:
def missing_test_message(msg: str) -> None:
"""Display a message saying that some test_value is missing.

This uses the appropriate form based on ``config.compute_test_value``:
Expand Down
6 changes: 2 additions & 4 deletions aesara/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
from itertools import chain
from typing import TYPE_CHECKING, Callable, Dict
from typing import Iterable as IterableType
from typing import List, Optional, Sequence, Tuple, Union, cast

from typing_extensions import Literal
from typing import List, Literal, Optional, Sequence, Tuple, Union, cast

import aesara
from aesara.configdefaults import config
Expand Down Expand Up @@ -1189,7 +1187,7 @@ def _find_impl(self, cls) -> List[NodeRewriter]:
matches.extend(match)
return matches

@functools.lru_cache()
@functools.lru_cache
def get_trackers(self, op: Op) -> List[NodeRewriter]:
"""Get all the rewrites applicable to `op`."""
return (
Expand Down
2 changes: 1 addition & 1 deletion aesara/graph/rewriting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
# We also need to make sure we replace a Variable if it is present in
# `givens`.
vars_replaced = [givens.get(v, v) for v in fgraph.outputs]
o1, o2 = [v.owner for v in vars_replaced]
o1, o2 = (v.owner for v in vars_replaced)
if o1 is None and o2 is None:
# Comparing two single-Variable graphs: they are equal if they are
# the same Variable.
Expand Down
8 changes: 4 additions & 4 deletions aesara/graph/type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union
from typing import Any, Generic, Optional, Tuple, TypeVar, Union

from typing_extensions import TypeAlias

Expand Down Expand Up @@ -188,7 +188,7 @@ def is_valid_value(self, data: D, strict: bool = True) -> bool:
except (TypeError, ValueError):
return False

def make_variable(self, name: Optional[Text] = None) -> variable_type:
def make_variable(self, name: Optional[str] = None) -> variable_type:
"""Return a new `Variable` instance of this `Type`.

Parameters
Expand All @@ -199,7 +199,7 @@ def make_variable(self, name: Optional[Text] = None) -> variable_type:
"""
return self.variable_type(self, None, name=name)

def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type:
def make_constant(self, value: D, name: Optional[str] = None) -> constant_type:
"""Return a new `Constant` instance of this `Type`.

Parameters
Expand All @@ -216,7 +216,7 @@ def clone(self, *args, **kwargs) -> "Type":
"""Clone a copy of this type with the given arguments/keyword values, if any."""
return type(self)(*args, **kwargs)

def __call__(self, name: Optional[Text] = None) -> variable_type:
def __call__(self, name: Optional[str] = None) -> variable_type:
"""Return a new `Variable` instance of Type `self`.

Parameters
Expand Down
4 changes: 1 addition & 3 deletions aesara/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,7 @@ def __str__(self):
def __str__(self):
return "{}{{{}}}".format(
self.__class__.__name__,
", ".join(
"{}={!r}".format(p, getattr(self, p)) for p in props
),
", ".join(f"{p}={getattr(self, p)!r}" for p in props),
)

dct["__str__"] = __str__
Expand Down
6 changes: 2 additions & 4 deletions aesara/link/c/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,16 +1695,14 @@ def instantiate_code(self, n_args):
print(" return NULL;", file=code)
print(" }", file=code)
print(
"""\
f"""\
PyObject* thunk = PyCapsule_New((void*)(&{struct_name}_executor), NULL, {struct_name}_destructor);
if (thunk != NULL && PyCapsule_SetContext(thunk, struct_ptr) != 0) {{
PyErr_Clear();
Py_DECREF(thunk);
thunk = NULL;
}}
""".format(
**locals()
),
""",
file=code,
)
print(" return thunk; }", file=code)
Expand Down
Loading