Skip to content

Commit

Permalink
Support debug components. (infiniflow#3994)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#3993

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
  • Loading branch information
KevinHuSh authored and isthaison committed Dec 13, 2024
1 parent f63f2f8 commit 8b1088b
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 15 deletions.
3 changes: 2 additions & 1 deletion agent/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def prepare2run(cpns):
except Exception as e:
logging.exception(f"Canvas.run got exception: {e}")
self.path[-1].append(c)
ran += 1
raise e
self.path[-1].append(c)
ran += 1
Expand Down Expand Up @@ -330,4 +331,4 @@ def get_preset_param(self):
return self.components["begin"]["obj"]._param.query

def get_component_input_elements(self, cpnnm):
return self.components["begin"]["obj"].get_input_elements()
return self.components[cpnnm]["obj"].get_input_elements()
19 changes: 14 additions & 5 deletions agent/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self):
self.message_history_window_size = 22
self.query = []
self.inputs = []
self.debug_inputs = []

def set_name(self, name: str):
self._name = name
Expand Down Expand Up @@ -410,6 +411,7 @@ def get_dependent_components(self):
def run(self, history, **kwargs):
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
json.dumps(kwargs, ensure_ascii=False)))
self._param.debug_inputs = []
try:
res = self._run(history, **kwargs)
self.set_output(res)
Expand Down Expand Up @@ -446,10 +448,13 @@ def reset(self):
setattr(self._param, self._param.output_var_name, None)
self._param.inputs = []

def set_output(self, v: partial | pd.DataFrame):
def set_output(self, v):
setattr(self._param, self._param.output_var_name, v)

def get_input(self):
if self._param.debug_inputs:
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs])

reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
Expand Down Expand Up @@ -531,14 +536,15 @@ def get_input_elements(self):
eles = []
for q in self._param.query:
if q.get("component_id"):
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
cpn_id = q["component_id"]
if cpn_id.split("@")[0].lower().find("begin") >= 0:
cpn_id, key = cpn_id.split("@")
eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
continue

eles.append({"key": q["key"], "component_id": q["component_id"]})
eles.append({"name": self._canvas.get_compnent_name(cpn_id), "key": cpn_id})
else:
eles.append({"key": q["key"]})
eles.append({"key": q["value"], "name": q["value"], "value": q["value"]})
return eles

def get_stream_input(self):
Expand All @@ -558,3 +564,6 @@ def be_output(v):

def get_component_name(self, cpn_id):
return self._canvas.get_component(cpn_id)["obj"].component_name.lower()

def debug(self, **kwargs):
return self._run([], **kwargs)
2 changes: 1 addition & 1 deletion agent/component/begin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _run(self, history, **kwargs):
def stream_output(self):
res = {"content": self._param.prologue}
yield res
self.set_output(res)
self.set_output(self.be_output(res))



18 changes: 15 additions & 3 deletions agent/component/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def set_cite(self, retrieval_res, answer):

def get_input_elements(self):
if self._param.parameters:
return self._param.parameters
return [{"key": "user"}, *self._param.parameters]

return [{"key": "input"}]
return [{"key": "user"}]

def _run(self, history, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
Expand Down Expand Up @@ -218,4 +218,16 @@ def stream_output(self, chat_mdl, prompt, retrieval_res):
res = self.set_cite(retrieval_res, answer)
yield res

self.set_output(res)
self.set_output(Generate.be_output(res))

def debug(self, history, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
prompt = self._param.prompt

for para in self._param.debug_inputs:
kwargs[para["key"]] = para["value"]

for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)

return chat_mdl.chat(prompt, [{"role": "user", "content": kwargs.get("user", "")}], self._param.gen_conf())
28 changes: 26 additions & 2 deletions api/apps/canvas_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,32 @@ def reset():


@manager.route('/input_elements', methods=['GET']) # noqa: F821
@validate_request("id", "component_id")
@login_required
def input_elements():
cvs_id = request.args.get("id")
cpn_id = request.args.get("component_id")
try:
e, user_canvas = UserCanvasService.get_by_id(cvs_id)
if not e:
return get_data_error_result(message="canvas not found.")
if not UserCanvasService.query(user_id=current_user.id, id=cvs_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)

canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
return get_json_result(data=canvas.get_component_input_elements(cpn_id))
except Exception as e:
return server_error_response(e)


@manager.route('/debug', methods=['POST']) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
req = request.json
for p in req["params"]:
assert p.get("key")
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
Expand All @@ -202,7 +224,9 @@ def input_elements():
code=RetCode.OPERATING_ERROR)

canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
return get_json_result(data=canvas.get_component_input_elements(req["component_id"]))
canvas.get_component(req["component_id"])["obj"]._param.debug_inputs = req["params"]
df = canvas.get_component(req["component_id"])["obj"].debug()
return get_json_result(data=df.to_dict(orient="records"))
except Exception as e:
return server_error_response(e)

Expand Down
2 changes: 2 additions & 0 deletions api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def get_value(d, k1, k2):
return d.get(k1, d.get(k2))

for ref in conv.reference:
if isinstance(ref, list):
continue
ref["chunks"] = [{
"id": get_value(ck, "chunk_id", "id"),
"content": get_value(ck, "content", "content_with_weight"),
Expand Down
2 changes: 1 addition & 1 deletion api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def parse():
})
driver = Chrome(options=options)
driver.get(url)
res_headers = [r.response.headers for r in driver.requests]
res_headers = [r.response.headers for r in driver.requests if r and r.response]
if len(res_headers) > 1:
sections = RAGFlowHtmlParser().parser_txt(driver.page_source)
driver.quit()
Expand Down
4 changes: 2 additions & 2 deletions rag/svr/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL

Expand Down Expand Up @@ -269,7 +269,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
batch_size = 16
tts, cnts = [], []
for d in docs:
tts.append(rmSpace(d.get("docnm_kwd", "Title")))
tts.append(d.get("docnm_kwd", "Title"))
c = "\n".join(d.get("question_kwd", []))
if not c:
c = d["content_with_weight"]
Expand Down

0 comments on commit 8b1088b

Please sign in to comment.