Skip to content

Commit

Permalink
Validate test values using a tensor's type
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Oct 17, 2020
1 parent 74ee82b commit 21daf4e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
14 changes: 8 additions & 6 deletions tests/gof/test_compute_test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,16 @@ def test_constant(self):

@theano.change_flags(compute_test_value="raise")
def test_incorrect_type(self):
x = tt.fmatrix("x")
# Incorrect dtype (float64) for test_value
x.tag.test_value = np.random.rand(3, 4)
y = tt.dmatrix("y")
y.tag.test_value = np.random.rand(4, 5)

x = tt.vector("x")
with pytest.raises(TypeError):
tt.dot(x, y)
# Incorrect shape for test value
x.tag.test_value = np.empty((2, 2))

x = tt.fmatrix("x")
with pytest.raises(TypeError):
# Incorrect dtype (float64) for test value
x.tag.test_value = np.random.rand(3, 4)

@theano.change_flags(compute_test_value="raise")
def test_overided_function(self):
Expand Down
2 changes: 1 addition & 1 deletion theano/gof/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ class Variable(Node):
def __init__(self, type, owner=None, index=None, name=None):
super(Variable, self).__init__()

self.tag = utils.Scratchpad()
self.tag = utils.ValidatingScratchpad("test_value", type.filter)

self.type = type
if owner is not None and not isinstance(owner, Apply):
Expand Down
17 changes: 17 additions & 0 deletions theano/gof/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,23 @@ def info(self):
print(" %s: %s" % (k, v))


class ValidatingScratchpad(Scratchpad):
"""This `Scratchpad` validates attribute values."""

def __init__(self, attr, attr_filter):
super().__init__()

object.__setattr__(self, "attr", attr)
object.__setattr__(self, "attr_filter", attr_filter)

def __setattr__(self, attr, obj):

if getattr(self, "attr", None) == attr:
obj = self.attr_filter(obj)

return object.__setattr__(self, attr, obj)


class D:
def __init__(self, **d):
self.__dict__.update(d)
Expand Down
5 changes: 4 additions & 1 deletion theano/tensor/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7743,7 +7743,10 @@ def local_fuse(node):
if tv.size > 0:
tmp.tag.test_value = tv.flatten()[0]
else:
tmp.tag.test_value = tv
_logger.warning(
"Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii)
)
except AttributeError:
pass

Expand Down

0 comments on commit 21daf4e

Please sign in to comment.