-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathget_feature.py
96 lines (83 loc) · 3.35 KB
/
get_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, print_function, division
from io import open
import torch
import json
import numpy as np
from word_emb import emb_size, word2id, id2word, emb, word2count, vocab_size, SOS_token, EOS_token, PAD_token, UNK_token
# Per-word dictionary loaded once at import time.  The rest of this file reads
# the 'pz' and 'yun' fields of its entries.  The path is relative to the
# current working directory.
with open('resource/word_dict.json', mode='r', encoding='utf-8') as f1:
    word_dict = json.loads(f1.read())
# Words/tokens listed as forbidden.  Nothing in the visible part of this file
# reads the list (the id conversion below is disabled) -- presumably it is
# consumed by other modules; verify before changing it.
forbidden_words = [
    '一', '二', '三', '四', '五', '六', '七', '八', '九', '十', '千', '百', '万',
    '艇', '些', '的',
    'START', 'END', '/', 'START1', 'END1', '-', 'UNK',  # END1
]
# forbidden_id = [word2id[word] for word in forbidden_words]

# 1-based numbers of the lines that rhyme; the first entry is the reference
# line whose last character the later rhyming lines are compared against.
yun_sen = [2, 4]

# When True, get_feature takes the wanted 'pz' label from lv_list instead of
# from the target word.
use_mode = False

poem_type = 'poem7'  # edit this to switch the poem form
hard_lv = True  # strict pattern: every position of a 7-character line is fixed

# sen_len is the number of characters per line; lv_list holds one row of
# per-position 'p'/'z' labels for each of the four lines.  In the non-strict
# 7-character table '0' presumably marks an unconstrained position -- TODO
# confirm.  Any poem_type other than 'poem7' selects the 5-character form,
# for which hard_lv has no effect.
sen_len = 7 if poem_type == 'poem7' else 5
if poem_type != 'poem7':
    lv_list = [list(row) for row in ('zzppz', 'ppzzp', 'pppzz', 'zzzpp')]
elif hard_lv:
    lv_list = [list(row) for row in ('ppzzppz', 'zzppzzp', 'zzpppzz', 'ppzzzpp')]
else:
    lv_list = [list(row) for row in ('0p0z0p0', '0z0p0zp', '0z0p0z0', '0p0z0pp')]
# Vocabulary words in id order (id2word is keyed by the id rendered as a
# string).
_vocab_words = [id2word[str(idx)] for idx in range(vocab_size)]

# all_lv[k] / all_yun[k] hold the 'pz' / 'yun' field of vocabulary word k as
# found in word_dict, or '' when the word has no dictionary entry.  Both are
# NumPy arrays so that get_feature can compare the whole vocabulary against
# one label at a time.
all_lv = np.array(
    [word_dict[w]['pz'] if w in word_dict else '' for w in _vocab_words]
)
all_yun = np.array(
    [word_dict[w]['yun'] if w in word_dict else '' for w in _vocab_words]
)

# Out-of-vocabulary counter; the only increment in the visible code is
# commented out, so it stays at 0 here.
oov = 0
def get_feature(decoded_words, sen_num, w_num, batch_size, target):
    """Build the 'pz' (lv) and rhyme indicator features over the vocabulary.

    For every batch row two 0/1 vectors of length ``vocab_size`` are built:

    * ``feature1`` marks the vocabulary entries whose ``'pz'`` label (taken
      from ``all_lv``) equals the wanted label.  With ``use_mode`` on, the
      wanted label is ``lv_list[sen_num][w_num]``; otherwise it is the
      ``'pz'`` label of the row's target word, or ``'0'`` when that word has
      no entry in ``word_dict``.
    * ``feature2`` marks the vocabulary entries whose ``'yun'`` label (taken
      from ``all_yun``) equals that of the last character of the first
      rhyming line.  It is non-zero only when the current position is the
      last character of a later rhyming line (``yun_sen[1:]``) and the
      reference character has an entry in ``word_dict``.

    Args:
        decoded_words: per-row sequences of the tokens decoded so far,
            indexed as ``decoded_words[i][k]``.  Only read on rhyme steps.
        sen_num: zero-based index of the line being generated.
        w_num: zero-based position inside that line.
        batch_size: number of rows to produce.
        target: per-row target ids; ``target[i].item()`` must return the id.
            Not read when ``use_mode`` is on.

    Returns:
        Tuple ``(feature1, feature2)`` of float32 arrays, each of shape
        ``(batch_size, vocab_size)`` -- also for an empty batch.

    Raises:
        IndexError: if a row of ``decoded_words`` is too short to contain the
            rhyme reference character on a rhyme step.
    """
    n_rows = max(batch_size, 0)
    # Pre-sized outputs keep the (rows, vocab_size) shape even when the batch
    # is empty; rows that get no match simply stay zero.
    feature1 = np.zeros((n_rows, vocab_size), dtype=np.float32)
    feature2 = np.zeros((n_rows, vocab_size), dtype=np.float32)

    # The rhyme constraint applies only on the last character of a rhyming
    # line other than the first one; this does not depend on the row.
    on_rhyme_step = (sen_num + 1) in yun_sen[1:] and (w_num + 1) == sen_len
    if on_rhyme_step:
        # Last character of the first rhyming line.  The formula assumes each
        # line takes sen_len + 1 slots in the decoded sequence (presumably
        # the characters plus one separator token).
        rhyme_ref_pos = yun_sen[0] * (sen_len + 1) - 2
    else:
        rhyme_ref_pos = None

    for i in range(n_rows):
        # feature1: lv
        if use_mode:
            wanted_lv = lv_list[sen_num][w_num]
        else:
            target_word = id2word[str(target[i].item())]
            if target_word in word_dict:
                wanted_lv = word_dict[target_word]['pz']
            else:
                # Label for targets without a dictionary entry; all_lv uses ''
                # for such words, so this normally matches nothing.
                wanted_lv = '0'
        # Comparing against the single label broadcasts over the vocabulary.
        feature1[i] = all_lv == wanted_lv

        # feature2: yun
        if on_rhyme_step:
            yun_word = decoded_words[i][rhyme_ref_pos]
            # Marker tokens and other words without an entry give no
            # constraint.
            if yun_word in word_dict:
                feature2[i] = all_yun == word_dict[yun_word]['yun']

    return feature1, feature2