diff --git a/examples/streaming/04_task_based_streaming.py b/examples/streaming/04_task_based_streaming.py index 34b88d9..c5efb86 100644 --- a/examples/streaming/04_task_based_streaming.py +++ b/examples/streaming/04_task_based_streaming.py @@ -107,151 +107,160 @@ def setup_routes(self, app): import threading from queue import Queue from flask import request, Response, jsonify - - # Register the tasks/stream endpoint + import time + import traceback + from contextlib import contextmanager + + STREAM_TIMEOUT = 300 + QUEUE_CHECK_INTERVAL = 0.01 + + @contextmanager + def managed_thread(target_func, daemon=True): + thread = threading.Thread(target=target_func) + thread.daemon = daemon + thread.start() + try: + yield thread + finally: + pass + + def create_sse_event(event_type, data): + if event_type: + return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" + return f"data: {json.dumps(data)}\n\n" + + def log_with_context(message, task_id=None, level="info"): + log_context = {"task_id": task_id} if task_id else {} + log_data = {"message": message, "level": level, "context": log_context} + print(f"{level.upper()}: {message}") + return log_data + + async def process_task_stream(task, queue, done_event, error_event, error_message): + task_id = task.id if task else "unknown" + task_stream = None + try: + task_stream = self.tasks_send_subscribe(task) + index = 0 + last_task_update = None + + async for task_update in task_stream: + last_task_update = task_update + update_dict = task_update.to_dict() + queue.put({ + "task": update_dict, + "index": index, + "append": True + }) + index += 1 + + if last_task_update: + final_dict = last_task_update.to_dict() + if isinstance(final_dict.get("status"), dict): + final_dict["status"]["state"] = "completed" + queue.put({ + "task": final_dict, + "index": index, + "append": True, + "lastChunk": True + }) + + except asyncio.CancelledError: + error_message[0] = "Task streaming cancelled" + error_event.set() + except Exception as e: + error_message[0] = str(e) + error_event.set() + finally: + done_event.set() + if hasattr(task_stream, 'aclose') and callable(task_stream.aclose): + try: + await task_stream.aclose() + except Exception as e: + log_with_context(f"Error closing task stream: {e}", task_id, "error") + + def run_task_stream(task, queue, done_event, error_event, error_message): + task_id = task.id if task else "unknown" + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + main_task = loop.create_task(process_task_stream(task, queue, done_event, error_event, error_message)) + try: + loop.run_until_complete(main_task) + except Exception as e: + error_message[0] = f"Event loop error: {str(e)}" + error_event.set() + finally: + pending = asyncio.all_tasks(loop) + for pending_task in pending: + pending_task.cancel() + if pending: + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + loop.close() + except Exception as e: + error_message[0] = f"Thread setup error: {str(e)}" + error_event.set() + done_event.set() + @app.route("/a2a/tasks/stream", methods=["POST"]) def handle_task_streaming(): - """Handle task streaming requests.""" + task = None try: data = request.json - print(f"Task streaming request received: {json.dumps(data)[:100]}...") - - # Parse the task - from python_a2a.models.task import Task if "task" in data: task = Task.from_dict(data["task"]) else: task = Task.from_dict(data) - - # Check if tasks_send_subscribe is implemented + if not hasattr(self, 'tasks_send_subscribe'): return jsonify({"error": "This agent does not support task streaming"}), 405 - - # Set up streaming response + def generate(): - """Generator for streaming server-sent events.""" - # Create a thread and asyncio event loop queue = Queue() done_event = threading.Event() - - def run_task_stream(): - """Run the task streaming in a dedicated thread.""" - async def process_task_stream(): - """Process the task stream.""" + error_event = threading.Event() + error_message = [None] + task_id = task.id if task else "unknown" + + with managed_thread(lambda: run_task_stream(task, queue, done_event, error_event, error_message)): + yield create_sse_event(None, {"message": "Task streaming established"}) + + deadline = time.time() + STREAM_TIMEOUT + sent_last_chunk = False + + while (not done_event.is_set() or not queue.empty()) and time.time() < deadline: + if error_event.is_set(): + yield create_sse_event("error", {"error": error_message[0] or "Unknown error"}) + break + try: - # Get the task stream generator - task_stream = self.tasks_send_subscribe(task) - - # Process each task update - index = 0 - async for task_update in task_stream: - # Convert task to dict - update_dict = task_update.to_dict() - - # Add metadata for streaming - update_data = { - "task": update_dict, - "index": index, - "append": True - } - - # Put in queue - queue.put(update_data) - print(f"Put task update {index} in queue") - index += 1 - - # Signal completion - queue.put({ - "task": task_update.to_dict(), - "index": index, - "append": True, - "lastUpdate": True - }) - print("Task streaming complete") - + if not queue.empty(): + update = queue.get(block=False) + yield create_sse_event(None, update) + if update.get("lastChunk", False): + sent_last_chunk = True + break + else: + time.sleep(QUEUE_CHECK_INTERVAL) except Exception as e: - # Log the error - print(f"Error in task streaming: {str(e)}") - import traceback - traceback.print_exc() - - # Put error in queue - queue.put({"error": str(e)}) - - finally: - # Signal we're done - done_event.set() - - # Create a new event loop - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Run the streaming process - try: - loop.run_until_complete(process_task_stream()) - finally: - loop.close() - - # Start the streaming thread - thread = threading.Thread(target=run_task_stream) - thread.daemon = True - thread.start() - - # Yield initial SSE comment - yield f": Task streaming established\n\n" - - # Process queue items until done - import time - timeout = time.time() + 60 # 60-second timeout - - while not done_event.is_set() and time.time() < timeout: - try: - # Check for update in queue - if not queue.empty(): - update = queue.get(block=False) - - # Check if it's an error - if "error" in update: - error_event = f"event: error\ndata: {json.dumps(update)}\n\n" - yield error_event - break - - # Format as SSE event - data_event = f"data: {json.dumps(update)}\n\n" - yield data_event - - # Check if it's the last update - if update.get("lastUpdate", False): - break - else: - # No data yet, sleep briefly - time.sleep(0.01) - except Exception as e: - # Error - print(f"Error in queue processing: {e}") - error_event = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" - yield error_event - break - - # If timed out - if time.time() >= timeout and not done_event.is_set(): - error_event = f"event: error\ndata: {json.dumps({'error': 'Task streaming timed out'})}\n\n" - yield error_event - - # Create the SSE response + yield create_sse_event("error", {"error": str(e)}) + break + + if time.time() >= deadline and not done_event.is_set(): + yield create_sse_event("error", {"error": "Task streaming timed out"}) + response = Response(generate(), mimetype="text/event-stream") - response.headers["Cache-Control"] = "no-cache" - response.headers["Connection"] = "keep-alive" - response.headers["X-Accel-Buffering"] = "no" + response.headers.update({ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "Transfer-Encoding": "chunked" + }) return response - + except Exception as e: - # Log the exception - print(f"Exception in task streaming handler: {str(e)}") - import traceback + task_id = task.id if task else None + log_with_context(f"Exception in task streaming handler: {str(e)}", task_id, "error") traceback.print_exc() - - # Return error return jsonify({"error": str(e)}), 500 def handle_message(self, message: Message) -> Message: @@ -356,6 +365,21 @@ async def tasks_send_subscribe(self, task: Task) -> AsyncGenerator[Task, None]: print(f"[Server] Processing task {task_id}") print(f"[Server] Query: {query[:50]}...") + task.status = TaskStatus(state=TaskState.SUBMITTED) + + print(f"[Server] Task {task_id}: Yielding SUBMITTED state") + yield Task( + id=task.id, + status=TaskStatus( + state=task.status.state, + message=task.status.message.copy() if task.status.message else None, + timestamp=task.status.timestamp + ), + message=task.message, + session_id=task.session_id, + artifacts=task.artifacts.copy() if task.artifacts else [] + ) + # Update task status to waiting (analogous to in_progress) task.status = TaskStatus(state=TaskState.WAITING) @@ -652,9 +676,15 @@ async def run_task(self, query: str) -> Dict[str, Any]: try: # Stream task updates async for task_update in self.client.tasks_send_subscribe(task): + print(f"task update is :{task_update}") # Track updates self.updates_received += 1 + print(f"Raw update {self.updates_received} artifacts:") + for i, artifact in enumerate(task_update.artifacts or []): + print(f" Artifact {i} type: {artifact.get('type', 'MISSING')}") + print(f" Artifact {i} raw: {json.dumps(artifact)[:200]}...") + # Store latest update latest_update = task_update @@ -705,6 +735,7 @@ async def _process_task_update(self, task: Task) -> None: print(status_line) # Process artifacts + print(f"Task for processing artifact: {task}") artifacts = task.artifacts or [] for artifact in artifacts: await self._process_artifact(artifact) @@ -766,6 +797,16 @@ async def _process_artifact(self, artifact: Dict[str, Any]) -> None: print(f"\n{CYAN}Partial Result:{RESET}") print(f"{text}") + elif "parts" in artifact and isinstance(artifact["parts"], list): + # Handle artifacts with parts but no type + text = "" + for part in artifact["parts"]: + if isinstance(part, dict) and part.get("type") == "text": + text += part.get("text", "") + if text: + print(f"\n{CYAN}Partial Result:{RESET}") + print(f"{text}") + elif artifact_type == "text": # Simple text artifact if "parts" in artifact: diff --git a/python_a2a/client/streaming.py b/python_a2a/client/streaming.py index 866fd13..ff47d2e 100644 --- a/python_a2a/client/streaming.py +++ b/python_a2a/client/streaming.py @@ -739,7 +739,8 @@ async def tasks_send_subscribe(self, task: Task) -> AsyncGenerator[Task, None]: if event_type == "update" or event_type == "complete": if isinstance(data_obj, dict): # Parse as a Task - current_task = Task.from_dict(data_obj) + task_data = data_obj.get("task", data_obj) + current_task = Task.from_dict(task_data) yield current_task # If this is a complete event, we're done