Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python updates for Elemwise broadcasting #336

Merged
9 changes: 4 additions & 5 deletions aesara/compile/compilelock.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""
import os
import threading
import typing
from contextlib import contextmanager
from typing import Optional

import filelock

Expand Down Expand Up @@ -45,7 +45,7 @@ def force_unlock(lock_dir: os.PathLike):


@contextmanager
def lock_ctx(lock_dir: os.PathLike = None, *, timeout: typing.Optional[float] = -1):
def lock_ctx(lock_dir: os.PathLike = None, *, timeout: Optional[float] = None):
"""Context manager that wraps around FileLock and SoftFileLock from filelock package.

Parameters
Expand All @@ -59,10 +59,9 @@ def lock_ctx(lock_dir: os.PathLike = None, *, timeout: typing.Optional[float] =
"""
if lock_dir is None:
lock_dir = config.compiledir
if timeout == -1:

if timeout is None:
timeout = config.compile__timeout
elif not (timeout is None or timeout > 0):
raise ValueError(f"Timeout parameter must be None or positive. Got {timeout}.")

# locks are kept in a dictionary to account for changing compiledirs
dir_key = f"{lock_dir}-{os.getpid()}"
Expand Down
3 changes: 3 additions & 0 deletions aesara/link/c/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,6 +1771,9 @@ def __call__(self):
raise
raise exc_value.with_traceback(exc_trace)

def __str__(self):
return f"{type(self).__name__}({self.module})"


class OpWiseCLinker(LocalLinker):
"""
Expand Down
8 changes: 4 additions & 4 deletions aesara/link/c/cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# we will abuse the lockfile mechanism when reading and writing the registry
from aesara.compile.compilelock import lock_ctx
from aesara.configdefaults import config, gcc_version_str
from aesara.link.c.exceptions import MissingGXX
from aesara.link.c.exceptions import CompileError, MissingGXX
from aesara.utils import (
LOCAL_BITWIDTH,
flatten,
Expand Down Expand Up @@ -2543,9 +2543,9 @@ def print_command_line_error():
# We replace '\n' by '. ' in the error message because when Python
# prints the exception, having '\n' in the text makes it more
# difficult to read.
compile_stderr = compile_stderr.replace("\n", ". ")
raise Exception(
f"Compilation failed (return status={status}): {compile_stderr}"
# compile_stderr = compile_stderr.replace("\n", ". ")
raise CompileError(
f"Compilation failed (return status={status}):\n{' '.join(cmd)}\n{compile_stderr}"
)
elif config.cmodule__compilation_warning and compile_stderr:
# Print errors just below the command line.
Expand Down
12 changes: 12 additions & 0 deletions aesara/link/c/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from distutils.errors import CompileError as BaseCompileError


class MissingGXX(Exception):
"""
This error is raised when we try to generate c code,
but g++ is not available.

"""


class CompileError(BaseCompileError):
"""This custom `Exception` prints compilation errors with their original
formatting.
"""

def __str__(self):
return self.args[0]
31 changes: 17 additions & 14 deletions aesara/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ def raise_with_op(
else:
scalar_values.append("not shown")
else:
shapes = "The thunk don't have an inputs attributes."
strides = "So we can't access the strides of inputs values"
scalar_values = "And can't print its inputs scalar value"
shapes = "The thunk doesn't have an `inputs` attributes."
strides = "So we can't access the strides of the input values"
scalar_values = "and we can't print its scalar input values"
clients = [[c[0] for c in fgraph.clients[var]] for var in node.outputs]
detailed_err_msg += (
f"Inputs shapes: {shapes}"
Expand All @@ -349,14 +349,15 @@ def raise_with_op(
detailed_err_msg += f"\nOutputs clients: {clients}\n"
else:
hints.append(
"HINT: Use another linker then the c linker to"
" have the inputs shapes and strides printed."
"HINT: Use a linker other than the C linker to"
" print the inputs' shapes and strides."
)

# Print node backtraces
tr = getattr(node.outputs[0].tag, "trace", [])
if isinstance(tr, list) and len(tr) > 0:
detailed_err_msg += "\nBacktrace when the node is created(use Aesara flag traceback__limit=N to make it longer):\n"
detailed_err_msg += "\nBacktrace when the node is created "
detailed_err_msg += "(use Aesara flag traceback__limit=N to make it longer):\n"

# Print separate message for each element in the list of batcktraces
sio = io.StringIO()
Expand All @@ -365,9 +366,9 @@ def raise_with_op(
detailed_err_msg += str(sio.getvalue())
else:
hints.append(
"HINT: Re-running with most Aesara optimization disabled could"
" give you a back-trace of when this node was created. This can"
" be done with by setting the Aesara flag"
"HINT: Re-running with most Aesara optimizations disabled could"
" provide a back-trace showing when this node was created. This can"
" be done by setting the Aesara flag"
" 'optimizer=fast_compile'. If that does not work,"
" Aesara optimizations can be disabled with 'optimizer=None'."
)
Expand All @@ -378,7 +379,7 @@ def raise_with_op(

f = io.StringIO()
aesara.printing.debugprint(node, file=f, stop_on_name=True, print_type=True)
detailed_err_msg += "\nDebugprint of the apply node: \n"
detailed_err_msg += "\nDebug print of the apply node: \n"
detailed_err_msg += f.getvalue()

# Prints output_map
Expand Down Expand Up @@ -497,16 +498,18 @@ def raise_with_op(

else:
hints.append(
"HINT: Use the Aesara flag 'exception_verbosity=high'"
" for a debugprint and storage map footprint of this apply node."
"HINT: Use the Aesara flag `exception_verbosity=high`"
" for a debug print-out and storage map footprint of this Apply node."
)

try:
exc_value = exc_type(
str(exc_value) + detailed_err_msg + "\n" + "\n".join(hints)
)
except TypeError:
warnings.warn(f"{exc_type} error does not allow us to add extra error message")
warnings.warn(
f"{exc_type} error does not allow us to add an extra error message"
)
# Some exception need extra parameter in inputs. So forget the
# extra long error message in that case.
raise exc_value.with_traceback(exc_trace)
Expand Down Expand Up @@ -541,7 +544,7 @@ def write(msg):
write(line)
write(
"For the full definition stack trace set"
" the Aesara flags traceback__limit to -1"
" the Aesara flags `traceback__limit` to -1"
)


Expand Down
33 changes: 6 additions & 27 deletions aesara/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def __setstate__(self, d):

def get_output_info(self, dim_shuffle, *inputs):
"""Return the outputs dtype and broadcastable pattern and the
dimshuffled niputs.
dimshuffled inputs.

