-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_ccl_rcnn_crf.sh
96 lines (83 loc) · 3.73 KB
/
train_ccl_rcnn_crf.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env bash
#######################################
# Train the rcnn_crf model with `python3 -m deepbond train` on one CCL-A fold.
# The fold's train/ directory is the training set and its test/ directory is
# passed as the dev set.
# Arguments: $1 - fold identifier (directory name under data/folds/CCL-A/)
# Outputs:   whatever `deepbond train` prints; the run directory
#            (runs/test-cinderela-togo/) and the model directory
#            (saved-models/test-cinderela-togo/) are the same for every fold.
# Returns:   2 when no fold is given, otherwise the exit status of python3.
#######################################
traindb() {
  if [[ -z "${1:-}" ]]; then
    printf 'usage: traindb FOLD\n' >&2
    return 2
  fi
  local fold=$1

  # Arguments live in an array rather than a backslash-continued command line:
  # a continuation glued to the previous token (as in `".?!"\`) silently merges
  # two arguments into one word, which an array cannot do.
  local -a train_args=(
    # run bookkeeping
    --seed 42
    --gpu-id 0
    --output-dir "runs/test-cinderela-togo/"
    --save "saved-models/test-cinderela-togo/"
    --print-parameters-per-layer
    --final-report

    # data
    --train-path "data/folds/CCL-A/${fold}/train/"
    --dev-path "data/folds/CCL-A/${fold}/test/"
    --punctuations '.?!'
    --max-length 9999999
    --min-length 0

    # vocabulary
    --vocab-size 9999999
    --vocab-min-frequency 1
    --keep-rare-with-vectors
    --add-embeddings-vocab

    # embeddings
    --emb-dropout 0.0
    --embeddings-format "text"
    --embeddings-path "data/embs/word2vec/pt_word2vec_sg_600.small.pickle.emb"
    --embeddings-binary
    --freeze-embeddings

    # architecture
    --model rcnn_crf

    # convolutional part
    --conv-size 100
    --kernel-size 7
    --pool-length 3
    --cnn-dropout 0.0

    # recurrent part
    --rnn-type rnn
    --hidden-size 100
    --bidirectional
    --sum-bidir
    --rnn-dropout 0.5

    # optimisation
    --loss-weights "balanced"
    --train-batch-size 1
    --dev-batch-size 1
    --epochs 40
    --optimizer "adamw"
    --learning-rate 0.001
    --weight-decay 0.01
    --save-best-only
    --early-stopping-patience 10
    --restore-best-model
  )

  python3 -m deepbond train "${train_args[@]}"
}
#######################################
# Run `python3 -m deepbond predict` on the test split of one CCL-A fold,
# loading the model from saved-models/test-cinderela-togo/.
# Arguments: $1 - fold identifier (directory name under data/folds/CCL-A/)
# Outputs:   whatever `deepbond predict` prints; the output directory passed
#            to it is data/folds/CCL-A/<fold>/pred/.
# Returns:   the exit status of python3.
#######################################
predictdb() {
  local fold_id=$1
  local fold_dir="data/folds/CCL-A/${fold_id}"

  local -a predict_args=(
    --gpu-id 0
    --prediction-type classes
    --load "saved-models/test-cinderela-togo/"
    --test-path "${fold_dir}/test/"
    --output-dir "${fold_dir}/pred/"
    --train-batch-size 1
    --dev-batch-size 1
  )

  python3 -m deepbond predict "${predict_args[@]}"
}
###################################
# Train and predict for each fold #
###################################
# For each fold: train, predict, then join the predicted labels with the
# original text. Every step is checked before the next one runs, because
# traindb and predictdb use the same saved-models/test-cinderela-togo/
# directory for all folds: if training failed and the script carried on,
# prediction could load whatever an earlier fold left there.
for fold in 0 1 2 3 4; do
  traindb "${fold}" || {
    printf 'fold %s: training failed\n' "${fold}" >&2
    exit 1
  }
  predictdb "${fold}" || {
    printf 'fold %s: prediction failed\n' "${fold}" >&2
    exit 1
  }
  python3 scripts/join_original_text_with_predicted_labels.py \
    "data/folds/CCL-A/${fold}/test/" \
    "data/folds/CCL-A/${fold}/pred/predictions" || {
    printf 'fold %s: joining text with predicted labels failed\n' "${fold}" >&2
    exit 1
  }
done

# Error analysis
# The patterns stay quoted so the shell passes them to the script unexpanded.
python3 scripts/error_analysis.py "data/folds/CCL-A/*/test/*" "data/folds/CCL-A/*/pred/predictions/*"