# CTC with MXNet

This is an example of `mx.contrib.sym.ctc_loss`, adapted from the [warpctc](https://github.com/dmlc/mxnet/tree/master/example/warpctc) example.

# Core code

The core change is in `lstm.py`:
```python
def lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # embedding layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    # stack all time steps: (seq_len * batch, num_hidden)
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
    # ctc_loss expects (seq_len, batch, num_classes)
    pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0))

    loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label)
    ctc_loss = mx.sym.MakeLoss(loss)

    # softmax branch is for prediction only; its gradients are blocked
    softmax_class = mx.sym.SoftmaxActivation(data=pred_fc)
    softmax_loss = mx.sym.MakeLoss(softmax_class)
    softmax_loss = mx.sym.BlockGrad(softmax_loss)

    return mx.sym.Group([softmax_loss, ctc_loss])
```
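`mx.contrib.sym.ctc_loss` takes its input laid out as `(sequence_length, batch_size, alphabet_size)`, which is why `pred_fc`, of shape `(seq_len * batch, 11)`, is reshaped with MXNet's special shape codes: `-4` splits the leading axis in two, `seq_len` and `-1` give the split sizes (the batch dimension is inferred), and `0` copies the class axis unchanged. A quick way to verify this is `infer_shape`; the values below, including the feature width 256, are illustrative only:

```python
import mxnet as mx

seq_len, batch_size, num_classes = 80, 32, 11  # illustrative values

data = mx.sym.Variable('data')
pred_fc = mx.sym.FullyConnected(data=data, num_hidden=num_classes)
# -4 splits axis 0 into (seq_len, -1); 0 keeps the class axis as-is
pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0))

_, out_shapes, _ = pred_ctc.infer_shape(data=(seq_len * batch_size, 256))
print(out_shapes[0])  # (80, 32, 11) -- the layout ctc_loss expects
```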
# Some results

With more training, the results would improve further:
```
2017-07-08 13:22:01,155 Epoch[94] Batch [50] Speed: 4273.43 samples/sec Accuracy=0.808747
2017-07-08 13:22:13,141 Epoch[94] Batch [100] Speed: 4271.84 samples/sec Accuracy=0.786855
2017-07-08 13:22:25,179 Epoch[94] Batch [150] Speed: 4253.81 samples/sec Accuracy=0.810625
2017-07-08 13:22:37,198 Epoch[94] Batch [200] Speed: 4259.96 samples/sec Accuracy=0.808809
2017-07-08 13:22:49,233 Epoch[94] Batch [250] Speed: 4254.13 samples/sec Accuracy=0.806426
2017-07-08 13:23:01,308 Epoch[94] Batch [300] Speed: 4239.98 samples/sec Accuracy=0.817305
2017-07-08 13:23:02,030 Epoch[94] Train-Accuracy=0.819336
2017-07-08 13:23:02,030 Epoch[94] Time cost=73.092
2017-07-08 13:23:02,101 Saved checkpoint to "ocr-0095.params"
2017-07-08 13:23:07,192 Epoch[94] Validation-Accuracy=0.819417
2017-07-08 13:23:20,579 Epoch[95] Batch [50] Speed: 4288.76 samples/sec Accuracy=0.817459
2017-07-08 13:23:32,573 Epoch[95] Batch [100] Speed: 4268.75 samples/sec Accuracy=0.815215
2017-07-08 13:23:44,635 Epoch[95] Batch [150] Speed: 4244.85 samples/sec Accuracy=0.820215
2017-07-08 13:23:56,670 Epoch[95] Batch [200] Speed: 4254.38 samples/sec Accuracy=0.823613
2017-07-08 13:24:08,650 Epoch[95] Batch [250] Speed: 4273.83 samples/sec Accuracy=0.827109
2017-07-08 13:24:20,680 Epoch[95] Batch [300] Speed: 4256.49 samples/sec Accuracy=0.824961
2017-07-08 13:24:21,401 Epoch[95] Train-Accuracy=0.840495
2017-07-08 13:24:21,401 Epoch[95] Time cost=73.008
2017-07-08 13:24:21,441 Saved checkpoint to "ocr-0096.params"
2017-07-08 13:24:26,508 Epoch[95] Validation-Accuracy=0.834798
2017-07-08 13:24:39,938 Epoch[96] Batch [50] Speed: 4259.32 samples/sec Accuracy=0.825578
2017-07-08 13:24:51,987 Epoch[96] Batch [100] Speed: 4249.67 samples/sec Accuracy=0.826562
2017-07-08 13:25:04,041 Epoch[96] Batch [150] Speed: 4247.44 samples/sec Accuracy=0.831855
2017-07-08 13:25:16,058 Epoch[96] Batch [200] Speed: 4260.77 samples/sec Accuracy=0.830840
2017-07-08 13:25:28,109 Epoch[96] Batch [250] Speed: 4248.44 samples/sec Accuracy=0.827168
2017-07-08 13:25:40,057 Epoch[96] Batch [300] Speed: 4285.23 samples/sec Accuracy=0.832715
2017-07-08 13:25:40,782 Epoch[96] Train-Accuracy=0.830729
2017-07-08 13:25:40,782 Epoch[96] Time cost=73.098
2017-07-08 13:25:40,821 Saved checkpoint to "ocr-0097.params"
2017-07-08 13:25:45,886 Epoch[96] Validation-Accuracy=0.840820
2017-07-08 13:25:59,283 Epoch[97] Batch [50] Speed: 4271.85 samples/sec Accuracy=0.831648
2017-07-08 13:26:11,243 Epoch[97] Batch [100] Speed: 4280.89 samples/sec Accuracy=0.835371
2017-07-08 13:26:23,263 Epoch[97] Batch [150] Speed: 4259.89 samples/sec Accuracy=0.831094
2017-07-08 13:26:35,230 Epoch[97] Batch [200] Speed: 4278.40 samples/sec Accuracy=0.827129
2017-07-08 13:26:47,199 Epoch[97] Batch [250] Speed: 4277.77 samples/sec Accuracy=0.834258
2017-07-08 13:26:59,257 Epoch[97] Batch [300] Speed: 4245.93 samples/sec Accuracy=0.833770
2017-07-08 13:26:59,971 Epoch[97] Train-Accuracy=0.844727
2017-07-08 13:26:59,971 Epoch[97] Time cost=72.908
2017-07-08 13:27:00,020 Saved checkpoint to "ocr-0098.params"
2017-07-08 13:27:05,130 Epoch[97] Validation-Accuracy=0.827962
2017-07-08 13:27:18,521 Epoch[98] Batch [50] Speed: 4281.06 samples/sec Accuracy=0.834118
2017-07-08 13:27:30,537 Epoch[98] Batch [100] Speed: 4261.20 samples/sec Accuracy=0.835352
2017-07-08 13:27:42,542 Epoch[98] Batch [150] Speed: 4264.88 samples/sec Accuracy=0.839395
2017-07-08 13:27:54,544 Epoch[98] Batch [200] Speed: 4266.31 samples/sec Accuracy=0.836328
2017-07-08 13:28:06,550 Epoch[98] Batch [250] Speed: 4264.50 samples/sec Accuracy=0.841465
2017-07-08 13:28:18,622 Epoch[98] Batch [300] Speed: 4241.11 samples/sec Accuracy=0.831680
2017-07-08 13:28:19,349 Epoch[98] Train-Accuracy=0.833984
2017-07-08 13:28:19,349 Epoch[98] Time cost=73.018
2017-07-08 13:28:19,393 Saved checkpoint to "ocr-0099.params"
2017-07-08 13:28:24,472 Epoch[98] Validation-Accuracy=0.818034
2017-07-08 13:28:37,961 Epoch[99] Batch [50] Speed: 4242.14 samples/sec Accuracy=0.835861
2017-07-08 13:28:50,031 Epoch[99] Batch [100] Speed: 4241.94 samples/sec Accuracy=0.846543
2017-07-08 13:29:02,108 Epoch[99] Batch [150] Speed: 4239.22 samples/sec Accuracy=0.850645
2017-07-08 13:29:14,160 Epoch[99] Batch [200] Speed: 4248.34 samples/sec Accuracy=0.844141
2017-07-08 13:29:26,225 Epoch[99] Batch [250] Speed: 4243.71 samples/sec Accuracy=0.842129
2017-07-08 13:29:38,277 Epoch[99] Batch [300] Speed: 4248.07 samples/sec Accuracy=0.851250
2017-07-08 13:29:38,975 Epoch[99] Train-Accuracy=0.854492
2017-07-08 13:29:38,976 Epoch[99] Time cost=73.315
2017-07-08 13:29:39,023 Saved checkpoint to "ocr-0100.params"
2017-07-08 13:29:44,110 Epoch[99] Validation-Accuracy=0.851969
```
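To turn the softmax branch of the grouped symbol into label sequences, the standard greedy CTC collapse applies: take the per-frame argmax, merge consecutive repeats, then drop blanks. Here is a minimal sketch, assuming label 0 is the blank class (the `mx.contrib.sym.ctc_loss` default) and a `(seq_len * batch, num_classes)` softmax output as produced by `lstm_unroll` above:

```python
import numpy as np

def ctc_greedy_decode(pred, seq_len, blank=0):
    """Collapse per-frame predictions: argmax, merge repeats, drop blanks.

    pred: array of shape (seq_len * batch, num_classes), the softmax output.
    Returns one list of labels per batch element.
    """
    batch = pred.shape[0] // seq_len
    # restore the time-major (seq_len, batch) layout implied by the Concat
    frames = pred.reshape(seq_len, batch, -1).argmax(axis=2)
    results = []
    for b in range(batch):
        seq, prev = [], -1
        for t in range(seq_len):
            cur = frames[t, b]
            if cur != prev and cur != blank:
                seq.append(int(cur))
            prev = cur
        results.append(seq)
    return results
```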
# Full lstm.py

For reference, the complete `lstm.py`:
```python
# pylint:skip-file
import sys

sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])


def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
    """LSTM cell symbol."""
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    # one FullyConnected computes all four gates; slice them apart
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)


def lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # embedding layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    # stack all time steps: (seq_len * batch, num_hidden)
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)

    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
    # ctc_loss expects (seq_len, batch, num_classes)
    pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0))

    loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label)
    ctc_loss = mx.sym.MakeLoss(loss)

    # softmax branch is for prediction only; its gradients are blocked
    softmax_class = mx.sym.SoftmaxActivation(data=pred_fc)
    softmax_loss = mx.sym.MakeLoss(softmax_class)
    softmax_loss = mx.sym.BlockGrad(softmax_loss)

    return mx.sym.Group([softmax_loss, ctc_loss])
```
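For completeness, here is one way the grouped symbol might be bound and initialized with the Module API. Nothing below comes from this commit: the batch size, sequence length, hidden size, per-step feature width, and the convention of feeding the `l%d_init_c`/`l%d_init_h` variables as extra data inputs are illustrative assumptions in the style of the original warpctc example.

```python
import mxnet as mx
from lstm import lstm_unroll  # the file above

# Illustrative hyper-parameters; not taken from this commit.
batch_size, seq_len = 32, 80
num_hidden, num_label, num_lstm_layer = 100, 4, 2

sym = lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label)

# The l%d_init_c / l%d_init_h variables become extra "data" inputs.
init_states = [('l%d_init_c' % i, (batch_size, num_hidden))
               for i in range(num_lstm_layer)]
init_states += [('l%d_init_h' % i, (batch_size, num_hidden))
                for i in range(num_lstm_layer)]

# 30 = per-step feature width (e.g. one image column), illustrative.
data_shapes = [('data', (batch_size, seq_len, 30))] + init_states
label_shapes = [('label', (batch_size, num_label))]

mod = mx.mod.Module(sym,
                    data_names=[name for name, _ in data_shapes],
                    label_names=['label'],
                    context=mx.cpu())
mod.bind(data_shapes=data_shapes, label_shapes=label_shapes)
mod.init_params(initializer=mx.init.Xavier())
# Training would then call mod.fit(...) with a DataIter that provides
# the init-state arrays alongside 'data'.
```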