Skip to content

Commit 4a687c0

Browse files
committed
Apply pyupgrade
1 parent 8bac20b commit 4a687c0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+359
-400
lines changed

aesara/compile/function/types.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import time
77
import warnings
88
from itertools import chain
9-
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union
9+
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Tuple, Type, Union
1010

1111
import numpy as np
12-
from typing_extensions import Literal
1312

1413
import aesara
1514
import aesara.compile.profiling

aesara/compile/mode.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
import logging
77
import warnings
8-
from typing import Optional, Tuple, Union
9-
10-
from typing_extensions import Literal
8+
from typing import Literal, Optional, Tuple, Union
119

1210
from aesara.compile.function.types import Supervisor
1311
from aesara.configdefaults import config

aesara/configparser.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,7 @@ def get_config_hash(self):
125125
)
126126
return hash_from_code(
127127
"\n".join(
128-
[
129-
"{} = {}".format(cv.name, cv.__get__(self, self.__class__))
130-
for cv in all_opts
131-
]
128+
[f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts]
132129
)
133130
)
134131

aesara/gradient.py

+15-25
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Callable,
99
Dict,
1010
List,
11+
Literal,
1112
Mapping,
1213
MutableSequence,
1314
Optional,
@@ -18,7 +19,6 @@
1819
)
1920

2021
import numpy as np
21-
from typing_extensions import Literal
2222

2323
import aesara
2424
from aesara.compile.ops import ViewOp
@@ -91,10 +91,8 @@ def grad_not_implemented(op, x_pos, x, comment=""):
9191

9292
return (
9393
NullType(
94-
(
95-
"This variable is Null because the grad method for "
96-
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
97-
)
94+
"This variable is Null because the grad method for "
95+
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
9896
)
9997
)()
10098

@@ -114,10 +112,8 @@ def grad_undefined(op, x_pos, x, comment=""):
114112

115113
return (
116114
NullType(
117-
(
118-
"This variable is Null because the grad method for "
119-
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
120-
)
115+
"This variable is Null because the grad method for "
116+
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
121117
)
122118
)()
123119

@@ -1270,14 +1266,12 @@ def try_to_copy_if_needed(var):
12701266
# We therefore don't allow it because its usage has become
12711267
# so muddied.
12721268
raise TypeError(
1273-
(
1274-
f"{node.op}.grad returned None for a gradient term, "
1275-
"this is prohibited. Instead of None,"
1276-
"return zeros_like(input), disconnected_type(),"
1277-
" or a NullType variable such as those made with "
1278-
"the grad_undefined or grad_unimplemented helper "
1279-
"functions."
1280-
)
1269+
f"{node.op}.grad returned None for a gradient term, "
1270+
"this is prohibited. Instead of None,"
1271+
"return zeros_like(input), disconnected_type(),"
1272+
" or a NullType variable such as those made with "
1273+
"the grad_undefined or grad_unimplemented helper "
1274+
"functions."
12811275
)
12821276

12831277
# Check that the gradient term for this input
@@ -1396,10 +1390,8 @@ def access_grad_cache(var):
13961390

13971391
if hasattr(var, "ndim") and term.ndim != var.ndim:
13981392
raise ValueError(
1399-
(
1400-
f"{node.op}.grad returned a term with"
1401-
f" {int(term.ndim)} dimensions, but {int(var.ndim)} are required."
1402-
)
1393+
f"{node.op}.grad returned a term with"
1394+
f" {int(term.ndim)} dimensions, but {int(var.ndim)} are required."
14031395
)
14041396

