Skip to content

Commit

Permalink
Support iframe chatbot. (infiniflow#3961)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#3909

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
KevinHuSh authored Dec 10, 2024
1 parent 601d741 commit e9b8c30
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 141 deletions.
5 changes: 4 additions & 1 deletion agent/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,7 @@ def set_global_param(self, **kwargs):
q["value"] = v

def get_preset_param(self):
    """Return the ``query`` list held by the ``begin`` component's parameters."""
    # The original listed this return twice; the second copy was unreachable.
    return self.components["begin"]["obj"]._param.query

def get_component_input_elements(self, cpnnm):
    """Return the input elements declared by the component named ``cpnnm``.

    Args:
        cpnnm: Key of the component in ``self.components``.

    Returns:
        Whatever that component's ``get_input_elements()`` returns.

    Raises:
        KeyError: if ``cpnnm`` is not a key of ``self.components``.
    """
    # Look up the requested component; previously "begin" was hard-coded and
    # the ``cpnnm`` argument was silently ignored.
    return self.components[cpnnm]["obj"].get_input_elements()
17 changes: 16 additions & 1 deletion agent/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def get_input(self):
self._param.inputs.append({"component_id": q["component_id"],
"content": "\n".join(
[str(d["content"]) for d in outs[-1].to_dict('records')])})
elif q["value"]:
elif q.get("value"):
self._param.inputs.append({"component_id": None, "content": q["value"]})
outs.append(pd.DataFrame([{"content": q["value"]}]))
if outs:
Expand Down Expand Up @@ -526,6 +526,21 @@ def get_input(self):

return df

def get_input_elements(self):
    """List the input elements described by this component's query parameters.

    Each entry of ``self._param.query`` is handled as follows:

    * no ``component_id``: contributes ``{"key": ...}``;
    * ``component_id`` whose part before ``"@"`` contains ``"begin"``
      (case-insensitive): contributes every entry of that component's own
      ``_param.query``, fetched through ``self._canvas.get_component``;
    * any other ``component_id``: contributes
      ``{"key": ..., "component_id": ...}``.

    Returns:
        list: the collected element dicts, in query order.

    Raises:
        AssertionError: if ``self._param.query`` is empty.
    """
    if not self._param.query:
        # Raised explicitly (same type and message as the former ``assert``)
        # so the check is not stripped when Python runs with -O.
        raise AssertionError("Please identify input parameters firstly.")

    eles = []
    for q in self._param.query:
        component_id = q.get("component_id")
        if not component_id:
            eles.append({"key": q["key"]})
            continue

        # partition() tolerates ids with no "@" (or several); the previous
        # two-name unpacking of split("@") raised ValueError in those cases.
        cpn_id = component_id.partition("@")[0]
        if "begin" in cpn_id.lower():
            eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
            continue

        eles.append({"key": q["key"], "component_id": component_id})
    return eles

def get_stream_input(self):
reversed_cpnts = []
if len(self._canvas.path) > 1:
Expand Down
8 changes: 8 additions & 0 deletions agent/component/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from functools import partial
import pandas as pd
from api.db import LLMType
from api.db.services.conversation_service import structure_answer
from api.db.services.dialog_service import message_fit_in
from api.db.services.llm_service import LLMBundle
from api import settings
Expand Down Expand Up @@ -104,9 +105,16 @@ def set_cite(self, retrieval_res, answer):
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
res = {"content": answer, "reference": reference}
res = structure_answer(None, res, "", "")

return res

def get_input_elements(self):
    """Return the configured parameters, or one generic ``input`` element if none are set."""
    default_elements = [{"key": "input"}]
    return self._param.parameters or default_elements

def _run(self, history, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
prompt = self._param.prompt
Expand Down
20 changes: 20 additions & 0 deletions api/apps/canvas_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,26 @@ def reset():
return server_error_response(e)


@manager.route('/input_elements', methods=['GET'])  # noqa: F821
@validate_request("id", "component_id")
@login_required
def input_elements():
    """Return the input elements of one component of a canvas owned by the current user.

    Reads ``id`` (canvas id) and ``component_id`` from the request's JSON
    body. Responds with a data error if the canvas does not exist, an
    operating error if the current user does not own it, and a server error
    response for any exception raised while handling the request.
    """
    # NOTE(review): this is a GET route that reads ``request.json``; GET
    # requests frequently carry no JSON body -- confirm clients send one, or
    # consider reading query-string arguments instead.
    req = request.json
    try:
        found, user_canvas = UserCanvasService.get_by_id(req["id"])
        if not found:
            return get_data_error_result(message="canvas not found.")

        owned = UserCanvasService.query(user_id=current_user.id, id=req["id"])
        if not owned:
            return get_json_result(
                data=False, message='Only owner of canvas authorized for this operation.',
                code=RetCode.OPERATING_ERROR)

        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        elements = canvas.get_component_input_elements(req["component_id"])
        return get_json_result(data=elements)
    except Exception as exc:
        return server_error_response(exc)


@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
Expand Down
52 changes: 36 additions & 16 deletions api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import traceback
from copy import deepcopy

from api.db.services.conversation_service import ConversationService
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.user_service import UserTenantService
from flask import request, Response
from flask_login import login_required, current_user
Expand Down Expand Up @@ -90,6 +90,21 @@ def get():
return get_json_result(
data=False, message='Only owner of conversation authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)

def get_value(d, k1, k2):
    """Return ``d[k1]`` when ``k1`` is present, otherwise ``d.get(k2)``."""
    fallback = d.get(k2)
    return d.get(k1, fallback)

for ref in conv.reference:
ref["chunks"] = [{
"id": get_value(ck, "chunk_id", "id"),
"content": get_value(ck, "content", "content_with_weight"),
"document_id": get_value(ck, "doc_id", "document_id"),
"document_name": get_value(ck, "docnm_kwd", "document_name"),
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
"image_id": get_value(ck, "image_id", "img_id"),
"positions": get_value(ck, "positions", "position_int"),
} for ck in ref.get("chunks", [])]

conv = conv.to_dict()
return get_json_result(data=conv)
except Exception as e:
Expand Down Expand Up @@ -132,6 +147,7 @@ def list_convsersation():
dialog_id=dialog_id,
order_by=ConversationService.model.create_time,
reverse=True)

convs = [d.to_dict() for d in convs]
return get_json_result(data=convs)
except Exception as e:
Expand Down Expand Up @@ -164,24 +180,29 @@ def completion():

if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})

