Commit: rewrite the text classification demo.
Showing 14 changed files with 762 additions and 694 deletions.
.gitignore
@@ -0,0 +1,3 @@
data
*.log
*.pyc
(Two large diffs are not rendered by default.)
infer.py
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import gzip

import paddle.v2 as paddle
from paddle.v2.layer import parse_network

import network_conf
import reader
from utils import *


def infer(topology,
          data_dir,
          word_dict_path,
          model_path,
          class_num=2,
          batch_size=50):
    def _infer_a_batch(inferer, test_batch):
        probs = inferer.infer(input=test_batch, field=["value"])
        for prob in probs:
            print(prob)

    print("Begin to predict...")
    use_default_data = (data_dir is None)

    if use_default_data:
        word_dict = paddle.dataset.imdb.word_dict()
        test_reader = paddle.dataset.imdb.test(word_dict)
    else:
        assert os.path.exists(
            word_dict_path), "The word dictionary file does not exist."
        word_dict = load_dict(word_dict_path)
        test_reader = reader.test_reader(data_dir, word_dict)()

    dict_dim = len(word_dict)
    # build the inference-side network
    prob = topology(dict_dim, class_num, is_infer=True)

    # initialize PaddlePaddle
    paddle.init(use_gpu=False, trainer_count=1)

    # load the trained model
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(model_path, "r"))
    inferer = paddle.inference.Inference(
        output_layer=prob, parameters=parameters)

    test_batch = []
    for idx, item in enumerate(test_reader):
        test_batch.append([item[0]])
        if idx and (not (idx + 1) % batch_size):
            _infer_a_batch(inferer, test_batch)
            test_batch = []

    # predict the last, possibly incomplete batch
    if len(test_batch):
        _infer_a_batch(inferer, test_batch)
        test_batch = []


if __name__ == "__main__":
    model_path = "dnn_params_pass_00000.tar.gz"
    test_dir = None
    word_dict_path = None
    nn_type = "dnn"
    class_num = 2

    if nn_type == "dnn":
        topology = network_conf.fc_net
    elif nn_type == "cnn":
        topology = network_conf.convolution_net

    infer(
        topology=topology,
        data_dir=test_dir,
        word_dict_path=word_dict_path,
        model_path=model_path,
        class_num=class_num)
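The call to load_dict above is pulled in from utils.py via "from utils import *"; that file's diff is not rendered in this view. A minimal sketch of what the helper presumably does, assuming one token per line with the id given by the line number:

# Hypothetical sketch of utils.load_dict; the committed implementation
# is in one of the unrendered diffs.
def load_dict(dict_path):
    with open(dict_path, "r") as f:
        return dict((line.strip(), idx) for idx, line in enumerate(f))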
network_conf.py
@@ -0,0 +1,95 @@
import sys
import math
import gzip

from paddle.v2.layer import parse_network
import paddle.v2 as paddle

__all__ = ["fc_net", "convolution_net"]


def fc_net(dict_dim,
           class_num,
           emb_dim=28,
           hidden_layer_sizes=[28, 8],
           is_infer=False):
    """
    Define the topology of the DNN network.

    :param dict_dim: size of the word dictionary
    :type dict_dim: int
    :param class_num: number of instance classes
    :type class_num: int
    :param emb_dim: embedding vector dimension
    :type emb_dim: int
    :param hidden_layer_sizes: sizes of the fully-connected hidden layers
    :type hidden_layer_sizes: list of int
    :param is_infer: if True, return only the class probabilities
    :type is_infer: bool
    """

    # define the input layers
    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(dict_dim))
    if not is_infer:
        lbl = paddle.layer.data("label",
                                paddle.data_type.integer_value(class_num))

    # define the embedding layer
    emb = paddle.layer.embedding(input=data, size=emb_dim)
    # max pooling to reduce the input sequence into a vector (non-sequence)
    seq_pool = paddle.layer.pooling(
        input=emb, pooling_type=paddle.pooling.Max())

    # stack the fully-connected hidden layers
    hidden = seq_pool
    for hidden_size in hidden_layer_sizes:
        hidden_init_std = 1.0 / math.sqrt(hidden_size)
        hidden = paddle.layer.fc(
            input=hidden,
            size=hidden_size,
            act=paddle.activation.Tanh(),
            param_attr=paddle.attr.Param(initial_std=hidden_init_std))

    prob = paddle.layer.fc(
        input=hidden,
        size=class_num,
        act=paddle.activation.Softmax(),
        param_attr=paddle.attr.Param(initial_std=1.0 / math.sqrt(class_num)))

    if is_infer:
        return prob
    else:
        return paddle.layer.classification_cost(
            input=prob, label=lbl), prob, lbl


def convolution_net(dict_dim,
                    class_dim=2,
                    emb_dim=28,
                    hid_dim=128,
                    is_infer=False):
    """
    CNN network definition.

    :param dict_dim: size of the word dictionary
    :type dict_dim: int
    :param class_dim: number of instance classes
    :type class_dim: int
    :param emb_dim: embedding vector dimension
    :type emb_dim: int
    :param hid_dim: number of same-size convolution kernels
    :type hid_dim: int
    :param is_infer: if True, return only the class probabilities
    :type is_infer: bool
    """

    # input layers
    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(dict_dim))

    # embedding layer
    emb = paddle.layer.embedding(input=data, size=emb_dim)

    # convolution layers with max pooling
    conv_3 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=3, hidden_size=hid_dim)
    conv_4 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=4, hidden_size=hid_dim)

    # fc and output layer
    output = paddle.layer.fc(
        input=[conv_3, conv_4], size=class_dim, act=paddle.activation.Softmax())

    if is_infer:
        return output

    lbl = paddle.layer.data("label",
                            paddle.data_type.integer_value(class_dim))
    cost = paddle.layer.classification_cost(input=output, label=lbl)

    return cost, output, lbl
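parse_network is imported here and in infer.py but never called; it is handy for sanity-checking a topology before training. A small usage sketch, not part of this commit, that prints the serialized configuration of the inference-side DNN:

import paddle.v2 as paddle
from paddle.v2.layer import parse_network

import network_conf

paddle.init(use_gpu=False, trainer_count=1)
# build the inference branch of the DNN for a toy 100-word vocabulary
# and two classes, then dump the generated network configuration
prob = network_conf.fc_net(dict_dim=100, class_num=2, is_infer=True)
print(parse_network(prob))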
reader.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os


def train_reader(data_dir, word_dict, label_dict):
    """
    Reader interface for the training data.

    :param data_dir: data directory
    :type data_dir: str
    :param word_dict: word dictionary;
        it must contain an "<UNK>" token.
    :type word_dict: Python dict
    :param label_dict: label dictionary
    :type label_dict: Python dict
    """

    def reader():
        UNK_ID = word_dict["<UNK>"]
        word_col = 1
        lbl_col = 0

        for file_name in os.listdir(data_dir):
            with open(os.path.join(data_dir, file_name), "r") as f:
                for line in f:
                    line_split = line.strip().split("\t")
                    word_ids = [
                        word_dict.get(w, UNK_ID)
                        for w in line_split[word_col].split()
                    ]
                    yield word_ids, label_dict[line_split[lbl_col]]

    return reader


def test_reader(data_dir, word_dict):
    """
    Reader interface for the testing data.

    :param data_dir: data directory
    :type data_dir: str
    :param word_dict: word dictionary;
        it must contain an "<UNK>" token.
    :type word_dict: Python dict
    """

    def reader():
        UNK_ID = word_dict["<UNK>"]
        word_col = 1

        for file_name in os.listdir(data_dir):
            with open(os.path.join(data_dir, file_name), "r") as f:
                for line in f:
                    line_split = line.strip().split("\t")
                    # skip lines that have no text column
                    if len(line_split) <= word_col: continue
                    word_ids = [
                        word_dict.get(w, UNK_ID)
                        for w in line_split[word_col].split()
                    ]
                    yield word_ids, line_split[word_col]

    return reader
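Both readers expect tab-separated lines with the label in column 0 and the tokenized text in column 1. A hedged example of wiring train_reader into a shuffled, batched reader; the toy dictionaries and data path below are made up for illustration, the real ones come from the data-preparation step:

import paddle.v2 as paddle
import reader

# toy dictionaries, for illustration only
word_dict = {"<UNK>": 0, "good": 1, "movie": 2}
label_dict = {"neg": 0, "pos": 1}

# shuffle within a 1000-sample buffer, then group into mini-batches of 64
train_batch_reader = paddle.batch(
    paddle.reader.shuffle(
        reader.train_reader("./data/train", word_dict, label_dict),
        buf_size=1000),
    batch_size=64)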
train.sh
@@ -0,0 +1,7 @@
#!/bin/sh

python train.py \
  --nn_type="dnn" \
  --batch_size=64 \
  --num_passes=10 \
  2>&1 | tee train.log
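train.sh passes --nn_type, --batch_size, and --num_passes to train.py, whose diff is not rendered in this view. A plausible, purely hypothetical flag-parsing block consistent with the script:

import argparse

# hypothetical argument parsing for train.py; only the flag names are
# taken from train.sh, the defaults and help strings are assumptions
parser = argparse.ArgumentParser(
    description="Train the text classification demo.")
parser.add_argument("--nn_type", type=str, default="dnn",
                    choices=["dnn", "cnn"],
                    help="network topology to train")
parser.add_argument("--batch_size", type=int, default=64,
                    help="number of samples per mini-batch")
parser.add_argument("--num_passes", type=int, default=10,
                    help="number of passes over the training data")
args = parser.parse_args()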