在当前环境安装本项目使用的环境
pip install -r ./requirements.txt
我尝试将项目克隆下来之后安装了requirements里的包,发现始终缺少依赖.
所以我将重要的包版本罗列出来, 建议使用conda安装以下列表的依赖和对应的版本...
jieba==0.39
numpy==1.17.3
requests==2.22.0
keras==2.3.0
pandas==0.25.1
tqdm==4.31.0
tensorflow==1.14.0
下载wiki.zh.vec至项目文件夹下 ./data/ 下载地址
找到或者直接点击Chinese: bin+text, text下载
python train.py # 运行train.py文件进行训练demo数据
一列为class用于存储每个类别的标签, 一列为data用于存储每条文本数据
class | data |
---|---|
phone | 苹果 |
phone | 华为 |
phone | 小米 |
phone | 传音 |
bank | 中国建设 银行 |
bank | 中国 银行 |
bank | 中国工商银行 |
bank | 中国农业银行 |
country | 中国 |
country | 美国 |
country | 俄罗斯 |
country | 加拿大 |
- train_data_path 为自定义数据的文件路径,也可覆盖demo数据.默认为: "./data/train_data.csv"
- embedded_matrix_size 为嵌入矩阵大小, 根据词频保留的词数,用于构建嵌入矩阵.默认为: 10240
- validation_ratio 为划分测试数据集占总数据集比例. 默认为: 0.2
- epochs 为整个数据集迭代次数. 默认为: 512
- batch_size 为优化模型每个批次的数据条数. 默认为: 2 注意:当前2为特殊情况(因为测试数据集较小)一定记得修改
- learning_rate 为优化模型的学习速率. 默认为: 0.01
- learning_rate_decay 为学习速率每个epochs进行衰减的比率. 默认为: 0.95
- 运行过程中会在
./save_model/save/
下生成model.h5
模型文件,运行结束会生成final_model.h5
- 运行过程中会在
./save_model/logs/
下生成并不断更新一个日志文件,在项目根目录执行tensorboard --logdir=save_model/logs
即可监控模型训练过程 - 运行成功后会在
./save_model/deploy/
下生成可用于服务器部署的 pb 格式文件
:
.
└── 0
├── saved_model.pb
└── variables
├── variables.data-00000-of-00001
└── variables.index
记得修改class_dict = {0: "phone", 1: "bank", 2: "country"}
模型输出对应的值,即可得到对应的类别名称