Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Howie/fully customizable event handler #14

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ To achieve this, when users submit a message to the web server, the web server w

## Local Development

1. Run `pip install -r requirements.txt`.
1. Run `pip install -r ./src/requirements.txt`.

2. Make sure that the `.env` file exists.

Expand Down
Binary file not shown.
118 changes: 69 additions & 49 deletions src/quartapp/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

from typing import Any
from typing import AsyncGenerator, Dict, Optional, Tuple
from quart import Blueprint, jsonify, request, Response, render_template, current_app

import asyncio
Expand All @@ -12,16 +12,54 @@
from azure.identity import DefaultAzureCredential

from azure.ai.projects.models import (
MessageDeltaTextContent,
MessageDeltaChunk,
ThreadMessage,
FileSearchTool,
AsyncToolSet,
FilePurpose,
AgentStreamEvent
ThreadMessage,
StreamEventData,
AsyncAgentEventHandler,
Agent,
VectorStore
)

bp = Blueprint("chat", __name__, template_folder="templates", static_folder="static")
class ChatBlueprint(Blueprint):
# Blueprint subclass that declares, with types, the attributes this module
# attaches to the `bp` instance during server startup.
# Project client; assigned in start_server and closed in stop_server.
ai_client: AIProjectClient
# Agent assigned in start_server; its id is compared with the `agent_id` cookie in chat().
agent: Agent
# Maps each uploaded file's id to the local path it was uploaded from;
# fetch_document looks the path up by file id.
files: Dict[str, str]
# Vector store created from the uploaded file ids in start_server.
vector_store: VectorStore

bp = ChatBlueprint("chat", __name__, template_folder="templates", static_folder="static")

class MyEventHandler(AsyncAgentEventHandler[str]):

async def on_message_delta(self, delta: "MessageDeltaChunk") -> Optional[str]:
    """Turn one streamed text delta into a server-sent-event frame.

    Args:
        delta: The delta chunk; only its ``text`` attribute is read.

    Returns:
        A ``data: <json>`` frame terminated by a blank line, whose JSON body
        carries the delta text under ``content`` and ``type`` set to
        ``"message"``.
    """
    payload = {'content': delta.text, 'type': "message"}
    stream_data = json.dumps(payload)
    return f"data: {stream_data}\n\n"

async def on_thread_message(
self, message: "ThreadMessage"
) -> Optional[str]:
if message.status == "completed":
annotations = [annotation.as_dict() for annotation in message.file_citation_annotations]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are annotations always citations? Do they have any other additional information?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are annotations always citations?
No. A while ago there was an issue where the delta text contained an annotation but we received no citation. The backend team says they have fixed it. To be safe, we collect all citations and convert the annotations into clickable links only if the corresponding citation exists.

Do they have any other additional information?
Yes, but for citations I only care about file_citation_annotations.

stream_data = json.dumps({'content': message.text_messages[0].text.value, 'annotations': annotations, 'type': "completed_message"})
return f"data: {stream_data}\n\n"
return None

async def on_error(self, data: str) -> Optional[str]:
    """Print a stream error and tell the client the stream has ended.

    Args:
        data: Error payload handed to the handler; printed as-is.

    Returns:
        A ``data: <json>`` frame terminated by a blank line, whose JSON body
        is ``{"type": "stream_end"}``. The error text itself is not sent to
        the client.
    """
    print(f"An error occurred. Data: {data}")
    payload = {'type': "stream_end"}
    stream_data = json.dumps(payload)
    return f"data: {stream_data}\n\n"

async def on_done(self) -> Optional[str]:
    """Emit the frame that tells the client the stream has finished.

    Returns:
        A ``data: <json>`` frame terminated by a blank line, whose JSON body
        is ``{"type": "stream_end"}``.
    """
    payload = {'type': "stream_end"}
    stream_data = json.dumps(payload)
    return f"data: {stream_data}\n\n"



@bp.before_app_serving
Expand All @@ -33,15 +71,15 @@ async def start_server():
)

# TODO: add more files (not supported for citation at the moment)
files = ["product_info_1.md"]
file_ids = []
for file in files:
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file))
file_names = ["product_info_1.md", "product_info_2.md"]
files: Dict[str, str] = {}
for file_name in file_names:
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file_name))
print(f"Uploading file {file_path}")
file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
file_ids.append(file.id)
files.update({file.id: file_path})

