Skip to content

[Redo of #23090] Clean up binary element-wise assertions #23109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test

Expand Down Expand Up @@ -461,13 +462,14 @@ def dynamic_run(fun, x_value, **kwargs):
def testNoBatchMultivariateRaisesWhenSingular(self):
with self.cached_session():
mu = [1., -1]
bijector = Affine(
shift=mu,
# Has zero on the diagonal.
scale_diag=[0., 1],
validate_args=True)
with self.assertRaisesOpError("diagonal part must be non-zero"):
bijector.forward([1., 1.]).eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"diagonal part must be non-zero"):
_ = Affine(
shift=mu,
# Has zero on the diagonal.
scale_diag=[0., 1],
validate_args=True)
# Error detected statically; don't need to run the op.

def _makeScale(self,
x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
Expand Down Expand Up @@ -150,6 +151,22 @@ def _testInvalidDimensionsOpError(self, expected_error_message):
with self.assertRaisesError(expected_error_message):
sess.run(bijector.forward_event_shape_tensor(shape_in),
feed_dict=feed_dict)

def _testInvalidDimensionsStatic(self, expected_error_message):
"""Version of _testInvalidDimensionsOpError for errors detected statically
at graph construction time.

Args:
expected_error_message: String that should be present in the error
message that `Reshape` raises for invalid shapes.
"""
shape_in, shape_out, _ = self.build_shapes([2, 3], [1, 2, -2,])
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
expected_error_message):
_ = Reshape(
event_shape_out=shape_out,
event_shape_in=shape_in,
validate_args=True)
# pylint: enable=invalid-name

def testValidButNonMatchingInputOpError(self):
Expand Down Expand Up @@ -300,9 +317,9 @@ def testBijectiveAndFinite(self):
assert_bijective_and_finite(
bijector, x, y, event_ndims=2, rtol=1e-6, atol=0)

def testInvalidDimensionsOpError(self):
self._testInvalidDimensionsOpError(
"Invalid value in tensor used for shape: -2")
def testInvalidDimensionsStatic(self):
self._testInvalidDimensionsStatic(
"elements must be either positive integers or `-1`")

def testInputOutputMismatchOpError(self):
self._testInputOutputMismatchOpError("Cannot reshape a tensor with")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np

from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.framework import errors
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
Expand All @@ -43,9 +44,10 @@ def _softplus_ildj_before_reduction(self, y):

def testHingeSoftnessZeroRaises(self):
with self.cached_session():
bijector = Softplus(hinge_softness=0., validate_args=True)
with self.assertRaisesOpError("must be non-zero"):
bijector.forward([1., 1.]).eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"must be non-zero"):
_ = Softplus(hinge_softness=0., validate_args=True)
# Error detected statically; don't need to run op.

def testBijectorForwardInverseEventDimsZero(self):
with self.cached_session():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow.contrib.distributions.python.ops import cauchy as cauchy_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
Expand Down Expand Up @@ -400,9 +401,10 @@ def testCauchySampleMultiDimensional(self):

def testCauchyNegativeLocFails(self):
with self.cached_session():
cauchy = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True)
with self.assertRaisesOpError("Condition x > 0 did not hold"):
cauchy.mode().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"Condition x > 0 did not hold"):
_ = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True)
# Error detected statically; no need for _.mode().eval()

def testCauchyShape(self):
with self.cached_session():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np
from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
Expand All @@ -40,11 +41,11 @@ def testShape(self):

def testInvalidTolRaises(self):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.Deterministic(
loc, atol=-1, validate_args=True)
with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
deterministic.prob(0.).eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Condition x >= 0"):
_ = deterministic_lib.Deterministic(
loc, atol=-1, validate_args=True)
# Error detected statically; no need for _.prob(0.).eval()

def testProbWithNoBatchDimsIntegerType(self):
deterministic = deterministic_lib.Deterministic(0)
Expand Down Expand Up @@ -195,16 +196,16 @@ def testShape(self):

def testInvalidTolRaises(self):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.VectorDeterministic(
loc, atol=-1, validate_args=True)
with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
deterministic.prob(loc).eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Condition x >= 0"):
_ = deterministic_lib.VectorDeterministic(
loc, atol=-1, validate_args=True)
# Error detected statically; no need for _.prob(loc).eval()

