34 changes: 16 additions & 18 deletions scripts/query_rag.py
@@ -42,8 +42,8 @@ def _llama_index_query(args: argparse.Namespace) -> None:
         "node": {
             "id": node.node_id,
             "text": node.text,
-            "metadata": node.metadata if hasattr(node, 'metadata') else {}
-        }
+            "metadata": node.metadata if hasattr(node, "metadata") else {},
+        },
     }
     if args.json:
         print(json.dumps(result, indent=2))
@@ -60,7 +60,7 @@ def _llama_index_query(args: argparse.Namespace) -> None:
         "query": args.query,
         "top_k": args.top_k,
         "threshold": args.threshold,
-        "nodes": []
+        "nodes": [],
     }
     print(json.dumps(result, indent=2))
     exit(1)
@@ -75,7 +75,7 @@ def _llama_index_query(args: argparse.Namespace) -> None:
         "query": args.query,
         "top_k": args.top_k,
         "threshold": args.threshold,
-        "nodes": []
+        "nodes": [],
     }
     print(json.dumps(result, indent=2))
     exit(1)
@@ -85,14 +85,14 @@ def _llama_index_query(args: argparse.Namespace) -> None:
         "query": args.query,
         "top_k": args.top_k,
         "threshold": args.threshold,
-        "nodes": []
+        "nodes": [],
     }
-    for node in nodes:
+    for node in nodes:  # type: ignore
         node_data = {
             "id": node.node_id,
             "score": node.score,
             "text": node.text,
-            "metadata": node.metadata if hasattr(node, 'metadata') else {}
+            "metadata": node.metadata if hasattr(node, "metadata") else {},
         }
         result["nodes"].append(node_data)

@@ -169,7 +169,7 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
         "query": args.query,
         "top_k": args.top_k,
         "threshold": args.threshold,
-        "nodes": []
+        "nodes": [],
     }
     print(json.dumps(result, indent=2))
     exit(1)
@@ -185,7 +185,7 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
         "query": args.query,
         "top_k": args.top_k,
         "threshold": args.threshold,
-        "nodes": []
+        "nodes": [],
     }
     print(json.dumps(result, indent=2))
     exit(1)
@@ -195,15 +195,15 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
         "query": args.query,
         "top_k": args.top_k,
         "threshold": args.threshold,
-        "nodes": []
+        "nodes": [],
     }

     for _id, chunk, score in zip(md["document_ids"], md["chunks"], md["scores"]):
         node_data = {
             "id": _id,
-            "score": score.score if hasattr(score, 'score') else score,
+            "score": score.score if hasattr(score, "score") else score,
             "text": chunk,
-            "metadata": {}
+            "metadata": {},
         }
         result["nodes"].append(node_data)
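
Both query paths guard optional attributes with hasattr before building the JSON payload, so a node without metadata or a score returned as a bare float does not break serialization. Below is a minimal standalone sketch of that pattern; FakeScore is a hypothetical stand-in for whatever object the retrieval backend actually returns.

from dataclasses import dataclass


@dataclass
class FakeScore:
    # Hypothetical stand-in for a backend score object; not part of the script.
    score: float


def normalize_score(score):
    # Accept either a bare float or an object exposing a .score attribute,
    # mirroring the hasattr check in the hunks above.
    return score.score if hasattr(score, "score") else score


print(normalize_score(0.42))             # -> 0.42
print(normalize_score(FakeScore(0.87)))  # -> 0.87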

@@ -221,6 +221,7 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
     # else:
     #     print(content)

+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Utility script for querying RAG database"
@@ -263,15 +264,12 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
         # In JSON mode, only show ERROR or higher to avoid polluting JSON output
         logging.basicConfig(
             level=logging.ERROR,
-            format='%(levelname)s: %(message)s',
-            stream=sys.stderr  # Send logs to stderr to keep stdout clean for JSON
+            format="%(levelname)s: %(message)s",
+            stream=sys.stderr,  # Send logs to stderr to keep stdout clean for JSON
         )
     else:
         # In normal mode, show info and above
-        logging.basicConfig(
-            level=logging.INFO,
-            format='%(message)s'
-        )
+        logging.basicConfig(level=logging.INFO, format="%(message)s")

     if not args.json:
         logging.info("Command line used: " + " ".join(sys.argv))
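
The logging change above is what keeps stdout machine-readable in JSON mode: diagnostics go to stderr so the only thing written to stdout is the JSON payload. A minimal sketch of the same two-mode setup, detached from the script's argument parsing (json_mode here is an illustrative flag, not the script's actual CLI handling):

import logging
import sys

json_mode = True  # illustrative stand-in for the script's args.json flag

if json_mode:
    # Only ERROR and above, routed to stderr so stdout stays clean for JSON.
    logging.basicConfig(
        level=logging.ERROR,
        format="%(levelname)s: %(message)s",
        stream=sys.stderr,
    )
else:
    # Human-readable mode: INFO and above.
    logging.basicConfig(level=logging.INFO, format="%(message)s")

logging.info("suppressed when json_mode is True")
logging.error("always emitted")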