vector_store = await ai_client.agents.create_vector_store(file_ids=file_ids, name="sample_store")
vector_store = await ai_client.agents.create_vector_store_and_poll(file_ids=list(files.keys()), name="sample_store")

file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])

Expand All @@ -59,12 +97,12 @@ async def start_server():
bp.ai_client = ai_client
bp.agent = agent
bp.vector_store = vector_store
bp.file_ids = file_ids
bp.files = files


@bp.after_app_serving
async def stop_server():
for file_id in bp.file_ids:
for file_id in bp.files.keys():
await bp.ai_client.agents.delete_file(file_id)
print(f"Deleted file {file_id}")

Expand All @@ -78,47 +116,32 @@ async def stop_server():
await bp.ai_client.close()
print("Closed AIProjectClient")




@bp.get("/")
async def index():
    """Serve the site root by rendering the ``index.html`` template."""
    page = await render_template("index.html")
    return page

async def create_stream(thread_id: str, agent_id: str):


async def get_result(thread_id: str, agent_id: str) -> AsyncGenerator[str, None]:
async with await bp.ai_client.agents.create_stream(
thread_id=thread_id, assistant_id=agent_id
thread_id=thread_id, assistant_id=agent_id,
event_handler=MyEventHandler()
) as stream:
accumulated_text = ""

async for event_type, event_data in stream:

stream_data = None
if isinstance(event_data, MessageDeltaChunk):
for content_part in event_data.delta.content:
if isinstance(content_part, MessageDeltaTextContent):
text_value = content_part.text.value if content_part.text else "No text"
accumulated_text += text_value
print(f"Text delta received: {text_value}")
stream_data = json.dumps({'content': text_value, 'type': "message"})

elif isinstance(event_data, ThreadMessage):
print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}")
if (event_data.status == "completed"):
stream_data = json.dumps({'content': accumulated_text, 'type': "completed_message"})

elif event_type == AgentStreamEvent.DONE:
print("Stream completed.")
stream_data = json.dumps({'type': "stream_end"})

if stream_data:
yield f"data: {stream_data}\n\n"
# Iterate over the stream to trigger the event handler functions
async for _, _, event_func_return_val in stream:
if event_func_return_val:
yield event_func_return_val


@bp.route('/chat', methods=['POST'])
async def chat():
thread_id = request.cookies.get('thread_id')
agent_id = request.cookies.get('agent_id')
thread = None

if thread_id or agent_id != bp.agent.id:
if thread_id and agent_id == bp.agent.id:
# Check if the thread is still active
try:
thread = await bp.ai_client.agents.get_thread(thread_id)
Expand Down Expand Up @@ -147,24 +170,21 @@ async def chat():
'Content-Type': 'text/event-stream'
}

response = Response(create_stream(thread_id, agent_id), headers=headers)
response = Response(get_result(thread_id, agent_id), headers=headers)
response.set_cookie('thread_id', thread_id)
response.set_cookie('agent_id', agent_id)
return response

@bp.route('/fetch-document', methods=['GET'])
async def fetch_document():
filename = "product_info_1.md"

# Get the file path from the mapping
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', filename))

if not os.path.exists(file_path):
return jsonify({"error": f"File not found: {filename}"}), 404
file_id = request.args.get('file_id')
current_app.logger.info(f"Fetching document: {file_id}")
if not file_id:
return jsonify({"error": "file_id is required"}), 400

try:
# Read the file content asynchronously using asyncio.to_thread
data = await asyncio.to_thread(read_file, file_path)
data = await asyncio.to_thread(read_file, bp.files[file_id])
return Response(data, content_type='text/plain')

except Exception as e:
Expand Down
4 changes: 3 additions & 1 deletion src/quartapp/static/ChatClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ChatClient {
let accumulatedContent = '';
let isStreaming = true;
let buffer = '';
let annotations = [];

const reader = stream.getReader();
const decoder = new TextDecoder();
Expand Down Expand Up @@ -73,12 +74,13 @@ class ChatClient {
if (data.type === "completed_message") {
this.ui.clearAssistantMessage(messageDiv);
accumulatedContent = data.content;
annotations = data.annotations;
isStreaming = false;
} else {
accumulatedContent += data.content;
}

this.ui.appendAssistantMessage(messageDiv, accumulatedContent, isStreaming);
this.ui.appendAssistantMessage(messageDiv, accumulatedContent, isStreaming, annotations);
}
}