def testInvalidXRaises(self):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.VectorDeterministic(
loc, atol=-1, validate_args=True)
loc, atol=None, validate_args=True)
with self.cached_session():
with self.assertRaisesRegexp(ValueError, "must have rank at least 1"):
deterministic.prob(0.).eval()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
Expand All @@ -41,6 +42,7 @@ def try_import(name): # pylint: disable=invalid-name
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module


stats = try_import("scipy.stats")


Expand Down Expand Up @@ -288,9 +290,10 @@ def testHalfNormalSampleMultiDimensional(self):

def testNegativeSigmaFails(self):
with self.cached_session():
halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G")
with self.assertRaisesOpError("Condition x > 0 did not hold"):
halfnorm.mean().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"Condition x > 0 did not hold"):
_ = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G")
# Error detected statically; no need for _.mean().eval()

def testHalfNormalShape(self):
with self.cached_session():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tensorflow.contrib.distributions.python.ops import inverse_gamma
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
Expand Down Expand Up @@ -249,7 +250,8 @@ def testInverseGammaSampleMultiDimensional(self):
fails += 0 if self._kstest(a, b, s) else 1
self.assertLess(fails, trials * 0.03)

def _kstest(self, alpha, beta, samples):
@staticmethod
def _kstest(alpha, beta, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
ks, _ = stats.kstest(samples, stats.invgamma(alpha, scale=beta).cdf)
# Return True when the test passes.
Expand Down Expand Up @@ -295,16 +297,18 @@ def testInverseGammaNonPositiveInitializationParamsRaises(self):
with self.cached_session():
alpha_v = constant_op.constant(0.0, name="alpha")
beta_v = constant_op.constant(1.0, name="beta")
inv_gamma = inverse_gamma.InverseGamma(
concentration=alpha_v, rate=beta_v, validate_args=True)
with self.assertRaisesOpError("alpha"):
inv_gamma.mean().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"alpha"):
_ = inverse_gamma.InverseGamma(
concentration=alpha_v, rate=beta_v, validate_args=True)
# Error detected statically; no need for _.mean().eval()
alpha_v = constant_op.constant(1.0, name="alpha")
beta_v = constant_op.constant(0.0, name="beta")
inv_gamma = inverse_gamma.InverseGamma(
concentration=alpha_v, rate=beta_v, validate_args=True)
with self.assertRaisesOpError("beta"):
inv_gamma.mean().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"beta"):
_ = inverse_gamma.InverseGamma(
concentration=alpha_v, rate=beta_v, validate_args=True)
# Error detected statically; no need for _.mean().eval()

def testInverseGammaWithSoftplusConcentrationRate(self):
with self.cached_session():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from scipy import stats
from tensorflow.contrib import distributions as distributions_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
Expand Down Expand Up @@ -361,15 +362,14 @@ def testProbAndGradGivesFiniteResultsForCommonEvents(self):

def testLowerCutoffMustBeBelowUpperCutoffOrWeRaise(self):
with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(loc=0., scale=1.),
low=1., # not strictly less than high.
high=1.,
validate_args=True)

self.assertTrue(qdist.validate_args) # Default is True.
with self.assertRaisesOpError("must be strictly less"):
qdist.sample().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"must be strictly less"):
_ = distributions.QuantizedDistribution(
distribution=distributions.Normal(loc=0., scale=1.),
low=1., # not strictly less than high.
high=1.,
validate_args=True)
# Error detected statically; no need for _.sample().eval()

def testCutoffsMustBeIntegerValuedIfValidateArgsTrue(self):
with self.cached_session():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,11 @@ def testZeroTemperature(self):
"""If validate_args, raises InvalidArgumentError when temperature is 0."""
temperature = constant_op.constant(0.0)
p = constant_op.constant([0.1, 0.4])
dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p,
validate_args=True)
with self.cached_session():
sample = dist.sample()
with self.assertRaises(errors_impl.InvalidArgumentError):
sample.eval()
with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError,
"x > 0 did not hold"):
_ = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p,
validate_args=True)
# Error detected statically; no need to run the op.

