-
Notifications
You must be signed in to change notification settings - Fork 439
New TimeseriesGenerator #7
Changes from all commits
8514b36
9a0a403
2308f06
aa573f9
5e4b35e
04b019d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,10 +8,12 @@ | |
import numpy as np | ||
import random | ||
from six.moves import range | ||
from math import ceil | ||
|
||
from . import get_keras_submodule | ||
|
||
keras_utils = get_keras_submodule('utils') | ||
from keras.utils.data_utils import Sequence | ||
|
||
|
||
def pad_sequences(sequences, maxlen=None, dtype='int32', | ||
|
@@ -213,7 +215,7 @@ def skipgrams(sequence, vocabulary_size, | |
random.shuffle(words) | ||
|
||
couples += [[words[i % len(words)], | ||
random.randint(1, vocabulary_size - 1)] | ||
random.randint(1, vocabulary_size - 1)] | ||
for i in range(num_negative_samples)] | ||
if categorical: | ||
labels += [[1, 0]] * num_negative_samples | ||
|
@@ -250,121 +252,245 @@ def _remove_long_seq(maxlen, seq, label): | |
return new_seq, new_label | ||
|
||
|
||
class TimeseriesGenerator(keras_utils.Sequence): | ||
class TimeseriesGenerator(Sequence): | ||
"""Utility class for generating batches of temporal data. | ||
|
||
This class takes in a sequence of data-points gathered at | ||
equal intervals, along with time series parameters such as | ||
stride, length of history, etc., to produce batches for | ||
training/validation. | ||
|
||
# Arguments | ||
data: Indexable generator (such as list or Numpy array) | ||
containing consecutive data points (timesteps). | ||
The data should be at 2D, and axis 0 is expected | ||
to be the time dimension. | ||
The data should be convertible into a 1D numpy array, | ||
if 2D or more, axis 0 is expected to be the time dimension. | ||
targets: Targets corresponding to timesteps in `data`. | ||
It should have same length as `data`. | ||
length: Length of the output sequences (in number of timesteps). | ||
It should have at least the same length as `data`. | ||
length: length of the output sub-sequence before sampling | ||
(deprecated, use hlength instead). | ||
sampling_rate: Period between successive individual timesteps | ||
within sequences. For rate `r`, timesteps | ||
`data[i]`, `data[i-r]`, ... `data[i - length]` | ||
are used for create a sample sequence. | ||
within sequences, `length` has to be a multiple of `sampling_rate`. | ||
stride: Period between successive output sequences. | ||
For stride `s`, consecutive output samples would | ||
be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc. | ||
start_index: Data points earlier than `start_index` will not be used | ||
in the output sequences. This is useful to reserve part of the | ||
data for test or validation. | ||
end_index: Data points later than `end_index` will not be used | ||
in the output sequences. This is useful to reserve part of the | ||
data for test or validation. | ||
start_index, end_index: Data points earlier than `start_index` | ||
or later than `end_index` will not be used in the output sequences. | ||
This is useful to reserve part of the data for test or validation. | ||
shuffle: Whether to shuffle output samples, | ||
or instead draw them in chronological order. | ||
reverse: Boolean: if `true`, timesteps in each output sample will be | ||
reverse: Boolean: if `True`, timesteps in each output sample will be | ||
in reverse chronological order. | ||
batch_size: Number of timeseries samples in each batch | ||
(except maybe the last one). | ||
batch_size: Number of timeseries samples in each batch. | ||
hlength: Effective "history" length of the output sub-sequences after | ||
sampling (in number of timesteps). | ||
gap: prediction gap, i.e. number of timesteps ahead (usually zero, or | ||
same as sampling_rate) | ||
`x=data[i - (hlength-1)*sampling_rate - gap:i-gap+1:sampling_rate]` | ||
and `y=targets[i]` | ||
are used respectively as sample sequence `x` and target value `y`. | ||
target_seq: Boolean: if 'True', produces full shifted sequence targets: | ||
If target_seq is set, for sampling rate `r`, timesteps | ||
`data[i - (hlength-1)*r - gap]`, ..., `data[i-r-gap]`, `data[i-gap]` | ||
and | ||
`targets[i - (hlength-1)*r]`, ..., `targets[i-r]`, `targets[i]` | ||
are used respectively as sample sequence `x` and target sequence `y`. | ||
dtype: force sample/target dtype (default is None) | ||
stateful: helper to check if parameters are valid for stateful learning | ||
(experimental). | ||
|
||
|
||
# Returns | ||
A [Sequence](/utils/#sequence) instance. | ||
A [Sequence](/utils/#sequence) instance of tuples (x,y) | ||
where x is a numpy array of shape (batch_size, hlength, ...) | ||
and y is a numpy array of shape (batch_size, ...) if target_seq is `False` | ||
or (batch_size, hlength, ...) if target_seq is `True`. | ||
If not specified, output dtype is inferred from data dtype. | ||
|
||
# Examples | ||
|
||
```python | ||
from keras.preprocessing.sequence import TimeseriesGenerator | ||
import numpy as np | ||
|
||
data = np.array([[i] for i in range(50)]) | ||
targets = np.array([[i] for i in range(50)]) | ||
|
||
data_gen = TimeseriesGenerator(data, targets, | ||
length=10, sampling_rate=2, | ||
batch_size=2) | ||
assert len(data_gen) == 20 | ||
|
||
batch_0 = data_gen[0] | ||
x, y = batch_0 | ||
assert np.array_equal(x, | ||
np.array([[[0], [2], [4], [6], [8]], | ||
[[1], [3], [5], [7], [9]]])) | ||
assert np.array_equal(y, | ||
np.array([[10], [11]])) | ||
txt = bytearray("Keras is simple.", 'utf-8') | ||
data_gen = TimeseriesGenerator(txt, txt, hlength=10, batch_size=1, gap=1) | ||
|
||
for i in range(len(data_gen)): | ||
print(data_gen[i][0].tostring(), "->'%s'" % data_gen[i][1].tostring()) | ||
|
||
assert data_gen[-1][0].shape == (1, 10) and data_gen[-1][1].shape == (1,) | ||
assert data_gen[-1][0].tostring() == u" is simple" | ||
assert data_gen[-1][1].tostring() == u"." | ||
|
||
t = np.linspace(0,20*np.pi, num=1000) # time | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. space after "," |
||
x = np.sin(np.cos(3*t)) # input signa | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. signal* |
||
y = np.sin(np.cos(6*t+4)) # output signal | ||
|
||
# define recurrent model | ||
from keras.models import Model | ||
from keras.layers import Input, SimpleRNN, LSTM, GRU,Dense | ||
|
||
inputs = Input(batch_shape=(None, None, 1)) | ||
l = SimpleRNN(100, return_sequences=True)(inputs) | ||
l = Dense(100, activation='tanh')(l) | ||
preds = Dense(1, activation='linear')(l) | ||
model = Model(inputs=inputs, outputs=preds) | ||
model.compile(loss='mean_squared_error', optimizer='Nadam') | ||
|
||
# fit model to sequence | ||
xx = np.expand_dims(x, axis=-1) | ||
g = TimeseriesGenerator(xx, y, hlength=100, target_seq=True, shuffle=True) | ||
model.fit_generator(g, steps_per_epoch=len(g), epochs=20, shuffle=True) | ||
|
||
# plot prediction | ||
x2 = np.reshape(x,(1,x.shape[0],1)) | ||
z = model.predict(x2) | ||
|
||
import matplotlib.pyplot as plt | ||
plt.figure(figsize=(12,12)) | ||
plt.title('Phase representation') | ||
plt.plot(x,y.flatten(), color='black') | ||
plt.plot(x,z.flatten(), dashes=[8,1], label='prediction', color='orange') | ||
plt.xlabel('input') | ||
plt.ylabel('output') | ||
plt.grid() | ||
plt.show() | ||
|
||
``` | ||
""" | ||
|
||
def __init__(self, data, targets, length, | ||
def __init__(self, data, targets, length=None, | ||
sampling_rate=1, | ||
stride=1, | ||
start_index=0, | ||
end_index=None, | ||
start_index=0, end_index=None, | ||
shuffle=False, | ||
reverse=False, | ||
batch_size=128): | ||
self.data = data | ||
self.targets = targets | ||
self.length = length | ||
batch_size=128, | ||
hlength=None, | ||
target_seq=False, | ||
gap=0, | ||
dtype=None, | ||
stateful=False): | ||
|
||
# Sanity check | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove stray newline |
||
if sampling_rate <= 0: | ||
raise ValueError('`sampling_rate` must be strictly positive.') | ||
if stride <= 0: | ||
raise ValueError('`stride` must be strictly positive.') | ||
if batch_size <= 0: | ||
raise ValueError('`batch_size` must be strictly positive.') | ||
if len(data) > len(targets): | ||
raise ValueError('`targets` has to be at least as long as `data`.') | ||
|
||
if hlength is None: | ||
if length % sampling_rate != 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a DeprecationWarning |
||
raise ValueError( | ||
"`length` has to be a multiple of `sampling_rate`." | ||
" For instance, `length=%i` would do." % (2 * sampling_rate)) | ||
hlength = length // sampling_rate | ||
|
||
if gap % sampling_rate != 0: | ||
warnings.warn( | ||
"Unless you know what you do, `gap` should be zero or" | ||
" a multiple of `sampling_rate`.", UserWarning) | ||
|
||
self.hlength = hlength | ||
assert self.hlength > 0 | ||
|
||
self.data = np.asarray(data) | ||
self.targets = np.asarray(targets) | ||
|
||
# FIXME: targets must be 2D for sequences output | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a new limitation or was it always like that? |
||
if target_seq and len(self.targets.shape) < 2: | ||
self.targets = np.expand_dims(self.targets, axis=-1) | ||
|
||
if dtype is None: | ||
self.data_type = self.data.dtype | ||
self.targets_type = self.targets.dtype | ||
else: | ||
self.data_type = dtype | ||
self.targets_type = dtype | ||
|
||
# Check if parameters are stateful-compatible | ||
if stateful: | ||
if shuffle: | ||
raise ValueError('Do not shuffle for stateful learning.') | ||
if self.hlength % batch_size != 0: | ||
raise ValueError("For stateful learning, `hlength` has to be" | ||
"a multiple of `batch_size`." | ||
"For instance, `hlength=%i` would do." | ||
% (3 * batch_size)) | ||
if stride != (self.hlength // batch_size) * sampling_rate: | ||
raise ValueError( | ||
'`stride=%i`, for these parameters set `stride=%i`.' | ||
% (stride, (hlength // batch_size) * sampling_rate)) | ||
|
||
self.sampling_rate = sampling_rate | ||
self.batch_size = batch_size | ||
assert stride > 0 | ||
self.stride = stride | ||
self.start_index = start_index + length | ||
self.gap = gap | ||
|
||
sliding_win_size = (self.hlength - 1) * sampling_rate + gap | ||
self.start_index = start_index + sliding_win_size | ||
if end_index is None: | ||
end_index = len(data) - 1 | ||
end_index = len(data) | ||
assert end_index <= len(data) | ||
self.end_index = end_index | ||
self.shuffle = shuffle | ||
self.reverse = reverse | ||
self.batch_size = batch_size | ||
self.target_seq = target_seq | ||
|
||
self.len = int(ceil(float(self.end_index - self.start_index) / | ||
(self.batch_size * self.stride))) | ||
if self.len <= 0: | ||
err = "This configuration gives no output, try with a longer" | ||
" input sequence or different parameters." | ||
raise ValueError(err) | ||
|
||
assert self.len > 0 | ||
|
||
if self.start_index > self.end_index: | ||
raise ValueError('`start_index+length=%i > end_index=%i` ' | ||
'is disallowed, as no part of the sequence ' | ||
'would be left to be used as current step.' | ||
% (self.start_index, self.end_index)) | ||
self.perm = np.arange(self.start_index, self.end_index) | ||
if shuffle: | ||
np.random.shuffle(self.perm) | ||
|
||
def __len__(self): | ||
return (self.end_index - self.start_index + | ||
self.batch_size * self.stride) // (self.batch_size * self.stride) | ||
return self.len | ||
|
||
def _empty_batch(self, num_rows): | ||
samples_shape = [num_rows, self.length // self.sampling_rate] | ||
samples_shape = [num_rows, self.hlength] | ||
samples_shape.extend(self.data.shape[1:]) | ||
targets_shape = [num_rows] | ||
if self.target_seq: | ||
targets_shape = [num_rows, self.hlength] | ||
else: | ||
targets_shape = [num_rows] | ||
targets_shape.extend(self.targets.shape[1:]) | ||
return np.empty(samples_shape), np.empty(targets_shape) | ||
|
||
return np.empty(samples_shape, dtype=self.data_type), np.empty( | ||
targets_shape, dtype=self.targets_type) | ||
|
||
def __getitem__(self, index): | ||
if self.shuffle: | ||
rows = np.random.randint( | ||
self.start_index, self.end_index + 1, size=self.batch_size) | ||
else: | ||
i = self.start_index + self.batch_size * self.stride * index | ||
rows = np.arange(i, min(i + self.batch_size * | ||
self.stride, self.end_index + 1), self.stride) | ||
while index < 0: | ||
index += self.len | ||
assert index < self.len | ||
batch_start = self.batch_size * self.stride * index | ||
rows = np.arange(batch_start, min(batch_start + self.batch_size * | ||
self.stride, | ||
self.end_index - self.start_index), | ||
self.stride) | ||
rows = self.perm[rows] | ||
|
||
samples, targets = self._empty_batch(len(rows)) | ||
for j, row in enumerate(rows): | ||
indices = range(rows[j] - self.length, rows[j], self.sampling_rate) | ||
indices = range(rows[j] - self.gap - | ||
(self.hlength - 1) * self.sampling_rate, | ||
rows[j] - self.gap + 1, self.sampling_rate) | ||
samples[j] = self.data[indices] | ||
targets[j] = self.targets[rows[j]] | ||
if self.target_seq: | ||
shifted_indices = range(rows[j] - (self.hlength - 1) * | ||
self.sampling_rate, | ||
rows[j] + 1, self.sampling_rate) | ||
targets[j] = self.targets[shifted_indices] | ||
else: | ||
targets[j] = self.targets[rows[j]] | ||
if self.reverse: | ||
return samples[:, ::-1, ...], targets | ||
return samples, targets |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Might make this a little easier to understand if we say "and* axis 0"