From 21daf4e0fb82b9024ac1a7b57a3a815ba9f1ba68 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 17 Oct 2020 01:12:31 -0500 Subject: [PATCH] Validate test values using a tensor's type --- tests/gof/test_compute_test_value.py | 14 ++++++++------ theano/gof/graph.py | 2 +- theano/gof/utils.py | 17 +++++++++++++++++ theano/tensor/opt.py | 5 ++++- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/tests/gof/test_compute_test_value.py b/tests/gof/test_compute_test_value.py index b25991f4a2..34c4c7808b 100644 --- a/tests/gof/test_compute_test_value.py +++ b/tests/gof/test_compute_test_value.py @@ -167,14 +167,16 @@ def test_constant(self): @theano.change_flags(compute_test_value="raise") def test_incorrect_type(self): - x = tt.fmatrix("x") - # Incorrect dtype (float64) for test_value - x.tag.test_value = np.random.rand(3, 4) - y = tt.dmatrix("y") - y.tag.test_value = np.random.rand(4, 5) + x = tt.vector("x") with pytest.raises(TypeError): - tt.dot(x, y) + # Incorrect shape for test value + x.tag.test_value = np.empty((2, 2)) + + x = tt.fmatrix("x") + with pytest.raises(TypeError): + # Incorrect dtype (float64) for test value + x.tag.test_value = np.random.rand(3, 4) @theano.change_flags(compute_test_value="raise") def test_overided_function(self): diff --git a/theano/gof/graph.py b/theano/gof/graph.py index 881a45a08c..f36dd122be 100644 --- a/theano/gof/graph.py +++ b/theano/gof/graph.py @@ -383,7 +383,7 @@ class Variable(Node): def __init__(self, type, owner=None, index=None, name=None): super(Variable, self).__init__() - self.tag = utils.Scratchpad() + self.tag = utils.ValidatingScratchpad("test_value", type.filter) self.type = type if owner is not None and not isinstance(owner, Apply): diff --git a/theano/gof/utils.py b/theano/gof/utils.py index 261702756f..bd5e333a8e 100644 --- a/theano/gof/utils.py +++ b/theano/gof/utils.py @@ -259,6 +259,23 @@ def info(self): print(" %s: %s" % (k, v)) +class ValidatingScratchpad(Scratchpad): + """This `Scratchpad` validates attribute values.""" + + def __init__(self, attr, attr_filter): + super().__init__() + + object.__setattr__(self, "attr", attr) + object.__setattr__(self, "attr_filter", attr_filter) + + def __setattr__(self, attr, obj): + + if getattr(self, "attr", None) == attr: + obj = self.attr_filter(obj) + + return object.__setattr__(self, attr, obj) + + class D: def __init__(self, **d): self.__dict__.update(d) diff --git a/theano/tensor/opt.py b/theano/tensor/opt.py index 97ddf04f69..00e61ea548 100644 --- a/theano/tensor/opt.py +++ b/theano/tensor/opt.py @@ -7743,7 +7743,10 @@ def local_fuse(node): if tv.size > 0: tmp.tag.test_value = tv.flatten()[0] else: - tmp.tag.test_value = tv + _logger.warning( + "Cannot construct a scalar test value" + " from a test value with no size: {}".format(ii) + ) except AttributeError: pass