diff --git a/examples/mcp_agent.py b/examples/mcp_agent.py new file mode 100644 index 00000000..8dfd2f69 --- /dev/null +++ b/examples/mcp_agent.py @@ -0,0 +1,136 @@ +import json +import logging +from urllib.parse import urlparse + +import fire +import httpx +from llama_stack_client import Agent, AgentEventLogger, LlamaStackClient +from llama_stack_client.lib import get_oauth_token_for_mcp_server +from rich import print as rprint + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +import tempfile +from pathlib import Path + +TMP_DIR = Path(tempfile.gettempdir()) / "llama-stack" +TMP_DIR.mkdir(parents=True, exist_ok=True) + +CACHE_FILE = TMP_DIR / "mcp_tokens.json" + + +def main(model_id: str, mcp_servers: str = "https://mcp.asana.com/sse", llama_stack_url: str = "http://localhost:8321"): + """Run an MCP agent with the specified model and servers. + + Args: + model_id: The model to use for the agent. + mcp_servers: Comma-separated list of MCP servers to use for the agent. + llama_stack_url: The URL of the Llama Stack server to use. + + Examples: + python mcp_agent.py "meta-llama/Llama-4-Scout-17B-16E-Instruct" \ + -m "https://mcp.asana.com/sse" \ + -l "http://localhost:8321" + """ + client = LlamaStackClient(base_url=llama_stack_url) + if not check_model_exists(client, model_id): + return + + servers = [s.strip() for s in mcp_servers.split(",")] + mcp_headers = get_and_cache_mcp_headers(servers) + + toolgroup_ids = [] + for server in servers: + # we cannot use "/" in the toolgroup_id because we have some tech debt from earlier which uses + # "/" as a separator for toolgroup_id and tool_name. We should fix this in the future. + group_id = urlparse(server).netloc + toolgroup_ids.append(group_id) + client.toolgroups.register( + toolgroup_id=group_id, mcp_endpoint=dict(uri=server), provider_id="model-context-protocol" + ) + + agent = Agent( + client=client, + model=model_id, + instructions="You are a helpful assistant who can use tools when necessary to answer questions.", + tools=toolgroup_ids, + extra_headers={ + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": mcp_headers, + } + ), + }, + ) + + session_id = agent.create_session("test-session") + + while True: + user_input = input("Enter a question: ") + if user_input.lower() in ("q", "quit", "exit", "bye", ""): + print("Exiting...") + break + response = agent.create_turn( + session_id=session_id, + messages=[{"role": "user", "content": user_input}], + stream=True, + ) + for log in AgentEventLogger().log(response): + log.print() + + +def check_model_exists(client: LlamaStackClient, model_id: str) -> bool: + models = [m for m in client.models.list() if m.model_type == "llm"] + if model_id not in [m.identifier for m in models]: + rprint(f"[red]Model {model_id} not found[/red]") + rprint("[yellow]Available models:[/yellow]") + for model in models: + rprint(f" - {model.identifier}") + return False + return True + + +def get_and_cache_mcp_headers(servers: list[str]) -> dict[str, dict[str, str]]: + mcp_headers = {} + + logger.info(f"Using cache file: {CACHE_FILE} for MCP tokens") + tokens = {} + if CACHE_FILE.exists(): + with open(CACHE_FILE, "r") as f: + tokens = json.load(f) + for server, token in tokens.items(): + mcp_headers[server] = { + "Authorization": f"Bearer {token}", + } + + for server in servers: + with httpx.Client() as http_client: + headers = mcp_headers.get(server, {}) + try: + response = http_client.get(server, headers=headers, timeout=1.0) + except httpx.TimeoutException: + # timeout 
means success since we did not get an immediate 40X + continue + + if response.status_code in (401, 403): + logger.info(f"Server {server} requires authentication, getting token") + token = get_oauth_token_for_mcp_server(server) + if not token: + logger.error(f"No token obtained for {server}") + return + + tokens[server] = token + mcp_headers[server] = { + "Authorization": f"Bearer {token}", + } + + with open(CACHE_FILE, "w") as f: + json.dump(tokens, f, indent=2) + + return mcp_headers + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/pyproject.toml b/pyproject.toml index 2677cfeb..7e5c55e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,9 +4,7 @@ version = "0.2.7" description = "The official Python library for the llama-stack-client API" dynamic = ["readme"] license = "Apache-2.0" -authors = [ -{ name = "Llama Stack Client", email = "dev-feedback@llama-stack-client.com" }, -] +authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }] dependencies = [ "httpx>=0.23.0, <1", "pydantic>=1.9.0, <3", @@ -48,52 +46,6 @@ Repository = "https://github.com/meta-llama/llama-stack-client-python" -[tool.rye] -managed = true -# version pins are in requirements-dev.lock -dev-dependencies = [ - "pyright>=1.1.359", - "mypy", - "respx", - "pytest", - "pytest-asyncio", - "ruff", - "time-machine", - "nox", - "dirty-equals>=0.6.0", - "importlib-metadata>=6.7.0", - "rich>=13.7.1", -] - -[tool.rye.scripts] -format = { chain = [ - "format:ruff", - "format:docs", - "fix:ruff", -]} -"format:black" = "black ." -"format:docs" = "python scripts/utils/ruffen-docs.py README.md api.md" -"format:ruff" = "ruff format" -"format:isort" = "isort ." - -"lint" = { chain = [ - "check:ruff", - "typecheck", - "check:importable", -]} -"check:ruff" = "ruff check ." -"fix:ruff" = "ruff check --fix ." - -"check:importable" = "python -c 'import llama_stack_client'" - -typecheck = { chain = [ - "typecheck:pyright", - "typecheck:mypy" -]} -"typecheck:pyright" = "pyright" -"typecheck:verify-types" = "pyright --verifytypes llama_stack_client --ignoreexternal" -"typecheck:mypy" = "mypy ." - [build-system] requires = ["hatchling", "hatch-fancy-pypi-readme"] build-backend = "hatchling.build" @@ -132,37 +84,6 @@ path = "README.md" pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)' replacement = '[\1](https://github.com/meta-llama/llama-stack-client-python/tree/main/\g<2>)' -[tool.black] -line-length = 120 - -[tool.pytest.ini_options] -testpaths = ["tests"] -addopts = "--tb=short" -xfail_strict = true -asyncio_mode = "auto" -filterwarnings = [ - "error" -] - -[tool.pyright] -# this enables practically every flag given by pyright. -# there are a couple of flags that are still disabled by -# default in strict mode as they are experimental and niche. -typeCheckingMode = "strict" -pythonVersion = "3.7" - -exclude = [ - "_dev", - ".venv", - ".nox", -] - -reportImplicitOverride = true - -reportImportCycles = false -reportPrivateUsage = false - - [tool.ruff] line-length = 120 output-format = "grouped" diff --git a/src/llama_stack_client/lib/__init__.py b/src/llama_stack_client/lib/__init__.py index 756f351d..6bc5d151 100644 --- a/src/llama_stack_client/lib/__init__.py +++ b/src/llama_stack_client/lib/__init__.py @@ -3,3 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+ +from .tools.mcp_oauth import get_oauth_token_for_mcp_server + +__all__ = ["get_oauth_token_for_mcp_server"] diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index aa002440..ebdc4abd 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -138,6 +138,7 @@ def __init__( output_shields: Optional[List[str]] = None, response_format: Optional[ResponseFormat] = None, enable_session_persistence: Optional[bool] = None, + extra_headers: Headers | None = None, ): """Construct an Agent with the given parameters. @@ -162,6 +163,7 @@ def __init__( :param output_shields: The output shields for the agent. :param response_format: The response format for the agent. :param enable_session_persistence: Whether to enable session persistence. + :param extra_headers: Extra headers to add to all requests sent by the agent. """ self.client = client @@ -191,21 +193,25 @@ def __init__( self.sessions = [] self.tool_parser = tool_parser self.builtin_tools = {} + self.extra_headers = extra_headers self.initialize() def initialize(self) -> None: agentic_system_create_response = self.client.agents.create( agent_config=self.agent_config, + extra_headers=self.extra_headers, ) self.agent_id = agentic_system_create_response.agent_id for tg in self.agent_config["toolgroups"]: - for tool in self.client.tools.list(toolgroup_id=tg if isinstance(tg, str) else tg.get("name")): + toolgroup_id = tg if isinstance(tg, str) else tg.get("name") + for tool in self.client.tools.list(toolgroup_id=toolgroup_id, extra_headers=self.extra_headers): self.builtin_tools[tool.identifier] = tg.get("args", {}) if isinstance(tg, dict) else {} def create_session(self, session_name: str) -> str: agentic_system_create_session_response = self.client.agents.session.create( agent_id=self.agent_id, session_name=session_name, + extra_headers=self.extra_headers, ) self.session_id = agentic_system_create_session_response.session_id self.sessions.append(self.session_id) @@ -243,6 +249,7 @@ def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: **tool_call.arguments, **self.builtin_tools[tool_call.tool_name], }, + extra_headers=self.extra_headers, ) return ToolResponseParam( call_id=tool_call.call_id, @@ -264,10 +271,13 @@ def create_turn( toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, stream: bool = True, + # TODO: deprecate this extra_headers: Headers | None = None, ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: if stream: - return self._create_turn_streaming(messages, session_id, toolgroups, documents, extra_headers=extra_headers) + return self._create_turn_streaming( + messages, session_id, toolgroups, documents, extra_headers=extra_headers or self.extra_headers + ) else: chunks = [ x @@ -276,7 +286,7 @@ def create_turn( session_id, toolgroups, documents, - extra_headers=extra_headers, + extra_headers=extra_headers or self.extra_headers, ) ] if not chunks: @@ -300,6 +310,7 @@ def _create_turn_streaming( session_id: Optional[str] = None, toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, + # TODO: deprecate this extra_headers: Headers | None = None, ) -> Iterator[AgentTurnResponseStreamChunk]: n_iter = 0 @@ -313,7 +324,7 @@ def _create_turn_streaming( stream=True, documents=documents, toolgroups=toolgroups, - extra_headers=extra_headers, + extra_headers=extra_headers or self.extra_headers, ) # 2. 
process turn and resume if there's a tool call @@ -350,7 +361,7 @@ def _create_turn_streaming( turn_id=turn_id, tool_responses=tool_responses, stream=True, - extra_headers=extra_headers, + extra_headers=extra_headers or self.extra_headers, ) n_iter += 1 @@ -377,6 +388,7 @@ def __init__( output_shields: Optional[List[str]] = None, response_format: Optional[ResponseFormat] = None, enable_session_persistence: Optional[bool] = None, + extra_headers: Headers | None = None, ): """Construct an Agent with the given parameters. @@ -401,6 +413,7 @@ def __init__( :param output_shields: The output shields for the agent. :param response_format: The response format for the agent. :param enable_session_persistence: Whether to enable session persistence. + :param extra_headers: Extra headers to add to all requests sent by the agent. """ self.client = client @@ -430,6 +443,7 @@ def __init__( self.sessions = [] self.tool_parser = tool_parser self.builtin_tools = {} + self.extra_headers = extra_headers self._agent_id = None if isinstance(client, LlamaStackClient): @@ -450,7 +464,7 @@ async def initialize(self) -> None: ) self._agent_id = agentic_system_create_response.agent_id for tg in self.agent_config["toolgroups"]: - for tool in await self.client.tools.list(toolgroup_id=tg): + for tool in await self.client.tools.list(toolgroup_id=tg, extra_headers=self.extra_headers): self.builtin_tools[tool.identifier] = tg.get("args", {}) if isinstance(tg, dict) else {} async def create_session(self, session_name: str) -> str: @@ -458,6 +472,7 @@ async def create_session(self, session_name: str) -> str: agentic_system_create_session_response = await self.client.agents.session.create( agent_id=self.agent_id, session_name=session_name, + extra_headers=self.extra_headers, ) self.session_id = agentic_system_create_session_response.session_id self.sessions.append(self.session_id) @@ -509,6 +524,7 @@ async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: **tool_call.arguments, **self.builtin_tools[tool_call.tool_name], }, + extra_headers=self.extra_headers, ) return ToolResponseParam( call_id=tool_call.call_id, @@ -541,6 +557,7 @@ async def _create_turn_streaming( stream=True, documents=documents, toolgroups=toolgroups, + extra_headers=self.extra_headers, ) # 2. 
process turn and resume if there's a tool call @@ -576,6 +593,7 @@ async def _create_turn_streaming( turn_id=turn_id, tool_responses=tool_responses, stream=True, + extra_headers=self.extra_headers, ) n_iter += 1 diff --git a/src/llama_stack_client/lib/tools/mcp_oauth.py b/src/llama_stack_client/lib/tools/mcp_oauth.py new file mode 100644 index 00000000..a3c03416 --- /dev/null +++ b/src/llama_stack_client/lib/tools/mcp_oauth.py @@ -0,0 +1,297 @@ +import asyncio +import base64 +import hashlib +import logging +import os +import socket +import threading +import time +import urllib.parse +import uuid +from http.server import BaseHTTPRequestHandler, HTTPServer + +import fire +import requests + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class McpOAuthHelper: + """A simpler helper for OAuth2 authentication with MCP servers with OAuth discovery.""" + + def __init__(self, server_url): + self.server_url = server_url + self.server_base_url = get_base_url(server_url) + self.access_token = None + + # For PKCE (Proof Key for Code Exchange) + self.code_verifier = None + self.code_challenge = None + + # OAuth client registration + self.client_id = None + self.client_secret = None + self.registered_redirect_uris = [] + + # Callback server + self.callback_port = find_available_port(8000, 8100) + self.redirect_uri = f"http://localhost:{self.callback_port}/callback" + self.auth_code = None + self.auth_error = None + self.http_server = None + self.server_thread = None + + # Software statement for DCR + self.software_statement = { + "software_id": "simple-mcp-client", + "software_version": "1.0.0", + "software_name": "Simple MCP Client Example", + "software_description": "A simple MCP client for demonstration purposes", + "software_uri": "https://github.com/example/simple-mcp-client", + "redirect_uris": [self.redirect_uri], + "client_name": "Simple MCP Client", + "client_uri": "https://example.com/mcp-client", + "token_endpoint_auth_method": "none", # Public client + } + + def discover_auth_endpoints(self): + """ + Discover the OAuth server metadata according to RFC8414. + MCP servers MUST support this discovery mechanism. 
+ """ + well_known_url = f"{self.server_base_url}/.well-known/oauth-authorization-server" + response = requests.get(well_known_url) + if response.status_code == 200: + metadata = response.json() + logger.info("✅ Successfully discovered OAuth metadata") + return metadata + + raise Exception(f"OAuth metadata discovery failed with status: {response.status_code}") + + def register_client(self, registration_endpoint): + headers = {"Content-Type": "application/json"} + + registration_request = { + "client_name": self.software_statement["client_name"], + "redirect_uris": [self.redirect_uri], + "token_endpoint_auth_method": "none", # Public client + "grant_types": ["authorization_code"], + "response_types": ["code"], + "scope": "openid", + "software_id": self.software_statement["software_id"], + "software_version": self.software_statement["software_version"], + } + + response = requests.post(registration_endpoint, headers=headers, json=registration_request) + + if response.status_code in (201, 200): + registration_data = response.json() + self.client_id = registration_data.get("client_id") + self.client_secret = registration_data.get("client_secret") + self.registered_redirect_uris = registration_data.get("redirect_uris", [self.redirect_uri]) + + logger.info(f"Client ID: {self.client_id}") + return registration_data + + raise Exception(f"Client registration failed: {response.status_code}") + + def generate_pkce_values(self): + """Generate PKCE code verifier and challenge.""" + # Generate a random code verifier + code_verifier = base64.urlsafe_b64encode(os.urandom(32)).decode("utf-8").rstrip("=") + + # Generate the code challenge using SHA-256 + code_challenge_digest = hashlib.sha256(code_verifier.encode("utf-8")).digest() + code_challenge = base64.urlsafe_b64encode(code_challenge_digest).decode("utf-8").rstrip("=") + + self.code_verifier = code_verifier + self.code_challenge = code_challenge + + return code_verifier, code_challenge + + def stop_server(self): + time.sleep(1) + if self.http_server: + self.http_server.shutdown() + + def start_callback_server(self): + def auth_callback(auth_code: str | None, error: str | None): + logger.info(f"Authorization callback received: auth_code={auth_code}, error={error}") + self.auth_code = auth_code + self.auth_error = error + threading.Thread(target=self.stop_server).start() + + self.http_server = CallbackServer(("localhost", self.callback_port), auth_callback) + + self.server_thread = threading.Thread(target=self.http_server.serve_forever) + self.server_thread.daemon = True + self.server_thread.start() + + logger.info(f"🌐 Callback server started on port {self.callback_port}") + + def exchange_code_for_token(self, auth_code, token_endpoint): + logger.info("Exchanging authorization code for access token...") + + data = { + "grant_type": "authorization_code", + "code": auth_code, + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + "code_verifier": self.code_verifier, + } + if self.client_secret: + data["client_secret"] = self.client_secret + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + response = requests.post(token_endpoint, data=data, headers=headers) + if response.status_code == 200: + token_data = response.json() + self.access_token = token_data.get("access_token") + logger.info(f"✅ Successfully obtained access token: {self.access_token}") + return self.access_token + + raise Exception(f"Failed to exchange code for token: {response.status_code}") + + def initiate_auth_flow(self): + auth_metadata = 
self.discover_auth_endpoints() + registration_endpoint = auth_metadata.get("registration_endpoint") + if registration_endpoint and not self.client_id: + self.register_client(registration_endpoint) + + self.generate_pkce_values() + + self.start_callback_server() + + auth_url = auth_metadata.get("authorization_endpoint") + if not auth_url: + raise Exception("No authorization endpoint in metadata") + + token_endpoint = auth_metadata.get("token_endpoint") + if not token_endpoint: + raise Exception("No token endpoint in metadata") + + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "response_type": "code", + "state": str(uuid.uuid4()), # Random state + "code_challenge": self.code_challenge, + "code_challenge_method": "S256", + "scope": "openid", # Add appropriate scopes for Asana + } + + full_auth_url = f"{auth_url}?{urllib.parse.urlencode(params)}" + logger.info(f"Opening browser to authorize URL: {full_auth_url}") + logger.info("Flow will continue after you log in") + + import webbrowser + + webbrowser.open(full_auth_url) + self.server_thread.join(60) # Wait up to 1 minute + + if self.auth_code: + return self.exchange_code_for_token(self.auth_code, token_endpoint) + elif self.auth_error: + logger.error(f"Authorization failed: {self.auth_error}") + return None + else: + logger.error("Timed out waiting for authorization") + return None + + +def get_base_url(url): + parsed_url = urllib.parse.urlparse(url) + return f"{parsed_url.scheme}://{parsed_url.netloc}" + + +def find_available_port(start_port, end_port): + """Find an available port within a range.""" + for port in range(start_port, end_port + 1): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("localhost", port)) + return port + except socket.error: + continue + raise RuntimeError(f"No available ports in range {start_port}-{end_port}") + + +class CallbackServer(HTTPServer): + class OAuthCallbackHandler(BaseHTTPRequestHandler): + def do_GET(self): + parsed_path = urllib.parse.urlparse(self.path) + query_params = urllib.parse.parse_qs(parsed_path.query) + + if parsed_path.path == "/callback": + auth_code = query_params.get("code", [None])[0] + error = query_params.get("error", [None])[0] + + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + + if error: + self.wfile.write(b"Authorization Failed") + self.wfile.write(f"

<html><body><h1>Authorization Failed</h1><p>Error: {error}</p></body></html>
".encode()) + self.server.auth_code_callback(None, error) + elif auth_code: + self.wfile.write(b"Authorization Successful") + self.wfile.write( + b"

<html><body><h1>Authorization Successful</h1><p>You can close this window now.</p></body></html>
" + ) + # Call the callback with the auth code + self.server.auth_code_callback(auth_code, None) + else: + self.wfile.write(b"Authorization Failed") + self.wfile.write( + b"

<html><body><h1>Authorization Failed</h1><p>No authorization code received.</p></body></html>
" + ) + self.server.auth_code_callback(None, "No authorization code received") + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + """Override to suppress HTTP server logs.""" + return + + def __init__(self, server_address, auth_code_callback): + self.auth_code_callback = auth_code_callback + super().__init__(server_address, self.OAuthCallbackHandler) + + +def get_oauth_token_for_mcp_server(url: str) -> str | None: + helper = McpOAuthHelper(url) + return helper.initiate_auth_flow() + + +async def run_main(url: str): + from mcp import ClientSession + from mcp.client.sse import sse_client + + token = get_oauth_token_for_mcp_server(url) + if not token: + return + + headers = { + "Authorization": f"Bearer {token}", + } + + async with sse_client(url, headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + result = await session.list_tools() + + logger.info(f"Tools: {len(result.tools)}, showing first 5:") + for t in result.tools[:5]: + logger.info(f"{t.name}: {t.description}") + + +def main(url: str): + asyncio.run(run_main(url)) + + +if __name__ == "__main__": + fire.Fire(main)