14051397
terms.append(term)
@@ -1761,10 +1753,8 @@ def verify_grad(
17611753
for i, p in enumerate(pt):
17621754
if p.dtype not in ("float16", "float32", "float64"):
17631755
raise TypeError(
1764-
(
1765-
"verify_grad can work only with floating point "
1766-
f'inputs, but input {i} has dtype "{p.dtype}".'
1767-
)
1756+
"verify_grad can work only with floating point "
1757+
f'inputs, but input {i} has dtype "{p.dtype}".'
17681758
)
17691759

17701760
_type_tol = dict( # relative error tolerances for different types

aesara/graph/basic.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1183,11 +1183,9 @@ def clone_replace(
11831183
items = []
11841184
else:
11851185
raise ValueError(
1186-
(
1187-
"replace is neither a dictionary, list, "
1188-
f"tuple or None ! The value provided is {replace},"
1189-
f"of type {type(replace)}"
1190-
)
1186+
"replace is neither a dictionary, list, "
1187+
f"tuple or None ! The value provided is {replace},"
1188+
f"of type {type(replace)}"
11911189
)
11921190
tmp_replace = [(x, x.type()) for x, y in items]
11931191
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]

aesara/graph/fg.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Dict,
99
Iterable,
1010
List,
11+
Literal,
1112
Optional,
1213
Sequence,
1314
Set,
@@ -16,8 +17,6 @@
1617
cast,
1718
)
1819

19-
from typing_extensions import Literal
20-
2120
import aesara
2221
from aesara.configdefaults import config
2322
from aesara.graph.basic import Apply, AtomicVariable, Variable, applys_between

aesara/graph/op.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,14 @@
99
Dict,
1010
List,
1111
Optional,
12+
Protocol,
1213
Sequence,
13-
Text,
1414
Tuple,
1515
TypeVar,
1616
Union,
1717
cast,
1818
)
1919

20-
from typing_extensions import Protocol
21-
2220
import aesara
2321
from aesara.configdefaults import config
2422
from aesara.graph.basic import Apply, NoParams, Variable
@@ -496,7 +494,7 @@ def prepare_node(
496494
node: Apply,
497495
storage_map: Optional[StorageMapType],
498496
compute_map: Optional[ComputeMapType],
499-
impl: Optional[Text],
497+
impl: Optional[str],
500498
) -> None:
501499
"""Make any special modifications that the `Op` needs before doing :meth:`Op.make_thunk`.
502500
@@ -573,7 +571,7 @@ def make_thunk(
573571
storage_map: StorageMapType,
574572
compute_map: ComputeMapType,
575573
no_recycling: List[Variable],
576-
impl: Optional[Text] = None,
574+
impl: Optional[str] = None,
577575
) -> ThunkType:
578576
r"""Create a thunk.
579577
@@ -676,7 +674,7 @@ def get_test_value(v: Any) -> Any:
676674
return v.get_test_value()
677675

678676

679-
def missing_test_message(msg: Text) -> None:
677+
def missing_test_message(msg: str) -> None:
680678
"""Display a message saying that some test_value is missing.
681679
682680
This uses the appropriate form based on ``config.compute_test_value``:

aesara/graph/rewriting/basic.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
from itertools import chain
1616
from typing import TYPE_CHECKING, Callable, Dict
1717
from typing import Iterable as IterableType
18-
from typing import List, Optional, Sequence, Tuple, Union, cast
19-
20-
from typing_extensions import Literal
18+
from typing import List, Literal, Optional, Sequence, Tuple, Union, cast
2119

2220
import aesara
2321
from aesara.configdefaults import config
@@ -1189,7 +1187,7 @@ def _find_impl(self, cls) -> List[NodeRewriter]:
11891187
matches.extend(match)
11901188
return matches
11911189

1192-
@functools.lru_cache()
1190+
@functools.lru_cache
11931191
def get_trackers(self, op: Op) -> List[NodeRewriter]:
11941192
"""Get all the rewrites applicable to `op`."""
11951193
return (

aesara/graph/rewriting/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
114114
# We also need to make sure we replace a Variable if it is present in
115115
# `givens`.
116116
vars_replaced = [givens.get(v, v) for v in fgraph.outputs]
117-
o1, o2 = [v.owner for v in vars_replaced]
117+
o1, o2 = (v.owner for v in vars_replaced)
118118
if o1 is None and o2 is None:
119119
# Comparing two single-Variable graphs: they are equal if they are
120120
# the same Variable.

aesara/graph/type.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union
2+
from typing import Any, Generic, Optional, Tuple, TypeVar, Union
33

44
from typing_extensions import TypeAlias
55

@@ -188,7 +188,7 @@ def is_valid_value(self, data: D, strict: bool = True) -> bool:
188188
except (TypeError, ValueError):
189189
return False
190190

191-
def make_variable(self, name: Optional[Text] = None) -> variable_type:
191+
def make_variable(self, name: Optional[str] = None) -> variable_type:
192192
"""Return a new `Variable` instance of this `Type`.
193193
194194
Parameters
@@ -199,7 +199,7 @@ def make_variable(self, name: Optional[Text] = None) -> variable_type:
199199
"""
200200
return self.variable_type(self, None, name=name)
201201

202-
def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type:
202+
def make_constant(self, value: D, name: Optional[str] = None) -> constant_type:
203203
"""Return a new `Constant` instance of this `Type`.
204204
205205
Parameters
@@ -216,7 +216,7 @@ def clone(self, *args, **kwargs) -> "Type":
216216
"""Clone a copy of this type with the given arguments/keyword values, if any."""
217217
return type(self)(*args, **kwargs)
218218

