-
Notifications
You must be signed in to change notification settings - Fork 4
/
config.py
121 lines (104 loc) · 6.24 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
class Config:
    """Static experiment configuration for BERT fine-tuning.

    Every setting is a class attribute that is evaluated exactly once, when
    this class body runs.  The control flags at the top (``datatype``,
    ``data_mode``, ``is_zero_shot``, ``ratio``, ``retrain``) decide which
    dataset paths and which checkpoint path are defined below.  Change them
    here, in the class body: assigning to them afterwards does NOT recompute
    the derived paths.

    Path attributes that the selected dataset/mode does not provide
    (``dev_path``, ``dic_path``, ``dic_path_with_tokens_test``,
    ``embedding_path``, ``pretrain_path``) are ``None`` instead of being
    left undefined, so every configuration exposes the same attribute set.

    Raises:
        ValueError: at class-definition time, if ``datatype`` is not one of
            the datasets handled below.
    """

    #################### For BERT fine-tuning ####################
    # control
    datatype = "mixatis"
    data_mode = "multi"  # "single" or "multi": single or multi intent in data
    retrain = False  # Reuse trained model weights
    test_mode = "data"  # "embedding", "validation", "data"

    # For zero-shot training/testing
    # data: (real_num/ratio)
    # semantic: (19/13), (17/9), (15/5), (14/3)
    # mixatis: (16/15), (14/12), (13/10), (12/8)
    # mixsnips: (6/6), (5/4), (4/2), (3/0)
    is_zero_shot = False
    real_num = 19
    ratio = '13'  # kept as a string; it is spliced into zero-shot file names

    # For few-shot training/testing
    is_few_shot = False
    few_shot_ratio = 0.1

    sentence_mode = "one"  # "one" or "two": one or two sentence in data
    dialog_data_mode = False  # for dialogue-wise data (A+B)

    #################################
    # Optional paths default to None so that reading one that the selected
    # dataset/mode does not define yields None rather than AttributeError.
    dev_path = None
    dic_path = None
    dic_path_with_tokens_test = None
    embedding_path = None
    pretrain_path = None

    if datatype == "atis":
        # atis dataset
        train_path = "data/atis/raw_data.pkl"
        test_path = "data/atis/raw_data_test.pkl"
        dic_path_with_tokens = "data/atis/intent2id_with_tokens.pkl"
        embedding_path = "finetune_results/atis_embeddings_with_hidden.pth"
    elif datatype == "snips":
        # snips dataset
        train_path = "data/snips/raw_data_train.pkl"
        test_path = "data/snips/raw_data_test.pkl"
        dic_path_with_tokens = "data/snips/intent2id_with_tokens.pkl"
    elif datatype == "semantic":
        # semantic parsing dataset
        # The intent dictionary is the same in normal and zero-shot mode.
        dic_path = (
            "data/semantic/intent2id_se.pkl"
            if data_mode == "single"
            else "data/semantic/intent2id_multi_se.pkl"
        )
        if not is_zero_shot:
            # normal
            train_path = (
                "data/semantic/raw_data_se.pkl"
                if data_mode == "single"
                else "data/semantic/raw_data_multi_se.pkl"
            )
            test_path = "data/semantic/raw_data_multi_se_test.pkl"
            dic_path_with_tokens = "data/semantic/intent2id_multi_se_with_tokens.pkl"
            embedding_path = "embeddings/se_embeddings_with_hidden.pth"
        else:
            # zero-shot/few-shot
            # Only the multi-intent training file name depends on `ratio`.
            train_path = (
                "data/semantic/raw_data_se.pkl"
                if data_mode == "single"
                else "data/semantic/zeroshot/raw_data_multi_se_zst_train{}.pkl".format(ratio)
            )
            test_path = "data/semantic/zeroshot/raw_data_multi_se_zst_test{}.pkl".format(ratio)
            dic_path_with_tokens = "data/semantic/zeroshot/intent2id_multi_se_with_tokens_zst_train{}.pkl".format(ratio)
            dic_path_with_tokens_test = "data/semantic/zeroshot/intent2id_multi_se_with_tokens_zst_test{}.pkl".format(ratio)
    elif datatype == "mixatis":
        # mix atis dataset
        if not is_zero_shot:
            train_path = "data/MixATIS_clean/raw_data_multi_ma_train.pkl"
            dev_path = "data/MixATIS_clean/raw_data_multi_ma_dev.pkl"
            test_path = "data/MixATIS_clean/raw_data_multi_ma_test.pkl"
            dic_path_with_tokens = "data/MixATIS_clean/intent2id_multi_ma_with_tokens.pkl"
        else:
            train_path = "data/MixATIS_clean/zeroshot/raw_data_multi_ma_train{}.pkl".format(ratio)
            test_path = "data/MixATIS_clean/zeroshot/raw_data_multi_ma_test{}.pkl".format(ratio)
            dic_path_with_tokens = "data/MixATIS_clean/zeroshot/intent2id_multi_ma_with_tokens_train{}.pkl".format(ratio)
            dic_path_with_tokens_test = "data/MixATIS_clean/zeroshot/intent2id_multi_ma_with_tokens_test{}.pkl".format(ratio)
        # NOTE(review): assumed to apply to both modes (the path does not
        # depend on `ratio`) - confirm against the original indentation.
        embedding_path = "embeddings/mixatis_embeddings_with_hidden.pth"
    elif datatype == "mixsnips":
        # mix snips dataset
        if not is_zero_shot:
            train_path = "data/MixSNIPS_clean/raw_data_multi_sn_train.pkl"
            dev_path = "data/MixSNIPS_clean/raw_data_multi_sn_dev.pkl"
            test_path = "data/MixSNIPS_clean/raw_data_multi_sn_test.pkl"
            dic_path_with_tokens = "data/MixSNIPS_clean/intent2id_multi_sn_with_tokens.pkl"
        else:
            train_path = "data/MixSNIPS_clean/zeroshot/raw_data_multi_sn_train{}.pkl".format(ratio)
            test_path = "data/MixSNIPS_clean/zeroshot/raw_data_multi_sn_test{}.pkl".format(ratio)
            dic_path_with_tokens = "data/MixSNIPS_clean/zeroshot/intent2id_multi_sn_with_tokens_train{}.pkl".format(ratio)
            dic_path_with_tokens_test = "data/MixSNIPS_clean/zeroshot/intent2id_multi_sn_with_tokens_test{}.pkl".format(ratio)
        # NOTE(review): assumed to apply to both modes (the path does not
        # depend on `ratio`) - confirm against the original indentation.
        embedding_path = "embeddings/mixsnips_embeddings_with_hidden.pth"
    elif datatype == "e2e":
        # Microsoft e2e dialogue dataset
        train_path = (
            "data/e2e_dialogue/dialogue_data.pkl"
            if data_mode == "single"
            else "data/e2e_dialogue/dialogue_data_multi.pkl"
        )
        test_path = "data/e2e_dialogue/dialogue_data_multi.pkl"
        dic_path = (
            "data/e2e_dialogue/intent2id.pkl"
            if data_mode == "single"
            else "data/e2e_dialogue/intent2id_multi.pkl"
        )
        dic_path_with_tokens = "data/e2e_dialogue/intent2id_multi_with_tokens.pkl"
        embedding_path = "embeddings/e2e_embeddings_with_hidden.pth"
        pretrain_path = "data/e2e_dialogue/dialogue_data_pretrain.pkl"
    elif datatype == "sgd":
        # dstc8-sgd dialogue dataset
        train_path = (
            "data/sgd_dialogue/dialogue_data.pkl"
            if data_mode == "single"
            else "data/sgd_dialogue/dialogue_data_multi.pkl"
        )
        test_path = "data/sgd_dialogue/dialogue_data_multi.pkl"
        dic_path = (
            "data/sgd_dialogue/intent2id.pkl"
            if data_mode == "single"
            else "data/sgd_dialogue/intent2id_multi.pkl"
        )
        dic_path_with_tokens = "data/sgd_dialogue/intent2id_multi_with_tokens.pkl"
        embedding_path = "embeddings/sgd_embeddings_with_hidden.pth"
        pretrain_path = "data/sgd_dialogue/dialogue_data_pretrain.pkl"
    else:
        # Fail here with a clear message instead of leaving the class
        # without any dataset paths.
        raise ValueError("Unsupported datatype: {!r}".format(datatype))

    # Checkpoint to load: only set when `retrain` is on; zero-shot runs use a
    # separate directory and include `ratio` in the file name.
    if not retrain:
        model_path = None
    elif is_zero_shot:
        model_path = "checkpoints/BEST/best_{}_{}_{}.pth".format(datatype, data_mode, ratio)
    else:
        model_path = "checkpoints/best_{}_{}.pth".format(datatype, data_mode)
    # model_path = "checkpoints/best_mixatis_multi.pth"

    maxlen = 50  # 20
    batch_size = 32  # 16, 128
    epochs = 15  # 30, 5
    learning_rate_bert = 2e-5  # 1e-3
    learning_rate_classifier = 1e-3

    max_dialog_size = 25 if datatype == "e2e" else 50
    dialog_batch_size = 100
    rnn_hidden = 256
# Module-level instance of Config. Every setting is a class attribute, so this
# object holds no per-instance state of its own.
opt = Config()