136 changes: 136 additions & 0 deletions examples/mcp_agent.py
@@ -0,0 +1,136 @@
import json
import logging
import tempfile
from pathlib import Path
from urllib.parse import urlparse

import fire
import httpx
from llama_stack_client import Agent, AgentEventLogger, LlamaStackClient
from llama_stack_client.lib import get_oauth_token_for_mcp_server
from rich import print as rprint

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

TMP_DIR = Path(tempfile.gettempdir()) / "llama-stack"
TMP_DIR.mkdir(parents=True, exist_ok=True)

CACHE_FILE = TMP_DIR / "mcp_tokens.json"
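# Cache shape (illustrative): {"<mcp server url>": "<oauth access token>", ...}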


def main(model_id: str, mcp_servers: str = "https://mcp.asana.com/sse", llama_stack_url: str = "http://localhost:8321"):
    """Run an MCP agent with the specified model and servers.

    Args:
        model_id: The model to use for the agent.
        mcp_servers: Comma-separated list of MCP servers to use for the agent.
        llama_stack_url: The URL of the Llama Stack server to use.

    Examples:
        python mcp_agent.py "meta-llama/Llama-4-Scout-17B-16E-Instruct" \
            -m "https://mcp.asana.com/sse" \
            -l "http://localhost:8321"
    """
    client = LlamaStackClient(base_url=llama_stack_url)
    if not check_model_exists(client, model_id):
        return

    servers = [s.strip() for s in mcp_servers.split(",")]
    mcp_headers = get_and_cache_mcp_headers(servers)

    toolgroup_ids = []
    for server in servers:
        # "/" cannot appear in a toolgroup_id: legacy code still uses it as the
        # separator between toolgroup_id and tool_name, so we use the server's
        # hostname instead. This tech debt should be fixed in the future.
        group_id = urlparse(server).netloc
        toolgroup_ids.append(group_id)
        client.toolgroups.register(
            toolgroup_id=group_id, mcp_endpoint=dict(uri=server), provider_id="model-context-protocol"
        )

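    # Provider data rides along with each request: the stack uses these
    # per-server MCP headers (including any Bearer tokens) when it connects
    # to the MCP endpoints on the agent's behalf.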
    agent = Agent(
        client=client,
        model=model_id,
        instructions="You are a helpful assistant who can use tools when necessary to answer questions.",
        tools=toolgroup_ids,
        extra_headers={
            "X-LlamaStack-Provider-Data": json.dumps(
                {
                    "mcp_headers": mcp_headers,
                }
            ),
        },
    )

    session_id = agent.create_session("test-session")

    while True:
        user_input = input("Enter a question: ")
        if user_input.lower() in ("q", "quit", "exit", "bye", ""):
            print("Exiting...")
            break
        response = agent.create_turn(
            session_id=session_id,
            messages=[{"role": "user", "content": user_input}],
            stream=True,
        )
        # Stream the turn and pretty-print each event as it arrives.
        for log in AgentEventLogger().log(response):
            log.print()

def check_model_exists(client: LlamaStackClient, model_id: str) -> bool:
    models = [m for m in client.models.list() if m.model_type == "llm"]
    if model_id not in [m.identifier for m in models]:
        rprint(f"[red]Model {model_id} not found[/red]")
        rprint("[yellow]Available models:[/yellow]")
        for model in models:
            rprint(f"  - {model.identifier}")
        return False
    return True


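# Token acquisition strategy (a sketch of what the function below does): reuse
# any cached tokens, probe each server with a short GET, and run the OAuth flow
# only for servers that answer 401/403; a quick timeout means no auth was demanded.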
def get_and_cache_mcp_headers(servers: list[str]) -> dict[str, dict[str, str]]:
    mcp_headers = {}

    logger.info(f"Using cache file: {CACHE_FILE} for MCP tokens")
    tokens = {}
    if CACHE_FILE.exists():
        with open(CACHE_FILE, "r") as f:
            tokens = json.load(f)
        for server, token in tokens.items():
            mcp_headers[server] = {
                "Authorization": f"Bearer {token}",
            }

    for server in servers:
        with httpx.Client() as http_client:
            headers = mcp_headers.get(server, {})
            try:
                response = http_client.get(server, headers=headers, timeout=1.0)
            except httpx.TimeoutException:
                # A timeout means the endpoint accepted the connection without an
                # immediate 401/403, so the current headers (if any) suffice.
                continue

            if response.status_code in (401, 403):
                logger.info(f"Server {server} requires authentication, getting token")
                token = get_oauth_token_for_mcp_server(server)
                if not token:
                    logger.error(f"No token obtained for {server}")
                    # Bail out with whatever headers we have rather than None,
                    # matching the annotated return type.
                    return mcp_headers

                tokens[server] = token
                mcp_headers[server] = {
                    "Authorization": f"Bearer {token}",
                }

    with open(CACHE_FILE, "w") as f:
        json.dump(tokens, f, indent=2)

    return mcp_headers


if __name__ == "__main__":
    fire.Fire(main)
81 changes: 1 addition & 80 deletions pyproject.toml
@@ -4,9 +4,7 @@ version = "0.2.7"
description = "The official Python library for the llama-stack-client API"
dynamic = ["readme"]
license = "Apache-2.0"
authors = [
{ name = "Llama Stack Client", email = "dev-feedback@llama-stack-client.com" },
]
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
dependencies = [
"httpx>=0.23.0, <1",
"pydantic>=1.9.0, <3",
@@ -48,52 +46,6 @@ Repository = "https://github.com/meta-llama/llama-stack-client-python"



[tool.rye]
managed = true
# version pins are in requirements-dev.lock
dev-dependencies = [
"pyright>=1.1.359",
"mypy",
"respx",
"pytest",
"pytest-asyncio",
"ruff",
"time-machine",
"nox",
"dirty-equals>=0.6.0",
"importlib-metadata>=6.7.0",
"rich>=13.7.1",
]

[tool.rye.scripts]
format = { chain = [
"format:ruff",
"format:docs",
"fix:ruff",
]}
"format:black" = "black ."
"format:docs" = "python scripts/utils/ruffen-docs.py README.md api.md"
"format:ruff" = "ruff format"
"format:isort" = "isort ."

"lint" = { chain = [
"check:ruff",
"typecheck",
"check:importable",
]}
"check:ruff" = "ruff check ."
"fix:ruff" = "ruff check --fix ."

"check:importable" = "python -c 'import llama_stack_client'"

typecheck = { chain = [
"typecheck:pyright",
"typecheck:mypy"
]}
"typecheck:pyright" = "pyright"
"typecheck:verify-types" = "pyright --verifytypes llama_stack_client --ignoreexternal"
"typecheck:mypy" = "mypy ."

[build-system]
requires = ["hatchling", "hatch-fancy-pypi-readme"]
build-backend = "hatchling.build"
@@ -132,37 +84,6 @@ path = "README.md"
pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)'
replacement = '[\1](https://github.com/meta-llama/llama-stack-client-python/tree/main/\g<2>)'

[tool.black]
line-length = 120

[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "--tb=short"
xfail_strict = true
asyncio_mode = "auto"
filterwarnings = [
"error"
]

[tool.pyright]
# this enables practically every flag given by pyright.
# there are a couple of flags that are still disabled by
# default in strict mode as they are experimental and niche.
typeCheckingMode = "strict"
pythonVersion = "3.7"

exclude = [
"_dev",
".venv",
".nox",
]

reportImplicitOverride = true

reportImportCycles = false
reportPrivateUsage = false


[tool.ruff]
line-length = 120
output-format = "grouped"
4 changes: 4 additions & 0 deletions src/llama_stack_client/lib/__init__.py
@@ -3,3 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .tools.mcp_oauth import get_oauth_token_for_mcp_server

__all__ = ["get_oauth_token_for_mcp_server"]
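
For reference, a minimal usage sketch of the newly exported helper, mirroring the call in examples/mcp_agent.py above (the server URL is illustrative):

from llama_stack_client.lib import get_oauth_token_for_mcp_server

# Runs the OAuth flow for the given MCP server; returns a bearer token string,
# or a falsy value on failure (which the example above treats as an error).
token = get_oauth_token_for_mcp_server("https://mcp.asana.com/sse")
if token:
    headers = {"Authorization": f"Bearer {token}"}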