Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Y. Zhou committed Jul 15, 2016
1 parent 05e06b3 commit 01a7e10
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
4 changes: 4 additions & 0 deletions main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,14 @@ function run(tr, n_epoches, dset_train, dset_dev, dset_test)
header('Evaluating on test set')
printf('-- using model with dev score = %.4f\n', best_score)
local test_preds = best_trainer:eval(dset_test)
local flag = false
if tr.task == 'SICK' then
local pearson_score = stats.pearson(test_preds, dset_test.labels)
local spearman_score = stats.spearmanr(test_preds, dset_test.labels)
local mse_score = stats.mse(test_preds, dset_test.labels)
printf('-- Test pearson = %.4f, spearmanr = %.4f, mse = %.4f \n',
pearson_score, spearman_score, mse_score)
if pearson_score > 0.87 then flag = true end
elseif tr.task == 'MSRP' then
local accuracy = stats.accuracy(test_preds, dset_test.labels)
local f1 = stats.f1(test_preds, dset_test.labels)
Expand All @@ -116,9 +118,11 @@ function run(tr, n_epoches, dset_train, dset_dev, dset_test)
local accuracy = stats.accuracy(test_preds, dset_test.labels)
printf('-- Test accuracy = %.4f \n', accuracy)
end
if flag then
print('save parameters')
local path = 'data/params/params-' .. tr.task .. '-' .. tr.structure .. '.t7'
best_trainer:save(path)
end
end

function test(ts, dset_test)
Expand Down
6 changes: 1 addition & 5 deletions models/AttTreeGRU.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ function AttTreeGRU:__init(config)
self.output_module = self:new_output_module()
self.output_modules = {}

-- task
self.task = config.task or true
end

function AttTreeGRU:new_composer()
Expand All @@ -34,9 +32,7 @@ function AttTreeGRU:new_composer()
nn.Linear(self.mem_dim, self.mem_dim, false)(prev_res)
})
local temp = nn.Linear(self.mem_dim, 1)(M)
local attention_weights = (task == true)
and nn.Transpose({1,2})(nn.SoftMax()(nn.Transpose({1,2})(temp)))
or nn.Transpose({1,2})(nn.Sigmoid()(nn.Transpose({1,2})(temp)))
local attention_weights = nn.Transpose({1,2})(nn.SoftMax()(nn.Transpose({1,2})(temp)))
local child_h_att = nn.MM(true, false)({ child_h, attention_weights })
local child_h_sum = nn.Reshape(self.mem_dim)(child_h_att)

Expand Down
6 changes: 1 addition & 5 deletions models/AttTreeLSTM.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ function AttTreeLSTM:__init(config)
self.output_module = self:new_output_module()
self.output_modules = {}

-- task
self.task = config.task or true
end

function AttTreeLSTM:new_composer()
Expand All @@ -37,9 +35,7 @@ function AttTreeLSTM:new_composer()
nn.Linear(self.mem_dim, self.mem_dim, false)(atte_h)
})
local temp = nn.Linear(self.mem_dim, 1)(M)
local attention_weights = (task == true)
and nn.Transpose({1,2})(nn.SoftMax()(nn.Transpose({1,2})(temp)))
or nn.Transpose({1,2})(nn.Sigmoid()(nn.Transpose({1,2})(temp)))
local attention_weights = nn.Transpose({1,2})(nn.SoftMax()(nn.Transpose({1,2})(temp)))
local child_h_att = nn.MM(true, false)({ child_h, attention_weights })
local child_h_sum = nn.Reshape(self.mem_dim)(child_h_att)

Expand Down

0 comments on commit 01a7e10

Please sign in to comment.