-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_rep_copy.lua
76 lines (55 loc) · 1.63 KB
/
load_rep_copy.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
require 'nn'
require 'ntm'
require 'gnuplot'
local tasks = require 'tasks'
-- Configuration table passed to the nn.NTM constructor. The values mirror the
-- checkpoint comment below (input size 7, output size 7, 128 memory slots).
-- NOTE(review): the checkpoint is copied element-for-element into the model's
-- flat parameter tensor, so these sizes presumably must match the ones used
-- at training time -- confirm before changing any of them.
local ntm_params = {
  input_size = 7,
  output_size = 7,
  mem_locations = 128,
  mem_location_size = 20,
  hidden_state_size = 100,
  allowed_shifts = {-1,0,1}
}
local ntm = nn.NTM(ntm_params)

-- Model trained with zeros as target for the input phase.
-- input size : 7, output size : 7, n memory slots : 128
local loaded_params = torch.load('parameters/rep_cop/25000-0.00004.params')

-- getParameters() also returns a gradient tensor as its second value; this
-- script only runs forward passes, so that value is intentionally not bound.
local ntm_p = ntm:getParameters()

-- Overwrite the freshly initialised weights with the saved ones.
ntm_p:copy(loaded_params)
-- Evaluate the loaded NTM on the repeat-copy task for every sequence length in
-- [min_seq_len, max_seq_len] (inclusive). For each length the script prints the
-- inputs, the model outputs and the absolute target/output difference, plots
-- all three with gnuplot, and reports the mean BCE over the scored steps.
local min_seq_len = 5
local max_seq_len = 5
local crit = nn.BCECriterion()

for seq_len = min_seq_len, max_seq_len do
  local n_repeat = 5
  local inputs, targets, exp_out = tasks.generate_repeat_copy_sequence(seq_len, ntm_params.input_size, n_repeat, false)

  -- One output row per time step, same overall shape as the targets.
  local out = torch.Tensor(targets:size())
  local err = 0
  local n_out = 0

  for j = 1, inputs:size(1) do
    out[j] = ntm:forward(inputs[j])

    -- Only steps flagged by exp_out contribute to the reported error.
    -- NOTE(review): the element type of exp_out is not visible from this
    -- file. In Lua the number 0 is truthy, so a numeric 0/1 flag would make
    -- a plain `if exp_out[j]` accept every step; 0 is therefore rejected
    -- explicitly. Boolean/nil flags behave exactly as before.
    local scored = exp_out[j]
    if scored and scored ~= 0 then
      err = err + crit:forward(out[j], targets[j])
      n_out = n_out + 1
    end
  end

  if n_out > 0 then
    err = err / n_out
  else
    -- Nothing was scored: the mean is undefined. Say so instead of letting
    -- an unexplained NaN from 0/0 show up in the results table.
    io.stderr:write('Warning: no output steps were scored; mean error is undefined\n')
    err = 0/0
  end

  io.write('Input : \n')
  print(inputs)
  io.write('Output : \n')
  print(out)

  -- Element-wise absolute deviation between targets and model outputs.
  local diff = (targets - out):abs()
  io.write('Error : \n')
  print (diff)

  -- Figures 1-3 are reused on every iteration, so only the plots of the last
  -- evaluated sequence length remain visible when the range has several.
  gnuplot.figure(1)
  gnuplot.imagesc(inputs:t(),'color')
  gnuplot.figure(2)
  gnuplot.imagesc(out:t(),'color')
  gnuplot.figure(3)
  gnuplot.imagesc(diff:t(),'color')

  io.write('Sequence length\t\tError\n')
  local str_format = '%d\t\t\t%f\n'
  io.write(str_format:format(seq_len, err))
  io.flush()

  -- Called once per evaluated sequence, after its results are reported.
  -- NOTE(review): presumably resets the model's per-sequence state so the
  -- next sequence starts clean -- confirm against the ntm module.
  ntm:new_sequence()
end