这是一个使用 Llama2 权重设计的循环神经网络(RNN)模型,旨在无限期运行(终身)。
- llama2: 可以使用 llama2 各种版本模型的权重
- rnn: 每个token的 attention sequence 长度固定,计算和内存开销不会增加,理论上支持无限长序列,可以从硬盘读取和保存记忆
- .c: 可以在本地设备上运行,甚至是移动平台
这里的主要思路是将 max_seq_len 拆分为长度更小为 mem_seq_len 的chunks,不同chunks之间通过RNN的形式连接中间状态。这样做的主要优势在于推理的时间和空间复杂度会更少,并支持无限长序列。
method\seq_len | 256 | 512 | 1024 | 4096 | 32768 |
---|---|---|---|---|---|
Attention Interpolation | 1.0583 | 1.3335 | 2.2598 | 4.1215 | 4.7887 |
Memory Attention | 1.0751 | 1.0611 | 1.0562 | 1.0321 | 0.9400 |
以上模型是在训练长度为256的tinistory的文本生成任务上训练的,性能用的是token预测的交叉熵损失。以下长度外推方案都是在没有微调模型的结果。其中 memory attention 的attention seq len为32。从结果可以看出:
- attention外推的各种改进方案只能缓解泛化问题,但仍然不会有序列长度收益存在。也就是说,随序列长度的增加性能会变好
- 而memory attention可以实现外推长度的性能收益,而且明确有长度越长性能越好。
more
示例
# mode = llama2Rnn_toy20M_q80.bin, train_seq_len = 256, attention_seq_len = 32
Memory at (null) is not exists.
Initialize memory.
(2023-12-19 14:16:05) User: 你能为我做什么?
Assistant: 作为一个AI助手,我没有个人情感和立场,但我可以提供一些有关人类发展的实用建议,如下所示:
1. 坚持学习和成长:不断学习和发展自己可以提高自己的技能和知识,帮助你在不同的领域中更好地发挥自己的潜力。
2. 建立良好的人际关系:与他人建立良好的人际关系可以让你获得支持和帮助,并且可以在社交场合中获得更多的机会和信息。
3. 坚持健康生活方式:保持健康的生活方式可以提高身体和心理健康,减轻压力和焦虑,从而提高生活质量。
4. 坚持适度的运动:适度的运动可以提高身体素质和心理健康,同时也可以减轻压力和焦虑,从而提高生活质量。
5. 保持积极的心态:保持积极的心态可以帮助你更好地应对生活中的挑战和困难,从而提高生活质量。
(2023-12-19 14:16:20) User:
能够看到虽然模型的训练长度只有256,而注意力长度只有32,却能生成更长的连贯回复。
要编译llama2Rnn.c
代码,有以下两种选择:
要快速编译,不使用 OpenMP 支持,请使用以下命令:
make runfast
要编译并支持 OpenMP,请使用以下命令:
make runomp
要无限期运行 Llama2RNN 模型,请使用以下命令:
./runqm llama2Rnn_toy20M_q80.bin -z llama2_tokenizer.bin -o mem20M.bin -m chat
见siyuanseever/llama2Rnn: How to train Llama2Rnn in torch (github.com)
- 202312.28
- 添加训练代码
- 2023.12.19
- 增加中文模型
- 2023.11.13
- 优化memory save,包括kv cache和token position
- 2023.11.06
- update 20M(22M) chat model: memory length 从32增加到128(val loss 2.1 -> 1.6)
- 增加记忆管理功能
- 2023.11.03
- 量化代码
- release 20M chat model
model | settings |
---|---|
20M | 英文模型,数据为wikipedia |
178M | 英文模型,数据为wikipedia |
178M_zh | 中文模型,数据包括moss |
- 调查并合并
run.cu
(CUDA) - 添加更多模型,如 1B 和 7B
- (LoRA)Llama2 模型微调
- 添加训练代码
- 支持 .txt 文档输入
- 感知物理时间
- molloc prompt 时可能有溢出问题?
- chat encode 有内存访问问题?
当前仓库基于llama2.c构建。
MIT