Skip to content

Commit

Permalink
Update pre-commit config
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 21, 2023
1 parent 5268d20 commit 1da48b5
Show file tree
Hide file tree
Showing 168 changed files with 75 additions and 562 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
)$
- id: check-merge-conflict
- repo: https://github.com/psf/black
rev: 22.12.0
rev: 23.1.0
hooks:
- id: black
language_version: python3
Expand Down Expand Up @@ -47,7 +47,7 @@ repos:
)$
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
rev: v1.1.1
hooks:
- id: mypy
additional_dependencies:
Expand Down
2 changes: 0 additions & 2 deletions aesara/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def __init__(self, name):
self.name = name

def make_node(self, condition, *monitored_vars):

# Ensure that condition is an Aesara tensor
if not isinstance(condition, Variable):
condition = as_tensor_variable(condition)
Expand Down Expand Up @@ -150,7 +149,6 @@ def infer_shape(self, fgraph, inputs, input_shapes):
return input_shapes[1:]

def connection_pattern(self, node):

nb_inp = len(node.inputs)
nb_out = nb_inp - 1

Expand Down
1 change: 0 additions & 1 deletion aesara/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,6 @@ def connection_pattern(self, node):
return list(map(list, cpmat_self))

def infer_shape(self, fgraph, node, shapes):

# TODO: Use `fgraph.shape_feature` to do this instead.
out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes)

Expand Down
7 changes: 1 addition & 6 deletions aesara/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,6 @@ def _check_viewmap(fgraph, node, storage_map):
"""

for oi, onode in enumerate(node.outputs):

good_alias, bad_alias = {}, {}
outstorage = storage_map[onode][0]

Expand All @@ -590,13 +589,11 @@ def _check_viewmap(fgraph, node, storage_map):
if hasattr(inode.type, "may_share_memory") and inode.type.may_share_memory(
outstorage, in_storage
):

nodeid = id(inode)
bad_alias[nodeid] = ii

# check that the aliasing was declared in [view|destroy]_map
if [ii] == view_map.get(oi, None) or [ii] == destroy_map.get(oi, None):

good_alias[nodeid] = bad_alias.pop(nodeid)

# TODO: make sure this is correct
Expand Down Expand Up @@ -1010,7 +1007,7 @@ def _check_preallocated_output(
aliased_inputs.add(r)

_logger.debug("starting preallocated output checking")
for (name, out_map) in _get_preallocated_maps(
for name, out_map in _get_preallocated_maps(
node,
thunk,
prealloc_modes,
Expand Down Expand Up @@ -1180,7 +1177,6 @@ class _VariableEquivalenceTracker(Feature):
"""

def on_attach(self, fgraph):

if hasattr(fgraph, "_eq_tracker_equiv"):
raise AlreadyThere()

Expand Down Expand Up @@ -1675,7 +1671,6 @@ def f():
sys.stdout.flush()

if thunk_c:

clobber = True
if thunk_py:
dmap = node.op.destroy_map
Expand Down
3 changes: 1 addition & 2 deletions aesara/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def clone_inputs(i):
# Fill update_d and update_expr with provided updates
if updates is None:
updates = []
for (store_into, update_val) in iter_over_pairs(updates):
for store_into, update_val in iter_over_pairs(updates):
if not isinstance(store_into, SharedVariable):
raise TypeError("update target must be a SharedVariable", store_into)
if store_into in update_d:
Expand Down Expand Up @@ -471,7 +471,6 @@ def construct_pfunc_ins_and_outs(
)

if not fgraph:

