Skip to content

Commit 0c5889d

Browse files
committed
Add: linear chain conditional random field. I made it!
1 parent 057c285 commit 0c5889d

File tree

2 files changed

+339
-4
lines changed

2 files changed

+339
-4
lines changed
Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
from math import log
2+
import os
3+
from matplotlib.tri.triinterpolate import LinearTriInterpolator
4+
import numpy as np
5+
from functools import partial
6+
import sys
7+
from pathlib import Path
8+
from rich.console import Console
9+
from rich.table import Table
10+
sys.path.append(str(Path(os.path.abspath(__file__)).parent.parent))
11+
from utils import *
12+
13+
class LinearChainConditionalRandomField:
14+
def __init__(self, feature_funcs, trans_feature_funcs, sequence_length, n_x, n_y, max_iteration=100, verbose=False):
15+
"""
16+
`feature_funcs` are a group of functions s(y_i, X, i) in a list
17+
`trans_feature_funcs` are a group of functions t(y_{i-1}, y_i, X, i) in a list
18+
`sequence_length` is the length of each input sequence
19+
`n_x` is the number of possible values of each item in a sequence x
20+
`n_y` is the number of possible values of each item in a sequence y
21+
"""
22+
self.feature_funcs = feature_funcs
23+
self.trans_feature_funcs = trans_feature_funcs
24+
self.n_x = n_x
25+
self.n_y = n_y
26+
self.sequence_length = sequence_length
27+
self.max_iteration = max_iteration
28+
self.verbose = verbose
29+
30+
def get_trans(self, x):
31+
"""get transition matrix given observed sequence x"""
32+
trans_feature = np.zeros([self.sequence_length, self.n_y, self.n_y])
33+
for i in range(self.sequence_length):
34+
for y_i_1 in range(self.n_y):
35+
for y_i in range(self.n_y):
36+
for j, func in enumerate(self.used_feature_funcs):
37+
trans_feature[i, y_i_1, y_i] += self.w_feature_funcs[j] * func(y_i, x, i)
38+
if i > 0:
39+
for y_i_1 in range(self.n_y):
40+
for y_i in range(self.n_y):
41+
for j, func in enumerate(self.used_trans_feature_funcs):
42+
trans_feature[i, y_i_1, y_i] += self.w_trans_feature_funcs[j] * func(y_i_1, y_i, x, i)
43+
return np.exp(trans_feature)
44+
45+
def fit(self, X, Y):
46+
"""
47+
X is a two dimensional matrix of observation sequence
48+
Y is a two dimensional matrix of hidden state sequence
49+
optimize weights by Improved Iterative Scaling
50+
"""
51+
E_feature = np.zeros(len(self.feature_funcs))
52+
E_trans_feature = np.zeros(len(self.trans_feature_funcs))
53+
54+
# Because each x is a sequence, it's vector space is too large to iterate.
55+
# We need to store all the possible sequence x during the training time
56+
# and only iterate over existing x.
57+
p_x = {tuple(x): 0. for x in X}
58+
59+
for x, y in zip(X, Y):
60+
x_key = tuple(x)
61+
p_x[x_key] += 1 / len(X)
62+
for i, yi in enumerate(y):
63+
for j, func in enumerate(self.feature_funcs):
64+
E_feature[j] += func(yi, x, i) / len(X)
65+
for i in range(1, self.sequence_length):
66+
yi_1, yi = y[i - 1], y[i]
67+
for j, func in enumerate(self.trans_feature_funcs):
68+
E_trans_feature[j] += func(yi_1, yi, x, i) / len(X)
69+
70+
# features that don't show in training data are useless, filter them
71+
self.used_feature_funcs = [func for E, func in zip(E_feature, self.feature_funcs) if E != 0]
72+
self.used_trans_feature_funcs = [func for E, func in zip(E_trans_feature, self.trans_feature_funcs) if E != 0]
73+
E_feature = E_feature[E_feature.nonzero()]
74+
E_trans_feature = E_trans_feature[E_trans_feature.nonzero()]
75+
self.w_feature_funcs = np.zeros(len(self.used_feature_funcs))
76+
self.w_trans_feature_funcs = np.zeros(len(self.used_trans_feature_funcs))
77+
78+
# pre-calculate all the possible values of feature functions
79+
feature = np.zeros([len(self.used_feature_funcs), len(p_x), self.sequence_length, self.n_y])
80+
trans_feature = np.zeros([len(self.used_trans_feature_funcs), len(p_x), self.sequence_length, self.n_y, self.n_y])
81+
for x_i, x_key in enumerate(p_x):
82+
x = np.array(x_key)
83+
for func_i, func in enumerate(self.used_trans_feature_funcs):
84+
for i in range(1, self.sequence_length):
85+
for y_i_1 in range(self.n_y):
86+
for y_i in range(self.n_y):
87+
trans_feature[func_i, x_i, i, y_i_1, y_i] = func(y_i_1, y_i, x, i)
88+
for func_i, func in enumerate(self.used_feature_funcs):
89+
for i in range(self.sequence_length):
90+
for y_i in range(self.n_y):
91+
feature[func_i, x_i, i, y_i] = func(y_i, x, i)
92+
93+
# pre-calculate the max number of features, given x
94+
max_feature = np.zeros(len(p_x), dtype=int)
95+
sum_trans_feature = trans_feature.sum(axis=0)
96+
sum_feature = feature.sum(axis=0)
97+
for x_i, x_key in enumerate(p_x):
98+
cur_max_feature = np.zeros(self.n_y)
99+
for i in range(self.sequence_length):
100+
cur_max_feature = (cur_max_feature[:, None] + sum_trans_feature[x_i, i]).max(axis=0) + sum_feature[x_i, i]
101+
max_feature[x_i] = cur_max_feature.max()
102+
n_coef = max(max_feature) + 1
103+
104+
# train
105+
for iteration in range(self.max_iteration):
106+
if self.verbose:
107+
print(f'Iteration {iteration} starts...')
108+
loss = 0.
109+
for funcs, w, E_experience in zip(
110+
[self.used_feature_funcs, self.used_trans_feature_funcs],
111+
[self.w_feature_funcs, self.w_trans_feature_funcs],
112+
[E_feature, E_trans_feature]):
113+
for func_i in range(len(funcs)):
114+
# if funcs is self.used_trans_feature_funcs:
115+
coef = np.zeros(n_coef)
116+
# only iterater over possible x
117+
for x_i, x_key in enumerate(p_x):
118+
cur_p_x = p_x[x_key]
119+
x = np.array(x_key)
120+
121+
trans = self.get_trans(x)
122+
# forward algorithm
123+
cur_prob = np.ones(self.n_y)
124+
forward_prob = np.zeros([self.sequence_length + 1, self.n_y])
125+
forward_prob[0] = cur_prob
126+
for i in range(self.sequence_length):
127+
cur_prob = cur_prob @ trans[i]
128+
forward_prob[i + 1] = cur_prob
129+
# backward algorithm
130+
cur_prob = np.ones(self.n_y)
131+
backward_prob = np.zeros([self.sequence_length + 1, self.n_y])
132+
backward_prob[-1] = cur_prob
133+
for i in range(self.sequence_length - 1, -1, -1):
134+
cur_prob = trans[i] @ cur_prob
135+
backward_prob[i] = cur_prob
136+
137+
if iteration < 10:
138+
np.testing.assert_almost_equal(
139+
forward_prob[-1].sum(),
140+
backward_prob[0].sum()
141+
)
142+
for i in range(1, self.sequence_length + 1):
143+
np.testing.assert_almost_equal(
144+
forward_prob[i] @ backward_prob[i],
145+
forward_prob[-1].sum()
146+
)
147+
for i in range(0, self.sequence_length):
148+
np.testing.assert_almost_equal(
149+
(np.outer(forward_prob[i], backward_prob[i + 1]) * trans[i]).sum(),
150+
forward_prob[-1].sum()
151+
)
152+
153+
# calculate expectation of each feature_function given x
154+
cur_E_feature = 0.
155+
if funcs is self.used_feature_funcs:
156+
for i in range(1, self.sequence_length + 1):
157+
cur_E_feature += (
158+
forward_prob[i] * backward_prob[i] * feature[func_i, x_i, i - 1]
159+
).sum()
160+
elif funcs is self.used_trans_feature_funcs:
161+
for i in range(0, self.sequence_length):
162+
cur_E_feature += (
163+
np.outer(forward_prob[i], backward_prob[i + 1]) * trans[i] * trans_feature[func_i, x_i, i]
164+
).sum()
165+
else:
166+
raise Exception("Unknown function set!")
167+
cur_E_feature /= forward_prob[-1].sum()
168+
169+
coef[max_feature[x_i]] += cur_p_x * cur_E_feature
170+
171+
# update w
172+
dw_i = log(newton(
173+
lambda x: sum(c * x ** i for i, c in enumerate(coef)) - E_experience[func_i],
174+
lambda x: sum(i * c * x ** (i - 1) for i, c in enumerate(coef) if i > 0),
175+
1
176+
))
177+
w[func_i] += dw_i
178+
loss += abs(E_experience[func_i] - coef.sum())
179+
loss /= len(self.feature_funcs) + len(self.trans_feature_funcs)
180+
if self.verbose:
181+
print(f'Iteration {iteration} ends, Loss: {loss}')
182+
183+
def predict(self, X):
184+
"""
185+
predict state sequence y using viterbi algorithm
186+
X is a group of sequence x in a two-dimensional array
187+
"""
188+
189+
ans = np.zeros([len(X), self.sequence_length])
190+
for x_i, x in enumerate(X):
191+
# pre-calculate all the possible values of feature functions
192+
feature = np.zeros([len(self.used_feature_funcs), self.sequence_length, self.n_y])
193+
trans_feature = np.zeros([len(self.used_trans_feature_funcs), self.sequence_length, self.n_y, self.n_y])
194+
for func_i, func in enumerate(self.used_trans_feature_funcs):
195+
for i in range(1, self.sequence_length):
196+
for y_i_1 in range(self.n_y):
197+
for y_i in range(self.n_y):
198+
trans_feature[func_i, i, y_i_1, y_i] = func(y_i_1, y_i, x, i)
199+
for func_i, func in enumerate(self.used_feature_funcs):
200+
for i in range(self.sequence_length):
201+
for y_i in range(self.n_y):
202+
feature[func_i, i, y_i] = func(y_i, x, i)
203+
feature = (self.w_feature_funcs[:, None, None] * feature).sum(axis=0)
204+
trans_feature = (self.w_trans_feature_funcs[:, None, None, None] * trans_feature).sum(axis=0)
205+
206+
# viterbi
207+
pre_state = np.zeros([self.sequence_length, self.n_y], dtype=int) - 1
208+
prob = np.zeros([self.sequence_length, self.n_y])
209+
cur_prob = np.ones(self.n_y)
210+
for i in range(self.sequence_length):
211+
trans_prob = cur_prob[:, None] + trans_feature[i]
212+
pre_state[i] = trans_prob.argmax(axis=0)
213+
cur_prob = trans_prob.max(axis=0) + feature[i]
214+
prob[i] = cur_prob
215+
216+
# back track the trace
217+
cur_state = prob[-1].argmax()
218+
for i in range(self.sequence_length - 1, -1, -1):
219+
ans[x_i, i] = cur_state
220+
cur_state = pre_state[i, cur_state]
221+
return ans
222+
223+
224+
if __name__ == '__main__':
225+
def demonstrate(X, Y, testX, n_y, desc):
226+
console = Console(markup=False)
227+
228+
vocab = set(X.flatten())
229+
vocab_size = len(vocab)
230+
word2num = {word: num for num, word in enumerate(vocab)}
231+
232+
f_word2num = np.vectorize(lambda word: word2num[word])
233+
234+
numX, num_testX = map(f_word2num, (X, testX))
235+
236+
sequence_length = numX.shape[-1]
237+
238+
class FeatureFunc:
239+
def __init__(self, x_i, y_i):
240+
self.x_i = x_i
241+
self.y_i = y_i
242+
243+
def __call__(self, y_i, x, i):
244+
return int(y_i == self.y_i and x[i] == self.x_i)
245+
246+
class TransFeatureFunc:
247+
def __init__(self, y_i_1, y_i):
248+
self.y_i = y_i
249+
self.y_i_1 = y_i_1
250+
251+
def __call__(self, y_i_1, y_i, x, i):
252+
return int(y_i_1 == self.y_i_1 and y_i == self.y_i)
253+
254+
feature_funcs = [FeatureFunc(x_i, y_i)
255+
for x_i in range(vocab_size)
256+
for y_i in range(n_y)]
257+
trans_feature_funcs = [TransFeatureFunc(y_i_1, y_i)
258+
for y_i_1 in range(n_y)
259+
for y_i in range(n_y)]
260+
261+
linear_chain_conditional_random_field = LinearChainConditionalRandomField(
262+
feature_funcs,
263+
trans_feature_funcs,
264+
sequence_length,
265+
vocab_size,
266+
n_y,
267+
verbose=True
268+
)
269+
linear_chain_conditional_random_field.fit(numX, Y)
270+
pred = linear_chain_conditional_random_field.predict(num_testX)
271+
272+
# show in table
273+
print(desc)
274+
table = Table()
275+
for x, p in zip(testX, pred):
276+
table.add_row(*map(str, x))
277+
table.add_row(*map(str, p))
278+
console.print(table)
279+
280+
281+
# ---------------------- Example 1 --------------------------------------------
282+
X = np.array([s.split() for s in
283+
['i am good .',
284+
'i am bad .',
285+
'you are good .',
286+
'you are bad .',
287+
'it is good .',
288+
'it is bad .',
289+
]
290+
])
291+
Y = np.array([
292+
[0, 1, 2, 3],
293+
[0, 1, 2, 3],
294+
[0, 1, 2, 3],
295+
[0, 1, 2, 3],
296+
[0, 1, 2, 3],
297+
])
298+
testX = np.array([s.split() for s in
299+
['you is good .',
300+
'i are bad .',
301+
'it are good .']
302+
])
303+
testX = np.concatenate([X, testX])
304+
demonstrate(X, Y, testX, 4, "Example 1")
305+
306+
# ---------------------- Example 1 --------------------------------------------
307+
X = np.array([s.split() for s in
308+
['i be good .',
309+
'you be good .',
310+
'be good . .',
311+
'i love you .',
312+
'he be . .',
313+
]
314+
])
315+
# pronoun: 0, verb: 1, adjective: 2, ".": 3
316+
Y = np.array([
317+
[0, 1, 2, 3],
318+
[0, 1, 2, 3],
319+
[1, 2, 3, 3],
320+
[0, 1, 0, 3],
321+
[0, 1, 3, 3],
322+
])
323+
testX = np.array([s.split() for s in
324+
['you be good .',
325+
'he love you .',
326+
'i love good .',
327+
'. be love .',
328+
'. love be .',
329+
'. . be good']
330+
])
331+
testX = np.concatenate([X, testX])
332+
demonstrate(X, Y, testX, 4, "Example 2")

utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,15 @@ def line_search(f, l, r, epsilon=1e-6):
8686
fll, frr = None, None
8787
return (l + r) / 2
8888

89-
def newton(f, x0, epsilon=1e-6):
90-
"""Find the fixed point wehre f(x) = x of function f"""
89+
def newton(f, g, x0, epsilon=1e-6):
90+
"""
91+
Find the zero point wehre f(x) = 0 of function f
92+
g(x) is the gradient function of f
93+
"""
9194
prex = x0
92-
x = f(x0)
95+
x = x0 - f(x0) / g(x0)
9396
while abs(x - prex) > epsilon:
94-
prex, x = x, f(x)
97+
prex, x = x, x - f(x) / g(x)
9598
return x
9699

97100
def one_hot(i, size):

0 commit comments

Comments
 (0)