Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

支持上下文存储 & 持续对话能力 #30

Merged
merged 10 commits into from
Feb 16, 2023
130 changes: 108 additions & 22 deletions event.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const aircode = require("aircode");
const lark = require("@larksuiteoapi/node-sdk");
var axios = require("axios");
const EventDB = aircode.db.table("event");
const MsgTable = aircode.db.table("msg"); // 用于保存历史会话的表

// 如果你不想配置环境变量,或环境变量不生效,则可以把结果填写在每一行最后的 "" 内部
const FEISHU_APP_ID = process.env.APPID || ""; // 飞书的应用 ID
Expand Down Expand Up @@ -42,32 +43,102 @@ async function reply(messageId, content) {
}
}

// 根据中英文设置不同的 prompt
function getPrompt(content) {
if (content.length === 0) {
return "";

// 根据sessionId构造用户会话
async function buildConversation(sessionId, question) {
// 根据中英文设置不同的 prompt
let prompt = "你是 ChatGPT, 一个由 OpenAI 训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。\n请回答我下面的问题\n";
if ((question[0] >= "a" && question[0] <= "z") || (question[0] >= "A" && question[0] <= "Z")) {
prompt = "You are ChatGPT, a LLM model trained by OpenAI. \nplease answer my following question\n";
}

// 从 MsgTable 表中取出历史记录构造 question
const historyMsgs = await MsgTable.where({ sessionId }).find();
for (const conversation of historyMsgs) {
prompt += "Q: " + conversation.question + "\nA: " + conversation.answer + "\n\n";
}

// 拼接最新 question
return prompt + "Q: " + question + "\nA: ";
}

// 保存用户会话
async function saveConversation(sessionId, question, answer) {
const msgSize = question.length + answer.length
const result = await MsgTable.save({
sessionId,
question,
answer,
msgSize,
});
if (result) {
// 有历史会话是否需要抛弃
await discardConversation(sessionId);
}
}

// 如果历史会话记录大于OPENAI_MAX_TOKEN,则从第一条开始抛弃超过限制的对话
async function discardConversation(sessionId) {
let totalSize = 0;
const countList = [];
const historyMsgs = await MsgTable.where({ sessionId }).sort({ createdAt: -1 }).find();
const historyMsgLen = historyMsgs.length;
for (let i = 0; i < historyMsgLen; i++) {
const msgId = historyMsgs[i]._id;
totalSize += historyMsgs[i].msgSize;
countList.push({
msgId,
totalSize,
});
}
if (
(content[0] >= "a" && content[0] <= "z") ||
(content[0] >= "A" && content[0] <= "Z")
) {
return (
"You are ChatGPT, a LLM model trained by OpenAI. \nplease answer my following question\nQ: " +
content +
"\nA: "
);
for (const c of countList) {
if (c.totalSize > OPENAI_MAX_TOKEN) {
await MsgTable.where({_id: c.msgId}).delete();
}
}
}

return (
"你是 ChatGPT, 一个由 OpenAI 训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。\n请回答我下面的问题\nQ: " +
content +
"\nA: "
);
// 清除历史会话
async function clearConversation(sessionId) {
return await MsgTable.where({ sessionId }).delete();
}

// 指令处理
async function cmdProcess(cmdParams) {
switch (cmdParams && cmdParams.action) {
case "/help":
await cmdHelp(cmdParams.messageId);
break;
case "/clear":
await cmdClear(cmdParams.sessionId, cmdParams.messageId);
break;
default:
await cmdHelp(cmdParams.messageId);
break;
}
return { code: 0 }
}

// 帮助指令
async function cmdHelp(messageId) {
helpText = `ChatGPT 指令使用指南

Usage:
/clear 清除上下文
/help 获取更多帮助
`
await reply(messageId, helpText);
}

// 清除记忆指令
async function cmdClear(sessionId, messageId) {
await clearConversation(sessionId)
await reply(messageId, "✅记忆已清除");
}

// 通过 OpenAI API 获取回复
async function getOpenAIReply(content) {
var prompt = getPrompt(content.trim());
async function getOpenAIReply(prompt) {
logger("send prompt: " + prompt);

var data = JSON.stringify({
model: OPENAI_MODEL,
Expand Down Expand Up @@ -219,6 +290,9 @@ module.exports = async function (params, context) {
if ((params.header.event_type === "im.message.receive_v1")) {
let eventId = params.header.event_id;
let messageId = params.event.message.message_id;
let chatId = params.event.message.chat_id;
let senderId = params.event.sender.sender_id.user_id;
let sessionId = chatId + senderId;

// 对于同一个事件,只处理一次
const count = await EventDB.where({ event_id: eventId }).count();
Expand All @@ -239,7 +313,13 @@ module.exports = async function (params, context) {
// 是文本消息,直接回复
const userInput = JSON.parse(params.event.message.content);
const question = userInput.text.replace("@_user_1", "");
const openaiResponse = await getOpenAIReply(question);
const action = question.trim();
if (action.startsWith("/")) {
return await cmdProcess({action, sessionId, messageId});
}
const prompt = await buildConversation(sessionId, question);
const openaiResponse = await getOpenAIReply(prompt);
await saveConversation(sessionId, question, openaiResponse)
await reply(messageId, openaiResponse);
return { code: 0 };
}
Expand All @@ -261,7 +341,13 @@ module.exports = async function (params, context) {
}
const userInput = JSON.parse(params.event.message.content);
const question = userInput.text.replace("@_user_1", "");
const openaiResponse = await getOpenAIReply(question);
const action = question.trim();
if (action.startsWith("/")) {
return await cmdProcess({action, sessionId, messageId});
}
const prompt = await buildConversation(sessionId, question);
const openaiResponse = await getOpenAIReply(prompt);
await saveConversation(sessionId, question, openaiResponse)
await reply(messageId, openaiResponse);
return { code: 0 };
}
Expand Down