"""
shadow = self.scalar_op.make_node(
Expand Down Expand Up @@ -736,30 +736,9 @@ def perform(self, node, inputs, output_storage):
# should be disabled.
super().perform(node, inputs, output_storage)

for dims in zip(
*[
list(zip(input.shape, sinput.type.broadcastable))
for input, sinput in zip(inputs, node.inputs)
]
):
if max(d for d, b in dims) != 1 and (1, False) in dims:
# yes there may be more compact ways to write this code,
# but please maintain python 2.4 compatibility
# (no "x if c else y")
msg = []
assert len(inputs) == len(node.inputs)
for input, sinput in zip(inputs, node.inputs):
assert len(input.shape) == len(sinput.type.broadcastable)
msg2 = []
for d, b in zip(input.shape, sinput.type.broadcastable):
if b:
msg2 += ["*"]
else:
msg2 += [str(d)]
msg.append(f"({', '.join(msg2)})")

base_exc_str = f"Dimension mismatch; shapes are {', '.join(msg)}"
raise ValueError(base_exc_str)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")

# Determine the shape of outputs
out_shape = []
Expand Down Expand Up @@ -878,9 +857,9 @@ def infer_shape(self, fgraph, node, i_shapes):
return rval

def _c_all(self, node, nodename, inames, onames, sub):
# Some ops call directly the Elemwise._c_all or Elemwise.c_code
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
# To not request all of them to call prepare_node(), do it here.
# There is no harm if it get called multile time.
# There is no harm if it get called multiple times.
if not hasattr(node.tag, "fake_node"):
self.prepare_node(node, None, None, "c")
_inames = inames
Expand Down
2 changes: 1 addition & 1 deletion aesara/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):

multinomial = MultinomialRV()

vsearchsorted = np.vectorize(np.searchsorted, otypes=[np.int], signature="(n),()->()")
vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")


class CategoricalRV(RandomVariable):
Expand Down
41 changes: 26 additions & 15 deletions tests/compile/test_compilelock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,6 @@
from aesara.compile.compilelock import force_unlock, local_mem, lock_ctx


def test_compilelock_errors():
with tempfile.TemporaryDirectory() as dir:
with pytest.raises(ValueError):
with lock_ctx(dir, timeout=0):
pass
with pytest.raises(ValueError):
with lock_ctx(dir, timeout=-2):
pass


def test_compilelock_force_unlock():
with tempfile.TemporaryDirectory() as dir_name:
with lock_ctx(dir_name):
Expand Down Expand Up @@ -81,13 +71,22 @@ def run_locking_test(ctx):


def test_locking_thread():
import traceback

with tempfile.TemporaryDirectory() as dir_name:

def test_fn_1():
with lock_ctx(dir_name):
# Sleep "indefinitely"
time.sleep(100)
def test_fn_1(arg):
try:
with lock_ctx(dir_name):
# Notify the outside that we've obtained the lock
arg.append(False)
while True not in arg:
time.sleep(0.5)
except Exception:
# Notify the outside that we done
arg.append(False)
# If something unexpected happened, we want to know what it was
traceback.print_exc()

def test_fn_2(arg):
try:
Expand All @@ -98,18 +97,30 @@ def test_fn_2(arg):
# It timed out, which means that the lock was still held by the
# first thread
arg.append(True)
except Exception:
# If something unexpected happened, we want to know what it was
traceback.print_exc()

thread_1 = threading.Thread(target=test_fn_1)
res = []
thread_1 = threading.Thread(target=test_fn_1, args=(res,))
thread_2 = threading.Thread(target=test_fn_2, args=(res,))

thread_1.start()

# Make sure the first thread has obtained the lock
while False not in res:
time.sleep(0.5)

thread_2.start()

# The second thread should raise `filelock.Timeout`
thread_2.join()
assert True in res

thread_1.join()
assert not thread_1.is_alive()
assert not thread_2.is_alive()


@pytest.mark.skipif(sys.platform != "linux", reason="Fork is only available on linux")
def test_locking_multiprocess_fork():
Expand Down
Loading