219-
def __call__(self, name: Optional[Text] = None) -> variable_type:
219+
def __call__(self, name: Optional[str] = None) -> variable_type:
220220
"""Return a new `Variable` instance of Type `self`.
221221
222222
Parameters

aesara/graph/utils.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,7 @@ def __str__(self):
245245
def __str__(self):
246246
return "{}{{{}}}".format(
247247
self.__class__.__name__,
248-
", ".join(
249-
"{}={!r}".format(p, getattr(self, p)) for p in props
250-
),
248+
", ".join(f"{p}={getattr(self, p)!r}" for p in props),
251249
)
252250

253251
dct["__str__"] = __str__

aesara/link/c/basic.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1695,16 +1695,14 @@ def instantiate_code(self, n_args):
16951695
print(" return NULL;", file=code)
16961696
print(" }", file=code)
16971697
print(
1698-
"""\
1698+
f"""\
16991699
PyObject* thunk = PyCapsule_New((void*)(&{struct_name}_executor), NULL, {struct_name}_destructor);
17001700
if (thunk != NULL && PyCapsule_SetContext(thunk, struct_ptr) != 0) {{
17011701
PyErr_Clear();
17021702
Py_DECREF(thunk);
17031703
thunk = NULL;
17041704
}}
1705-
""".format(
1706-
**locals()
1707-
),
1705+
""",
17081706
file=code,
17091707
)
17101708
print(" return thunk; }", file=code)

aesara/link/c/cmodule.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,17 @@
1919
import time
2020
import warnings
2121
from io import BytesIO, StringIO
22-
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, cast
22+
from typing import (
23+
TYPE_CHECKING,
24+
Callable,
25+
Dict,
26+
List,
27+
Optional,
28+
Protocol,
29+
Set,
30+
Tuple,
31+
cast,
32+
)
2333

2434
import numpy as np
2535
from setuptools._distutils.sysconfig import (
@@ -28,7 +38,6 @@
2838
get_python_inc,
2939
get_python_lib,
3040
)
31-
from typing_extensions import Protocol
3241

3342
# we will abuse the lockfile mechanism when reading and writing the registry
3443
from aesara.compile.compilelock import lock_ctx
@@ -2157,15 +2166,13 @@ def get_lines(cmd, parse=True):
21572166
if len(default_lines) < 1:
21582167
reported_lines = get_lines(f"{config.cxx} -E -v -", parse=False)
21592168
warnings.warn(
2160-
(
2161-
"Aesara was not able to find the "
2162-
"default g++ parameters. This is needed to tune "
2163-
"the compilation to your specific "
2164-
"CPU. This can slow down the execution of Aesara "
2165-
"functions. Please submit the following lines to "
2166-
"Aesara's mailing list so that we can fix this "
2167-
f"problem:\n {reported_lines}"
2168-
)
2169+
"Aesara was not able to find the "
2170+
"default g++ parameters. This is needed to tune "
2171+
"the compilation to your specific "
2172+
"CPU. This can slow down the execution of Aesara "
2173+
"functions. Please submit the following lines to "
2174+
"Aesara's mailing list so that we can fix this "
2175+
f"problem:\n {reported_lines}"
21692176
)
21702177
else:
21712178
# Some options are actually given as "-option value",
@@ -2248,7 +2255,7 @@ def join_options(init_part):
22482255
if len(version) != 3:
22492256
# Unexpected, but should not be a problem
22502257
continue
2251-
mj, mn, patch = [int(vp) for vp in version]
2258+
mj, mn, patch = (int(vp) for vp in version)
22522259
if (
22532260
((mj, mn) == (4, 6) and patch < 4)
22542261
or ((mj, mn) == (4, 7) and patch <= 3)

0 commit comments

Comments
 (0)