requirements: PyTorch, Jieba
使用QQ的聊天记录作为语料库,需要先把QQ聊天记录提取为txt格式才能使用
每个模块在干什么,在代码注释里写的还是比较清楚的
!!!由于QQ聊天过于碎片化,并且每句都很短,聊天记录最好选取关于短时间内同一个话题的内容(其实手工筛最好了)!!!
程序运行顺序
格式化聊天记录-->预处理-->生成语料库-->训练模型-->聊天测试;
运行程序
chatlog-->preprocess-->train;
会根据QQ聊天记录生成一个csv文件,作为之后的材料
- 所需要的字典index2voc和voc2index
- 根据voc2index字典生成的index句子对,句子对文件分为2行,左边是input,右边是target
根据输入的句子对生成一个batch,这个batch会作为一次训练的素材
训练用的神经网络,模型为RNN+Attention,其中RNN作为双向编码还有解码,Attention作为解码的一部分
训练和测试,使用trainBegin函数进行训练,使用chatBegin开始试着聊天 在实际进行聊天时使用了贪婪算法,在GreedySearchDecoder当中可以调参数 默认为每1000次迭代保存一次模型参数,初始时不加载模型参数(从0开始),字典默认是加 载的以便加快速度(重新生成字典很慢的)