Expand Down
33 changes: 16 additions & 17 deletions src/quartapp/static/ChatUI.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,31 @@ class ChatUI {
this.attachCloseButtonListener();
}

preprocessContent(content) {
// Regular expression to find citations like 【n:m†filename.md】
const citationRegex = /\u3010(\d+):(\d+)\u2020([^\s]+)\u3011/g;
return content.replace(citationRegex, (match, _, __, filename) => {
return `<a href="#" class="file-citation" data-file-name="${filename}">${match}</a>`;
});
preprocessContent(content, annotations) {
if (annotations) {
annotations.slice().reverse().forEach(annotation => {
// start_index and end_index mark where the annotation's label sits in the content; replace that span with a link
content = content.slice(0, annotation.start_index) +
`<a href="#" class="file-citation" data-file-id="${annotation.file_citation.file_id}">${annotation.text}</a>` +
content.slice(annotation.end_index);
});
}
return content;
}

addCitationClickListener() {
document.addEventListener('click', (event) => {
if (event.target.classList.contains('file-citation')) {
event.preventDefault();
const filename = event.target.getAttribute('data-file-name');
this.loadDocument(filename);
const file_id = event.target.getAttribute('data-file-id');
this.loadDocument(file_id);
}
});
}

async loadDocument(filename) {
async loadDocument(file_id) {
try {
const response = await fetch(`/fetch-document?filename=${filename}`);
const response = await fetch(`/fetch-document?file_id=${file_id}`);
if (!response.ok) {
throw new Error('Network response was not ok');
}
Expand All @@ -53,7 +57,6 @@ class ChatUI {
}

showDocument(content) {
console.log("showDocument:", content);
const docViewerSection = document.getElementById("document-viewer-section");
const chatColumn = document.getElementById("chat-container");

Expand Down Expand Up @@ -109,8 +112,7 @@ class ChatUI {
this.scrollToBottom();
}

appendAssistantMessage(messageDiv, accumulatedContent, isStreaming) {
//console.log("Accumulated Content before conversion:", accumulatedContent);
appendAssistantMessage(messageDiv, accumulatedContent, isStreaming, annotations) {
const md = window.markdownit({
html: true,
linkify: true,
Expand All @@ -120,7 +122,7 @@ class ChatUI {

try {
// Preprocess content to convert citations to links
const preprocessedContent = this.preprocessContent(accumulatedContent);
const preprocessedContent = this.preprocessContent(accumulatedContent, annotations);
// Convert the accumulated content to HTML using markdown-it
let htmlContent = md.render(preprocessedContent);
const messageTextDiv = messageDiv.querySelector(".message-text");
Expand All @@ -130,13 +132,10 @@ class ChatUI {

// Set the innerHTML of the message text div to the HTML content
messageTextDiv.innerHTML = htmlContent;
console.log("HTML set to messageTextDiv:", messageTextDiv.innerHTML);

// Use requestAnimationFrame to ensure the DOM has updated before scrolling
// Only scroll if not streaming
if (!isStreaming) {
console.log("Accumulated content:", accumulatedContent);
console.log("HTML set to messageTextDiv:", messageTextDiv.innerHTML);
requestAnimationFrame(() => {
this.scrollToBottom();
});
Expand Down
6 changes: 3 additions & 3 deletions src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ anyio==4.3.0
# watchfiles
attrs==23.2.0
# via aiohttp
azure-core==1.30.1
azure-core==1.31.0
# via azure-identity
azure-identity==1.15.0
# via quartapp (pyproject.toml)
Expand Down Expand Up @@ -146,7 +146,7 @@ sniffio==1.3.1
# openai
tqdm==4.66.2
# via openai
typing-extensions==4.11.0
typing-extensions==4.12.2
# via
# azure-core
# openai
Expand All @@ -171,5 +171,5 @@ wsproto==1.2.0
# via hypercorn
yarl==1.9.4
# via aiohttp
./packages/azure_ai_projects-1.0.0b1-py3-none-any.whl
azure-ai-projects==1.0.0b5
# via quartapp (pyproject.toml)