-
Notifications
You must be signed in to change notification settings - Fork 6
/
dataset.py
29 lines (26 loc) · 998 Bytes
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import tensorflow as tf
def create(filepath, batch_size=1, repeat=False, buffsize=1000):
    """Build a batched TFRecord input pipeline of audio/text examples.

    Each serialized ``tf.train.Example`` is parsed for the features
    ``uid`` (scalar string), ``audio/data`` (float32 list), ``audio/shape``
    (int64 list giving the shape ``audio/data`` is reshaped to) and
    ``text`` (int64 list).

    Args:
        filepath: File name(s) passed straight to ``tf.data.TFRecordDataset``.
        batch_size: Number of examples per batch.
        repeat: If true, repeat the dataset indefinitely.
        buffsize: Shuffle buffer size, measured in individual records.
            Shuffling is skipped when this is 0 or negative.

    Returns:
        A tuple ``(audio, text, length, uid, initializer)``. The first four
        entries are the batched tensors from ``iterator.get_next()``. The
        last entry is the initializable iterator's initializer op, which
        has to be run before the tensors are evaluated.
        ``audio`` and ``text`` are ``SparseTensor``s. ``length`` holds the
        first entry of each example's ``audio/shape``.
    """
    def _parse(record):
        """Decode one serialized example into (audio, text, length, uid)."""
        keys_to_features = {
            'uid': tf.FixedLenFeature([], tf.string),
            'audio/data': tf.VarLenFeature(tf.float32),
            'audio/shape': tf.VarLenFeature(tf.int64),
            'text': tf.VarLenFeature(tf.int64),
        }
        features = tf.parse_single_example(
            record,
            features=keys_to_features,
        )
        audio = features['audio/data'].values
        shape = features['audio/shape'].values
        # audio/data is stored flat; restore it to its recorded shape.
        audio = tf.reshape(audio, shape)
        # NOTE(review): dense_to_sparse defaults to eos_token=0, so entries
        # exactly equal to 0.0 are omitted from the SparseTensor. They read
        # back as 0 when densified; confirm no consumer relies on the
        # explicit indices.
        audio = tf.contrib.layers.dense_to_sparse(audio)
        text = features['text']
        # shape[0] is the size of the audio tensor's leading dimension
        # (presumably the number of frames -- confirm with the writer).
        return audio, text, shape[0], features['uid']

    dataset = tf.data.TFRecordDataset(filepath)
    if buffsize > 0:
        # Shuffle individual serialized records *before* batching. Shuffling
        # after .batch() would only permute whole batches, leaving every
        # batch's membership fixed in file order across epochs. Shuffling
        # the raw records also keeps the buffer small, since nothing has
        # been parsed yet.
        dataset = dataset.shuffle(buffer_size=buffsize)
    dataset = dataset.map(_parse).batch(batch_size=batch_size)
    if repeat:
        # Repeat after batching so each epoch ends on its own (possibly
        # partial) batch instead of batches spanning epoch boundaries.
        dataset = dataset.repeat()
    iterator = dataset.make_initializable_iterator()
    return tuple(list(iterator.get_next()) + [iterator.initializer])