|
1 | | -## Text-CNN |
2 | | - |
3 | | -使用 PyTorch 实现 [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf) 中提出的文本分类方法。 |
4 | | - |
5 | | -## 数据集 |
6 | | - |
7 | | -此处使用的数据集来自 [text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn) 作者整理的数据集。下载链接:https://pan.baidu.com/s/1hugrfRu 密码: qfud |
8 | | - |
9 | | -该数据集共包含 10 个类别,每个类别有 6500 条数据。类别如下: |
10 | | - |
11 | | -``` |
12 | | -'体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐' |
13 | | -``` |
14 | | - |
15 | | -数据集划分如下: |
16 | | - |
17 | | -- 训练集: 5000 * 10 |
18 | | -- 验证集: 500 * 10 |
19 | | -- 测试集: 1000 * 10 |
20 | | - |
21 | | -## 运行方法 |
22 | | - |
23 | | -下载数据集,并解压至 `datasets` 目录下,在 `main.py` 中做适当调整,然后运行: |
24 | | - |
25 | | -``` |
26 | | -$ python main.py |
27 | | -``` |
28 | | - |
29 | | -运行结果: |
30 | | - |
31 | | -``` |
32 | | -2019-05-24 20:45:03,204 - using device: cuda:7 |
33 | | -2019-05-24 20:45:03,205 - load and preprocess data... |
34 | | -2019-05-24 20:45:15,800 - training... |
35 | | -2019-05-24 20:45:30,872 - epoch: 1 - loss: 0.06 acc: 0.65 - val_loss: 0.03 val_acc: 0.75 |
36 | | -2019-05-24 20:45:41,568 - epoch: 2 - loss: 0.05 acc: 0.80 - val_loss: 0.03 val_acc: 0.77 |
37 | | -2019-05-24 20:45:52,137 - epoch: 3 - loss: 0.05 acc: 0.82 - val_loss: 0.03 val_acc: 0.82 |
38 | | -2019-05-24 20:46:02,975 - epoch: 4 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.78 |
39 | | -2019-05-24 20:46:13,769 - epoch: 5 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.82 |
40 | | -2019-05-24 20:46:24,514 - epoch: 6 - loss: 0.05 acc: 0.87 - val_loss: 0.02 val_acc: 0.90 |
41 | | -2019-05-24 20:46:35,237 - epoch: 7 - loss: 0.05 acc: 0.92 - val_loss: 0.02 val_acc: 0.90 |
42 | | -2019-05-24 20:46:45,801 - epoch: 8 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.91 |
43 | | -2019-05-24 20:46:56,050 - epoch: 9 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.93 |
44 | | -2019-05-24 20:47:06,771 - epoch: 10 - loss: 0.05 acc: 0.94 - val_loss: 0.02 val_acc: 0.94 |
45 | | -2019-05-24 20:47:07,000 - predicting... |
46 | | -2019-05-24 20:47:07,435 - test - acc: 0.9326 |
47 | | -``` |
48 | | - |
49 | | -这里并没有对文本进行过多的预处理,比如去除特殊符号,停用词等。另外直接采用了字作为特征,对于中文文本分类,感觉分词已经没有必要了。 |
50 | | - |
51 | | -我使用 [FastText](https://fasttext.cc/) 对该数据集进行了分类,发现分类准确度能轻松达到 99% 以上。这也表明,对于长文本分类问题,词袋模型就足够了。深度模型,可能更适合于一些复杂的场景,比如词与词之间关系较大时。 |
52 | | - |
53 | | -``` |
54 | | -F1-Score : 0.999400 Precision : 0.999800 Recall : 0.999000 __label__0 |
55 | | -F1-Score : 0.995690 Precision : 0.997991 Recall : 0.993400 __label__5 |
56 | | -F1-Score : 0.996396 Precision : 0.997395 Recall : 0.995400 __label__1 |
57 | | -F1-Score : 0.998701 Precision : 0.998003 Recall : 0.999400 __label__2 |
58 | | -F1-Score : 0.999000 Precision : 0.999400 Recall : 0.998600 __label__3 |
59 | | -F1-Score : 0.983119 Precision : 0.987884 Recall : 0.978400 __label__8 |
60 | | -F1-Score : 0.997598 Precision : 0.998397 Recall : 0.996800 __label__9 |
61 | | -F1-Score : 0.985344 Precision : 0.975873 Recall : 0.995000 __label__4 |
62 | | -F1-Score : 0.996898 Precision : 0.997597 Recall : 0.996200 __label__6 |
63 | | -F1-Score : 0.998700 Precision : 0.998800 Recall : 0.998600 __label__7 |
64 | | -N 50000 |
65 | | -P@1 0.995 |
66 | | -R@1 0.995 |
67 | | -``` |
68 | | - |
69 | | -## 配置 |
70 | | - |
71 | | -```python |
72 | | -class_num=10 # 类别数 |
73 | | -embed_num=5000 # 字典大小 |
74 | | -embed_dim=64 # 字向量维度 |
75 | | -kernel_num=128 # 卷积核数量 |
76 | | -kernel_size_list=[3,4,5] # 卷积核尺寸 |
77 | | -dropout=0.5 # 置 0 的概率 |
78 | | -``` |
79 | | - |
80 | | -## Text CNN 模型 |
81 | | - |
82 | | - |
83 | | - |
84 | | -该模型的基本思想是对输入序列先做 Embedding,而后使用不同窗口大小的 1D Conv 提取特征,经过 MaxPooing1D 后 一个卷积核得到一个标量,最后全部拼接起来,得到一个向量,然后使用全连接层加 softmax 进行分类。 |
| 1 | +## Text-Classification |
| 2 | + |
| 3 | +使用 PyTorch 实现了以下几种文本分类模型: |
| 4 | + |
| 5 | +#### Text-CNN |
| 6 | + |
| 7 | +- 目录:[cnn](./cnn) |
| 8 | +- 论文:[Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf) |
| 9 | + |
| 10 | +#### Text-RCNN |
| 11 | + |
| 12 | +- 目录:[rcnn](./rcnn) |
| 13 | +- 论文: [Recurrent Convolutional Neural Networks for Text Classification](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745/9552) |
| 14 | + |
| 15 | +#### RNN-Attention |
| 16 | + |
| 17 | +- 目录:[rnn-attention](./rnn-attention) |
| 18 | +- 论文: [Hierarchical Attention Networks for Document Classification](https://www.aclweb.org/anthology/N16-1174) - 简化版实现。 |
| 19 | + |
| 20 | +## 数据集 |
| 21 | + |
| 22 | +此处使用的数据集来自 [text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn) 作者整理的数据集。下载链接:https://pan.baidu.com/s/1hugrfRu 密码: qfud |
| 23 | + |
| 24 | +该数据集共包含 10 个类别,每个类别有 6500 条数据。类别如下: |
| 25 | + |
| 26 | +``` |
| 27 | +'体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐' |
| 28 | +``` |
| 29 | + |
| 30 | +数据集划分如下: |
| 31 | + |
| 32 | +- 训练集: 5000 * 10 |
| 33 | +- 验证集: 500 * 10 |
| 34 | +- 测试集: 1000 * 10 |
| 35 | + |
| 36 | +## 运行方法 |
| 37 | + |
| 38 | +**1. 下载数据集** |
| 39 | + |
| 40 | +下载数据集并解压至 `datasets` 目录下。 |
| 41 | + |
| 42 | +**2. 配置参数** |
| 43 | + |
| 44 | +在 `mian.py` 中做适当调整,然后运行: |
| 45 | + |
| 46 | +``` |
| 47 | +$ python main.py |
| 48 | +``` |
| 49 | + |
| 50 | +## 运行结果: |
| 51 | + |
| 52 | +这里并没有对文本进行过多的预处理,比如去除特殊符号,停用词等。另外直接采用了字作为特征,对于中文文本分类,感觉分词已经没有必要了。 |
| 53 | + |
| 54 | +以下都是用默认参数跑出来的结果,实验使用的 GPU 为 Tesla V100,如果要用 CPU 跑建议减少数据量,并限制文本长度。 |
| 55 | + |
| 56 | +### Text-CNN |
| 57 | + |
| 58 | +``` |
| 59 | +2019-05-24 20:45:30,872 - epoch: 1 - loss: 0.06 acc: 0.65 - val_loss: 0.03 val_acc: 0.75 |
| 60 | +2019-05-24 20:45:41,568 - epoch: 2 - loss: 0.05 acc: 0.80 - val_loss: 0.03 val_acc: 0.77 |
| 61 | +2019-05-24 20:45:52,137 - epoch: 3 - loss: 0.05 acc: 0.82 - val_loss: 0.03 val_acc: 0.82 |
| 62 | +2019-05-24 20:46:02,975 - epoch: 4 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.78 |
| 63 | +2019-05-24 20:46:13,769 - epoch: 5 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.82 |
| 64 | +2019-05-24 20:46:24,514 - epoch: 6 - loss: 0.05 acc: 0.87 - val_loss: 0.02 val_acc: 0.90 |
| 65 | +2019-05-24 20:46:35,237 - epoch: 7 - loss: 0.05 acc: 0.92 - val_loss: 0.02 val_acc: 0.90 |
| 66 | +2019-05-24 20:46:45,801 - epoch: 8 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.91 |
| 67 | +2019-05-24 20:46:56,050 - epoch: 9 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.93 |
| 68 | +2019-05-24 20:47:06,771 - epoch: 10 - loss: 0.05 acc: 0.94 - val_loss: 0.02 val_acc: 0.94 |
| 69 | +
|
| 70 | +2019-05-24 20:47:07,435 - test - acc: 0.9326 |
| 71 | +``` |
| 72 | + |
| 73 | +### Text-RCNN |
| 74 | + |
| 75 | +``` |
| 76 | +2019-05-26 12:40:35,331 - epoch 1 - loss: 0.02 acc: 0.81 - val_loss: 0.00 val_acc: 0.90 |
| 77 | +2019-05-26 12:42:10,316 - epoch 2 - loss: 0.01 acc: 0.94 - val_loss: 0.01 val_acc: 0.90 |
| 78 | +2019-05-26 12:43:42,279 - epoch 3 - loss: 0.01 acc: 0.95 - val_loss: 0.00 val_acc: 0.93 |
| 79 | +2019-05-26 12:45:14,370 - epoch 4 - loss: 0.00 acc: 0.96 - val_loss: 0.00 val_acc: 0.91 |
| 80 | +2019-05-26 12:46:46,713 - epoch 5 - loss: 0.00 acc: 0.96 - val_loss: 0.00 val_acc: 0.94 |
| 81 | +
|
| 82 | +2019-05-26 12:46:51,099 - test - acc: 0.95 |
| 83 | +``` |
| 84 | + |
| 85 | +相对 CNN 而言,RCNN 训练花费时间更多,RCNN 训练一个 epoch 可以让 CNN 训练 10 个 epoch。另外 RCNN 需要的 epoch 数相对较少,这里第一个 epoch 结束后,验证集上就达到了 90% 的准确度。 |
| 86 | + |
| 87 | +### RNN-Attention |
| 88 | + |
| 89 | +``` |
| 90 | +2019-05-26 12:55:42,786 - epoch 1 - loss: 0.03 acc: 0.66 - val_loss: 0.01 val_acc: 0.80 |
| 91 | +2019-05-26 12:57:04,999 - epoch 2 - loss: 0.01 acc: 0.87 - val_loss: 0.01 val_acc: 0.84 |
| 92 | +2019-05-26 12:58:36,714 - epoch 3 - loss: 0.01 acc: 0.91 - val_loss: 0.01 val_acc: 0.88 |
| 93 | +2019-05-26 13:00:08,892 - epoch 4 - loss: 0.01 acc: 0.93 - val_loss: 0.01 val_acc: 0.89 |
| 94 | +2019-05-26 13:01:41,746 - epoch 5 - loss: 0.01 acc: 0.94 - val_loss: 0.00 val_acc: 0.92 |
| 95 | +
|
| 96 | +2019-05-26 13:01:47,011 - test - acc: 0.9212 |
| 97 | +``` |
| 98 | + |
| 99 | +### FastText |
| 100 | + |
| 101 | +另外,我使用 [FastText](https://fasttext.cc/) 对该数据集进行了分类,发现分类准确度能轻松达到 99% 以上。这也表明,对于长文本分类问题,词袋模型就足够了。深度模型,在此简单任务上并没有优势。 |
| 102 | + |
| 103 | +``` |
| 104 | +F1-Score : 0.999400 Precision : 0.999800 Recall : 0.999000 __label__0 |
| 105 | +F1-Score : 0.995690 Precision : 0.997991 Recall : 0.993400 __label__5 |
| 106 | +F1-Score : 0.996396 Precision : 0.997395 Recall : 0.995400 __label__1 |
| 107 | +F1-Score : 0.998701 Precision : 0.998003 Recall : 0.999400 __label__2 |
| 108 | +F1-Score : 0.999000 Precision : 0.999400 Recall : 0.998600 __label__3 |
| 109 | +F1-Score : 0.983119 Precision : 0.987884 Recall : 0.978400 __label__8 |
| 110 | +F1-Score : 0.997598 Precision : 0.998397 Recall : 0.996800 __label__9 |
| 111 | +F1-Score : 0.985344 Precision : 0.975873 Recall : 0.995000 __label__4 |
| 112 | +F1-Score : 0.996898 Precision : 0.997597 Recall : 0.996200 __label__6 |
| 113 | +F1-Score : 0.998700 Precision : 0.998800 Recall : 0.998600 __label__7 |
| 114 | +N 50000 |
| 115 | +P@1 0.995 |
| 116 | +R@1 0.995 |
| 117 | +``` |
0 commit comments