combining readout and teacher_force #61

Open
temporaer opened this issue May 15, 2017 · 1 comment

temporaer commented May 15, 2017

I cannot find a way to combine the two examples from the documentation (readout and teacher forcing).
For example, I want to decode a 13-dimensional state vector s0 into a sequence of length 11, where each element of the sequence is a softmax over a vocabulary of 3 words. I'm using a batch size of 7. This is what I have so far, but it fails in fit with a shape problem:

import numpy as np
from keras.layers import Input, Dense
from keras.engine import Model
from recurrentshop import RecurrentModel, RecurrentSequential
from recurrentshop.cells import GRUCell

rnn = RecurrentSequential(decode=True, output_length=11, readout='readout_only',
          teacher_force=True, return_sequences=True)
rnn.add(GRUCell(13, input_dim=3))
rnn.add(Dense(3, activation='softmax', input_dim=13))

x = Input((3,), name='x')
y0 = Input((3,), name='y0')
s0 = Input((13,), name='s0')
yt = Input((11, 3), name='yt')
y = rnn(s0, ground_truth=yt, initial_readout=y0)

model = Model([x, y0, yt, s0], y)
model.compile('sgd', 'categorical_crossentropy')

npyt = np.ones((7, 11, 3))
npx = np.zeros((7, 3))
npy0 = np.ones((7, 3))
nps0 = np.ones((7, 13))
model.fit([npx, npy0, npyt, nps0], npyt)
y2 = model.predict([npx, npy0, npyt, nps0])

I also tried passing s0 via the initial_state parameter of rnn, but that breaks even earlier (a pop from an empty inputs_list in the rnn(...) call).
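
To make the shapes I expect explicit, here is a minimal NumPy-only sketch that does not touch recurrentshop at all. It assumes the usual meaning of teacher forcing: the decoder input at step t is the ground truth from step t-1, with y0 used as the first input. The helper name teacher_forced_inputs is made up for illustration and is not part of recurrentshop or Keras.

```python
import numpy as np

# Sizes taken from the description above.
batch, steps, vocab, state_dim = 7, 11, 3, 13


def teacher_forced_inputs(y0, yt):
    """Hypothetical helper (not a recurrentshop API).

    Shifts the ground truth right by one step and prepends y0, so that
    the input at step t is the target from step t-1.
    """
    return np.concatenate([y0[:, None, :], yt[:, :-1, :]], axis=1)


# One-hot targets: each time step selects one of the 3 words.
rng = np.random.default_rng(0)
labels = rng.integers(0, vocab, size=(batch, steps))
yt = np.eye(vocab)[labels]          # shape (7, 11, 3)

y0 = np.zeros((batch, vocab))       # initial readout, shape (7, 3)
s0 = np.ones((batch, state_dim))    # initial state, shape (7, 13)

dec_in = teacher_forced_inputs(y0, yt)
print(dec_in.shape, yt.shape, s0.shape)
```

If this matches what teacher_force does internally (an assumption on my side), the model output and the target passed to fit should both have shape (7, 11, 3).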

@TrentBrick

I am having problems with this too.

3 participants