Skip to content

Commit e2b36c5

Browse files
mrrytensorflower-gardener
authored andcommitted
[tf.data] Remove Dataset.make_one_shot_iterator() from the V2 API.
Add `tf.compat.v1.data.make_one_shot_iterator(dataset)` to enable the use of V2 `Dataset` objects in a legacy V1 pipeline. PiperOrigin-RevId: 223896743
1 parent 81eaf1d commit e2b36c5

File tree

88 files changed

+402
-411
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+402
-411
lines changed

tensorflow/contrib/data/python/ops/readers.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def read_batch_features(file_pattern,
355355
shuffle=randomize_input,
356356
num_epochs=num_epochs,
357357
shuffle_buffer_size=capacity)
358-
iterator = dataset.make_one_shot_iterator()
358+
iterator = dataset_ops.make_one_shot_iterator(dataset)
359359
outputs = iterator.get_next()
360360
return outputs
361361

@@ -379,15 +379,13 @@ def __init__(self, filenames):
379379
(key value) pairs sequentially.
380380
For example:
381381
```python
382+
tf.enable_eager_execution()
383+
382384
dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb")
383-
iterator = dataset.make_one_shot_iterator()
384-
next_element = iterator.get_next()
385+
385386
# Prints the (key, value) pairs inside a lmdb file.
386-
while True:
387-
try:
388-
print(sess.run(next_element))
389-
except tf.errors.OutOfRangeError:
390-
break
387+
for key, value in dataset:
388+
print(key, value)
391389
```
392390
Args:
393391
filenames: A `tf.string` tensor containing one or more filenames.

tensorflow/contrib/eager/python/evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from tensorflow.contrib.eager.python import datasets
2424
from tensorflow.contrib.eager.python import metrics
25+
from tensorflow.python.data.ops import dataset_ops
2526
from tensorflow.python.eager import context
2627
from tensorflow.python.eager import function
2728
from tensorflow.python.framework import errors_impl
@@ -164,8 +165,8 @@ def evaluate_on_dataset(self, dataset, *args, **kwargs):
164165
self.__call__(example, *args, **kwargs)
165166
return self.all_metric_results(summary_logdir)
166167
# Graph construction
167-
call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args,
168-
**kwargs)
168+
call_op = self.__call__(
169+
dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs)
169170
init_op = self.init_variables()
170171
results_op = self.all_metric_results(summary_logdir)
171172
return (init_op, call_op, results_op)

tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def benchmark_graph_train(self):
119119
with tf.Graph().as_default():
120120
np_images, np_labels = random_batch(batch_size)
121121
dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat()
122-
(images, labels) = dataset.make_one_shot_iterator().get_next()
122+
(images, labels) = tf.compat.v1.data.make_one_shot_iterator(
123+
dataset).get_next()
123124

124125
model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks,
125126
self.output_classes,

tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def _create_graph(self, batch_size):
4242
# Generate some random data.
4343
images_data = np.random.randn(batch_size, 784).astype(np.float32)
4444
dataset = tf.data.Dataset.from_tensors(images_data)
45-
images = dataset.repeat().make_one_shot_iterator().get_next()
45+
images = tf.compat.v1.data.make_one_shot_iterator(
46+
dataset.repeat()).get_next()
4647

4748
# Create the models and optimizers
4849
generator = mnist.Generator(data_format())

tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@
470470
"\n",
471471
" if epoch % 1 == 0:\n",
472472
" loss = tfe.metrics.Mean()\n",
473-
" for test_x in test_dataset.make_one_shot_iterator():\n",
473+
" for test_x in test_dataset:\n",
474474
" loss(compute_loss(model, test_x))\n",
475475
" elbo = -loss.result()\n",
476476
" display.clear_output(wait=False)\n",

tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def benchmark_graph_train(self):
142142
with tf.Graph().as_default():
143143
np_images, np_labels = random_batch(batch_size)
144144
dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat()
145-
(images, labels) = dataset.make_one_shot_iterator().get_next()
145+
images, labels = tf.compat.v1.data.make_one_shot_iterator(
146+
dataset).get_next()
146147

147148
model = resnet50.ResNet50(data_format())
148149
logits = model(images, training=True)

tensorflow/contrib/eager/python/examples/revnet/main.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,11 @@ def main(_):
7272
train_one_iter(model, x, y, optimizer, global_step=global_step)
7373

7474
if global_step.numpy() % config.log_every == 0:
75-
it_test = ds_test.make_one_shot_iterator()
76-
acc_test, loss_test = evaluate(model, it_test)
75+
acc_test, loss_test = evaluate(model, ds_test)
7776

7877
if FLAGS.validate:
79-
it_train = ds_train_one_shot.make_one_shot_iterator()
80-
it_validation = ds_validation.make_one_shot_iterator()
81-
acc_train, loss_train = evaluate(model, it_train)
82-
acc_validation, loss_validation = evaluate(model, it_validation)
78+
acc_train, loss_train = evaluate(model, ds_train_one_shot)
79+
acc_validation, loss_validation = evaluate(model, ds_validation)
8380
print("Iter {}, "
8481
"training set accuracy {:.4f}, loss {:.4f}; "
8582
"validation set accuracy {:.4f}, loss {:.4f}; "
@@ -218,11 +215,11 @@ def train_one_iter(model, inputs, labels, optimizer, global_step=None):
218215
return logits, loss
219216

220217

221-
def evaluate(model, iterator):
218+
def evaluate(model, dataset):
222219
"""Compute accuracy with the given dataset iterator."""
223220
mean_loss = tfe.metrics.Mean()
224221
accuracy = tfe.metrics.Accuracy()
225-
for x, y in iterator:
222+
for x, y in dataset:
226223
logits, _ = model(x, training=False)
227224
loss = model.compute_loss(logits=logits, labels=y)
228225
accuracy(

tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _benchmark_apply(self, label, model):
8282
tf.ones(
8383
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE],
8484
dtype=tf.int64)).repeat(num_iters + num_warmup)
85-
inputs = dataset.make_one_shot_iterator().get_next()
85+
inputs = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
8686