def testDtype(self):
temperature = constant_op.constant(1.0, dtype=dtypes.float32)
Expand Down
7 changes: 4 additions & 3 deletions tensorflow/contrib/metrics/python/ops/metric_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,9 +1734,10 @@ def testPredictionsOutOfRange(self):
predictions = constant_op.constant(
[1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
_, update_op = metrics.streaming_auc(predictions, labels)
sess.run(variables.local_variables_initializer())
self.assertRaises(errors_impl.InvalidArgumentError, update_op.eval)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
r"predictions must be in \[0, 1\]"):
_, _ = metrics.streaming_auc(predictions, labels)
# Error detected statically; no need to run the op.

def testAllCorrect(self):
self.allCorrectAsExpected('ROC')
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/python/feature_column/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2147,7 +2147,7 @@ def expand_dims(input_tensor):
if rank is not None:
if rank == 0:
raise ValueError(
'Feature (key: {}) cannot have rank 0. Give: {}'.format(
'Feature (key: {}) cannot have rank 0. Given: {}'.format(
key, feature_tensor))
return feature_tensor if rank != 1 else expand_dims(feature_tensor)

Expand Down Expand Up @@ -2833,9 +2833,13 @@ def _transform_feature(self, inputs):
# Fail if values are out-of-range.
assert_less = check_ops.assert_less(
values, num_buckets, data=(values, num_buckets),
message='Bucket index for categorical column '
'"{}" exceeds number of buckets'.format(self.name),
name='assert_less_than_num_buckets')
assert_greater = check_ops.assert_greater_equal(
values, zero, data=(values,),
message='Negative bucket index for categorical column "{}"'.format(
self.name),
name='assert_greater_or_equal_0')
with ops.control_dependencies((assert_less, assert_greater)):
values = array_ops.identity(values)
Expand Down
26 changes: 16 additions & 10 deletions tensorflow/python/feature_column/feature_column_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4387,30 +4387,36 @@ def test_get_sparse_tensors_dense_input(self):
@test_util.run_deprecated_v1
def test_get_sparse_tensors_with_inputs_too_small(self):
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
inputs = sparse_tensor.SparseTensorValue(
inputs_value = sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(1, -1, 0),
dense_shape=(2, 2))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
inputs_placeholder = array_ops.sparse_placeholder(dtypes.int32)
id_weight_pair = column._get_sparse_tensors(
_LazyBuilder({'aaa': inputs_placeholder}))
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
with _initialized_session() as sess:
with self.assertRaisesRegexp(
errors.OpError, 'assert_greater_or_equal_0'):
id_weight_pair.id_tensor.eval()
errors.OpError, 'Negative bucket index'):
sess.run(id_weight_pair.id_tensor,
feed_dict={inputs_placeholder: inputs_value})

@test_util.run_deprecated_v1
def test_get_sparse_tensors_with_inputs_too_big(self):
column = fc._categorical_column_with_identity(key='aaa', num_buckets=3)
inputs = sparse_tensor.SparseTensorValue(
inputs_value = sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(1, 99, 0),
dense_shape=(2, 2))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
inputs_placeholder = array_ops.sparse_placeholder(dtypes.int32)
id_weight_pair = column._get_sparse_tensors(
_LazyBuilder({'aaa': inputs_placeholder}))
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
with _initialized_session() as sess:
with self.assertRaisesRegexp(
errors.OpError, 'assert_less_than_num_buckets'):
id_weight_pair.id_tensor.eval()
errors.OpError, 'exceeds number of buckets'):
sess.run(id_weight_pair.id_tensor,
feed_dict={inputs_placeholder: inputs_value})

@test_util.run_deprecated_v1
def test_get_sparse_tensors_with_default_value(self):
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/python/feature_column/feature_column_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2599,7 +2599,7 @@ def expand_dims(input_tensor):
if rank is not None:
if rank == 0:
raise ValueError(
'Feature (key: {}) cannot have rank 0. Give: {}'.format(
'Feature (key: {}) cannot have rank 0. Given: {}'.format(
key, feature_tensor))
return feature_tensor if rank != 1 else expand_dims(feature_tensor)

Expand Down Expand Up @@ -3780,9 +3780,13 @@ def _transform_input_tensor(self, input_tensor):
# Fail if values are out-of-range.
assert_less = check_ops.assert_less(
values, num_buckets, data=(values, num_buckets),
message='Bucket index for categorical column '
'"{}" exceeds number of buckets'.format(self.name),
name='assert_less_than_num_buckets')
assert_greater = check_ops.assert_greater_equal(
values, zero, data=(values,),
message='Negative bucket index for categorical column "{}"'.format(
self.name),
name='assert_greater_or_equal_0')
with ops.control_dependencies((assert_less, assert_greater)):
values = array_ops.identity(values)
Expand Down
Loading