-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_copy.lua
84 lines (60 loc) · 2.15 KB
/
load_copy.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
require 'nn'
require 'ntm'
require 'gnuplot'
local tasks = require 'tasks'
-- Architecture of the NTM under evaluation. These sizes have to agree with
-- the checkpoint loaded below, because its parameters are copied into this
-- model's flat parameter vector.
local ntm_params = {
  input_size        = 9,
  output_size       = 9,
  mem_locations     = 128,
  mem_location_size = 20,
  hidden_state_size = 100,
  allowed_shifts    = { -1, 0, 1 },
}
local ntm = nn.NTM(ntm_params)
-- Load the flat parameter vector of a pretrained copy-task model and copy it
-- into `ntm`. Exactly one checkpoint should be active; its architecture must
-- match `ntm_params` above.
-- NOTE(review): the checkpoint directory name encodes len=1-20 (presumably the
-- range of training sequence lengths); the leading 25000 in the file name
-- looks like an iteration count rather than a sequence length — confirm.

-- Trained with zeros as targets for the input phase:
-- input size: 9, output size: 9, memory slots: 128
local loaded_params = torch.load('parameters/copy/copy_force/24.01.2017_14:26:31_len=1-20_lr=0.0001/25000-0.00001.params')
-- Trained without a specific output target for the input phase:
-- input size: 9, output size: 9, memory slots: 128
-- local loaded_params = torch.load('parameters/copy/copy_no_force/24.01.2017_12:32:00_len=1-20_lr=0.0001/25000-0.00002.params')
-- Smaller model (requires changing ntm_params to match):
-- input size: 7, output size: 7, memory slots: 10
-- local loaded_params = torch.load('parameters/copy/toy/copy_force=false_seed=1/29.01.2017_10:58:15_len=1-7_lr=0.0001/25000-0.00002.params')

-- Only the parameter vector is needed here; gradients are never used.
local ntm_p = ntm:getParameters()
-- Fail with an actionable message when the checkpoint and ntm_params disagree.
-- Assumes the checkpoint file holds a single flat tensor.
assert(ntm_p:nElement() == loaded_params:nElement(),
  string.format('checkpoint has %d parameters but the model has %d; ' ..
    'do ntm_params match the checkpoint architecture?',
    loaded_params:nElement(), ntm_p:nElement()))
ntm_p:copy(loaded_params)
-- Evaluate the loaded model on copy sequences of every length in
-- [min_seq_len, max_seq_len]. Both bounds are 37 here, presumably to probe
-- lengths outside the range encoded in the checkpoint path.
local min_seq_len = 37
local max_seq_len = 37
local crit = nn.BCECriterion()

-- True when a per-step "output expected" flag is set. Accepts booleans and
-- numbers alike: in Lua the number 0 is truthy, so a plain `if flag then`
-- would count every step if the flags come as a 0/1 tensor.
-- NOTE(review): the concrete type of exp_out is defined in tasks.lua — confirm.
local function is_output_step(flag)
  return flag ~= nil and flag ~= false and flag ~= 0
end

for seq_len = min_seq_len, max_seq_len do
  -- NOTE(review): the third argument `false` presumably selects the
  -- "no forced targets during the input phase" variant — see tasks.lua.
  local inputs, targets, exp_out =
    tasks.generate_copy_sequence(seq_len, ntm_params.input_size, false)
  local out = torch.Tensor(targets:size())

  -- Feed the sequence one time step at a time; accumulate the BCE loss only
  -- on the steps flagged as producing an output.
  local err = 0
  local n_out = 0
  for j = 1, inputs:size(1) do
    out[j] = ntm:forward(inputs[j])
    if is_output_step(exp_out[j]) then
      err = err + crit:forward(out[j], targets[j])
      n_out = n_out + 1
    end
  end
  -- Mean loss over the output steps (NaN if no step was flagged).
  err = err / n_out

  io.write('Input : \n')
  print(inputs)
  io.write('Output : \n')
  print(out)
  local diff = (targets - out):abs()
  io.write('Error : \n')
  print (diff)

  gnuplot.figure(1)
  gnuplot.imagesc(inputs:t(),'color')
  gnuplot.figure(2)
  gnuplot.imagesc(out:t(),'color')
  gnuplot.figure(3)
  gnuplot.imagesc(diff:t(),'color')

  io.write('Sequence length\t\tError\n')
  local str_format = '%d\t\t\t%f\n'
  io.write(str_format:format(seq_len, err))
  io.flush()

  -- Reset the NTM's state so the next sequence starts fresh.
  ntm:new_sequence()
end