Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions python/packages/core/agent_framework/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,13 +1405,17 @@ async def _stream_generator() -> Any:
call_middleware = kwargs.pop("middleware", None)
instance_middleware = getattr(self, "middleware", None)

# Merge middleware from both sources, filtering for chat middleware only
all_middleware: list[ChatMiddleware | ChatMiddlewareCallable] = _merge_and_filter_chat_middleware(
instance_middleware, call_middleware
)
# Merge all middleware and separate by type
middleware = categorize_middleware(instance_middleware, call_middleware)
chat_middleware_list = middleware["chat"]
function_middleware_list = middleware["function"]

# Pass function middleware to function invocation system if present
if function_middleware_list:
kwargs["_function_middleware_pipeline"] = FunctionMiddlewarePipeline(function_middleware_list)

# If no middleware, use original method
if not all_middleware:
# If no chat middleware, use original method
if not chat_middleware_list:
async for update in original_get_streaming_response(self, messages, **kwargs):
yield update
return
Expand All @@ -1422,7 +1426,7 @@ async def _stream_generator() -> Any:
# Extract chat_options or create default
chat_options = kwargs.pop("chat_options", ChatOptions())

pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type]
pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type]
context = ChatContext(
chat_client=self,
messages=prepare_messages(messages),
Expand Down Expand Up @@ -1536,27 +1540,40 @@ def _merge_and_filter_chat_middleware(
return middleware["chat"] # type: ignore[return-value]


def extract_and_merge_function_middleware(chat_client: Any, **kwargs: Any) -> None:
def extract_and_merge_function_middleware(
chat_client: Any, kwargs: dict[str, Any]
) -> "FunctionMiddlewarePipeline | None":
"""Extract function middleware from chat client and merge with existing pipeline in kwargs.

Args:
chat_client: The chat client instance to extract middleware from.
kwargs: Dictionary containing middleware and pipeline information.

Keyword Args:
**kwargs: Dictionary containing middleware and pipeline information.
Returns:
A FunctionMiddlewarePipeline if function middleware is found, None otherwise.
"""
# Check if a pipeline was already created by use_chat_middleware
existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline")

# Get middleware sources
client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None
run_level_middleware = kwargs.get("middleware")
existing_pipeline = kwargs.get("_function_middleware_pipeline")

# Extract existing pipeline middlewares if present
existing_middlewares = existing_pipeline._middlewares if existing_pipeline else None
# If we have an existing pipeline but no additional middleware sources, return it directly
if existing_pipeline and not client_middleware and not run_level_middleware:
return existing_pipeline

# If we have an existing pipeline with additional middleware, we need to merge
# Extract existing pipeline middlewares if present - cast to list[Middleware] for type compatibility
existing_middlewares: list[Middleware] | None = list(existing_pipeline._middlewares) if existing_pipeline else None

# Create combined pipeline from all sources using existing helper
combined_pipeline = create_function_middleware_pipeline(
client_middleware, run_level_middleware, existing_middlewares
)

if combined_pipeline:
kwargs["_function_middleware_pipeline"] = combined_pipeline
# If we have an existing pipeline but combined is None (no new middlewares), return existing
if existing_pipeline and combined_pipeline is None:
return existing_pipeline

return combined_pipeline
Loading
Loading