examples/streaming/04_task_based_streaming.py
293 changes: 167 additions & 126 deletions

@@ -107,151 +107,160 @@ def setup_routes(self, app):
         import threading
         from queue import Queue
         from flask import request, Response, jsonify
-
-        # Register the tasks/stream endpoint
+        import time
+        import traceback
+        from contextlib import contextmanager
+
+        STREAM_TIMEOUT = 300
+        QUEUE_CHECK_INTERVAL = 0.01
+
+        @contextmanager
+        def managed_thread(target_func, daemon=True):
+            thread = threading.Thread(target=target_func)
+            thread.daemon = daemon
+            thread.start()
+            try:
+                yield thread
+            finally:
+                pass
+
+        def create_sse_event(event_type, data):
+            if event_type:
+                return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+            return f"data: {json.dumps(data)}\n\n"
+
+        def log_with_context(message, task_id=None, level="info"):
+            log_context = {"task_id": task_id} if task_id else {}
+            log_data = {"message": message, "level": level, "context": log_context}
+            print(f"{level.upper()}: {message}")
+            return log_data
+
+        async def process_task_stream(task, queue, done_event, error_event, error_message):
+            task_id = task.id if task else "unknown"
+            task_stream = None
+            try:
+                task_stream = self.tasks_send_subscribe(task)
+                index = 0
+                last_task_update = None
+
+                async for task_update in task_stream:
+                    last_task_update = task_update
+                    update_dict = task_update.to_dict()
+                    queue.put({
+                        "task": update_dict,
+                        "index": index,
+                        "append": True
+                    })
+                    index += 1
+
+                if last_task_update:
+                    final_dict = last_task_update.to_dict()
+                    if isinstance(final_dict.get("status"), dict):
+                        final_dict["status"]["state"] = "completed"
+                    queue.put({
+                        "task": final_dict,
+                        "index": index,
+                        "append": True,
+                        "lastChunk": True
+                    })
+
+            except asyncio.CancelledError:
+                error_message[0] = "Task streaming cancelled"
+                error_event.set()
+            except Exception as e:
+                error_message[0] = str(e)
+                error_event.set()
+            finally:
+                done_event.set()
+                if hasattr(task_stream, 'aclose') and callable(task_stream.aclose):
+                    try:
+                        await task_stream.aclose()
+                    except Exception as e:
+                        log_with_context(f"Error closing task stream: {e}", task_id, "error")
+
+        def run_task_stream(task, queue, done_event, error_event, error_message):
+            task_id = task.id if task else "unknown"
+            try:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+                main_task = loop.create_task(process_task_stream(task, queue, done_event, error_event, error_message))
+                try:
+                    loop.run_until_complete(main_task)
+                except Exception as e:
+                    error_message[0] = f"Event loop error: {str(e)}"
+                    error_event.set()
+                finally:
+                    pending = asyncio.all_tasks(loop)
+                    for pending_task in pending:
+                        pending_task.cancel()
+                    if pending:
+                        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+                    loop.close()
+            except Exception as e:
+                error_message[0] = f"Thread setup error: {str(e)}"
+                error_event.set()
+                done_event.set()
+
         @app.route("/a2a/tasks/stream", methods=["POST"])
         def handle_task_streaming():
             """Handle task streaming requests."""
+            task = None
             try:
                 data = request.json
                 print(f"Task streaming request received: {json.dumps(data)[:100]}...")
 
                 # Parse the task
                 from python_a2a.models.task import Task
                 if "task" in data:
                     task = Task.from_dict(data["task"])
                 else:
                     task = Task.from_dict(data)
 
-                # Check if tasks_send_subscribe is implemented
                 if not hasattr(self, 'tasks_send_subscribe'):
                     return jsonify({"error": "This agent does not support task streaming"}), 405
 
-                # Set up streaming response
                 def generate():
                     """Generator for streaming server-sent events."""
-                    # Create a thread and asyncio event loop
                     queue = Queue()
                     done_event = threading.Event()
-
-                    def run_task_stream():
-                        """Run the task streaming in a dedicated thread."""
-                        async def process_task_stream():
-                            """Process the task stream."""
-                            try:
-                                # Get the task stream generator
-                                task_stream = self.tasks_send_subscribe(task)
-
-                                # Process each task update
-                                index = 0
-                                async for task_update in task_stream:
-                                    # Convert task to dict
-                                    update_dict = task_update.to_dict()
-
-                                    # Add metadata for streaming
-                                    update_data = {
-                                        "task": update_dict,
-                                        "index": index,
-                                        "append": True
-                                    }
-
-                                    # Put in queue
-                                    queue.put(update_data)
-                                    print(f"Put task update {index} in queue")
-                                    index += 1
-
-                                # Signal completion
-                                queue.put({
-                                    "task": task_update.to_dict(),
-                                    "index": index,
-                                    "append": True,
-                                    "lastUpdate": True
-                                })
-                                print("Task streaming complete")
-
-                            except Exception as e:
-                                # Log the error
-                                print(f"Error in task streaming: {str(e)}")
-                                import traceback
-                                traceback.print_exc()
-
-                                # Put error in queue
-                                queue.put({"error": str(e)})
-
-                            finally:
-                                # Signal we're done
-                                done_event.set()
-
-                        # Create a new event loop
-                        loop = asyncio.new_event_loop()
-                        asyncio.set_event_loop(loop)
-
-                        # Run the streaming process
-                        try:
-                            loop.run_until_complete(process_task_stream())
-                        finally:
-                            loop.close()
-
-                    # Start the streaming thread
-                    thread = threading.Thread(target=run_task_stream)
-                    thread.daemon = True
-                    thread.start()
-
-                    # Yield initial SSE comment
-                    yield f": Task streaming established\n\n"
-
-                    # Process queue items until done
-                    import time
-                    timeout = time.time() + 60  # 60-second timeout
-
-                    while not done_event.is_set() and time.time() < timeout:
-                        try:
-                            # Check for update in queue
-                            if not queue.empty():
-                                update = queue.get(block=False)
-
-                                # Check if it's an error
-                                if "error" in update:
-                                    error_event = f"event: error\ndata: {json.dumps(update)}\n\n"
-                                    yield error_event
-                                    break
-
-                                # Format as SSE event
-                                data_event = f"data: {json.dumps(update)}\n\n"
-                                yield data_event
-
-                                # Check if it's the last update
-                                if update.get("lastUpdate", False):
-                                    break
-                            else:
-                                # No data yet, sleep briefly
-                                time.sleep(0.01)
-                        except Exception as e:
-                            # Error
-                            print(f"Error in queue processing: {e}")
-                            error_event = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
-                            yield error_event
-                            break
-
-                    # If timed out
-                    if time.time() >= timeout and not done_event.is_set():
-                        error_event = f"event: error\ndata: {json.dumps({'error': 'Task streaming timed out'})}\n\n"
-                        yield error_event
-
-                # Create the SSE response
+                    error_event = threading.Event()
+                    error_message = [None]
+                    task_id = task.id if task else "unknown"
+
+                    with managed_thread(lambda: run_task_stream(task, queue, done_event, error_event, error_message)):
+                        yield create_sse_event(None, {"message": "Task streaming established"})
+
+                        deadline = time.time() + STREAM_TIMEOUT
+                        sent_last_chunk = False
+
+                        while (not done_event.is_set() or not queue.empty()) and time.time() < deadline:
+                            if error_event.is_set():
+                                yield create_sse_event("error", {"error": error_message[0] or "Unknown error"})
+                                break
+
+                            try:
+                                if not queue.empty():
+                                    update = queue.get(block=False)
+                                    yield create_sse_event(None, update)
+                                    if update.get("lastChunk", False):
+                                        sent_last_chunk = True
+                                        break
+                                else:
+                                    time.sleep(QUEUE_CHECK_INTERVAL)
+                            except Exception as e:
+                                yield create_sse_event("error", {"error": str(e)})
+                                break
+
+                        if time.time() >= deadline and not done_event.is_set():
+                            yield create_sse_event("error", {"error": "Task streaming timed out"})
 
                 response = Response(generate(), mimetype="text/event-stream")
-                response.headers["Cache-Control"] = "no-cache"
-                response.headers["Connection"] = "keep-alive"
-                response.headers["X-Accel-Buffering"] = "no"
+                response.headers.update({
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "X-Accel-Buffering": "no",
+                    "Transfer-Encoding": "chunked"
+                })
                 return response
 
             except Exception as e:
-                # Log the exception
-                print(f"Exception in task streaming handler: {str(e)}")
-                import traceback
+                task_id = task.id if task else None
+                log_with_context(f"Exception in task streaming handler: {str(e)}", task_id, "error")
                 traceback.print_exc()
-
-                # Return error
                 return jsonify({"error": str(e)}), 500
 
     def handle_message(self, message: Message) -> Message:
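
Note: a minimal sketch of how a client might consume the endpoint added above, assuming the example server is reachable at http://localhost:5000 (hypothetical host and port) and that the requests library is available. The endpoint path, the {"task": ...} request shape, and the "lastChunk" flag come from the handler in this diff; everything else is illustrative.

# Hypothetical consumer for the /a2a/tasks/stream endpoint above.
import json
import requests

def stream_task(task_dict, base_url="http://localhost:5000"):
    # The handler accepts either a bare task dict or {"task": {...}}.
    resp = requests.post(
        f"{base_url}/a2a/tasks/stream",
        json={"task": task_dict},
        stream=True,
    )
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # SSE frames are "event: ..." / "data: ..." lines separated by blanks.
        if line and line.startswith("data: "):
            update = json.loads(line[len("data: "):])
            yield update
            if update.get("lastChunk"):
                break
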
@@ -356,6 +365,21 @@ async def tasks_send_subscribe(self, task: Task) -> AsyncGenerator[Task, None]:
         print(f"[Server] Processing task {task_id}")
         print(f"[Server] Query: {query[:50]}...")
 
+        task.status = TaskStatus(state=TaskState.SUBMITTED)
+
+        print(f"[Server] Task {task_id}: Yielding SUBMITTED state")
+        yield Task(
+            id=task.id,
+            status=TaskStatus(
+                state=task.status.state,
+                message=task.status.message.copy() if task.status.message else None,
+                timestamp=task.status.timestamp
+            ),
+            message=task.message,
+            session_id=task.session_id,
+            artifacts=task.artifacts.copy() if task.artifacts else []
+        )
+
         # Update task status to waiting (analogous to in_progress)
         task.status = TaskStatus(state=TaskState.WAITING)
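
Note: the SUBMITTED update above is yielded as a freshly built Task with a rebuilt status and copied artifacts rather than the mutable task object itself. A standalone sketch (plain dicts, not this library's types) of the aliasing hazard that snapshot-copying avoids:

# Yielding a shared mutable object lets later mutations leak into
# updates a consumer has already received.
def updates_shared():
    state = {"state": "submitted"}
    yield state                  # consumer receives a live reference
    state["state"] = "waiting"   # ...which this mutation silently rewrites

def updates_copied():
    state = {"state": "submitted"}
    yield dict(state)            # consumer receives a snapshot
    state["state"] = "waiting"

shared = list(updates_shared())
copied = list(updates_copied())
print(shared[0]["state"])  # "waiting" -- history was rewritten
print(copied[0]["state"])  # "submitted" -- the snapshot survived
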

@@ -652,9 +676,15 @@ async def run_task(self, query: str) -> Dict[str, Any]:
         try:
             # Stream task updates
             async for task_update in self.client.tasks_send_subscribe(task):
+                print(f"task update is :{task_update}")
                 # Track updates
                 self.updates_received += 1
 
+                print(f"Raw update {self.updates_received} artifacts:")
+                for i, artifact in enumerate(task_update.artifacts or []):
+                    print(f"  Artifact {i} type: {artifact.get('type', 'MISSING')}")
+                    print(f"  Artifact {i} raw: {json.dumps(artifact)[:200]}...")
+
                 # Store latest update
                 latest_update = task_update

@@ -705,6 +735,7 @@ async def _process_task_update(self, task: Task) -> None:
         print(status_line)
 
         # Process artifacts
+        print(f"Task for processing artifact: {task}")
         artifacts = task.artifacts or []
         for artifact in artifacts:
             await self._process_artifact(artifact)
@@ -766,6 +797,16 @@ async def _process_artifact(self, artifact: Dict[str, Any]) -> None:
             print(f"\n{CYAN}Partial Result:{RESET}")
             print(f"{text}")
 
+        elif "parts" in artifact and isinstance(artifact["parts"], list):
+            # Handle artifacts with parts but no type
+            text = ""
+            for part in artifact["parts"]:
+                if isinstance(part, dict) and part.get("type") == "text":
+                    text += part.get("text", "")
+            if text:
+                print(f"\n{CYAN}Partial Result:{RESET}")
+                print(f"{text}")
+
         elif artifact_type == "text":
             # Simple text artifact
             if "parts" in artifact:
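
Note: the new branch above handles artifacts that carry a "parts" list but no top-level "type". A sketch of the payload shape it accepts, with made-up values, mirroring the concatenation logic from the diff:

# Illustrative artifact with "parts" but no top-level "type";
# the new elif branch concatenates its text parts.
artifact = {
    "parts": [
        {"type": "text", "text": "Streaming chunk 1. "},
        {"type": "text", "text": "Streaming chunk 2."},
        {"type": "data", "data": {"ignored": True}},  # non-text parts are skipped
    ]
}
text = "".join(
    p.get("text", "")
    for p in artifact["parts"]
    if isinstance(p, dict) and p.get("type") == "text"
)
assert text == "Streaming chunk 1. Streaming chunk 2."
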
python_a2a/client/streaming.py
3 changes: 2 additions & 1 deletion

@@ -739,7 +739,8 @@ async def tasks_send_subscribe(self, task: Task) -> AsyncGenerator[Task, None]:
                 if event_type == "update" or event_type == "complete":
                     if isinstance(data_obj, dict):
                         # Parse as a Task
-                        current_task = Task.from_dict(data_obj)
+                        task_data = data_obj.get("task", data_obj)
+                        current_task = Task.from_dict(task_data)
                         yield current_task
 
                         # If this is a complete event, we're done
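
Note: this one-line client fix pairs with the server-side change above. Task updates now arrive wrapped in a {"task": ...} envelope, and data_obj.get("task", data_obj) accepts both the new envelope and the old bare-task payload. A quick sketch with made-up field values:

# Both payload shapes the patched client now accepts.
wrapped = {"task": {"id": "task-123", "status": {"state": "working"}},
           "index": 0, "append": True}
bare = {"id": "task-123", "status": {"state": "working"}}

for data_obj in (wrapped, bare):
    task_data = data_obj.get("task", data_obj)  # unwrap if enveloped
    assert task_data["id"] == "task-123"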