def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
"id": message_id, "prompt": ans.get("prompt", "")}
ans["id"] = message_id
else:
def get_value(d, k1, k2):
    """Return ``d[k1]`` when ``k1`` is present, otherwise ``d.get(k2)``."""
    fallback = d.get(k2)
    return d.get(k1, fallback)

for ref in conv.reference:
ref["chunks"] = [{
"id": get_value(ck, "chunk_id", "id"),
"content": get_value(ck, "content", "content_with_weight"),
"document_id": get_value(ck, "doc_id", "document_id"),
"document_name": get_value(ck, "docnm_kwd", "document_name"),
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
"image_id": get_value(ck, "image_id", "img_id"),
"positions": get_value(ck, "positions", "position_int"),
} for ck in ref.get("chunks", [])]

if not conv.reference:
conv.reference = []
conv.reference.append({"chunks": [], "doc_aggs": []})
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
ans = structure_answer(conv, ans, message_id, conv.id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
Expand All @@ -202,8 +223,7 @@ def stream():
else:
answer = None
for ans in chat(dia, msg, **req):
answer = ans
fillin_conv(ans)
answer = structure_answer(conv, ans, message_id, req["conversation_id"])
ConversationService.update_by_id(conv.id, conv.to_dict())
break
return get_json_result(data=answer)
Expand Down
10 changes: 10 additions & 0 deletions api/apps/sdk/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ def update(tenant_id, chat_id, session_id):
@token_required
def chat_completion(tenant_id, chat_id):
req = request.json
if not DialogService.query(tenant_id=tenant_id,id=chat_id,status=StatusEnum.VALID.value):
return get_error_data_result(f"You don't own the chat {chat_id}")
if req.get("session_id"):
if not ConversationService.query(id=req["session_id"],dialog_id=chat_id):
return get_error_data_result(f"You don't own the session {req['session_id']}")
if req.get("stream", True):
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
Expand All @@ -133,6 +138,11 @@ def chat_completion(tenant_id, chat_id):
@token_required
def agent_completions(tenant_id, agent_id):
req = request.json
if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
return get_error_data_result(f"You don't own the agent {agent_id}")
if req.get("session_id"):
if not API4ConversationService.query(id=req["session_id"],dialog_id=agent_id):
return get_error_data_result(f"You don't own the session {req['session_id']}")
if req.get("stream", True):
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
Expand Down
44 changes: 16 additions & 28 deletions api/db/services/canvas_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#
import json
import traceback
from uuid import uuid4
from agent.canvas import Canvas
from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
Expand Down Expand Up @@ -58,6 +59,8 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, tenant_id)
canvas.reset()
message_id = str(uuid4())

