@@ -8,6 +8,7 @@
     Callable,
     Dict,
     List,
+    Literal,
     Mapping,
     MutableSequence,
     Optional,
@@ -18,7 +19,6 @@
 )

 import numpy as np
-from typing_extensions import Literal

 import aesara
 from aesara.compile.ops import ViewOp
@@ -91,10 +91,8 @@ def grad_not_implemented(op, x_pos, x, comment=""):

     return (
         NullType(
-            (
-                "This variable is Null because the grad method for "
-                f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
-            )
+            "This variable is Null because the grad method for "
+            f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
         )
     )()

@@ -114,10 +112,8 @@ def grad_undefined(op, x_pos, x, comment=""):

     return (
         NullType(
-            (
-                "This variable is Null because the grad method for "
-                f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
-            )
+            "This variable is Null because the grad method for "
+            f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
         )
     )()

@@ -1270,14 +1266,12 @@ def try_to_copy_if_needed(var):
                    # We therefore don't allow it because its usage has become
                    # so muddied.
                    raise TypeError(
-                        (
-                            f"{node.op}.grad returned None for a gradient term, "
-                            "this is prohibited. Instead of None,"
-                            "return zeros_like(input), disconnected_type(),"
-                            " or a NullType variable such as those made with "
-                            "the grad_undefined or grad_unimplemented helper "
-                            "functions."
-                        )
+                        f"{node.op}.grad returned None for a gradient term, "
+                        "this is prohibited. Instead of None,"
+                        "return zeros_like(input), disconnected_type(),"
+                        " or a NullType variable such as those made with "
+                        "the grad_undefined or grad_unimplemented helper "
+                        "functions."
                    )

                # Check that the gradient term for this input
@@ -1396,10 +1390,8 @@ def access_grad_cache(var):

                        if hasattr(var, "ndim") and term.ndim != var.ndim:
                            raise ValueError(
-                                (
-                                    f"{node.op}.grad returned a term with"
-                                    f" {int(term.ndim)} dimensions, but {int(var.ndim)} are required."
-                                )
+                                f"{node.op}.grad returned a term with"
+                                f" {int(term.ndim)} dimensions, but {int(var.ndim)} are required."
                            )

                        terms.append(term)
@@ -1761,10 +1753,8 @@ def verify_grad(
     for i, p in enumerate(pt):
         if p.dtype not in ("float16", "float32", "float64"):
             raise TypeError(
-                (
-                    "verify_grad can work only with floating point "
-                    f'inputs, but input {i} has dtype "{p.dtype}".'
-                )
+                "verify_grad can work only with floating point "
+                f'inputs, but input {i} has dtype "{p.dtype}".'
             )

     _type_tol = dict(  # relative error tolerances for different types