From f010305a7fc5b7d20354f01f7f46465c2246f1f0 Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Thu, 1 Aug 2019 15:15:44 -0700 Subject: [PATCH] Refactor tests in preparation for more asserts triggering at op creation time This is purely refactoring that is compatible with current asserts firing during session.run and newer asserts that sometime fire at op creation time. It is needed to submit https://github.com/tensorflow/tensorflow/pull/23109. PiperOrigin-RevId: 261207674 --- .../estimator/python/estimator/head_test.py | 28 +++++++-------- .../estimator/head/multi_label_head_test.py | 36 +++++++++---------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tensorflow_estimator/contrib/estimator/python/estimator/head_test.py b/tensorflow_estimator/contrib/estimator/python/estimator/head_test.py index e08c917b..e20265fc 100644 --- a/tensorflow_estimator/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow_estimator/contrib/estimator/python/estimator/head_test.py @@ -858,18 +858,18 @@ def _train_op_fn(loss): del loss return control_flow_ops.no_op() - spec = head.create_estimator_spec( - features={}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels, - train_op_fn=_train_op_fn) with self.cached_session() as sess: - _initialize_variables(self, spec.scaffold) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'labels must be an integer indicator Tensor with values in ' r'\[0, 1\]'): + spec = head.create_estimator_spec( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + _initialize_variables(self, spec.scaffold) sess.run(spec.loss) def test_train_invalid_sparse_labels(self): @@ -884,17 +884,17 @@ def _train_op_fn(loss): del loss return control_flow_ops.no_op() - spec = head.create_estimator_spec( - features={}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels, - train_op_fn=_train_op_fn) with self.cached_session() as sess: - _initialize_variables(self, spec.scaffold) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'labels must be an integer SparseTensor with values in \[0, 2\)'): + spec = head.create_estimator_spec( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + _initialize_variables(self, spec.scaffold) sess.run(spec.loss) def _test_train(self, head, logits, labels, expected_loss): diff --git a/tensorflow_estimator/python/estimator/head/multi_label_head_test.py b/tensorflow_estimator/python/estimator/head/multi_label_head_test.py index c42bd05a..8d77d6e4 100644 --- a/tensorflow_estimator/python/estimator/head/multi_label_head_test.py +++ b/tensorflow_estimator/python/estimator/head/multi_label_head_test.py @@ -817,20 +817,20 @@ def test_train_invalid_indicator_labels(self): def _train_op_fn(loss): del loss return control_flow_ops.no_op() - spec = head.create_estimator_spec( - features={}, - mode=ModeKeys.TRAIN, - logits=logits, - labels=labels, - train_op_fn=_train_op_fn, - trainable_variables=[ - variables.Variable([1.0, 2.0], dtype=dtypes.float32)]) with self.cached_session() as sess: - test_lib._initialize_variables(self, spec.scaffold) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'labels must be an integer indicator Tensor with values in ' r'\[0, 1\]'): + spec = head.create_estimator_spec( + features={}, + mode=ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn, + trainable_variables=[ + variables.Variable([1.0, 2.0], dtype=dtypes.float32)]) + test_lib._initialize_variables(self, spec.scaffold) sess.run(spec.loss) def test_train_invalid_sparse_labels(self): @@ -855,19 +855,19 @@ def test_train_invalid_sparse_labels(self): def _train_op_fn(loss): del loss return control_flow_ops.no_op() - spec = head.create_estimator_spec( - features={}, - mode=ModeKeys.TRAIN, - logits=logits, - labels=labels, - train_op_fn=_train_op_fn, - trainable_variables=[ - variables.Variable([1.0, 2.0], dtype=dtypes.float32)]) with self.cached_session() as sess: - test_lib._initialize_variables(self, spec.scaffold) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'labels must be an integer SparseTensor with values in \[0, 2\)'): + spec = head.create_estimator_spec( + features={}, + mode=ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn, + trainable_variables=[ + variables.Variable([1.0, 2.0], dtype=dtypes.float32)]) + test_lib._initialize_variables(self, spec.scaffold) sess.run(spec.loss) def _test_train(self, head, logits, labels, expected_loss):