8787
with tf.device(tf.test.gpu_device_name()):
8888
outputs = model(inputs, training=True)
@@ -124,7 +124,8 @@ def _benchmark_train(self, label, model):
124124
dtype=tf.int64)).repeat(num_iters + num_warmup)
125125
# inputs and labels have the same shape
126126
dataset = tf.data.Dataset.zip((dataset, dataset))
127-
(inputs, labels) = dataset.make_one_shot_iterator().get_next()
127+
(inputs, labels) = tf.compat.v1.data.make_one_shot_iterator(
128+
dataset).get_next()
128129

129130
with tf.device(tf.test.gpu_device_name()):
130131
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,12 @@ def __init__(self, filenames):
4040
For example:
4141
4242
```python
43+
tf.enable_eager_execution()
44+
4345
dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq")
44-
iterator = dataset.make_one_shot_iterator()
45-
next_element = iterator.get_next()
4646
# Prints the (key, value) pairs inside a hadoop sequence file.
47-
while True:
48-
try:
49-
print(sess.run(next_element))
50-
except tf.errors.OutOfRangeError:
51-
break
47+
for key, value in dataset:
48+
print(key, value)
5249
```
5350
5451
Args:

tensorflow/contrib/ignite/README.md

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,12 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL
5454
```python
5555
>>> import tensorflow as tf
5656
>>> from tensorflow.contrib.ignite import IgniteDataset
57-
>>>
57+
>>> tf.enable_eager_execution()
58+
>>>
5859
>>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE")
59-
>>> iterator = dataset.make_one_shot_iterator()
60-
>>> next_obj = iterator.get_next()
6160
>>>
62-
>>> with tf.Session() as sess:
63-
>>> for _ in range(3):
64-
>>> print(sess.run(next_obj))
61+
>>> for element in dataset:
62+
>>> print(element)
6563

6664
{'key': 1, 'val': {'NAME': b'WARM KITTY'}}
6765
{'key': 2, 'val': {'NAME': b'SOFT KITTY'}}
@@ -74,23 +72,22 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL
7472
```python
7573
>>> import tensorflow as tf
7674
>>> from tensorflow.contrib.ignite import IgniteDataset
77-
>>>
75+
>>> tf.enable_eager_execution()
76+
>>>
7877
>>> dataset = IgniteDataset(cache_name="IMAGES")
79-
>>> iterator = dataset.make_one_shot_iterator()
80-
>>> next_obj = iterator.get_next()
8178
>>>
82-
>>> with tf.Session() as sess:
83-
>>> print(sess.run(next_obj))
79+
>>> for element in dataset.take(1):
80+
>>> print(element)
8481

8582
{
86-
'key': 'kitten.png',
83+
'key': 'kitten.png',
8784
'val': {
8885
'metadata': {
8986
'file_name': b'kitten.png',
9087
'label': b'little ball of fur',
91-
width: 800,
88+
width: 800,
9289
height: 600
93-
},
90+
},
9491
'pixels': [0, 0, 0, 0, ..., 0]
9592
}
9693
}
@@ -100,13 +97,11 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL
10097
```python
10198
>>> import tensorflow as tf
10299
>>> from tensorflow.contrib.ignite import IgniteDataset
103-
>>>
100+
>>>
104101
>>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels'])
105-
>>> iterator = dataset.make_one_shot_iterator()
106-
>>> next_obj = iterator.get_next()
107102
>>>
108-
>>> with tf.Session() as sess:
109-
>>> print(sess.run(next_obj))
103+
>>> for element in dataset:
104+
>>> print(element)
110105

111106
[0, 0, 0, 0, ..., 0]
112107
```
@@ -126,26 +121,26 @@ Ignite Dataset allows using these two aspects of distributed neural network trai
126121
```python
127122
>>> import tensorflow as tf
128123
>>> from tensorflow.contrib.ignite import IgniteDataset
129-
>>>
124+
>>>
130125
>>> dataset = IgniteDataset("IMAGES")
131126
>>>
132127
>>> # Compute gradients locally on every worker node.
133-
>>> gradients = []
128+
>>> gradients = []
134129
>>> for i in range(5):
135130
>>> with tf.device("/job:WORKER/task:%d" % i):
136-
>>> device_iterator = dataset.make_one_shot_iterator()
131+
>>> device_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
137132
>>> device_next_obj = device_iterator.get_next()
138133
>>> gradient = compute_gradient(device_next_obj)
139-
>>> gradients.append(gradient)
140-
>>>
134+
>>> gradients.append(gradient)
135+
>>>
141136
>>> # Aggregate them on master node.
142137
>>> result_gradient = tf.reduce_sum(gradients)
143138
>>>
144139
>>> with tf.Session("grpc://localhost:10000") as sess:
145140
>>> print(sess.run(result_gradient))
146141
```
147142

148-
High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well.
143+
High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well.
149144

150145
### Distributed File System
151146

0 commit comments

Comments
 (0)