if not session_id:
session_id = get_uuid()
Expand All @@ -84,40 +87,24 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
return
conv = API4Conversation(**conv)
else:
session_id = session_id
e, conv = API4ConversationService.get_by_id(session_id)
assert e, "Session not found!"
canvas = Canvas(json.dumps(conv.dsl), tenant_id)

if not conv.message:
conv.message = []
messages = conv.message
question = {
"role": "user",
"content": question,
"id": str(uuid4())
}
messages.append(question)
msg = []
for m in messages:
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]

if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
canvas.messages.append({"role": "user", "content": question, "id": message_id})
canvas.add_user_input(question)
if not conv.message:
conv.message = []
conv.message.append({
"role": "user",
"content": question,
"id": message_id
})
if not conv.reference:
conv.reference = []
conv.reference.append({"chunks": [], "doc_aggs": []})

final_ans = {"reference": [], "content": ""}

canvas.add_user_input(msg[-1]["content"])

if stream:
try:
for ans in canvas.run(stream=stream):
Expand All @@ -141,6 +128,7 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
conv.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
traceback.print_exc()
conv.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
yield "data:" + json.dumps({"code": 500, "message": str(e),
Expand Down
39 changes: 21 additions & 18 deletions api/db/services/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from api.db.services.dialog_service import DialogService, chat
from api.utils import get_uuid
import json
from copy import deepcopy


class ConversationService(CommonService):
Expand Down Expand Up @@ -49,30 +48,35 @@ def structure_answer(conv, ans, message_id, session_id):
reference = ans["reference"]
if not isinstance(reference, dict):
reference = {}
temp_reference = deepcopy(ans["reference"])
if not conv.reference:
conv.reference.append(temp_reference)
else:
conv.reference[-1] = temp_reference
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["reference"] = {}

def get_value(d, k1, k2):
    """Return ``d[k1]`` when ``k1`` is present, otherwise ``d.get(k2)``."""
    fallback = d.get(k2)
    return d.get(k1, fallback)
chunk_list = [{
"id": chunk["chunk_id"],
"content": chunk.get("content") if chunk.get("content") else chunk.get("content_with_content"),
"document_id": chunk["doc_id"],
"document_name": chunk["docnm_kwd"],
"dataset_id": chunk["kb_id"],
"image_id": chunk["image_id"],
"similarity": chunk["similarity"],
"vector_similarity": chunk["vector_similarity"],
"term_similarity": chunk["term_similarity"],
"positions": chunk["positions"],
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
} for chunk in reference.get("chunks", [])]

reference["chunks"] = chunk_list
ans["id"] = message_id
ans["session_id"] = session_id

if not conv:
return ans

if not conv.message:
conv.message = []
if not conv.message or conv.message[-1].get("role", "") != "assistant":
conv.message.append({"role": "assistant", "content": ans["answer"], "id": message_id})
else:
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
if conv.reference:
conv.reference[-1] = reference
return ans


Expand Down Expand Up @@ -199,7 +203,6 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg

if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})

if stream:
Expand Down
Loading

0 comments on commit e9b8c30

Please sign in to comment.