train.py
import logging
import wandb
import torch
import os
import dill
from tqdm import tqdm
from torch.utils.data import DataLoader
from TorchsRNN import RNNdataset, collate_fun2, DRNN, load_config
from evalTools import acc_metrics, recall_metrics, f1_metrics
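# TorchsRNN and evalTools are local modules: TorchsRNN is assumed to provide the
# DRNN model, the dataset wrapper, and the padding collate function, while
# evalTools provides the precision/recall/F1 helpers used below.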
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
wandb.login(host="http://47.108.152.202:8080",
            key="local-86eb7fd9098b0b6aa0e6ddd886a989e62b6075f0")
wandb.init(project="DRNN-Bert-embw",
           config={
               "learning_rate": 1e-3,
               "epochs": 1,
               "batch_size": 64,
           })
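# Metrics stream to the self-hosted W&B server configured above, not wandb.ai.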
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Basic hyperparameters
batch_size = 64
input_size = 300
hidden_size1 = 300
hidden_size2 = 300
epochs = 1
evaluation_epochs = 1
lr = 1e-3
# Load the pre-trained embedding matrix (assumed shape: vocab_size x input_size)
with open("embedding_origin.pkl", "rb") as embedding_file:
    matrix = dill.load(embedding_file)
matrix = torch.tensor(matrix).to(device)
model = DRNN(inputsize=input_size,
             inputsize1=900,
             hiddensize1=hidden_size1,
             hiddensize2=hidden_size2,
             inchanle=hidden_size2,
             outchanle1=2,
             outchanle2=5,
             batchsize=batch_size,
             embw=matrix).to(device)
load_config(
    model,
    target_path="/RNN_original/",
    para_name="epoch_1.pth",
    if_load_or_not=False  # set to True to resume from the saved checkpoint
)
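# Assumption: outchanle1=2 matches the binary sentence-level labels and
# outchanle2=5 the five sequence-label classes, per how output1/output2
# are scored against inputs[1][0] and inputs[1][1] below.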
with open("data_set.pkl", "rb") as dataset_file:
    train, test, word_dict = dill.load(dataset_file)  # renamed from `dict` to avoid shadowing the builtin
dataset = RNNdataset(train)
train_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=0,
                          drop_last=True,
                          collate_fn=collate_fun2)
evaluation_dataset = RNNdataset(test)
evaluation_loader = DataLoader(dataset=evaluation_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=0,
                               drop_last=True,
                               collate_fn=collate_fun2)
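# drop_last=True keeps every batch at exactly batch_size, which the fixed-size
# initial hidden state torch.randn([1, batch_size, hidden_size1]) depends on.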
lossfunction = torch.nn.CrossEntropyLoss()  # loss function and optimizer setup
optimizer = torch.optim.Adam(model.parameters(), lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
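# StepLR multiplies the learning rate by gamma every `step_size` calls to
# scheduler.step(), so it only takes effect on runs longer than 30 epochs.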
logging.basicConfig(level=logging.INFO)  # without this, INFO messages are swallowed by the default WARNING level
logging.info("Start Iteration")
for epoch in range(epochs):  # the padding length is 128
    iteration = tqdm(train_loader, desc=f"TRAIN on epoch {epoch}")
    model.train()
    for step, inputs in enumerate(iteration):
        # Forward pass; the random initial hidden state must live on the same device as the model
        output1, output2 = model(
            (inputs[0], torch.randn([1, batch_size, hidden_size1]).to(device)))
        sentence_preds = output1.argmax(axis=2)
        sequence_preds = output2.argmax(axis=2)
        # Metric computation
        sen_acc = acc_metrics(sentence_preds, inputs[1][0])
        seq_acc = acc_metrics(sequence_preds, inputs[1][1])
        sen_recall = recall_metrics(sentence_preds, inputs[1][0])
        seq_recall = recall_metrics(sequence_preds, inputs[1][1])
        sen_f1 = f1_metrics(sen_acc, sen_recall)
        seq_f1 = f1_metrics(seq_acc, seq_recall)
        # Log all training metrics in a single step
        wandb.log({"Train Sentence Precision": sen_acc,
                   "Train Sequence Precision": seq_acc,
                   "Train Sentence Recall": sen_recall,
                   "Train Sequence Recall": seq_recall,
                   "Train Sentence F1 Score": sen_f1,
                   "Train Sequence F1 Score": seq_f1})
        # Loss computation, NER-style: CrossEntropyLoss expects (batch, classes, seq_len)
        loss1 = lossfunction(output1.permute(0, 2, 1), inputs[1][0])
        loss2 = lossfunction(output2.permute(0, 2, 1), inputs[1][1])
        loss = loss2 * 0.7 + loss1 * 0.3  # weight the sequence task more heavily
        iteration.set_postfix(loss1='{:.4f}'.format(loss1.item()),
                              loss2='{:.4f}'.format(loss2.item()))
        wandb.log({"train loss1": loss1,
                   "train loss2": loss2,
                   "train Totalloss": loss,
                   "lr": optimizer.state_dict()['param_groups'][0]['lr']})
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        """
        for name, parms in model.named_parameters():  # for debugging: log each layer's mean weight and gradient
            wandb.log({f"{name} Weight:": torch.mean(parms.data)})
            if parms.grad is not None:
                wandb.log({f"{name} Grad_Value:": torch.mean(parms.grad)})
        """
    scheduler.step()  # advance the LR schedule once per epoch (assumed intent for the StepLR above)
os.makedirs("./check_points/RNN_original", exist_ok=True)  # ensure the checkpoint directory exists
torch.save(model.state_dict(), "./check_points/RNN_original/epoch_1.pth")
for epoch in range(evaluation_epochs):
    evaluation_iteration = tqdm(evaluation_loader, desc=f"EVALUATION on epoch {epoch + 1}")
    model.eval()
    for step, evaluation_input in enumerate(evaluation_iteration):
        with torch.no_grad():
            # Forward pass, mirroring the (inputs, initial hidden state) signature used in training
            output1, output2 = model(
                (evaluation_input[0], torch.randn([1, batch_size, hidden_size1]).to(device)))
        sentence_preds = output1.argmax(axis=2)
        sequence_preds = output2.argmax(axis=2)
        # Metric computation
        sen_acc = acc_metrics(sentence_preds, evaluation_input[1][0])
        seq_acc = acc_metrics(sequence_preds, evaluation_input[1][1])
        wandb.log({"Sentence Precision": sen_acc,
                   "Sequence Precision": seq_acc})