# Extend the outputs with the updates on input variables so they are
# also cloned
additional_outputs = [i.update for i in inputs if i.update]
Expand Down
9 changes: 1 addition & 8 deletions aesara/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ def copy(
aesara.Function
Copied aesara.Function
"""

# helper function
def checkSV(sv_ori, sv_rpl):
"""
Expand Down Expand Up @@ -761,7 +762,6 @@ def checkSV(sv_ori, sv_rpl):
for in_ori, in_cpy, ori, cpy in zip(
maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage
):

# Share immutable ShareVariable and constant input's storage
swapped = swap is not None and in_ori.variable in swap

Expand Down Expand Up @@ -911,7 +911,6 @@ def restore_defaults():
if hasattr(i_var.type, "may_share_memory"):
is_aliased = False
for j in range(len(args_share_memory)):

group_j = zip(
[
self.maker.inputs[k].variable
Expand All @@ -929,7 +928,6 @@ def restore_defaults():
)
for (var, val) in group_j
):

is_aliased = True
args_share_memory[j].append(i)
break
Expand Down Expand Up @@ -1057,9 +1055,7 @@ def restore_defaults():
elif self.unpack_single and len(outputs) == 1 and output_subset is None:
return outputs[0]
else:

if self.output_keys is not None:

assert len(self.output_keys) == len(outputs)

if output_subset is None:
Expand Down Expand Up @@ -1452,7 +1448,6 @@ def prepare_fgraph(
update = fgraph_outputs[out_idx]

if update.owner and update.owner.op == update_placeholder:

# TODO: Consider removing the corresponding
# `FunctionGraph` input when it has no other
# references?
Expand All @@ -1479,7 +1474,6 @@ def prepare_fgraph(
# Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:

# If the rewriter got interrupted
if rewrite_time is None:
end_rewriter = time.perf_counter()
Expand Down Expand Up @@ -1658,7 +1652,6 @@ def create(self, input_storage=None, trustme=False, storage_map=None):
for i, ((input, indices, subinputs), input_storage_i) in enumerate(
zip(self.indices, input_storage)
):

# Replace any default value given as a variable by its
# container. Note that this makes sense only in the
# context of shared variables, but for now we avoid
Expand Down
17 changes: 8 additions & 9 deletions aesara/compile/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def _atexit_print_fn():
destination_file = config.profiling__destination

with extended_open(destination_file, mode="w"):

# Reverse sort in the order of compile+exec time
for ps in sorted(
_atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
Expand Down Expand Up @@ -358,7 +357,7 @@ def class_impl(self):
"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for (fgraph, node) in self.apply_callcount:
for fgraph, node in self.apply_callcount:
typ = type(node.op)
if self.apply_cimpl[node]:
impl = "C "
Expand Down Expand Up @@ -401,7 +400,7 @@ def compute_total_times(self):
"""
rval = {}
for (fgraph, node) in self.apply_time:
for fgraph, node in self.apply_time:
if node not in rval:
self.fill_node_total_time(fgraph, node, rval)
return rval
Expand Down Expand Up @@ -437,7 +436,7 @@ def op_impl(self):
"""
# timing is stored by node, we compute timing by Op on demand
rval = {}
for (fgraph, node) in self.apply_callcount:
for fgraph, node in self.apply_callcount:
if self.apply_cimpl[node]:
rval[node.op] = "C "
else:
Expand Down Expand Up @@ -711,7 +710,7 @@ def summary_nodes(self, file=sys.stderr, N=None):

atimes.sort(reverse=True, key=lambda t: (t[1], t[3]))
tot = 0
for (f, t, a, nd_id, nb_call) in atimes[:N]:
for f, t, a, nd_id, nb_call in atimes[:N]:
tot += t
ftot = tot * 100 / local_time
if nb_call == 0:
Expand Down Expand Up @@ -840,7 +839,7 @@ def summary_memory(self, file, N=None):
var_mem = {} # variable->size in bytes; don't include input variables
node_mem = {} # (fgraph, node)->total outputs size (only dense outputs)

for (fgraph, node) in self.apply_callcount:
for fgraph, node in self.apply_callcount:
fct_memory.setdefault(fgraph, {})
fct_memory[fgraph].setdefault(node, [])
fct_shapes.setdefault(fgraph, {})
Expand Down Expand Up @@ -1611,7 +1610,7 @@ def exp_float32_op(op):
printed_tip = True

# tip 4
for (fgraph, a) in self.apply_time:
for fgraph, a in self.apply_time:
node = a
if isinstance(node.op, Dot) and all(
len(i.type.broadcastable) == 2 for i in node.inputs
Expand All @@ -1628,7 +1627,7 @@ def exp_float32_op(op):
printed_tip = True

# tip 5
for (fgraph, a) in self.apply_time:
for fgraph, a in self.apply_time:
node = a
if isinstance(node.op, RandomVariable):
printed_tip = True
Expand All @@ -1642,7 +1641,7 @@ def exp_float32_op(op):
break

# tip 6
for (fgraph, a) in self.apply_time:
for fgraph, a in self.apply_time:
node = a
if isinstance(node.op, Dot) and len({i.dtype for i in node.inputs}) != 1:
print(
Expand Down
5 changes: 0 additions & 5 deletions aesara/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ def short_platform(r=None, p=None):


def add_basic_configvars():

config.add(
"floatX",
"Default floating-point precision for python casts.\n"
Expand Down Expand Up @@ -388,7 +387,6 @@ def _is_greater_or_equal_0(x):


def add_compile_configvars():

config.add(
"mode",
"Default compilation mode",
Expand Down Expand Up @@ -631,7 +629,6 @@ def _is_valid_cmp_sloppy(v):


def add_tensor_configvars():

# This flag is used when we import Aesara to initialize global variables.
# So changing it after import will not modify these global variables.
# This could be done differently... but for now we simply prevent it from being
Expand Down Expand Up @@ -717,7 +714,6 @@ def add_experimental_configvars():


def add_error_and_warning_configvars():

###
# To disable some warning about old bug that are fixed now.
###
Expand Down Expand Up @@ -1196,7 +1192,6 @@ def add_vm_configvars():


def add_deprecated_configvars():

# TODO: remove this?
config.add(
"unittests__rseed",
Expand Down
8 changes: 0 additions & 8 deletions aesara/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,6 @@ def Rop(
# Check that each element of wrt corresponds to an element
# of eval_points with the same dimensionality.
for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)):

try:
if wrt_elem.type.ndim != eval_point.type.ndim:
raise ValueError(
Expand Down Expand Up @@ -266,7 +265,6 @@ def _traverse(node):
# arguments, like for example random states
local_eval_points.append(None)
elif inp.owner in seen_nodes:

local_eval_points.append(
seen_nodes[inp.owner][inp.owner.outputs.index(inp)]
)
Expand Down Expand Up @@ -941,7 +939,6 @@ def account_for(var):
var_idx = app.outputs.index(var)

for i, ipt in enumerate(app.inputs):

# don't process ipt if it is not a true
# parent of var
if not connection_pattern[i][var_idx]:
Expand Down Expand Up @@ -1052,7 +1049,6 @@ def access_term_cache(node):
"""Populates term_dict[node] and returns it"""

if node not in term_dict:

inputs = node.inputs

output_grads = [access_grad_cache(var) for var in node.outputs]
Expand Down Expand Up @@ -1267,7 +1263,6 @@ def try_to_copy_if_needed(var):
]

for i, term in enumerate(input_grads):

# Disallow Nones
if term is None:
# We don't know what None means. in the past it has been
Expand Down Expand Up @@ -1383,7 +1378,6 @@ def access_grad_cache(var):
node_to_idx = var_to_app_to_idx[var]
for node in node_to_idx:
for idx in node_to_idx[node]:

term = access_term_cache(node)[idx]

if not isinstance(term, Variable):
Expand Down Expand Up @@ -1868,7 +1862,6 @@ def random_projection():
)

if max_abs_err > abs_tol and max_rel_err > rel_tol:

raise GradientError(
max_arg,
max_err_pos,
Expand Down Expand Up @@ -2052,7 +2045,6 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):

hessians = []
for input in wrt:

if not isinstance(input, Variable):
raise TypeError("hessian expects a (list of) Variable as `wrt`")

Expand Down
Loading

1 comment on commit 1da48b5

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Python Benchmark with pytest-benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 1da48b5 Previous: 00c8b41 Ratio
tests/link/jax/test_elemwise.py::test_logsumexp_benchmark[0-size0] 17447.471931758955 iter/sec (stddev: 0.00012246005550494655) 60027.85746898697 iter/sec (stddev: 0.000001067323697542201) 3.44
tests/link/jax/test_elemwise.py::test_logsumexp_benchmark[0-size1] 52.85492011542199 iter/sec (stddev: 0.0025642582995337587) 122.97074395654728 iter/sec (stddev: 0.00010956053715406643) 2.33
tests/link/jax/test_elemwise.py::test_logsumexp_benchmark[0-size2] 0.6008892633276169 iter/sec (stddev: 0.039700634900855764) 1.2984929924318256 iter/sec (stddev: 0.005443965239389432) 2.16
tests/link/jax/test_elemwise.py::test_logsumexp_benchmark[1-size0] 23144.251134269387 iter/sec (stddev: 0.0000338022112841827) 61013.391178171216 iter/sec (stddev: 7.600864749608043e-7) 2.64
tests/link/jax/test_elemwise.py::test_logsumexp_benchmark[1-size1] 62.1358443160501 iter/sec (stddev: 0.0017988142366627596) 125.04561262015282 iter/sec (stddev: 0.00010720437524860068) 2.01
tests/scan/test_basic.py::TestExamples::test_reordering 639.5283432674082 iter/sec (stddev: 0.0005788173565035848) 1347.5555152950258 iter/sec (stddev: 0.00003738088077344204) 2.11

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.