-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1666b6b
commit 2e2f5b5
Showing
12 changed files
with
428 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
/venv | ||
/__pycache__ | ||
/.idea | ||
*/__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,151 @@ | ||
# PassageSummary | ||
PassageSummary是一个基于Gpt工作的Api,提供文章总结,话题提取等若干总结功能 | ||
PassageSummary是一个基于Gpt工作的Api,提供文章总结,话题提取等若干总结功能。 | ||
|
||
# 约束 | ||
|
||
ApiService提供基本的文本缓存服务,支持对于完全相同的文本的缓存。此外,在ApiService返回异常的时候,会返回`errno`和`message`两个参数,其分别表示全局唯一错误码和错误信息。 | ||
|
||
ApiService的请求与返回内容均为Json格式。请求中的`token`参数表示OpenAI的ApiKey;服务端使用`gpt-3.5-turbo`模型。
|
||
# Api | ||
|
||
得到ApiService的工作状态: | ||
|
||
``` | ||
GET / | ||
{ | ||
"message":"This server is working normally." | ||
} | ||
``` | ||
|
||
上传文本: | ||
|
||
``` | ||
POST /passage | ||
请求: | ||
{ | ||
"content":"This is a content", | ||
"token":"sk-ss" | ||
} | ||
返回: | ||
正常返回: | ||
{ | ||
"hash":"d622ac64268ce69eef0f3dc8277d06a9182f71c7" | ||
} | ||
错误返回: | ||
{ | ||
"errno":1, | ||
"message":"文本转换异常" | ||
} | ||
``` | ||
|
||
询问文本: | ||
|
||
``` | ||
POST /passage/{hash} | ||
请求: | ||
{ | ||
"action":"ask", | ||
"param":"这篇文章主要描述了什么?", | ||
"token":"sk-xxx" | ||
} | ||
返回: | ||
正常返回: | ||
{ | ||
"content":"这篇文章主要讲述了...." | ||
} | ||
错误返回: | ||
{ | ||
"errno":2, | ||
"message":"文章内容不存在..." | ||
} | ||
``` | ||
|
||
得到文章话题: | ||
|
||
``` | ||
POST /passage/{hash} | ||
请求: | ||
{ | ||
"action":"topic", | ||
"token":"sk-xxx" | ||
} | ||
返回: | ||
正常返回: | ||
{ | ||
"topics": | ||
[ | ||
{ | ||
"topic":"原神怎么你了?", | ||
"relative":"0.2" | ||
} | ||
] | ||
} | ||
//topic表示的是话题,relative表示话题相关度 | ||
错误返回: | ||
{ | ||
"errno":2, | ||
"message":"文章内容不存在..." | ||
} | ||
``` | ||
|
||
判断文章与话题的相关度: | ||
|
||
``` | ||
POST /passage/{hash} | ||
请求: | ||
{ | ||
"action":"getTopicRelative", | ||
"param":"原神,原批", | ||
"token":"sk-xxx" | ||
} | ||
返回: | ||
正常返回: | ||
{ | ||
"topics": | ||
[ | ||
{ | ||
"topic":"原神", | ||
"relative":"0.2" | ||
}, | ||
{ | ||
"topic":"原批", | ||
"relative":"0.9" | ||
} | ||
] | ||
} | ||
//topic表示的是话题,relative表示话题相关度 | ||
错误返回: | ||
{ | ||
"errno":2, | ||
"message":"文章内容不存在..." | ||
} | ||
``` | ||
|
||
总结文章: | ||
|
||
``` | ||
POST /passage/{hash} | ||
请求: | ||
{ | ||
"action":"summary", | ||
"token":"sk-xxx" | ||
} | ||
返回: | ||
正常返回: | ||
{ | ||
"content":"这篇文章讲述了一个原批转换为星批的故事。" | ||
} | ||
错误返回: | ||
{ | ||
"errno":2, | ||
"message":"文章内容不存在..." | ||
} | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import os | ||
from models import request | ||
from passages import savePassage | ||
from passages import passageAnalysis | ||
from fastapi import FastAPI | ||
|
||
# 初始化,创建缓存目录: | ||
if not os.path.exists('cache'): | ||
os.mkdir('cache') | ||
|
||
app = FastAPI() | ||
|
||
|
||
@app.get("/")
async def root():
    """Health-check endpoint: confirms the API service is reachable."""
    status = {"message": "This server is working normally."}
    return status
|
||
|
||
@app.post("/passage")
async def say_hello(req: request.SavePassageRequest):
    """Cache an uploaded passage and build its vector index.

    Delegates to savePassage.save_passage; responds with {"hash": ...}.
    """
    content, token = req.content, req.token
    return savePassage.save_passage(content, token)
|
||
|
||
@app.post("/passage/{hash}")
async def action(hash: str, req: request.PassageRequest):
    """Run an analysis action (ask/topic/getTopicRelative/summary) on a cached passage.

    Rejects the request with errno 10001 when no cache entry exists for `hash`.
    """
    cached = os.path.join('cache', hash)
    if os.path.exists(cached):
        return passageAnalysis.dispatch_action(req, hash)
    return {
        "errno": 10001,
        "message": "hash对应的文件不存在,或者是文件读取异常"
    }
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class SavePassageRequest(BaseModel):
    """Request body for POST /passage: the text to cache plus the caller's API key."""
    content: str  # raw passage text to be indexed and cached
    token: str  # caller-supplied OpenAI API key
|
||
|
||
class PassageRequest(BaseModel):
    """Request body for POST /passage/{hash}.

    Bug fix: `param` previously had no default, making it a *required*
    field — but the documented "topic" and "summary" actions send no
    `param` at all and would have failed validation. Default it to None.
    """
    action: str  # one of: "ask", "topic", "getTopicRelative", "summary"
    param: object | None = None  # action argument (question text or topic list); absent for topic/summary
    token: str  # caller-supplied OpenAI API key
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import os | ||
import json | ||
from langchain import OpenAI | ||
from models import request | ||
from llama_index import ( | ||
GPTSimpleVectorIndex, | ||
PromptHelper, | ||
LLMPredictor, | ||
QuestionAnswerPrompt, | ||
ServiceContext | ||
) | ||
|
||
|
||
def dispatch_action(req: request.PassageRequest, hash: str): | ||
path = os.path.join('cache', hash) | ||
vector = os.path.join(path, 'index.json') | ||
match req.action: | ||
case "ask": | ||
return ask(vector, str(req.param), req.token) | ||
case "topic": | ||
return get_topics(vector, req.token) | ||
case "getTopicRelative": | ||
return get_topic_relative(vector, str(req.param), req.token) | ||
case "summary": | ||
return summary(vector, req.token) | ||
|
||
|
||
def ask(vector: str, ask_question: str, token: str):
    """Answer a free-form question about the indexed passage.

    Returns {"content": answer} on success, or an errno 10002 dict when
    GPT produced no response (e.g. an invalid token).
    """
    result = common_ask(vector, ask_question, token)
    if result.response is not None:
        return {"content": result.response}
    return {
        "errno": 10002,
        "message": "Gpt未返回信息,请检查Token是否有效!"
    }
|
||
|
||
def summary(vector: str, token: str):
    """Summarize the indexed passage in Chinese.

    Returns {"content": summary} on success, or an errno 10002 dict when
    GPT produced no response (e.g. an invalid token).
    """
    result = common_ask(vector, "Summary this passage in Chinese", token)
    if result.response is not None:
        return {"content": result.response}
    return {
        "errno": 10002,
        "message": "Gpt未返回信息,请检查Token是否有效!"
    }
|
||
|
||
def get_topics(vector: str, token: str):
    """Ask GPT to extract topics/keywords (with relevance) from the passage.

    The prompt instructs GPT to reply in '{{topic#relative}},...' form,
    which get_topic_with_relative then parses into a structured list.
    """
    prompt = (
        "Analysis this passage, getting the topic or key word of it. "
        "Returning {{xxx#relative}}. 'xxx' is the the topic or key word of the passage and "
        "relative is a num between 0 and 1 presenting the closeness of "
        "the topic or key word and the text. "
        "For example, returning '{{Minecraft#0.2}},{{Game#0.8}}'"
    )
    return get_topic_with_relative(common_ask(vector, prompt, token))
|
||
|
||
def get_topic_relative(vector: str, key_word: str, token: str):
    """Ask GPT to score the given comma-separated keywords against the passage.

    The prompt instructs GPT to reply in '{{topic#relative}},...' form,
    which get_topic_with_relative then parses into a structured list.
    """
    prompt = (
        "Analysis this passage, getting the relative of the topic or key word "
        "with the passage. "
        "Returning {{xxx#relative}}. 'xxx' is the the topic or key word given and "
        "relative is a num between 0 and 1 presenting the closeness of "
        "the topic or key word and the text. "
        "For example, giving 'TopicA,TopicB' returning "
        "'{{TopicA#0.2}},{{TopicB#0.8}}'. Now the giving keyword is " + key_word
    )
    return get_topic_with_relative(common_ask(vector, prompt, token))
|
||
|
||
def get_topic_with_relative(response):
    """Parse a GPT '{{topic#relative}},...' reply into a structured list.

    Returns {"topics": [{"topic": ..., "relative": ...}, ...]} on success,
    errno 10002 when GPT returned nothing, or errno 10003 when any chunk
    does not have exactly one '#' separator.
    """
    if response.response is None:
        return {
            "errno": 10002,
            "message": "Gpt未返回信息,请检查Token是否有效!"
        }
    topics = []
    for chunk in response.response.split(","):
        parts = chunk.split("#")
        if len(parts) != 2:
            return {
                "errno": 10003,
                "message": "Gpt返回无效信息,请尝试重新请求或舍弃请求."
            }
        topics.append({
            # Strip the {{ }} wrappers (and stray newlines/braces) GPT emits.
            "topic": str(parts[0]).replace("{{", "").replace("\n", "").replace("{", ""),
            "relative": str(parts[1]).replace("}}", "").replace("}", "")
        })
    # Bug fix: the README documents this response key as "topics"
    # ({"topics": [...]}) but the code returned "content".
    return {
        "topics": topics
    }
|
||
|
||
def common_ask(vector: str, ask_question: str, token: str,
               prompt: str = "Please answer the question with the context information"):
    """Query the on-disk vector index at `vector` with `ask_question`.

    Builds a llama_index service context from the caller's token, wraps
    `prompt` into a question/answer template, and returns the raw query
    response object (callers inspect its .response attribute).
    """
    llm_predictor, prompt_helper = prepare_llama_para(token)
    context = ServiceContext.from_defaults(
        llm_predictor=llm_predictor, prompt_helper=prompt_helper
    )
    template = (
        "We have provided context information below. \n"
        "---------------------\n"
        "{context_str}"
        "\n---------------------\n"
        f"{prompt}: {{query_str}}\n"
    )
    index = GPTSimpleVectorIndex.load_from_disk(vector, service_context=context)
    return index.query(
        ask_question,
        response_mode="compact",
        text_qa_template=QuestionAnswerPrompt(template),
    )
|
||
|
||
def prepare_llama_para(token):
    """Build the llama_index LLM predictor and prompt helper for one request.

    Side effect: stores the caller-supplied OpenAI API key in the process
    environment, which is how langchain's OpenAI wrapper reads it.
    """
    # NOTE(review): the env var is process-wide — concurrent requests with
    # different tokens may race; confirm this is acceptable.
    os.environ["OPENAI_API_KEY"] = token
    max_input_size = 4096  # maximum prompt size accepted by the model
    num_outputs = 1024  # maximum tokens to generate per completion
    max_chunk_overlap = 20  # overlap between adjacent index chunks
    chunk_size_limit = 1000  # maximum size of each indexed chunk
    # NOTE(review): the README advertises `gpt-3.5-turbo`, but this uses
    # text-davinci-003 — confirm which model is intended.
    llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="text-davinci-003", max_tokens=num_outputs))
    prompt_helper = PromptHelper(max_input_size, num_outputs, max_chunk_overlap, chunk_size_limit=chunk_size_limit)
    return llm_predictor, prompt_helper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
from passages import passageAnalysis | ||
from llama_index import ( | ||
GPTSimpleVectorIndex, | ||
SimpleDirectoryReader, | ||
LLMPredictor, | ||
ServiceContext, | ||
PromptHelper | ||
) | ||
|
||
|
||
def save_passage(content: str, token: str):
    """Cache a passage on disk under cache/<sha1(content)>/ and build its index.

    Identical texts share one cache entry; when the entry already exists
    the cached hash is returned immediately without rebuilding the index.

    Returns {"hash": <40-char hex digest>}.
    """
    import hashlib  # local import keeps this fix self-contained

    # Bug fix: the original computed `hash(str)` — the hash of the builtin
    # str *type*, which is identical for every passage and unstable across
    # interpreter runs. Hash the actual content with a stable digest; the
    # README's example hash (40 hex chars) matches SHA-1.
    name = hashlib.sha1(content.encode("utf-8")).hexdigest()
    res = {
        "hash": name
    }

    # Permanent cache: reuse the existing entry for identical content.
    dir_path = os.path.join('cache', name)
    if os.path.exists(dir_path):
        return res
    os.mkdir(dir_path)
    file_name = os.path.join(dir_path, 'raw')
    index_name = os.path.join(dir_path, 'index.json')
    with open(file_name, "w") as file:
        file.write(content)
    llm_predictor, prompt_helper = passageAnalysis.prepare_llama_para(token)
    documents = SimpleDirectoryReader(dir_path).load_data()
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
    index = GPTSimpleVectorIndex.from_documents(
        documents, service_context=service_context
    )
    index.save_to_disk(index_name)
    return res
|
||
|
||
def get_passage_content(hash: str):
    """Return the raw cached passage text stored at the given path."""
    with open(hash, "r") as fp:
        contents = fp.read()
    return contents
Oops, something went wrong.