Skip to content
This repository has been archived by the owner on Jan 23, 2025. It is now read-only.

Commit

Permalink
Refactor tests in preparation for more asserts triggering at op creat…
Browse files Browse the repository at this point in the history
…ion time

This is purely refactoring that is compatible with current asserts firing
during session.run and newer asserts that sometime fire at op creation time.
It is needed to submit tensorflow/tensorflow#23109.

PiperOrigin-RevId: 261207674
  • Loading branch information
iganichev authored and mihaimaruseac committed Sep 18, 2019
1 parent e62e112 commit f010305
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -858,18 +858,18 @@ def _train_op_fn(loss):
del loss
return control_flow_ops.no_op()

spec = head.create_estimator_spec(
features={},
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'labels must be an integer indicator Tensor with values in '
r'\[0, 1\]'):
spec = head.create_estimator_spec(
features={},
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
_initialize_variables(self, spec.scaffold)
sess.run(spec.loss)

def test_train_invalid_sparse_labels(self):
Expand All @@ -884,17 +884,17 @@ def _train_op_fn(loss):
del loss
return control_flow_ops.no_op()

spec = head.create_estimator_spec(
features={},
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'labels must be an integer SparseTensor with values in \[0, 2\)'):
spec = head.create_estimator_spec(
features={},
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
_initialize_variables(self, spec.scaffold)
sess.run(spec.loss)

def _test_train(self, head, logits, labels, expected_loss):
Expand Down
36 changes: 18 additions & 18 deletions tensorflow_estimator/python/estimator/head/multi_label_head_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,20 +817,20 @@ def test_train_invalid_indicator_labels(self):
def _train_op_fn(loss):
del loss
return control_flow_ops.no_op()
spec = head.create_estimator_spec(
features={},
mode=ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn,
trainable_variables=[
variables.Variable([1.0, 2.0], dtype=dtypes.float32)])
with self.cached_session() as sess:
test_lib._initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'labels must be an integer indicator Tensor with values in '
r'\[0, 1\]'):
spec = head.create_estimator_spec(
features={},
mode=ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn,
trainable_variables=[
variables.Variable([1.0, 2.0], dtype=dtypes.float32)])
test_lib._initialize_variables(self, spec.scaffold)
sess.run(spec.loss)

def test_train_invalid_sparse_labels(self):
Expand All @@ -855,19 +855,19 @@ def test_train_invalid_sparse_labels(self):
def _train_op_fn(loss):
del loss
return control_flow_ops.no_op()
spec = head.create_estimator_spec(
features={},
mode=ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn,
trainable_variables=[
variables.Variable([1.0, 2.0], dtype=dtypes.float32)])
with self.cached_session() as sess:
test_lib._initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'labels must be an integer SparseTensor with values in \[0, 2\)'):
spec = head.create_estimator_spec(
features={},
mode=ModeKeys.TRAIN,
logits=logits,
labels=labels,
train_op_fn=_train_op_fn,
trainable_variables=[
variables.Variable([1.0, 2.0], dtype=dtypes.float32)])
test_lib._initialize_variables(self, spec.scaffold)
sess.run(spec.loss)

def _test_train(self, head, logits, labels, expected_loss):
Expand Down

0 comments on commit f010305

Please sign in to comment.