From 327bd1eaaf4a1300268a41d888981dfe646413d8 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 4 Nov 2021 15:58:32 +0100 Subject: [PATCH] Make input mismatch TypeError in make_node more readable --- aesara/graph/op.py | 10 +++++++++- tests/graph/test_op.py | 15 ++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/aesara/graph/op.py b/aesara/graph/op.py index 7c287ff1e5..3d69ed0bbb 100644 --- a/aesara/graph/op.py +++ b/aesara/graph/op.py @@ -225,7 +225,15 @@ def make_node(self, *inputs: Variable) -> Apply: ) if not all(inp.type == it for inp, it in zip(inputs, self.itypes)): raise TypeError( - f"We expected inputs of types '{str(self.itypes)}' but got types '{str([inp.type for inp in inputs])}'" + f"Invalid input types for Op {self}:\n" + + "\n".join( + f"Input {i}/{len(inputs)}: Expected {inp}, got {out}" + for i, (inp, out) in enumerate( + zip(self.itypes, (inp.type for inp in inputs)), + start=1, + ) + if inp != out + ) ) return Apply(self, inputs, [o() for o in self.otypes]) diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index 68a74290e0..0e1d923d6c 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -12,7 +12,7 @@ from aesara.graph.type import Generic, Type from aesara.graph.utils import MethodNotDefined, TestValueError from aesara.tensor.math import log -from aesara.tensor.type import dmatrix, vector +from aesara.tensor.type import dmatrix, dscalar, dvector, vector def as_variable(x): @@ -340,3 +340,16 @@ def test_get_test_values_exc(): with pytest.raises(TestValueError): x = vector() assert op.get_test_values(x) == [] + + +def test_op_invalid_input_types(): + class TestOp(aesara.graph.op.Op): + itypes = [dvector, dvector, dvector] + otypes = [dvector] + + def perform(self, node, inputs, outputs): + pass + + msg = r"^Invalid input types for Op TestOp:\nInput 2/3: Expected TensorType\(float64, vector\)" + with pytest.raises(TypeError, match=msg): + TestOp()(dvector(), dscalar(), dvector())