Skip to content

Commit

Permalink
Make input mismatch TypeError in make_node more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Nov 12, 2021
1 parent 7278da4 commit f9dbc8d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
10 changes: 9 additions & 1 deletion aesara/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,15 @@ def make_node(self, *inputs: Variable) -> Apply:
)
if not all(inp.type == it for inp, it in zip(inputs, self.itypes)):
raise TypeError(
f"We expected inputs of types '{str(self.itypes)}' but got types '{str([inp.type for inp in inputs])}'"
f"Invalid input types for Op {self}:\n"
+ "\n".join(
f"Input {i}/{len(inputs)}: Expected {inp}, got {out}"
for i, (inp, out) in enumerate(
zip(self.itypes, (inp.type for inp in inputs)),
start=1,
)
if inp != out
)
)
return Apply(self, inputs, [o() for o in self.otypes])

Expand Down
15 changes: 14 additions & 1 deletion tests/graph/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from aesara.graph.type import Generic, Type
from aesara.graph.utils import MethodNotDefined, TestValueError
from aesara.tensor.math import log
from aesara.tensor.type import dmatrix, vector
from aesara.tensor.type import dmatrix, dvector, scalar, vector


def as_variable(x):
Expand Down Expand Up @@ -340,3 +340,16 @@ def test_get_test_values_exc():
with pytest.raises(TestValueError):
x = vector()
assert op.get_test_values(x) == []


def test_op_invalid_input_types():
class TestOp(aesara.graph.op.Op):
itypes = [dvector, dvector, dvector]
otypes = [dvector]

def perform(self, node, inputs, outputs):
pass

msg = r"^Invalid input types for Op TestOp:\nInput 2/3: Expected TensorType\(float64, vector\)"
with pytest.raises(TypeError, match=msg):
TestOp()(vector(), scalar(), vector())

0 comments on commit f9dbc8d

Please sign in to comment.