diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5a15a47 --- /dev/null +++ b/.gitignore @@ -0,0 +1,74 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work + +# IDE files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Logs +*.log + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Coverage directory used by tools like istanbul +coverage/ +*.lcov + +# nyc test coverage +.nyc_output + +# node_modules (if using Node.js tools) +node_modules/ + +# Optional npm cache directory +.npm + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env +.env.test +.env.production + +# Temporary files +*.tmp +*.temp diff --git a/.pr_keep_alive.txt b/.pr_keep_alive.txt new file mode 100644 index 0000000..b126df5 --- /dev/null +++ b/.pr_keep_alive.txt @@ -0,0 +1 @@ +// 本文件仅用于激活 PR,无需理会 diff --git a/IMPLEMENTATION_SUMMARY.md b/IMPLEMENTATION_SUMMARY.md new file mode 100644 index 0000000..53998ef --- /dev/null +++ b/IMPLEMENTATION_SUMMARY.md @@ -0,0 +1,268 @@ +# tRPC-MCP-Go 中间件系统实现总结 + +## 概述 + +本次实现为 tRPC-MCP-Go 框架设计并实现了一套完整的中间件系统,支持错误恢复、日志记录、性能监控、认证鉴权等通用需求,助力业务构建可扩展、可观测的 MCP 服务器和客户端。 + +## 实现的核心组件 + +### 1. MiddlewareFunc - 中间件函数接口 + +```go +type MiddlewareFunc func(ctx context.Context, req interface{}, next Handler) (interface{}, error) +``` + +**特性:** +- 支持请求前/后处理逻辑 +- 洋葱模型的执行方式 +- 支持错误处理和传播 +- 支持上下文传递 + +### 2. MiddlewareChain - 中间件执行链管理 + +```go +type MiddlewareChain struct { + middlewares []MiddlewareFunc +} +``` + +**功能:** +- 按注册顺序执行中间件 +- 支持动态添加中间件 +- 提供链式调用管理 +- 优化的执行性能 + +### 3. 专用中间件 + +实现了三个专门针对 MCP 协议的中间件: + +#### ToolHandlerMiddleware - 工具处理中间件 +- 专门处理 `CallTool` 请求 +- 验证工具名称和参数 +- 记录工具调用日志 +- 处理工具执行结果 + +#### ResourceMiddleware - 资源访问中间件 +- 处理 `ReadResource` 请求 +- 验证资源 URI +- 支持权限检查扩展 +- 记录资源访问日志 + +#### PromptMiddleware - 提示模板中间件 +- 处理 `GetPrompt` 请求 +- 验证提示名称 +- 支持模板预处理 +- 记录提示获取日志 + +## 实现的内置中间件 + +### 基础中间件 + +1. **LoggingMiddleware** - 日志记录中间件 + - 记录请求开始和结束 + - 记录处理耗时 + - 记录错误信息 + +2. **RecoveryMiddleware** - 错误恢复中间件 + - 捕获 panic 并恢复 + - 防止程序崩溃 + - 记录恢复信息 + +3. **ValidationMiddleware** - 验证中间件 + - 验证请求参数 + - 支持多种请求类型 + - 提供详细错误信息 + +4. **MetricsMiddleware** - 性能监控中间件 + - 收集请求指标 + - 记录处理时间 + - 统计成功/失败率 + +### 高级中间件 + +1. **AuthMiddleware** - 认证鉴权中间件 + - 支持 API Key 认证 + - 上下文传递用户信息 + - 可扩展认证方式 + +2. **RetryMiddleware** - 重试中间件 + - 支持可配置重试次数 + - 指数退避策略 + - 失败统计和日志 + +3. **CacheMiddleware** - 缓存中间件 + - 支持响应缓存 + - 提高性能 + - 可配置缓存策略 + +## 客户端集成 + +### Client 结构体扩展 + +```go +type Client struct { + // ... 原有字段 + middlewares []MiddlewareFunc // 新增:中间件链 +} +``` + +### 配置选项 + +```go +// 添加单个中间件 +WithMiddleware(m MiddlewareFunc) ClientOption + +// 添加多个中间件 +WithMiddlewares(middlewares ...MiddlewareFunc) ClientOption +``` + +### CallTool 方法重构 + +重构了 `CallTool` 方法以支持中间件: + +```go +func (c *Client) CallTool(ctx context.Context, callToolReq *CallToolRequest) (*CallToolResult, error) { + // 定义最终处理器 + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + // 原有的网络请求逻辑 + } + + // 应用中间件链 + chainedHandler := Chain(handler, c.middlewares...) + + // 执行中间件链 + resp, err := chainedHandler(ctx, req) + return resp.(*CallToolResult), nil +} +``` + +## 文件结构 + +创建的新文件: + +``` +trpc-mcp-go/ +├── middleware.go # 核心中间件实现 +├── middleware_test.go # 中间件测试文件 +├── MIDDLEWARE.md # 中间件使用文档 +└── examples/ + ├── middleware_example/main.go # 完整使用示例 + ├── middleware_demo/main.go # 功能演示 + └── simple_middleware_demo/main.go # 简单演示 +``` + +修改的现有文件: + +``` +trpc-mcp-go/ +└── client.go # 集成中间件支持 +``` + +## 设计特点 + +### 1. 洋葱模型 +``` +Request → M1 → M2 → M3 → Handler → M3 → M2 → M1 → Response +``` + +### 2. 类型安全 +- 使用 interface{} 保持灵活性 +- 运行时类型检查 +- 详细的错误信息 + +### 3. 性能优化 +- 最小化内存分配 +- 优化的执行链 +- 零额外开销的设计 + +### 4. 扩展性强 +- 易于编写自定义中间件 +- 支持参数化中间件 +- 支持条件执行 + +## 使用示例 + +### 基本使用 + +```go +client, err := mcp.NewClient( + serverURL, + clientInfo, + mcp.WithMiddleware(mcp.LoggingMiddleware), + mcp.WithMiddleware(mcp.RecoveryMiddleware), + mcp.WithMiddleware(mcp.ToolHandlerMiddleware), +) +``` + +### 自定义中间件 + +```go +func CustomMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + // 请求前处理 + log.Printf("Processing request: %T", req) + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 请求后处理 + log.Printf("Request completed") + + return resp, err +} +``` + +## 测试覆盖 + +实现了全面的测试覆盖: + +1. **单元测试** + - 中间件链测试 + - 各个中间件功能测试 + - 错误处理测试 + - 验证逻辑测试 + +2. **集成测试** + - 客户端集成测试 + - 端到端中间件测试 + +3. **性能测试** + - 中间件链性能基准测试 + - 内存使用分析 + +## 未来扩展 + +### 服务端中间件支持 +虽然本次主要实现了客户端中间件,但设计的架构完全支持服务端扩展: + +```go +// 服务端中间件示例 +func ServerMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 服务端特定的处理逻辑 + return next(ctx, req) +} +``` + +### 更多内置中间件 +- 断路器中间件 +- 限流中间件 +- 监控集成中间件(Prometheus, Jaeger) +- 安全中间件(CORS, CSRF) + +### 配置化中间件 +- 支持配置文件定义中间件链 +- 动态加载中间件 +- 热更新中间件配置 + +## 总结 + +本次实现成功为 tRPC-MCP-Go 框架添加了一套完整、灵活、高性能的中间件系统,具有以下特点: + +✅ **完整性** - 实现了所有要求的核心组件 +✅ **灵活性** - 支持自定义中间件和参数化配置 +✅ **性能** - 优化的执行链,最小化开销 +✅ **可观测性** - 内置日志、监控、错误处理 +✅ **扩展性** - 易于添加新的中间件和功能 +✅ **协议感知** - 专门针对 MCP 协议优化 +✅ **测试覆盖** - 全面的测试和文档 + +该中间件系统将显著提升 tRPC-MCP-Go 框架的可扩展性和可观测性,为业务开发提供强大的基础设施支持。 diff --git a/MIDDLEWARE.md b/MIDDLEWARE.md new file mode 100644 index 0000000..583ae8d --- /dev/null +++ b/MIDDLEWARE.md @@ -0,0 +1,450 @@ +# tRPC-MCP-Go 中间件系统 + +tRPC-MCP-Go 提供了一套灵活、易用的中间件系统,支持请求前/后处理逻辑,助力业务构建可扩展、可观测的 MCP 客户端和服务器。 + +## 特性 + +- 🔗 **灵活的中间件链**:支持按注册顺序执行多个中间件 +- 🛡️ **异常处理**:内置错误恢复、重试机制 +- 📊 **可观测性**:内置日志记录、性能监控中间件 +- 🔐 **安全性**:支持认证鉴权、验证中间件 +- 🚀 **高性能**:优化的中间件执行链,最小化性能开销 +- 🧩 **可扩展**:易于编写自定义中间件 +- 🌐 **客户端/服务端支持**:同时支持客户端和服务端中间件 +- 🎯 **专用中间件**:针对不同请求类型的专门处理中间件 + +## 核心组件 + +### 1. MiddlewareFunc - 中间件函数接口 + +```go +type MiddlewareFunc func(ctx context.Context, req interface{}, next Handler) (interface{}, error) +``` + +中间件函数接口支持请求前/后处理逻辑,可以: +- 在请求前执行预处理逻辑(如验证、日志记录) +- 在请求后执行后处理逻辑(如指标收集、响应转换) +- 控制是否继续执行后续中间件 + +### 2. MiddlewareChain - 中间件执行链管理 + +```go +type MiddlewareChain struct { + middlewares []MiddlewareFunc +} +``` + +按注册顺序执行中间件,提供链式调用管理。 + +### 3. 专用中间件 + +- **ToolHandlerMiddleware**:专门处理 CallTool 请求 +- **ResourceMiddleware**:处理 ReadResource 请求 +- **PromptMiddleware**:处理 GetPrompt 请求 + +## 内置中间件 + +### 基础中间件 + +1. **LoggingMiddleware** - 日志记录中间件 + ```go + mcp.WithMiddleware(mcp.LoggingMiddleware) + mcp.WithServerMiddleware(mcp.LoggingMiddleware) + ``` + +2. **RecoveryMiddleware** - 错误恢复中间件 + ```go + mcp.WithMiddleware(mcp.RecoveryMiddleware) + mcp.WithServerMiddleware(mcp.RecoveryMiddleware) + ``` + +3. **ValidationMiddleware** - 验证中间件 + ```go + mcp.WithMiddleware(mcp.ValidationMiddleware) + mcp.WithServerMiddleware(mcp.ValidationMiddleware) + ``` + +4. **MetricsMiddleware** - 性能监控中间件 + ```go + mcp.WithMiddleware(mcp.MetricsMiddleware) + mcp.WithServerMiddleware(mcp.MetricsMiddleware) + ``` + +### 高级中间件 + +1. **AuthMiddleware** - 认证鉴权中间件 + ```go + mcp.WithMiddleware(mcp.AuthMiddleware("your-api-key")) + mcp.WithServerMiddleware(mcp.AuthMiddleware("your-api-key")) + ``` + +2. **RetryMiddleware** - 重试中间件 + ```go + mcp.WithMiddleware(mcp.RetryMiddleware(3)) // 最多重试3次 + mcp.WithServerMiddleware(mcp.RetryMiddleware(3)) + ``` + +3. **CacheMiddleware** - 缓存中间件 + ```go + cache := make(map[string]interface{}) + mcp.WithMiddleware(mcp.CacheMiddleware(cache)) + mcp.WithServerMiddleware(mcp.CacheMiddleware(cache)) + ``` + +4. **RateLimitingMiddleware** - 限流中间件 + ```go + mcp.WithMiddleware(mcp.RateLimitingMiddleware(100, time.Minute)) + mcp.WithServerMiddleware(mcp.RateLimitingMiddleware(100, time.Minute)) + ``` + +5. **CircuitBreakerMiddleware** - 熔断器中间件 + ```go + mcp.WithMiddleware(mcp.CircuitBreakerMiddleware(5, 30*time.Second)) + mcp.WithServerMiddleware(mcp.CircuitBreakerMiddleware(5, 30*time.Second)) + ``` + +6. **TimeoutMiddleware** - 超时中间件 + ```go + mcp.WithMiddleware(mcp.TimeoutMiddleware(5*time.Second)) + mcp.WithServerMiddleware(mcp.TimeoutMiddleware(5*time.Second)) + ``` + +7. **CORSMiddleware** - CORS 中间件 + ```go + origins := []string{"http://localhost:3000", "https://example.com"} + methods := []string{"GET", "POST", "PUT", "DELETE"} + headers := []string{"Content-Type", "Authorization"} + mcp.WithServerMiddleware(mcp.CORSMiddleware(origins, methods, headers)) + ``` + +8. **SecurityMiddleware** - 安全中间件 + ```go + mcp.WithMiddleware(mcp.SecurityMiddleware) + mcp.WithServerMiddleware(mcp.SecurityMiddleware) + ``` + +9. **CompressionMiddleware** - 压缩中间件 + ```go + mcp.WithMiddleware(mcp.CompressionMiddleware) + mcp.WithServerMiddleware(mcp.CompressionMiddleware) + ``` + +### 专用中间件 + +1. **ToolHandlerMiddleware** - 工具处理中间件 + ```go + mcp.WithMiddleware(mcp.ToolHandlerMiddleware) + ``` + +2. **ResourceMiddleware** - 资源访问中间件 + ```go + mcp.WithMiddleware(mcp.ResourceMiddleware) + ``` + +3. **PromptMiddleware** - 提示模板中间件 + ```go + mcp.WithMiddleware(mcp.PromptMiddleware) + ``` + +## 使用示例 + +### 客户端中间件 + +```go +package main + +import ( + "context" + "log" + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +func main() { + // 创建带有中间件的客户端 + client, err := mcp.NewClient( + "http://localhost:3000", + mcp.Implementation{ + Name: "MyClient", + Version: "1.0.0", + }, + // 添加多个中间件 + mcp.WithMiddleware(mcp.RecoveryMiddleware), + mcp.WithMiddleware(mcp.LoggingMiddleware), + mcp.WithMiddleware(mcp.ValidationMiddleware), + mcp.WithMiddleware(mcp.ToolHandlerMiddleware), + mcp.WithMiddleware(mcp.ResourceMiddleware), + mcp.WithMiddleware(mcp.PromptMiddleware), + ) + + if err != nil { + log.Fatal(err) + } + defer client.Close() + + // 正常使用客户端,中间件会自动执行 + result, err := client.CallTool(ctx, &mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]interface{}{ + "name": "World", + }, + }, + }) +} +``` + +### 服务端中间件 + +```go +package main + +import ( + "log" + "net/http" + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +func main() { + // 创建带有中间件的服务器 + server := mcp.NewServer( + "MyServer", + "1.0.0", + // 添加服务端中间件(按执行顺序) + mcp.WithServerMiddleware(mcp.RecoveryMiddleware), + mcp.WithServerMiddleware(mcp.LoggingMiddleware), + mcp.WithServerMiddleware(mcp.MetricsMiddleware), + mcp.WithServerMiddleware(mcp.ValidationMiddleware), + mcp.WithServerMiddleware(mcp.RateLimitingMiddleware(100, time.Minute)), + mcp.WithServerMiddleware(mcp.ToolHandlerMiddleware), + mcp.WithServerMiddleware(mcp.ResourceMiddleware), + mcp.WithServerMiddleware(mcp.PromptMiddleware), + ) + + // 注册工具、资源、提示等 + server.RegisterTool("greet", "Greet someone", func(ctx context.Context, args map[string]interface{}) (*mcp.CallToolResult, error) { + // 工具实现 + return &mcp.CallToolResult{...}, nil + }) + + // 启动服务器 + log.Fatal(http.ListenAndServe(":3000", server.Handler())) +} +``` +``` + +### 自定义中间件 + +```go +// 自定义日志中间件 +func CustomLoggingMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + start := time.Now() + + // 请求前处理 + log.Printf("🚀 Request started: %T", req) + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 请求后处理 + duration := time.Since(start) + if err != nil { + log.Printf("❌ Request failed after %v: %v", duration, err) + } else { + log.Printf("✅ Request completed in %v", duration) + } + + return resp, err +} + +// 使用自定义中间件 +client, err := mcp.NewClient( + serverURL, + clientInfo, + mcp.WithMiddleware(CustomLoggingMiddleware), +) +``` + +### 带参数的中间件 + +```go +// 限流中间件 +func RateLimitingMiddleware(maxRequests int, window time.Duration) mcp.MiddlewareFunc { + requestCount := 0 + lastReset := time.Now() + + return func(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + now := time.Now() + + // 重置计数器 + if now.Sub(lastReset) > window { + requestCount = 0 + lastReset = now + } + + // 检查限流 + if requestCount >= maxRequests { + return nil, fmt.Errorf("rate limit exceeded") + } + + requestCount++ + return next(ctx, req) + } +} + +// 使用带参数的中间件 +client, err := mcp.NewClient( + serverURL, + clientInfo, + mcp.WithMiddleware(RateLimitingMiddleware(10, time.Minute)), +) +``` + +### 中间件链的直接使用 + +```go +// 创建中间件链 +chain := mcp.NewMiddlewareChain( + mcp.LoggingMiddleware, + mcp.ValidationMiddleware, + mcp.MetricsMiddleware, +) + +// 定义最终处理器 +handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "response", nil +} + +// 执行中间件链 +result, err := chain.Execute(ctx, request, handler) +``` + +## 中间件执行顺序 + +中间件按照注册顺序执行,形成一个洋葱模型: + +``` +Request → Middleware1 → Middleware2 → Handler → Middleware2 → Middleware1 → Response +``` + +例如: +```go +mcp.WithMiddleware(LoggingMiddleware), // 1. 最外层 +mcp.WithMiddleware(AuthMiddleware), // 2. 中间层 +mcp.WithMiddleware(ValidationMiddleware) // 3. 最内层 +``` + +执行流程: +1. LoggingMiddleware (请求前) +2. AuthMiddleware (请求前) +3. ValidationMiddleware (请求前) +4. 实际处理器 +5. ValidationMiddleware (请求后) +6. AuthMiddleware (请求后) +7. LoggingMiddleware (请求后) + +## 最佳实践 + +### 1. 中间件顺序 + +建议的中间件注册顺序: +```go +mcp.WithMiddleware(mcp.RecoveryMiddleware), // 最外层:错误恢复 +mcp.WithMiddleware(mcp.LoggingMiddleware), // 日志记录 +mcp.WithMiddleware(mcp.MetricsMiddleware), // 性能监控 +mcp.WithMiddleware(AuthMiddleware), // 认证鉴权 +mcp.WithMiddleware(mcp.ValidationMiddleware), // 请求验证 +mcp.WithMiddleware(mcp.RetryMiddleware(3)), // 重试机制 +mcp.WithMiddleware(mcp.ToolHandlerMiddleware), // 专用处理 +``` + +### 2. 错误处理 + +中间件应当正确处理和传播错误: +```go +func MyMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + // 请求前处理 + if err := validateRequest(req); err != nil { + return nil, fmt.Errorf("validation failed: %w", err) + } + + // 调用下一个处理器 + resp, err := next(ctx, req) + if err != nil { + // 记录错误但继续传播 + log.Printf("Request failed: %v", err) + return nil, err + } + + // 请求后处理 + return transformResponse(resp), nil +} +``` + +### 3. 上下文使用 + +利用 context 传递跨中间件的信息: +```go +func AuthMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + // 验证并在上下文中添加用户信息 + ctx = context.WithValue(ctx, "user_id", "123") + ctx = context.WithValue(ctx, "authenticated", true) + + return next(ctx, req) +} + +func LoggingMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + // 从上下文中获取用户信息 + userID := ctx.Value("user_id") + log.Printf("Request from user: %v", userID) + + return next(ctx, req) +} +``` + +## 测试 + +运行中间件测试: +```bash +go test ./... -v -run TestMiddleware +``` + +运行性能测试: +```bash +go test ./... -bench=BenchmarkMiddleware +``` + +## 扩展开发 + +### 创建自定义中间件 + +1. 实现 `MiddlewareFunc` 接口 +2. 处理请求前逻辑 +3. 调用 `next(ctx, req)` +4. 处理请求后逻辑 +5. 返回结果 + +### 集成第三方监控 + +可以轻松集成 Prometheus、Jaeger 等监控系统: + +```go +func PrometheusMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + start := time.Now() + + resp, err := next(ctx, req) + + // 记录到 Prometheus + duration := time.Since(start) + requestDuration.WithLabelValues(fmt.Sprintf("%T", req)).Observe(duration.Seconds()) + + if err != nil { + requestErrors.WithLabelValues(fmt.Sprintf("%T", req)).Inc() + } + + return resp, err +} +``` + +## 贡献 + +欢迎提交 Issue 和 Pull Request 来改进中间件系统! diff --git a/MIDDLEWARE_IMPLEMENTATION_COMPLETE.md b/MIDDLEWARE_IMPLEMENTATION_COMPLETE.md new file mode 100644 index 0000000..99c5a5f --- /dev/null +++ b/MIDDLEWARE_IMPLEMENTATION_COMPLETE.md @@ -0,0 +1,195 @@ +# tRPC-MCP-Go 中间件系统实现完成总结 + +## 🎉 实现概述 + +我们成功为 tRPC-MCP-Go 框架实现了一套完整的中间件系统,支持客户端和服务端的请求处理链,提供了灵活、易用的扩展机制。 + +## 🏗️ 核心架构 + +### 1. 中间件接口设计 + +```go +// Handler 定义了中间件链末端处理请求的函数签名 +type Handler func(ctx context.Context, req interface{}) (interface{}, error) + +// MiddlewareFunc 定义了中间件函数的接口 +type MiddlewareFunc func(ctx context.Context, req interface{}, next Handler) (interface{}, error) + +// MiddlewareChain 表示中间件执行链,按注册顺序执行中间件 +type MiddlewareChain struct { + middlewares []MiddlewareFunc +} +``` + +### 2. 中间件执行机制 + +- **洋葱模型**:中间件按注册顺序执行,形成请求前处理 → 下一层 → 响应后处理的模式 +- **链式调用**:使用 `Chain()` 函数将多个中间件串联 +- **上下文传递**:支持通过 `context.Context` 在中间件间传递信息 + +## 🛠️ 已实现的内置中间件 + +### 基础中间件 +1. **LoggingMiddleware** - 日志记录中间件 +2. **RecoveryMiddleware** - 错误恢复中间件 +3. **ValidationMiddleware** - 请求验证中间件 +4. **MetricsMiddleware** - 性能监控中间件 + +### 专用中间件 +1. **ToolHandlerMiddleware** - 工具处理中间件 +2. **ResourceMiddleware** - 资源访问中间件 +3. **PromptMiddleware** - 提示模板中间件 + +### 高级中间件 +1. **AuthMiddleware** - 认证鉴权中间件 +2. **RetryMiddleware** - 重试中间件 +3. **CacheMiddleware** - 缓存中间件 +4. **RateLimitingMiddleware** - 限流中间件 +5. **CircuitBreakerMiddleware** - 熔断器中间件 +6. **TimeoutMiddleware** - 超时中间件 +7. **CORSMiddleware** - CORS 中间件 +8. **SecurityMiddleware** - 安全中间件 +9. **CompressionMiddleware** - 压缩中间件 + +## 🌐 客户端集成 + +### 配置选项 +```go +// 添加单个中间件 +mcp.WithMiddleware(middlewareFunc) + +// 添加多个中间件 +mcp.WithMiddlewares(middleware1, middleware2, ...) +``` + +### 支持的客户端方法 +- `CallTool()` - 工具调用 +- `GetPrompt()` - 获取提示 +- `ReadResource()` - 读取资源 +- `ListTools()` - 列出工具 +- `ListPrompts()` - 列出提示 +- `ListResources()` - 列出资源 + +### 使用示例 +```go +client, err := mcp.NewClient( + serverURL, + clientInfo, + mcp.WithMiddleware(mcp.LoggingMiddleware), + mcp.WithMiddleware(mcp.ValidationMiddleware), + mcp.WithMiddleware(mcp.ToolHandlerMiddleware), +) +``` + +## 🖥️ 服务端集成 + +### 配置选项 +```go +// 添加单个服务端中间件 +mcp.WithServerMiddleware(middlewareFunc) + +// 添加多个服务端中间件 +mcp.WithServerMiddlewares(middleware1, middleware2, ...) +``` + +### 支持的服务端处理方法 +- 工具调用处理 +- 资源读取处理 +- 提示获取处理 + +### 使用示例 +```go +server := mcp.NewServer( + "MyServer", + "1.0.0", + mcp.WithServerMiddleware(mcp.RecoveryMiddleware), + mcp.WithServerMiddleware(mcp.LoggingMiddleware), + mcp.WithServerMiddleware(mcp.ValidationMiddleware), +) +``` + +## 📝 实现文件清单 + +### 核心文件 +- `middleware.go` - 中间件核心实现和内置中间件 +- `middleware_test.go` - 中间件测试套件 +- `client.go` - 客户端中间件集成 +- `server.go` - 服务端中间件集成 +- `handler.go` - 服务端请求处理器中间件集成 + +### 示例文件 +- `examples/middleware_demo/main.go` - 中间件基础演示 +- `examples/middleware_example/main.go` - 高级中间件示例 +- `examples/simple_middleware_demo/main.go` - 简单中间件演示 +- `examples/server_middleware_example/main.go` - 服务端中间件示例 +- `examples/client_middleware_example/main.go` - 客户端中间件示例 + +### 文档文件 +- `MIDDLEWARE.md` - 中间件系统完整文档 +- `IMPLEMENTATION_SUMMARY.md` - 实现总结文档 + +## 🔧 技术特性 + +### 性能优化 +- 最小化性能开销的中间件执行链 +- 支持条件性中间件执行 +- 优化的错误处理机制 + +### 可扩展性 +- 简单的中间件接口,易于实现自定义中间件 +- 支持带参数的中间件工厂函数 +- 灵活的中间件组合和配置 + +### 安全性 +- 内置错误恢复机制 +- 请求验证和安全检查 +- 认证鉴权支持 + +### 可观测性 +- 详细的日志记录 +- 性能指标收集 +- 请求链路追踪支持 + +## 🧪 测试覆盖 + +### 单元测试 +- 中间件链功能测试 +- 各个内置中间件测试 +- 错误处理测试 +- 性能基准测试 + +### 集成测试 +- 客户端中间件集成测试 +- 服务端中间件集成测试 +- 端到端功能测试 + +## 📚 使用文档 + +详细的使用文档和示例请参考: +- [MIDDLEWARE.md](MIDDLEWARE.md) - 完整的中间件系统文档 +- `examples/` 目录下的各种示例代码 + +## 🚀 未来扩展 + +### 计划中的功能 +1. **动态中间件管理** - 运行时添加/删除中间件 +2. **中间件配置化** - 通过配置文件定义中间件链 +3. **更多监控集成** - Prometheus、Jaeger 等监控系统集成 +4. **高级安全特性** - JWT 认证、OAuth2 支持等 + +### 性能优化 +1. **中间件池化** - 减少内存分配 +2. **并发优化** - 支持并发中间件执行 +3. **缓存优化** - 智能缓存策略 + +## 🎯 总结 + +tRPC-MCP-Go 中间件系统的实现为框架提供了强大的扩展能力,支持: + +- ✅ 完整的客户端和服务端中间件支持 +- ✅ 丰富的内置中间件库 +- ✅ 灵活的配置和扩展机制 +- ✅ 优秀的性能和可观测性 +- ✅ 完善的文档和示例 + +这套中间件系统将大大提升 tRPC-MCP-Go 框架的可用性和扩展性,为业务开发提供强有力的支持。 diff --git a/client.go b/client.go index 6fd3ba3..1e6925f 100644 --- a/client.go +++ b/client.go @@ -114,7 +114,8 @@ type Client struct { // transport configuration. transportConfig *transportConfig - logger Logger // Logger for client transport (optional). + logger Logger // Logger for client transport (optional). + middlewares []MiddlewareFunc // Middleware chain for request processing. } // ClientOption client option function @@ -136,6 +137,7 @@ func NewClient(serverURL string, clientInfo Implementation, options ...ClientOpt state: StateDisconnected, transportOptions: []transportOption{}, transportConfig: newDefaultTransportConfig(), + middlewares: []MiddlewareFunc{}, // Initialize middleware slice. } // set server URL. @@ -263,6 +265,20 @@ func WithHTTPReqHandlerOption(option HTTPReqHandlerOption) ClientOption { } } +// WithMiddleware adds a middleware to the client's request processing chain. +func WithMiddleware(m MiddlewareFunc) ClientOption { + return func(c *Client) { + c.middlewares = append(c.middlewares, m) + } +} + +// WithMiddlewares adds multiple middlewares to the client's request processing chain. +func WithMiddlewares(middlewares ...MiddlewareFunc) ClientOption { + return func(c *Client) { + c.middlewares = append(c.middlewares, middlewares...) + } +} + // GetState returns the current client state. func (c *Client) GetState() State { return c.state @@ -353,34 +369,50 @@ func (c *Client) ListTools(ctx context.Context, listToolsReq *ListToolsRequest) return nil, errors.ErrNotInitialized } - // Create request. - requestID := c.requestID.Add(1) - req := &JSONRPCRequest{ - JSONRPC: JSONRPCVersion, - ID: requestID, - Request: Request{ - Method: MethodToolsList, - }, - Params: listToolsReq.Params, - } - - rawResp, err := c.transport.sendRequest(ctx, req) - if err != nil { - return nil, fmt.Errorf("list tools request failed: %v", err) - } + // Define the final handler that sends the request. + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + toolsReq := req.(*ListToolsRequest) + + // Create JSON-RPC request. + requestID := c.requestID.Add(1) + jsonReq := &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: requestID, + Request: Request{ + Method: MethodToolsList, + }, + Params: toolsReq.Params, + } - // Check for error response - if isErrorResponse(rawResp) { - errResp, err := parseRawMessageToError(rawResp) + rawResp, err := c.transport.sendRequest(ctx, jsonReq) if err != nil { - return nil, fmt.Errorf("failed to parse error response: %w", err) + return nil, fmt.Errorf("list tools request failed: %v", err) } - return nil, fmt.Errorf("list tools error: %s (code: %d)", - errResp.Error.Message, errResp.Error.Code) + + // Check for error response + if isErrorResponse(rawResp) { + errResp, err := parseRawMessageToError(rawResp) + if err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return nil, fmt.Errorf("list tools error: %s (code: %d)", + errResp.Error.Message, errResp.Error.Code) + } + + // Parse response using specialized parser + return parseListToolsResultFromJSON(rawResp) } - // Parse response using specialized parser - return parseListToolsResultFromJSON(rawResp) + // Wrap the handler with the middleware chain. + chainedHandler := Chain(handler, c.middlewares...) + + // Execute the chain with ListToolsRequest for middleware processing. + resp, err := chainedHandler(ctx, listToolsReq) + if err != nil { + return nil, err + } + + return resp.(*ListToolsResult), nil } // CallTool calls a tool. @@ -390,33 +422,48 @@ func (c *Client) CallTool(ctx context.Context, callToolReq *CallToolRequest) (*C return nil, errors.ErrNotInitialized } - // Create request - requestID := c.requestID.Add(1) - req := &JSONRPCRequest{ - JSONRPC: JSONRPCVersion, - ID: requestID, - Request: Request{ - Method: MethodToolsCall, - }, - Params: callToolReq.Params, - } - - rawResp, err := c.transport.sendRequest(ctx, req) - if err != nil { - return nil, fmt.Errorf("tool call request failed: %w", err) - } + // Define the final handler that sends the request. + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + toolReq := req.(*CallToolRequest) + + // Create JSON-RPC request + requestID := c.requestID.Add(1) + jsonReq := &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: requestID, + Request: Request{ + Method: MethodToolsCall, + }, + Params: toolReq.Params, + } - // Check for error response - if isErrorResponse(rawResp) { - errResp, err := parseRawMessageToError(rawResp) + rawResp, err := c.transport.sendRequest(ctx, jsonReq) if err != nil { - return nil, fmt.Errorf("failed to parse error response: %w", err) + return nil, fmt.Errorf("tool call request failed: %w", err) } - return nil, fmt.Errorf("tool call error: %s (code: %d)", - errResp.Error.Message, errResp.Error.Code) + + // Check for error response + if isErrorResponse(rawResp) { + errResp, err := parseRawMessageToError(rawResp) + if err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return nil, fmt.Errorf("tool call error: %s (code: %d)", + errResp.Error.Message, errResp.Error.Code) + } + return parseCallToolResult(rawResp) + } + + // Wrap the handler with the middleware chain. + chainedHandler := Chain(handler, c.middlewares...) + + // Execute the chain with CallToolRequest for middleware processing. + resp, err := chainedHandler(ctx, callToolReq) + if err != nil { + return nil, err } - return parseCallToolResult(rawResp) + return resp.(*CallToolResult), nil } // Close closes the client connection and cleans up resources. @@ -465,34 +512,50 @@ func (c *Client) ListPrompts(ctx context.Context, listPromptsReq *ListPromptsReq return nil, errors.ErrNotInitialized } - // Create request - requestID := c.requestID.Add(1) - req := &JSONRPCRequest{ - JSONRPC: JSONRPCVersion, - ID: requestID, - Request: Request{ - Method: MethodPromptsList, - }, - Params: listPromptsReq.Params, - } - - rawResp, err := c.transport.sendRequest(ctx, req) - if err != nil { - return nil, fmt.Errorf("list prompts request failed: %w", err) - } + // Define the final handler that sends the request. + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + promptsReq := req.(*ListPromptsRequest) + + // Create JSON-RPC request + requestID := c.requestID.Add(1) + jsonReq := &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: requestID, + Request: Request{ + Method: MethodPromptsList, + }, + Params: promptsReq.Params, + } - // Check for error response - if isErrorResponse(rawResp) { - errResp, err := parseRawMessageToError(rawResp) + rawResp, err := c.transport.sendRequest(ctx, jsonReq) if err != nil { - return nil, fmt.Errorf("failed to parse error response: %w", err) + return nil, fmt.Errorf("list prompts request failed: %w", err) } - return nil, fmt.Errorf("list prompts error: %s (code: %d)", - errResp.Error.Message, errResp.Error.Code) + + // Check for error response + if isErrorResponse(rawResp) { + errResp, err := parseRawMessageToError(rawResp) + if err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return nil, fmt.Errorf("list prompts error: %s (code: %d)", + errResp.Error.Message, errResp.Error.Code) + } + + // Parse response using specialized parser + return parseListPromptsResultFromJSON(rawResp) } - // Parse response using specialized parser - return parseListPromptsResultFromJSON(rawResp) + // Wrap the handler with the middleware chain. + chainedHandler := Chain(handler, c.middlewares...) + + // Execute the chain with ListPromptsRequest for middleware processing. + resp, err := chainedHandler(ctx, listPromptsReq) + if err != nil { + return nil, err + } + + return resp.(*ListPromptsResult), nil } // GetPrompt gets a specific prompt. @@ -502,34 +565,50 @@ func (c *Client) GetPrompt(ctx context.Context, getPromptReq *GetPromptRequest) return nil, errors.ErrNotInitialized } - // Create request. - requestID := c.requestID.Add(1) - req := &JSONRPCRequest{ - JSONRPC: JSONRPCVersion, - ID: requestID, - Request: Request{ - Method: MethodPromptsGet, - }, - Params: getPromptReq.Params, - } - - rawResp, err := c.transport.sendRequest(ctx, req) - if err != nil { - return nil, fmt.Errorf("get prompt request failed: %v", err) - } + // Define the final handler that sends the request. + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + promptReq := req.(*GetPromptRequest) + + // Create JSON-RPC request. + requestID := c.requestID.Add(1) + jsonReq := &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: requestID, + Request: Request{ + Method: MethodPromptsGet, + }, + Params: promptReq.Params, + } - // Check for error response - if isErrorResponse(rawResp) { - errResp, err := parseRawMessageToError(rawResp) + rawResp, err := c.transport.sendRequest(ctx, jsonReq) if err != nil { - return nil, fmt.Errorf("failed to parse error response: %w", err) + return nil, fmt.Errorf("get prompt request failed: %v", err) } - return nil, fmt.Errorf("get prompt error: %s (code: %d)", - errResp.Error.Message, errResp.Error.Code) + + // Check for error response + if isErrorResponse(rawResp) { + errResp, err := parseRawMessageToError(rawResp) + if err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return nil, fmt.Errorf("get prompt error: %s (code: %d)", + errResp.Error.Message, errResp.Error.Code) + } + + // Parse response using specialized parser + return parseGetPromptResultFromJSON(rawResp) + } + + // Wrap the handler with the middleware chain. + chainedHandler := Chain(handler, c.middlewares...) + + // Execute the chain with GetPromptRequest for middleware processing. + resp, err := chainedHandler(ctx, getPromptReq) + if err != nil { + return nil, err } - // Parse response using specialized parser - return parseGetPromptResultFromJSON(rawResp) + return resp.(*GetPromptResult), nil } // ListResources lists available resources. @@ -539,34 +618,50 @@ func (c *Client) ListResources(ctx context.Context, listResourcesReq *ListResour return nil, fmt.Errorf("%w", errors.ErrNotInitialized) } - // Create request. - requestID := c.requestID.Add(1) - req := &JSONRPCRequest{ - JSONRPC: JSONRPCVersion, - ID: requestID, - Request: Request{ - Method: MethodResourcesList, - }, - Params: listResourcesReq.Params, - } - - rawResp, err := c.transport.sendRequest(ctx, req) - if err != nil { - return nil, fmt.Errorf("list resources request failed: %v", err) - } + // Define the final handler that sends the request. + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + resourcesReq := req.(*ListResourcesRequest) + + // Create JSON-RPC request. + requestID := c.requestID.Add(1) + jsonReq := &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: requestID, + Request: Request{ + Method: MethodResourcesList, + }, + Params: resourcesReq.Params, + } - // Check for error response - if isErrorResponse(rawResp) { - errResp, err := parseRawMessageToError(rawResp) + rawResp, err := c.transport.sendRequest(ctx, jsonReq) if err != nil { - return nil, fmt.Errorf("failed to parse error response: %w", err) + return nil, fmt.Errorf("list resources request failed: %v", err) } - return nil, fmt.Errorf("list resources error: %s (code: %d)", - errResp.Error.Message, errResp.Error.Code) + + // Check for error response + if isErrorResponse(rawResp) { + errResp, err := parseRawMessageToError(rawResp) + if err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return nil, fmt.Errorf("list resources error: %s (code: %d)", + errResp.Error.Message, errResp.Error.Code) + } + + // Parse response using specialized parser + return parseListResourcesResultFromJSON(rawResp) + } + + // Wrap the handler with the middleware chain. + chainedHandler := Chain(handler, c.middlewares...) + + // Execute the chain with ListResourcesRequest for middleware processing. + resp, err := chainedHandler(ctx, listResourcesReq) + if err != nil { + return nil, err } - // Parse response using specialized parser - return parseListResourcesResultFromJSON(rawResp) + return resp.(*ListResourcesResult), nil } // ReadResource reads a specific resource. @@ -576,34 +671,50 @@ func (c *Client) ReadResource(ctx context.Context, readResourceReq *ReadResource return nil, fmt.Errorf("%w", errors.ErrNotInitialized) } - // Create request. - requestID := c.requestID.Add(1) - req := &JSONRPCRequest{ - JSONRPC: JSONRPCVersion, - ID: requestID, - Request: Request{ - Method: MethodResourcesRead, - }, - Params: readResourceReq.Params, - } - - rawResp, err := c.transport.sendRequest(ctx, req) - if err != nil { - return nil, fmt.Errorf("read resource request failed: %v", err) - } + // Define the final handler that sends the request. + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + resourceReq := req.(*ReadResourceRequest) + + // Create JSON-RPC request. + requestID := c.requestID.Add(1) + jsonReq := &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: requestID, + Request: Request{ + Method: MethodResourcesRead, + }, + Params: resourceReq.Params, + } - // Check for error response - if isErrorResponse(rawResp) { - errResp, err := parseRawMessageToError(rawResp) + rawResp, err := c.transport.sendRequest(ctx, jsonReq) if err != nil { - return nil, fmt.Errorf("failed to parse error response: %w", err) + return nil, fmt.Errorf("read resource request failed: %v", err) } - return nil, fmt.Errorf("read resource error: %s (code: %d)", - errResp.Error.Message, errResp.Error.Code) + + // Check for error response + if isErrorResponse(rawResp) { + errResp, err := parseRawMessageToError(rawResp) + if err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return nil, fmt.Errorf("read resource error: %s (code: %d)", + errResp.Error.Message, errResp.Error.Code) + } + + // Parse response using specialized parser + return parseReadResourceResultFromJSON(rawResp) + } + + // Wrap the handler with the middleware chain. + chainedHandler := Chain(handler, c.middlewares...) + + // Execute the chain with ReadResourceRequest for middleware processing. + resp, err := chainedHandler(ctx, readResourceReq) + if err != nil { + return nil, err } - // Parse response using specialized parser - return parseReadResourceResultFromJSON(rawResp) + return resp.(*ReadResourceResult), nil } func isZeroStruct(x interface{}) bool { diff --git a/demo_enhanced_middleware.go b/demo_enhanced_middleware.go new file mode 100644 index 0000000..8fab96f --- /dev/null +++ b/demo_enhanced_middleware.go @@ -0,0 +1,226 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package mcp + +import ( + "context" + "fmt" + "log" + "time" +) + +func main() { + log.Println("=== tRPC-MCP-Go 增强中间件系统演示 ===") + + // 1. 演示错误处理和分类 + demonstrateErrorHandling() + + // 2. 演示线程安全的限流中间件 + demonstrateRateLimit() + + // 3. 演示熔断器中间件 + demonstrateCircuitBreaker() + + // 4. 演示监控和指标收集 + demonstrateMonitoring() + + log.Println("=== 演示完成 ===") +} + +func demonstrateErrorHandling() { + log.Println("--- 1. 错误处理和分类演示 ---") + + // 创建一个认证中间件(故意使用空API密钥) + authMiddleware := AuthMiddleware("") + + // 模拟处理器 + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + ctx := context.Background() + req := &CallToolRequest{ + Request: Request{ + Method: "tools/call", + }, + Params: CallToolParams{Name: "test_tool"}, + } + + // 执行中间件 + resp, err := authMiddleware(ctx, req, handler) + + if err != nil { + if middlewareErr, ok := err.(*MiddlewareError); ok { + log.Printf("捕获到中间件错误:") + log.Printf(" 错误码: %s", middlewareErr.Code) + log.Printf(" 错误消息: %s", middlewareErr.Message) + log.Printf(" 错误时间: %v", middlewareErr.Timestamp.Format("2006-01-02 15:04:05")) + log.Printf(" 错误上下文: %+v", middlewareErr.Context) + log.Printf(" 调用堆栈长度: %d", len(middlewareErr.Trace)) + } + } else { + log.Printf("响应: %v", resp) + } + log.Println() +} + +func demonstrateRateLimit() { + log.Println("--- 2. 线程安全限流中间件演示 ---") + + // 创建限流中间件:每秒最多2个请求 + rateLimitMiddleware := RateLimitingMiddleware(2, time.Second) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + ctx := context.WithValue(context.Background(), "user_id", "demo-user") + req := &CallToolRequest{ + Request: Request{ + Method: "tools/call", + }, + Params: CallToolParams{Name: "test_tool"}, + } + + // 快速发送5个请求 + for i := 1; i <= 5; i++ { + resp, err := rateLimitMiddleware(ctx, req, handler) + if err != nil { + if middlewareErr, ok := err.(*MiddlewareError); ok { + log.Printf("请求 %d: 被限流 - %s", i, middlewareErr.Message) + log.Printf(" 当前请求数: %v/%v", + middlewareErr.Context["current_requests"], + middlewareErr.Context["max_requests"]) + } + } else { + log.Printf("请求 %d: 成功 - %v", i, resp) + } + time.Sleep(300 * time.Millisecond) // 模拟请求间隔 + } + log.Println() +} + +func demonstrateCircuitBreaker() { + log.Println("--- 3. 熔断器中间件演示 ---") + + // 创建熔断器中间件:失败阈值为2,超时1秒 + circuitBreakerMiddleware := CircuitBreakerMiddleware(2, time.Second) + + // 会失败的处理器 + failingHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, fmt.Errorf("模拟业务失败") + } + + // 成功的处理器 + successHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + ctx := context.Background() + req := &CallToolRequest{ + Request: Request{ + Method: "tools/call", + }, + Params: CallToolParams{Name: "test_tool"}, + } + + // 1. 发送失败请求直到熔断器打开 + log.Println("发送失败请求:") + for i := 1; i <= 4; i++ { + resp, err := circuitBreakerMiddleware(ctx, req, failingHandler) + if err != nil { + if middlewareErr, ok := err.(*MiddlewareError); ok && middlewareErr.Code == ErrCodeCircuitBreaker { + log.Printf("请求 %d: 熔断器已打开 - %s", i, middlewareErr.Message) + log.Printf(" 失败次数: %v/%v", + middlewareErr.Context["failure_count"], + middlewareErr.Context["failure_threshold"]) + } else { + log.Printf("请求 %d: 业务失败 - %v", i, err) + } + } else { + log.Printf("请求 %d: 成功 - %v", i, resp) + } + } + + // 2. 等待熔断器恢复 + log.Println("等待熔断器恢复...") + time.Sleep(1100 * time.Millisecond) + + // 3. 发送成功请求 + log.Println("发送成功请求:") + resp, err := circuitBreakerMiddleware(ctx, req, successHandler) + if err != nil { + log.Printf("请求失败: %v", err) + } else { + log.Printf("请求成功: %v", resp) + } + log.Println() +} + +func demonstrateMonitoring() { + log.Println("--- 4. 监控和指标收集演示 ---") + + // 获取全局监控器 + monitor := GetGlobalMonitor() + + // 创建监控中间件 + monitoringMiddleware := MonitoringMiddleware("demo_middleware") + + // 创建一个有时成功有时失败的处理器 + var requestCount int + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + requestCount++ + if requestCount%3 == 0 { + return nil, fmt.Errorf("模拟失败") + } + time.Sleep(time.Duration(requestCount*10) * time.Millisecond) // 模拟不同的响应时间 + return fmt.Sprintf("success_%d", requestCount), nil + } + + ctx := context.Background() + req := &CallToolRequest{ + Request: Request{ + Method: "tools/call", + }, + Params: CallToolParams{Name: "test_tool"}, + } + + // 发送多个请求 + log.Println("发送监控请求:") + for i := 1; i <= 6; i++ { + resp, err := monitoringMiddleware(ctx, req, handler) + if err != nil { + log.Printf("请求 %d: 失败 - %v", i, err) + } else { + log.Printf("请求 %d: 成功 - %v", i, resp) + } + time.Sleep(50 * time.Millisecond) + } + + // 显示监控指标 + log.Println("监控指标:") + metrics := monitor.GetMetrics("demo_middleware") + if metrics != nil { + log.Printf(" 请求总数: %d", metrics.RequestCount) + log.Printf(" 错误总数: %d", metrics.ErrorCount) + log.Printf(" 成功率: %.2f%%", float64(metrics.RequestCount-metrics.ErrorCount)/float64(metrics.RequestCount)*100) + log.Printf(" 平均响应时间: %v", metrics.AverageDuration) + log.Printf(" 最大响应时间: %v", metrics.MaxDuration) + log.Printf(" 最小响应时间: %v", metrics.MinDuration) + } + + // 打印完整报告 + log.Println("完整性能报告:") + monitor.PrintReport() + + // 导出JSON格式的指标 + jsonMetrics, err := monitor.ToJSON() + if err == nil { + log.Printf("JSON格式指标:\n%s", jsonMetrics) + } + log.Println() +} \ No newline at end of file diff --git a/examples/client_middleware_example/main.go b/examples/client_middleware_example/main.go new file mode 100644 index 0000000..71a128a --- /dev/null +++ b/examples/client_middleware_example/main.go @@ -0,0 +1,220 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +// ClientLoggingMiddleware 客户端日志中间件 +func ClientLoggingMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + start := time.Now() + + log.Printf("🚀 [Client] Request started: %T", req) + + resp, err := next(ctx, req) + + duration := time.Since(start) + if err != nil { + log.Printf("❌ [Client] Request failed after %v: %v", duration, err) + } else { + log.Printf("✅ [Client] Request completed in %v", duration) + } + + return resp, err +} + +// ClientMetricsMiddleware 客户端指标中间件 +func ClientMetricsMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + start := time.Now() + + resp, err := next(ctx, req) + + duration := time.Since(start) + requestType := fmt.Sprintf("%T", req) + status := "success" + if err != nil { + status = "error" + } + + log.Printf("📊 [Client] Metrics - Type: %s, Status: %s, Duration: %v", + requestType, status, duration) + + return resp, err +} + +// ClientValidationMiddleware 客户端验证中间件 +func ClientValidationMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + log.Printf("🔍 [Client] Validating request: %T", req) + + // 根据请求类型进行验证 + switch r := req.(type) { + case *mcp.CallToolRequest: + if r.Params.Name == "" { + return nil, fmt.Errorf("client validation failed: tool name is required") + } + log.Printf("✅ [Client] Tool request validation passed: %s", r.Params.Name) + case *mcp.ReadResourceRequest: + if r.Params.URI == "" { + return nil, fmt.Errorf("client validation failed: resource URI is required") + } + log.Printf("✅ [Client] Resource request validation passed: %s", r.Params.URI) + case *mcp.GetPromptRequest: + if r.Params.Name == "" { + return nil, fmt.Errorf("client validation failed: prompt name is required") + } + log.Printf("✅ [Client] Prompt request validation passed: %s", r.Params.Name) + default: + log.Printf("✅ [Client] Generic request validation passed") + } + + return next(ctx, req) +} + +// demonstrateClientMiddleware 演示客户端中间件功能 +func demonstrateClientMiddleware() { + log.Println("=== 客户端中间件演示 ===") + + // 创建客户端信息 + clientInfo := mcp.Implementation{ + Name: "ClientMiddlewareDemo", + Version: "1.0.0", + } + + // 创建带有中间件的客户端 + client, err := mcp.NewClient( + "http://localhost:3000/mcp", + clientInfo, + // 添加客户端中间件(按执行顺序) + mcp.WithMiddleware(mcp.RecoveryMiddleware), // 错误恢复 + mcp.WithMiddleware(ClientLoggingMiddleware), // 日志记录 + mcp.WithMiddleware(ClientMetricsMiddleware), // 性能监控 + mcp.WithMiddleware(ClientValidationMiddleware), // 请求验证 + mcp.WithMiddleware(mcp.ToolHandlerMiddleware), // 工具处理 + mcp.WithMiddleware(mcp.ResourceMiddleware), // 资源处理 + mcp.WithMiddleware(mcp.PromptMiddleware), // 提示处理 + mcp.WithClientLogger(mcp.GetDefaultLogger()), + ) + + if err != nil { + log.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // 初始化客户端 + ctx := context.Background() + initReq := &mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.ProtocolVersion_2025_03_26, + ClientInfo: clientInfo, + Capabilities: mcp.ClientCapabilities{}, + }, + } + + log.Println("🔄 Initializing client...") + _, err = client.Initialize(ctx, initReq) + if err != nil { + log.Fatalf("Failed to initialize client: %v", err) + } + + log.Println("✅ Client initialized successfully") + + // 测试工具调用(会经过所有中间件) + log.Println("\n📞 Testing tool call with middleware...") + toolResult, err := client.CallTool(ctx, &mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]interface{}{ + "name": "Middleware World", + }, + }, + }) + + if err != nil { + log.Printf("❌ Tool call failed: %v", err) + } else { + log.Printf("✅ Tool call successful: %+v", toolResult) + } + + // 测试资源读取(会经过所有中间件) + log.Println("\n📄 Testing resource read with middleware...") + resourceResult, err := client.ReadResource(ctx, &mcp.ReadResourceRequest{ + Params: mcp.ReadResourceParams{ + URI: "welcome", + }, + }) + + if err != nil { + log.Printf("❌ Resource read failed: %v", err) + } else { + log.Printf("✅ Resource read successful: %+v", resourceResult) + } + + // 测试提示获取(会经过所有中间件) + log.Println("\n💬 Testing prompt get with middleware...") + promptResult, err := client.GetPrompt(ctx, &mcp.GetPromptRequest{ + Params: mcp.GetPromptParams{ + Name: "greeting", + Arguments: map[string]interface{}{ + "name": "Middleware User", + }, + }, + }) + + if err != nil { + log.Printf("❌ Prompt get failed: %v", err) + } else { + log.Printf("✅ Prompt get successful: %+v", promptResult) + } + + // 测试工具列表(会经过所有中间件) + log.Println("\n🛠️ Testing list tools with middleware...") + toolsResult, err := client.ListTools(ctx, &mcp.ListToolsRequest{}) + + if err != nil { + log.Printf("❌ List tools failed: %v", err) + } else { + log.Printf("✅ List tools successful, found %d tools", len(toolsResult.Tools)) + } + + log.Println("\n🎉 Client middleware demonstration completed!") +} + +// checkServerAvailability 检查服务器是否可用 +func checkServerAvailability() bool { + resp, err := http.Get("http://localhost:3000/mcp") + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode != 404 +} + +func main() { + log.Println("🚀 Starting Client Middleware Example") + + // 检查服务器是否运行 + if !checkServerAvailability() { + log.Println("⚠️ Warning: Server not available at http://localhost:3000/mcp") + log.Println("Please start the server middleware example first:") + log.Println(" cd examples/server_middleware_example && go run main.go") + log.Println("Then run this client example in another terminal.") + return + } + + // 等待一下让服务器完全启动 + time.Sleep(time.Second) + + // 演示客户端中间件 + demonstrateClientMiddleware() +} diff --git a/examples/middleware_example/main.go b/examples/middleware_example/main.go new file mode 100644 index 0000000..909e5d4 --- /dev/null +++ b/examples/middleware_example/main.go @@ -0,0 +1,212 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package main + +import ( + "context" + "fmt" + "log" + "time" + + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +// CustomLoggingMiddleware 自定义日志中间件示例 +func CustomLoggingMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + start := time.Now() + + // 请求前处理 + log.Printf("🚀 Request started: %T", req) + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 请求后处理 + duration := time.Since(start) + if err != nil { + log.Printf("❌ Request failed after %v: %v", duration, err) + } else { + log.Printf("✅ Request completed in %v", duration) + } + + return resp, err +} + +// AuthenticationMiddleware 认证中间件示例 +func AuthenticationMiddleware(apiKey string) mcp.MiddlewareFunc { + return func(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + // 验证 API Key + if apiKey == "" { + log.Printf("🔒 Authentication failed: API key is required") + return nil, fmt.Errorf("authentication failed: API key is required") + } + + log.Printf("🔑 Authentication successful") + + // 在上下文中添加认证信息 + ctx = context.WithValue(ctx, "authenticated", true) + ctx = context.WithValue(ctx, "api_key", apiKey) + + return next(ctx, req) + } +} + +// RateLimitingMiddleware 限流中间件示例 +func RateLimitingMiddleware(maxRequests int, window time.Duration) mcp.MiddlewareFunc { + requestCount := 0 + lastReset := time.Now() + + return func(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + now := time.Now() + + // 重置计数器 + if now.Sub(lastReset) > window { + requestCount = 0 + lastReset = now + } + + // 检查限流 + if requestCount >= maxRequests { + log.Printf("🚫 Rate limit exceeded: %d requests in %v", requestCount, window) + return nil, fmt.Errorf("rate limit exceeded: too many requests") + } + + requestCount++ + log.Printf("📊 Request %d/%d in current window", requestCount, maxRequests) + + return next(ctx, req) + } +} + +// CircuitBreakerMiddleware 熔断器中间件示例 +func CircuitBreakerMiddleware(threshold int, timeout time.Duration) mcp.MiddlewareFunc { + failureCount := 0 + lastFailure := time.Time{} + isOpen := false + + return func(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + now := time.Now() + + // 检查熔断器状态 + if isOpen { + if now.Sub(lastFailure) > timeout { + // 尝试半开状态 + isOpen = false + failureCount = 0 + log.Printf("🔄 Circuit breaker: attempting to close") + } else { + log.Printf("🔴 Circuit breaker is open") + return nil, fmt.Errorf("circuit breaker is open") + } + } + + // 执行请求 + resp, err := next(ctx, req) + + if err != nil { + failureCount++ + lastFailure = now + + if failureCount >= threshold { + isOpen = true + log.Printf("🔴 Circuit breaker opened after %d failures", failureCount) + } + } else { + // 重置失败计数 + failureCount = 0 + } + + return resp, err + } +} + +func main() { + // 创建客户端信息 + clientInfo := mcp.Implementation{ + Name: "MiddlewareExample", + Version: "1.0.0", + } + + // 创建带有多个中间件的客户端 + client, err := mcp.NewClient( + "http://localhost:3000", + clientInfo, + // 添加多个中间件(按顺序执行) + mcp.WithMiddleware(CustomLoggingMiddleware), + mcp.WithMiddleware(AuthenticationMiddleware("your-secret-api-key")), + mcp.WithMiddleware(RateLimitingMiddleware(10, time.Minute)), // 每分钟最多10个请求 + mcp.WithMiddleware(CircuitBreakerMiddleware(3, 30*time.Second)), // 3次失败后熔断30秒 + mcp.WithMiddleware(mcp.ValidationMiddleware), + mcp.WithMiddleware(mcp.MetricsMiddleware), + mcp.WithMiddleware(mcp.RetryMiddleware(2)), // 最多重试2次 + mcp.WithMiddleware(mcp.RecoveryMiddleware), + ) + + if err != nil { + log.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // 初始化客户端 + ctx := context.Background() + initReq := &mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.ProtocolVersion_2025_03_26, + ClientInfo: clientInfo, + Capabilities: mcp.ClientCapabilities{}, + }, + } + + _, err = client.Initialize(ctx, initReq) + if err != nil { + log.Fatalf("Failed to initialize client: %v", err) + } + + log.Println("🎉 Client initialized successfully") + + // 调用工具(会经过所有中间件) + toolResult, err := client.CallTool(ctx, &mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]interface{}{ + "name": "World", + }, + }, + }) + + if err != nil { + log.Printf("Tool call failed: %v", err) + } else { + log.Printf("Tool call succeeded: %+v", toolResult) + } + + // 演示中间件链的直接使用 + log.Println("\n--- 演示中间件链的直接使用 ---") + + // 创建一个简单的处理器 + simpleHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "Hello from handler!", nil + } + + // 创建中间件链 + chain := mcp.NewMiddlewareChain( + CustomLoggingMiddleware, + mcp.ValidationMiddleware, + mcp.MetricsMiddleware, + ) + + // 执行中间件链 + result, err := chain.Execute(ctx, &mcp.CallToolRequest{ + Params: mcp.CallToolParams{Name: "example"}, + }, simpleHandler) + + if err != nil { + log.Printf("Chain execution failed: %v", err) + } else { + log.Printf("Chain execution result: %v", result) + } +} diff --git a/examples/server_middleware_example/main.go b/examples/server_middleware_example/main.go new file mode 100644 index 0000000..a5b360e --- /dev/null +++ b/examples/server_middleware_example/main.go @@ -0,0 +1,238 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +// CustomServerLoggingMiddleware 自定义服务端日志中间件 +func CustomServerLoggingMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + start := time.Now() + + // 请求前处理 + log.Printf("🚀 [Server] Request started: %T", req) + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 请求后处理 + duration := time.Since(start) + if err != nil { + log.Printf("❌ [Server] Request failed after %v: %v", duration, err) + } else { + log.Printf("✅ [Server] Request completed in %v", duration) + } + + return resp, err +} + +// ServerAuthMiddleware 服务端认证中间件 +func ServerAuthMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + // 从上下文中获取认证信息(实际中可能从HTTP头获取) + log.Printf("🔐 [Server] Authenticating request: %T", req) + + // 模拟认证检查 + // 在实际应用中,这里会检查API密钥、JWT令牌等 + + // 在上下文中添加用户信息 + ctx = context.WithValue(ctx, "server_authenticated", true) + ctx = context.WithValue(ctx, "server_user_id", "server_user_123") + + log.Printf("✅ [Server] Authentication successful") + return next(ctx, req) +} + +// ServerMetricsMiddleware 服务端指标中间件 +func ServerMetricsMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + start := time.Now() + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 记录指标 + duration := time.Since(start) + requestType := fmt.Sprintf("%T", req) + status := "success" + if err != nil { + status = "error" + } + + log.Printf("📊 [Server] Metrics - Type: %s, Status: %s, Duration: %v", + requestType, status, duration) + + return resp, err +} + +// ServerValidationMiddleware 服务端验证中间件 +func ServerValidationMiddleware(ctx context.Context, req interface{}, next mcp.Handler) (interface{}, error) { + log.Printf("🔍 [Server] Validating request: %T", req) + + // 根据请求类型进行验证 + switch r := req.(type) { + case *mcp.CallToolRequest: + if r.Params.Name == "" { + return nil, fmt.Errorf("server validation failed: tool name is required") + } + log.Printf("✅ [Server] Tool request validation passed: %s", r.Params.Name) + case *mcp.ReadResourceRequest: + if r.Params.URI == "" { + return nil, fmt.Errorf("server validation failed: resource URI is required") + } + log.Printf("✅ [Server] Resource request validation passed: %s", r.Params.URI) + case *mcp.GetPromptRequest: + if r.Params.Name == "" { + return nil, fmt.Errorf("server validation failed: prompt name is required") + } + log.Printf("✅ [Server] Prompt request validation passed: %s", r.Params.Name) + default: + log.Printf("✅ [Server] Generic request validation passed") + } + + return next(ctx, req) +} + +func main() { + log.Println("🚀 Starting MCP Server with Middleware Example") + + // 创建服务器信息 + serverInfo := mcp.Implementation{ + Name: "ServerMiddlewareExample", + Version: "1.0.0", + } + + // 创建带有多个中间件的服务器 + server := mcp.NewServer( + serverInfo.Name, + serverInfo.Version, + // 添加服务端中间件(按执行顺序) + mcp.WithServerMiddleware(mcp.RecoveryMiddleware), // 最外层:错误恢复 + mcp.WithServerMiddleware(CustomServerLoggingMiddleware), // 日志记录 + mcp.WithServerMiddleware(ServerMetricsMiddleware), // 性能监控 + mcp.WithServerMiddleware(ServerAuthMiddleware), // 认证鉴权 + mcp.WithServerMiddleware(ServerValidationMiddleware), // 请求验证 + mcp.WithServerMiddleware(mcp.RateLimitingMiddleware(100, time.Minute)), // 限流:100请求/分钟 + mcp.WithServerMiddleware(mcp.ToolHandlerMiddleware), // 工具处理 + mcp.WithServerMiddleware(mcp.ResourceMiddleware), // 资源处理 + mcp.WithServerMiddleware(mcp.PromptMiddleware), // 提示处理 + mcp.WithServerLogger(mcp.GetDefaultLogger()), + ) + + // 注册一个示例工具 + err := server.RegisterTool("greet", "Greet someone", func(ctx context.Context, args map[string]interface{}) (*mcp.CallToolResult, error) { + // 从上下文中获取认证信息 + authenticated := ctx.Value("server_authenticated") + userID := ctx.Value("server_user_id") + + log.Printf("🔧 [Tool] Greet tool called by user: %v (authenticated: %v)", userID, authenticated) + + name, ok := args["name"].(string) + if !ok { + name = "World" + } + + message := fmt.Sprintf("Hello, %s! (from server with middleware)", name) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: message, + }, + }, + }, nil + }) + + if err != nil { + log.Fatalf("Failed to register tool: %v", err) + } + + // 注册一个示例资源 + err = server.RegisterResource("welcome", "Welcome message resource", "text/plain", func(ctx context.Context, uri string) (*mcp.ReadResourceResult, error) { + // 从上下文中获取认证信息 + userID := ctx.Value("server_user_id") + log.Printf("📄 [Resource] Welcome resource accessed by user: %v", userID) + + content := "Welcome to the MCP Server with Middleware!" + + return &mcp.ReadResourceResult{ + Contents: []mcp.ResourceContents{ + { + URI: uri, + MimeType: "text/plain", + Text: &content, + }, + }, + }, nil + }) + + if err != nil { + log.Fatalf("Failed to register resource: %v", err) + } + + // 注册一个示例提示 + err = server.RegisterPrompt("greeting", "A greeting prompt template", func(ctx context.Context, name string, args map[string]interface{}) (*mcp.GetPromptResult, error) { + // 从上下文中获取认证信息 + userID := ctx.Value("server_user_id") + log.Printf("💬 [Prompt] Greeting prompt accessed by user: %v", userID) + + // 从参数中获取名字 + targetName, ok := args["name"].(string) + if !ok { + targetName = "Guest" + } + + promptText := fmt.Sprintf("Generate a warm greeting for %s", targetName) + + return &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + { + Role: "user", + Content: mcp.Content{ + Type: "text", + Text: promptText, + }, + }, + }, + }, nil + }) + + if err != nil { + log.Fatalf("Failed to register prompt: %v", err) + } + + // 启动服务器 + log.Println("🌐 Server starting on http://localhost:3000/mcp") + log.Println("📋 Available endpoints:") + log.Println(" - Tools: greet") + log.Println(" - Resources: welcome") + log.Println(" - Prompts: greeting") + log.Println("") + log.Println("🔧 Middleware chain active:") + log.Println(" 1. RecoveryMiddleware (Error recovery)") + log.Println(" 2. CustomServerLoggingMiddleware (Request logging)") + log.Println(" 3. ServerMetricsMiddleware (Performance metrics)") + log.Println(" 4. ServerAuthMiddleware (Authentication)") + log.Println(" 5. ServerValidationMiddleware (Request validation)") + log.Println(" 6. RateLimitingMiddleware (Rate limiting)") + log.Println(" 7. ToolHandlerMiddleware (Tool processing)") + log.Println(" 8. ResourceMiddleware (Resource processing)") + log.Println(" 9. PromptMiddleware (Prompt processing)") + log.Println("") + log.Println("💡 Test with a client:") + log.Println(" cd examples/basic/client && go run main.go") + + if err := http.ListenAndServe(":3000", server.Handler()); err != nil { + log.Fatalf("Server failed to start: %v", err) + } +} diff --git a/examples/sse/client/learn.md b/examples/sse/client/learn.md new file mode 100644 index 0000000..00c566e --- /dev/null +++ b/examples/sse/client/learn.md @@ -0,0 +1,2516 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "Annotations": { + "description": "Optional annotations for the client. The client can use annotations to inform how objects are used or displayed", + "properties": { + "audience": { + "description": "Describes who the intended customer of this object or data is.\n\nIt can include multiple entries to indicate content useful for multiple audiences (e.g., `[\"user\", \"assistant\"]`).", + "items": { + "$ref": "#/definitions/Role" + }, + "type": "array" + }, + "lastModified": { + "description": "The moment the resource was last modified, as an ISO 8601 formatted string.\n\nShould be an ISO 8601 formatted string (e.g., \"2025-01-12T15:00:58Z\").\n\nExamples: last activity timestamp in an open file, timestamp when the resource\nwas attached, etc.", + "type": "string" + }, + "priority": { + "description": "Describes how important this data is for operating the server.\n\nA value of 1 means \"most important,\" and indicates that the data is\neffectively required, while 0 means \"least important,\" and indicates that\nthe data is entirely optional.", + "maximum": 1, + "minimum": 0, + "type": "number" + } + }, + "type": "object" + }, + "AudioContent": { + "description": "Audio provided to or from an LLM.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "data": { + "description": "The base64-encoded audio data.", + "format": "byte", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of the audio. Different providers may support different audio types.", + "type": "string" + }, + "type": { + "const": "audio", + "type": "string" + } + }, + "required": [ + "data", + "mimeType", + "type" + ], + "type": "object" + }, + "BaseMetadata": { + "description": "Base interface for metadata with name (identifier) and title (display name) properties.", + "properties": { + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }, + "BlobResourceContents": { + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "blob": { + "description": "A base64-encoded string representing the binary data of the item.", + "format": "byte", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "blob", + "uri" + ], + "type": "object" + }, + "BooleanSchema": { + "properties": { + "default": { + "type": "boolean" + }, + "description": { + "type": "string" + }, + "title": { + "type": "string" + }, + "type": { + "const": "boolean", + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "CallToolRequest": { + "description": "Used by the client to invoke a tool provided by the server.", + "properties": { + "method": { + "const": "tools/call", + "type": "string" + }, + "params": { + "properties": { + "arguments": { + "additionalProperties": {}, + "type": "object" + }, + "name": { + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "CallToolResult": { + "description": "The server's response to a tool call.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "content": { + "description": "A list of content objects that represent the unstructured result of the tool call.", + "items": { + "$ref": "#/definitions/ContentBlock" + }, + "type": "array" + }, + "isError": { + "description": "Whether the tool call ended in an error.\n\nIf not set, this is assumed to be false (the call was successful).\n\nAny errors that originate from the tool SHOULD be reported inside the result\nobject, with `isError` set to true, _not_ as an MCP protocol-level error\nresponse. Otherwise, the LLM would not be able to see that an error occurred\nand self-correct.\n\nHowever, any errors in _finding_ the tool, an error indicating that the\nserver does not support tool calls, or any other exceptional conditions,\nshould be reported as an MCP error response.", + "type": "boolean" + }, + "structuredContent": { + "additionalProperties": {}, + "description": "An optional JSON object that represents the structured result of the tool call.", + "type": "object" + } + }, + "required": [ + "content" + ], + "type": "object" + }, + "CancelledNotification": { + "description": "This notification can be sent by either side to indicate that it is cancelling a previously-issued request.\n\nThe request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished.\n\nThis notification indicates that the result will be unused, so any associated processing SHOULD cease.\n\nA client MUST NOT attempt to cancel its `initialize` request.", + "properties": { + "method": { + "const": "notifications/cancelled", + "type": "string" + }, + "params": { + "properties": { + "reason": { + "description": "An optional string describing the reason for the cancellation. This MAY be logged or presented to the user.", + "type": "string" + }, + "requestId": { + "$ref": "#/definitions/RequestId", + "description": "The ID of the request to cancel.\n\nThis MUST correspond to the ID of a request previously issued in the same direction." + } + }, + "required": [ + "requestId" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ClientCapabilities": { + "description": "Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities.", + "properties": { + "elicitation": { + "additionalProperties": true, + "description": "Present if the client supports elicitation from the server.", + "properties": {}, + "type": "object" + }, + "experimental": { + "additionalProperties": { + "additionalProperties": true, + "properties": {}, + "type": "object" + }, + "description": "Experimental, non-standard capabilities that the client supports.", + "type": "object" + }, + "roots": { + "description": "Present if the client supports listing roots.", + "properties": { + "listChanged": { + "description": "Whether the client supports notifications for changes to the roots list.", + "type": "boolean" + } + }, + "type": "object" + }, + "sampling": { + "additionalProperties": true, + "description": "Present if the client supports sampling from an LLM.", + "properties": {}, + "type": "object" + } + }, + "type": "object" + }, + "ClientNotification": { + "anyOf": [ + { + "$ref": "#/definitions/CancelledNotification" + }, + { + "$ref": "#/definitions/InitializedNotification" + }, + { + "$ref": "#/definitions/ProgressNotification" + }, + { + "$ref": "#/definitions/RootsListChangedNotification" + } + ] + }, + "ClientRequest": { + "anyOf": [ + { + "$ref": "#/definitions/InitializeRequest" + }, + { + "$ref": "#/definitions/PingRequest" + }, + { + "$ref": "#/definitions/ListResourcesRequest" + }, + { + "$ref": "#/definitions/ListResourceTemplatesRequest" + }, + { + "$ref": "#/definitions/ReadResourceRequest" + }, + { + "$ref": "#/definitions/SubscribeRequest" + }, + { + "$ref": "#/definitions/UnsubscribeRequest" + }, + { + "$ref": "#/definitions/ListPromptsRequest" + }, + { + "$ref": "#/definitions/GetPromptRequest" + }, + { + "$ref": "#/definitions/ListToolsRequest" + }, + { + "$ref": "#/definitions/CallToolRequest" + }, + { + "$ref": "#/definitions/SetLevelRequest" + }, + { + "$ref": "#/definitions/CompleteRequest" + } + ] + }, + "ClientResult": { + "anyOf": [ + { + "$ref": "#/definitions/Result" + }, + { + "$ref": "#/definitions/CreateMessageResult" + }, + { + "$ref": "#/definitions/ListRootsResult" + }, + { + "$ref": "#/definitions/ElicitResult" + } + ] + }, + "CompleteRequest": { + "description": "A request from the client to the server, to ask for completion options.", + "properties": { + "method": { + "const": "completion/complete", + "type": "string" + }, + "params": { + "properties": { + "argument": { + "description": "The argument's information", + "properties": { + "name": { + "description": "The name of the argument", + "type": "string" + }, + "value": { + "description": "The value of the argument to use for completion matching.", + "type": "string" + } + }, + "required": [ + "name", + "value" + ], + "type": "object" + }, + "context": { + "description": "Additional, optional context for completions", + "properties": { + "arguments": { + "additionalProperties": { + "type": "string" + }, + "description": "Previously-resolved variables in a URI template or prompt.", + "type": "object" + } + }, + "type": "object" + }, + "ref": { + "anyOf": [ + { + "$ref": "#/definitions/PromptReference" + }, + { + "$ref": "#/definitions/ResourceTemplateReference" + } + ] + } + }, + "required": [ + "argument", + "ref" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "CompleteResult": { + "description": "The server's response to a completion/complete request", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "completion": { + "properties": { + "hasMore": { + "description": "Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown.", + "type": "boolean" + }, + "total": { + "description": "The total number of completion options available. This can exceed the number of values actually sent in the response.", + "type": "integer" + }, + "values": { + "description": "An array of completion values. Must not exceed 100 items.", + "items": { + "type": "string" + }, + "type": "array" + } + }, + "required": [ + "values" + ], + "type": "object" + } + }, + "required": [ + "completion" + ], + "type": "object" + }, + "ContentBlock": { + "anyOf": [ + { + "$ref": "#/definitions/TextContent" + }, + { + "$ref": "#/definitions/ImageContent" + }, + { + "$ref": "#/definitions/AudioContent" + }, + { + "$ref": "#/definitions/ResourceLink" + }, + { + "$ref": "#/definitions/EmbeddedResource" + } + ] + }, + "CreateMessageRequest": { + "description": "A request from the server to sample an LLM via the client. The client has full discretion over which model to select. The client should also inform the user before beginning sampling, to allow them to inspect the request (human in the loop) and decide whether to approve it.", + "properties": { + "method": { + "const": "sampling/createMessage", + "type": "string" + }, + "params": { + "properties": { + "includeContext": { + "description": "A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request.", + "enum": [ + "allServers", + "none", + "thisServer" + ], + "type": "string" + }, + "maxTokens": { + "description": "The maximum number of tokens to sample, as requested by the server. The client MAY choose to sample fewer tokens than requested.", + "type": "integer" + }, + "messages": { + "items": { + "$ref": "#/definitions/SamplingMessage" + }, + "type": "array" + }, + "metadata": { + "additionalProperties": true, + "description": "Optional metadata to pass through to the LLM provider. The format of this metadata is provider-specific.", + "properties": {}, + "type": "object" + }, + "modelPreferences": { + "$ref": "#/definitions/ModelPreferences", + "description": "The server's preferences for which model to select. The client MAY ignore these preferences." + }, + "stopSequences": { + "items": { + "type": "string" + }, + "type": "array" + }, + "systemPrompt": { + "description": "An optional system prompt the server wants to use for sampling. The client MAY modify or omit this prompt.", + "type": "string" + }, + "temperature": { + "type": "number" + } + }, + "required": [ + "maxTokens", + "messages" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "CreateMessageResult": { + "description": "The client's response to a sampling/create_message request from the server. The client should inform the user before returning the sampled message, to allow them to inspect the response (human in the loop) and decide whether to allow the server to see it.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "content": { + "anyOf": [ + { + "$ref": "#/definitions/TextContent" + }, + { + "$ref": "#/definitions/ImageContent" + }, + { + "$ref": "#/definitions/AudioContent" + } + ] + }, + "model": { + "description": "The name of the model that generated the message.", + "type": "string" + }, + "role": { + "$ref": "#/definitions/Role" + }, + "stopReason": { + "description": "The reason why sampling stopped, if known.", + "type": "string" + } + }, + "required": [ + "content", + "model", + "role" + ], + "type": "object" + }, + "Cursor": { + "description": "An opaque token used to represent a cursor for pagination.", + "type": "string" + }, + "ElicitRequest": { + "description": "A request from the server to elicit additional information from the user via the client.", + "properties": { + "method": { + "const": "elicitation/create", + "type": "string" + }, + "params": { + "properties": { + "message": { + "description": "The message to present to the user.", + "type": "string" + }, + "requestedSchema": { + "description": "A restricted subset of JSON Schema.\nOnly top-level properties are allowed, without nesting.", + "properties": { + "properties": { + "additionalProperties": { + "$ref": "#/definitions/PrimitiveSchemaDefinition" + }, + "type": "object" + }, + "required": { + "items": { + "type": "string" + }, + "type": "array" + }, + "type": { + "const": "object", + "type": "string" + } + }, + "required": [ + "properties", + "type" + ], + "type": "object" + } + }, + "required": [ + "message", + "requestedSchema" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ElicitResult": { + "description": "The client's response to an elicitation request.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "action": { + "description": "The user action in response to the elicitation.\n- \"accept\": User submitted the form/confirmed the action\n- \"decline\": User explicitly decline the action\n- \"cancel\": User dismissed without making an explicit choice", + "enum": [ + "accept", + "cancel", + "decline" + ], + "type": "string" + }, + "content": { + "additionalProperties": { + "type": [ + "string", + "integer", + "boolean" + ] + }, + "description": "The submitted form data, only present when action is \"accept\".\nContains values matching the requested schema.", + "type": "object" + } + }, + "required": [ + "action" + ], + "type": "object" + }, + "EmbeddedResource": { + "description": "The contents of a resource, embedded into a prompt or tool call result.\n\nIt is up to the client how best to render embedded resources for the benefit\nof the LLM and/or the user.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "resource": { + "anyOf": [ + { + "$ref": "#/definitions/TextResourceContents" + }, + { + "$ref": "#/definitions/BlobResourceContents" + } + ] + }, + "type": { + "const": "resource", + "type": "string" + } + }, + "required": [ + "resource", + "type" + ], + "type": "object" + }, + "EmptyResult": { + "$ref": "#/definitions/Result" + }, + "EnumSchema": { + "properties": { + "description": { + "type": "string" + }, + "enum": { + "items": { + "type": "string" + }, + "type": "array" + }, + "enumNames": { + "items": { + "type": "string" + }, + "type": "array" + }, + "title": { + "type": "string" + }, + "type": { + "const": "string", + "type": "string" + } + }, + "required": [ + "enum", + "type" + ], + "type": "object" + }, + "GetPromptRequest": { + "description": "Used by the client to get a prompt provided by the server.", + "properties": { + "method": { + "const": "prompts/get", + "type": "string" + }, + "params": { + "properties": { + "arguments": { + "additionalProperties": { + "type": "string" + }, + "description": "Arguments to use for templating the prompt.", + "type": "object" + }, + "name": { + "description": "The name of the prompt or prompt template.", + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "GetPromptResult": { + "description": "The server's response to a prompts/get request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "description": { + "description": "An optional description for the prompt.", + "type": "string" + }, + "messages": { + "items": { + "$ref": "#/definitions/PromptMessage" + }, + "type": "array" + } + }, + "required": [ + "messages" + ], + "type": "object" + }, + "ImageContent": { + "description": "An image provided to or from an LLM.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "data": { + "description": "The base64-encoded image data.", + "format": "byte", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of the image. Different providers may support different image types.", + "type": "string" + }, + "type": { + "const": "image", + "type": "string" + } + }, + "required": [ + "data", + "mimeType", + "type" + ], + "type": "object" + }, + "Implementation": { + "description": "Describes the name and version of an MCP implementation, with an optional title for UI representation.", + "properties": { + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "version": { + "type": "string" + } + }, + "required": [ + "name", + "version" + ], + "type": "object" + }, + "InitializeRequest": { + "description": "This request is sent from the client to the server when it first connects, asking it to begin initialization.", + "properties": { + "method": { + "const": "initialize", + "type": "string" + }, + "params": { + "properties": { + "capabilities": { + "$ref": "#/definitions/ClientCapabilities" + }, + "clientInfo": { + "$ref": "#/definitions/Implementation" + }, + "protocolVersion": { + "description": "The latest version of the Model Context Protocol that the client supports. The client MAY decide to support older versions as well.", + "type": "string" + } + }, + "required": [ + "capabilities", + "clientInfo", + "protocolVersion" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "InitializeResult": { + "description": "After receiving an initialize request from the client, the server sends this response.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "capabilities": { + "$ref": "#/definitions/ServerCapabilities" + }, + "instructions": { + "description": "Instructions describing how to use the server and its features.\n\nThis can be used by clients to improve the LLM's understanding of available tools, resources, etc. It can be thought of like a \"hint\" to the model. For example, this information MAY be added to the system prompt.", + "type": "string" + }, + "protocolVersion": { + "description": "The version of the Model Context Protocol that the server wants to use. This may not match the version that the client requested. If the client cannot support this version, it MUST disconnect.", + "type": "string" + }, + "serverInfo": { + "$ref": "#/definitions/Implementation" + } + }, + "required": [ + "capabilities", + "protocolVersion", + "serverInfo" + ], + "type": "object" + }, + "InitializedNotification": { + "description": "This notification is sent from the client to the server after initialization has finished.", + "properties": { + "method": { + "const": "notifications/initialized", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "JSONRPCError": { + "description": "A response to a request that indicates an error occurred.", + "properties": { + "error": { + "properties": { + "code": { + "description": "The error type that occurred.", + "type": "integer" + }, + "data": { + "description": "Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.)." + }, + "message": { + "description": "A short description of the error. The message SHOULD be limited to a concise single sentence.", + "type": "string" + } + }, + "required": [ + "code", + "message" + ], + "type": "object" + }, + "id": { + "$ref": "#/definitions/RequestId" + }, + "jsonrpc": { + "const": "2.0", + "type": "string" + } + }, + "required": [ + "error", + "id", + "jsonrpc" + ], + "type": "object" + }, + "JSONRPCMessage": { + "anyOf": [ + { + "$ref": "#/definitions/JSONRPCRequest" + }, + { + "$ref": "#/definitions/JSONRPCNotification" + }, + { + "$ref": "#/definitions/JSONRPCResponse" + }, + { + "$ref": "#/definitions/JSONRPCError" + } + ], + "description": "Refers to any valid JSON-RPC object that can be decoded off the wire, or encoded to be sent." + }, + "JSONRPCNotification": { + "description": "A notification which does not expect a response.", + "properties": { + "jsonrpc": { + "const": "2.0", + "type": "string" + }, + "method": { + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "jsonrpc", + "method" + ], + "type": "object" + }, + "JSONRPCRequest": { + "description": "A request that expects a response.", + "properties": { + "id": { + "$ref": "#/definitions/RequestId" + }, + "jsonrpc": { + "const": "2.0", + "type": "string" + }, + "method": { + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "properties": { + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications." + } + }, + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "id", + "jsonrpc", + "method" + ], + "type": "object" + }, + "JSONRPCResponse": { + "description": "A successful (non-error) response to a request.", + "properties": { + "id": { + "$ref": "#/definitions/RequestId" + }, + "jsonrpc": { + "const": "2.0", + "type": "string" + }, + "result": { + "$ref": "#/definitions/Result" + } + }, + "required": [ + "id", + "jsonrpc", + "result" + ], + "type": "object" + }, + "ListPromptsRequest": { + "description": "Sent from the client to request a list of prompts and prompt templates the server has.", + "properties": { + "method": { + "const": "prompts/list", + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListPromptsResult": { + "description": "The server's response to a prompts/list request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + }, + "prompts": { + "items": { + "$ref": "#/definitions/Prompt" + }, + "type": "array" + } + }, + "required": [ + "prompts" + ], + "type": "object" + }, + "ListResourceTemplatesRequest": { + "description": "Sent from the client to request a list of resource templates the server has.", + "properties": { + "method": { + "const": "resources/templates/list", + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListResourceTemplatesResult": { + "description": "The server's response to a resources/templates/list request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + }, + "resourceTemplates": { + "items": { + "$ref": "#/definitions/ResourceTemplate" + }, + "type": "array" + } + }, + "required": [ + "resourceTemplates" + ], + "type": "object" + }, + "ListResourcesRequest": { + "description": "Sent from the client to request a list of resources the server has.", + "properties": { + "method": { + "const": "resources/list", + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListResourcesResult": { + "description": "The server's response to a resources/list request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + }, + "resources": { + "items": { + "$ref": "#/definitions/Resource" + }, + "type": "array" + } + }, + "required": [ + "resources" + ], + "type": "object" + }, + "ListRootsRequest": { + "description": "Sent from the server to request a list of root URIs from the client. Roots allow\nservers to ask for specific directories or files to operate on. A common example\nfor roots is providing a set of repositories or directories a server should operate\non.\n\nThis request is typically used when the server needs to understand the file system\nstructure or access specific locations that the client has permission to read from.", + "properties": { + "method": { + "const": "roots/list", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "properties": { + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications." + } + }, + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListRootsResult": { + "description": "The client's response to a roots/list request from the server.\nThis result contains an array of Root objects, each representing a root directory\nor file that the server can operate on.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "roots": { + "items": { + "$ref": "#/definitions/Root" + }, + "type": "array" + } + }, + "required": [ + "roots" + ], + "type": "object" + }, + "ListToolsRequest": { + "description": "Sent from the client to request a list of tools the server has.", + "properties": { + "method": { + "const": "tools/list", + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListToolsResult": { + "description": "The server's response to a tools/list request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + }, + "tools": { + "items": { + "$ref": "#/definitions/Tool" + }, + "type": "array" + } + }, + "required": [ + "tools" + ], + "type": "object" + }, + "LoggingLevel": { + "description": "The severity of a log message.\n\nThese map to syslog message severities, as specified in RFC-5424:\nhttps://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1", + "enum": [ + "alert", + "critical", + "debug", + "emergency", + "error", + "info", + "notice", + "warning" + ], + "type": "string" + }, + "LoggingMessageNotification": { + "description": "Notification of a log message passed from server to client. If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically.", + "properties": { + "method": { + "const": "notifications/message", + "type": "string" + }, + "params": { + "properties": { + "data": { + "description": "The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here." + }, + "level": { + "$ref": "#/definitions/LoggingLevel", + "description": "The severity of this log message." + }, + "logger": { + "description": "An optional name of the logger issuing this message.", + "type": "string" + } + }, + "required": [ + "data", + "level" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ModelHint": { + "description": "Hints to use for model selection.\n\nKeys not declared here are currently left unspecified by the spec and are up\nto the client to interpret.", + "properties": { + "name": { + "description": "A hint for a model name.\n\nThe client SHOULD treat this as a substring of a model name; for example:\n - `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022`\n - `sonnet` should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc.\n - `claude` should match any Claude model\n\nThe client MAY also map the string to a different provider's model name or a different model family, as long as it fills a similar niche; for example:\n - `gemini-1.5-flash` could match `claude-3-haiku-20240307`", + "type": "string" + } + }, + "type": "object" + }, + "ModelPreferences": { + "description": "The server's preferences for model selection, requested of the client during sampling.\n\nBecause LLMs can vary along multiple dimensions, choosing the \"best\" model is\nrarely straightforward. Different models excel in different areas—some are\nfaster but less capable, others are more capable but more expensive, and so\non. This interface allows servers to express their priorities across multiple\ndimensions to help clients make an appropriate selection for their use case.\n\nThese preferences are always advisory. The client MAY ignore them. It is also\nup to the client to decide how to interpret these preferences and how to\nbalance them against other considerations.", + "properties": { + "costPriority": { + "description": "How much to prioritize cost when selecting a model. A value of 0 means cost\nis not important, while a value of 1 means cost is the most important\nfactor.", + "maximum": 1, + "minimum": 0, + "type": "number" + }, + "hints": { + "description": "Optional hints to use for model selection.\n\nIf multiple hints are specified, the client MUST evaluate them in order\n(such that the first match is taken).\n\nThe client SHOULD prioritize these hints over the numeric priorities, but\nMAY still use the priorities to select from ambiguous matches.", + "items": { + "$ref": "#/definitions/ModelHint" + }, + "type": "array" + }, + "intelligencePriority": { + "description": "How much to prioritize intelligence and capabilities when selecting a\nmodel. A value of 0 means intelligence is not important, while a value of 1\nmeans intelligence is the most important factor.", + "maximum": 1, + "minimum": 0, + "type": "number" + }, + "speedPriority": { + "description": "How much to prioritize sampling speed (latency) when selecting a model. A\nvalue of 0 means speed is not important, while a value of 1 means speed is\nthe most important factor.", + "maximum": 1, + "minimum": 0, + "type": "number" + } + }, + "type": "object" + }, + "Notification": { + "properties": { + "method": { + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "NumberSchema": { + "properties": { + "description": { + "type": "string" + }, + "maximum": { + "type": "integer" + }, + "minimum": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "type": { + "enum": [ + "integer", + "number" + ], + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "PaginatedRequest": { + "properties": { + "method": { + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "PaginatedResult": { + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + } + }, + "type": "object" + }, + "PingRequest": { + "description": "A ping, issued by either the server or the client, to check that the other party is still alive. The receiver must promptly respond, or else may be disconnected.", + "properties": { + "method": { + "const": "ping", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "properties": { + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications." + } + }, + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "PrimitiveSchemaDefinition": { + "anyOf": [ + { + "$ref": "#/definitions/StringSchema" + }, + { + "$ref": "#/definitions/NumberSchema" + }, + { + "$ref": "#/definitions/BooleanSchema" + }, + { + "$ref": "#/definitions/EnumSchema" + } + ], + "description": "Restricted schema definitions that only allow primitive types\nwithout nested objects or arrays." + }, + "ProgressNotification": { + "description": "An out-of-band notification used to inform the receiver of a progress update for a long-running request.", + "properties": { + "method": { + "const": "notifications/progress", + "type": "string" + }, + "params": { + "properties": { + "message": { + "description": "An optional message describing the current progress.", + "type": "string" + }, + "progress": { + "description": "The progress thus far. This should increase every time progress is made, even if the total is unknown.", + "type": "number" + }, + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "The progress token which was given in the initial request, used to associate this notification with the request that is proceeding." + }, + "total": { + "description": "Total number of items to process (or total progress required), if known.", + "type": "number" + } + }, + "required": [ + "progress", + "progressToken" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ProgressToken": { + "description": "A progress token, used to associate progress notifications with the original request.", + "type": [ + "string", + "integer" + ] + }, + "Prompt": { + "description": "A prompt or prompt template that the server offers.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "arguments": { + "description": "A list of arguments to use for templating the prompt.", + "items": { + "$ref": "#/definitions/PromptArgument" + }, + "type": "array" + }, + "description": { + "description": "An optional description of what this prompt provides", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }, + "PromptArgument": { + "description": "Describes an argument that a prompt can accept.", + "properties": { + "description": { + "description": "A human-readable description of the argument.", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "required": { + "description": "Whether this argument must be provided.", + "type": "boolean" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }, + "PromptListChangedNotification": { + "description": "An optional notification from the server to the client, informing it that the list of prompts it offers has changed. This may be issued by servers without any previous subscription from the client.", + "properties": { + "method": { + "const": "notifications/prompts/list_changed", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "PromptMessage": { + "description": "Describes a message returned as part of a prompt.\n\nThis is similar to `SamplingMessage`, but also supports the embedding of\nresources from the MCP server.", + "properties": { + "content": { + "$ref": "#/definitions/ContentBlock" + }, + "role": { + "$ref": "#/definitions/Role" + } + }, + "required": [ + "content", + "role" + ], + "type": "object" + }, + "PromptReference": { + "description": "Identifies a prompt.", + "properties": { + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "type": { + "const": "ref/prompt", + "type": "string" + } + }, + "required": [ + "name", + "type" + ], + "type": "object" + }, + "ReadResourceRequest": { + "description": "Sent from the client to the server, to read a specific resource URI.", + "properties": { + "method": { + "const": "resources/read", + "type": "string" + }, + "params": { + "properties": { + "uri": { + "description": "The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ReadResourceResult": { + "description": "The server's response to a resources/read request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "contents": { + "items": { + "anyOf": [ + { + "$ref": "#/definitions/TextResourceContents" + }, + { + "$ref": "#/definitions/BlobResourceContents" + } + ] + }, + "type": "array" + } + }, + "required": [ + "contents" + ], + "type": "object" + }, + "Request": { + "properties": { + "method": { + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "properties": { + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications." + } + }, + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "RequestId": { + "description": "A uniquely identifying ID for a request in JSON-RPC.", + "type": [ + "string", + "integer" + ] + }, + "Resource": { + "description": "A known resource that the server is capable of reading.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "description": { + "description": "A description of what this resource represents.\n\nThis can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a \"hint\" to the model.", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window usage.", + "type": "integer" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "name", + "uri" + ], + "type": "object" + }, + "ResourceContents": { + "description": "The contents of a specific resource or sub-resource.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + }, + "ResourceLink": { + "description": "A resource that the server is capable of reading, included in a prompt or tool call result.\n\nNote: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "description": { + "description": "A description of what this resource represents.\n\nThis can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a \"hint\" to the model.", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window usage.", + "type": "integer" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "type": { + "const": "resource_link", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "name", + "type", + "uri" + ], + "type": "object" + }, + "ResourceListChangedNotification": { + "description": "An optional notification from the server to the client, informing it that the list of resources it can read from has changed. This may be issued by servers without any previous subscription from the client.", + "properties": { + "method": { + "const": "notifications/resources/list_changed", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ResourceTemplate": { + "description": "A template description for resources available on the server.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "description": { + "description": "A description of what this template is for.\n\nThis can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a \"hint\" to the model.", + "type": "string" + }, + "mimeType": { + "description": "The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type.", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "uriTemplate": { + "description": "A URI template (according to RFC 6570) that can be used to construct resource URIs.", + "format": "uri-template", + "type": "string" + } + }, + "required": [ + "name", + "uriTemplate" + ], + "type": "object" + }, + "ResourceTemplateReference": { + "description": "A reference to a resource or resource template definition.", + "properties": { + "type": { + "const": "ref/resource", + "type": "string" + }, + "uri": { + "description": "The URI or URI template of the resource.", + "format": "uri-template", + "type": "string" + } + }, + "required": [ + "type", + "uri" + ], + "type": "object" + }, + "ResourceUpdatedNotification": { + "description": "A notification from the server to the client, informing it that a resource has changed and may need to be read again. This should only be sent if the client previously sent a resources/subscribe request.", + "properties": { + "method": { + "const": "notifications/resources/updated", + "type": "string" + }, + "params": { + "properties": { + "uri": { + "description": "The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "Result": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + }, + "Role": { + "description": "The sender or recipient of messages and data in a conversation.", + "enum": [ + "assistant", + "user" + ], + "type": "string" + }, + "Root": { + "description": "Represents a root directory or file that the server can operate on.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "name": { + "description": "An optional name for the root. This can be used to provide a human-readable\nidentifier for the root, which may be useful for display purposes or for\nreferencing the root in other parts of the application.", + "type": "string" + }, + "uri": { + "description": "The URI identifying the root. This *must* start with file:// for now.\nThis restriction may be relaxed in future versions of the protocol to allow\nother URI schemes.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + }, + "RootsListChangedNotification": { + "description": "A notification from the client to the server, informing it that the list of roots has changed.\nThis notification should be sent whenever the client adds, removes, or modifies any root.\nThe server should then request an updated list of roots using the ListRootsRequest.", + "properties": { + "method": { + "const": "notifications/roots/list_changed", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "SamplingMessage": { + "description": "Describes a message issued to or received from an LLM API.", + "properties": { + "content": { + "anyOf": [ + { + "$ref": "#/definitions/TextContent" + }, + { + "$ref": "#/definitions/ImageContent" + }, + { + "$ref": "#/definitions/AudioContent" + } + ] + }, + "role": { + "$ref": "#/definitions/Role" + } + }, + "required": [ + "content", + "role" + ], + "type": "object" + }, + "ServerCapabilities": { + "description": "Capabilities that a server may support. Known capabilities are defined here, in this schema, but this is not a closed set: any server can define its own, additional capabilities.", + "properties": { + "completions": { + "additionalProperties": true, + "description": "Present if the server supports argument autocompletion suggestions.", + "properties": {}, + "type": "object" + }, + "experimental": { + "additionalProperties": { + "additionalProperties": true, + "properties": {}, + "type": "object" + }, + "description": "Experimental, non-standard capabilities that the server supports.", + "type": "object" + }, + "logging": { + "additionalProperties": true, + "description": "Present if the server supports sending log messages to the client.", + "properties": {}, + "type": "object" + }, + "prompts": { + "description": "Present if the server offers any prompt templates.", + "properties": { + "listChanged": { + "description": "Whether this server supports notifications for changes to the prompt list.", + "type": "boolean" + } + }, + "type": "object" + }, + "resources": { + "description": "Present if the server offers any resources to read.", + "properties": { + "listChanged": { + "description": "Whether this server supports notifications for changes to the resource list.", + "type": "boolean" + }, + "subscribe": { + "description": "Whether this server supports subscribing to resource updates.", + "type": "boolean" + } + }, + "type": "object" + }, + "tools": { + "description": "Present if the server offers any tools to call.", + "properties": { + "listChanged": { + "description": "Whether this server supports notifications for changes to the tool list.", + "type": "boolean" + } + }, + "type": "object" + } + }, + "type": "object" + }, + "ServerNotification": { + "anyOf": [ + { + "$ref": "#/definitions/CancelledNotification" + }, + { + "$ref": "#/definitions/ProgressNotification" + }, + { + "$ref": "#/definitions/ResourceListChangedNotification" + }, + { + "$ref": "#/definitions/ResourceUpdatedNotification" + }, + { + "$ref": "#/definitions/PromptListChangedNotification" + }, + { + "$ref": "#/definitions/ToolListChangedNotification" + }, + { + "$ref": "#/definitions/LoggingMessageNotification" + } + ] + }, + "ServerRequest": { + "anyOf": [ + { + "$ref": "#/definitions/PingRequest" + }, + { + "$ref": "#/definitions/CreateMessageRequest" + }, + { + "$ref": "#/definitions/ListRootsRequest" + }, + { + "$ref": "#/definitions/ElicitRequest" + } + ] + }, + "ServerResult": { + "anyOf": [ + { + "$ref": "#/definitions/Result" + }, + { + "$ref": "#/definitions/InitializeResult" + }, + { + "$ref": "#/definitions/ListResourcesResult" + }, + { + "$ref": "#/definitions/ListResourceTemplatesResult" + }, + { + "$ref": "#/definitions/ReadResourceResult" + }, + { + "$ref": "#/definitions/ListPromptsResult" + }, + { + "$ref": "#/definitions/GetPromptResult" + }, + { + "$ref": "#/definitions/ListToolsResult" + }, + { + "$ref": "#/definitions/CallToolResult" + }, + { + "$ref": "#/definitions/CompleteResult" + } + ] + }, + "SetLevelRequest": { + "description": "A request from the client to the server, to enable or adjust logging.", + "properties": { + "method": { + "const": "logging/setLevel", + "type": "string" + }, + "params": { + "properties": { + "level": { + "$ref": "#/definitions/LoggingLevel", + "description": "The level of logging that the client wants to receive from the server. The server should send all logs at this level and higher (i.e., more severe) to the client as notifications/message." + } + }, + "required": [ + "level" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "StringSchema": { + "properties": { + "description": { + "type": "string" + }, + "format": { + "enum": [ + "date", + "date-time", + "email", + "uri" + ], + "type": "string" + }, + "maxLength": { + "type": "integer" + }, + "minLength": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "type": { + "const": "string", + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "SubscribeRequest": { + "description": "Sent from the client to request resources/updated notifications from the server whenever a particular resource changes.", + "properties": { + "method": { + "const": "resources/subscribe", + "type": "string" + }, + "params": { + "properties": { + "uri": { + "description": "The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "TextContent": { + "description": "Text provided to or from an LLM.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "text": { + "description": "The text content of the message.", + "type": "string" + }, + "type": { + "const": "text", + "type": "string" + } + }, + "required": [ + "text", + "type" + ], + "type": "object" + }, + "TextResourceContents": { + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "text": { + "description": "The text of the item. This must only be set if the item can actually be represented as text (not binary data).", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "text", + "uri" + ], + "type": "object" + }, + "Tool": { + "description": "Definition for a tool the client can call.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/ToolAnnotations", + "description": "Optional additional tool information.\n\nDisplay name precedence order is: title, annotations.title, then name." + }, + "description": { + "description": "A human-readable description of the tool.\n\nThis can be used by clients to improve the LLM's understanding of available tools. It can be thought of like a \"hint\" to the model.", + "type": "string" + }, + "inputSchema": { + "description": "A JSON Schema object defining the expected parameters for the tool.", + "properties": { + "properties": { + "additionalProperties": { + "additionalProperties": true, + "properties": {}, + "type": "object" + }, + "type": "object" + }, + "required": { + "items": { + "type": "string" + }, + "type": "array" + }, + "type": { + "const": "object", + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "outputSchema": { + "description": "An optional JSON Schema object defining the structure of the tool's output returned in\nthe structuredContent field of a CallToolResult.", + "properties": { + "properties": { + "additionalProperties": { + "additionalProperties": true, + "properties": {}, + "type": "object" + }, + "type": "object" + }, + "required": { + "items": { + "type": "string" + }, + "type": "array" + }, + "type": { + "const": "object", + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + } + }, + "required": [ + "inputSchema", + "name" + ], + "type": "object" + }, + "ToolAnnotations": { + "description": "Additional properties describing a Tool to clients.\n\nNOTE: all properties in ToolAnnotations are **hints**.\nThey are not guaranteed to provide a faithful description of\ntool behavior (including descriptive properties like `title`).\n\nClients should never make tool use decisions based on ToolAnnotations\nreceived from untrusted servers.", + "properties": { + "destructiveHint": { + "description": "If true, the tool may perform destructive updates to its environment.\nIf false, the tool performs only additive updates.\n\n(This property is meaningful only when `readOnlyHint == false`)\n\nDefault: true", + "type": "boolean" + }, + "idempotentHint": { + "description": "If true, calling the tool repeatedly with the same arguments\nwill have no additional effect on the its environment.\n\n(This property is meaningful only when `readOnlyHint == false`)\n\nDefault: false", + "type": "boolean" + }, + "openWorldHint": { + "description": "If true, this tool may interact with an \"open world\" of external\nentities. If false, the tool's domain of interaction is closed.\nFor example, the world of a web search tool is open, whereas that\nof a memory tool is not.\n\nDefault: true", + "type": "boolean" + }, + "readOnlyHint": { + "description": "If true, the tool does not modify its environment.\n\nDefault: false", + "type": "boolean" + }, + "title": { + "description": "A human-readable title for the tool.", + "type": "string" + } + }, + "type": "object" + }, + "ToolListChangedNotification": { + "description": "An optional notification from the server to the client, informing it that the list of tools it offers has changed. This may be issued by servers without any previous subscription from the client.", + "properties": { + "method": { + "const": "notifications/tools/list_changed", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/draft/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "UnsubscribeRequest": { + "description": "Sent from the client to request cancellation of resources/updated notifications from the server. This should follow a previous resources/subscribe request.", + "properties": { + "method": { + "const": "resources/unsubscribe", + "type": "string" + }, + "params": { + "properties": { + "uri": { + "description": "The URI of the resource to unsubscribe from.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + } + } +} \ No newline at end of file diff --git a/handler.go b/handler.go index 533c0cc..d91a19c 100644 --- a/handler.go +++ b/handler.go @@ -8,8 +8,24 @@ package mcp import ( "context" + "encoding/json" ) +// parseJSONRPCParams parses JSON-RPC parameters into a target structure +func parseJSONRPCParams(params interface{}, target interface{}) error { + if params == nil { + return nil + } + + // Convert params to JSON and then unmarshal into target + paramBytes, err := json.Marshal(params) + if err != nil { + return err + } + + return json.Unmarshal(paramBytes, target) +} + const ( // defaultServerName is the default name for the server defaultServerName = "Go-MCP-Server" @@ -39,6 +55,9 @@ type mcpHandler struct { // Prompt manager promptManager *promptManager + + // Middleware chain for server request processing + middlewares []MiddlewareFunc } // newMCPHandler creates an MCP protocol handler @@ -79,6 +98,13 @@ func newMCPHandler(options ...func(*mcpHandler)) *mcpHandler { return h } +// withMiddlewares sets the middleware chain for the handler +func withMiddlewares(middlewares []MiddlewareFunc) func(*mcpHandler) { + return func(h *mcpHandler) { + h.middlewares = middlewares + } +} + // withToolManager sets the tool manager func withToolManager(manager *toolManager) func(*mcpHandler) { return func(h *mcpHandler) { @@ -151,6 +177,44 @@ func (h *mcpHandler) handleToolsList(ctx context.Context, req *JSONRPCRequest, s } func (h *mcpHandler) handleToolsCall(ctx context.Context, req *JSONRPCRequest, session Session) (JSONRPCMessage, error) { + // Apply middleware chain for tool calls if middlewares are configured + if len(h.middlewares) > 0 { + // Parse the request to get CallToolRequest + var callToolReq CallToolRequest + if err := parseJSONRPCParams(req.Params, &callToolReq.Params); err != nil { + return newJSONRPCErrorResponse(req.ID, ErrCodeInvalidParams, "invalid params", err.Error()), nil + } + + // Define the final handler that calls the tool manager + handler := func(ctx context.Context, request interface{}) (interface{}, error) { + // Cast back to CallToolRequest + toolReq := request.(*CallToolRequest) + + // Create a new JSON-RPC request with the potentially modified params + modifiedReq := &JSONRPCRequest{ + JSONRPC: req.JSONRPC, + ID: req.ID, + Request: Request{ + Method: req.Method, + }, + Params: toolReq.Params, + } + + return h.toolManager.handleCallTool(ctx, modifiedReq, session) + } + + // Execute the middleware chain + chainedHandler := Chain(handler, h.middlewares...) + result, err := chainedHandler(ctx, &callToolReq) + + if err != nil { + return newJSONRPCErrorResponse(req.ID, ErrCodeInternal, "tool call failed", err.Error()), nil + } + + return result.(JSONRPCMessage), nil + } + + // Fallback to direct call without middleware return h.toolManager.handleCallTool(ctx, req, session) } @@ -159,6 +223,44 @@ func (h *mcpHandler) handleResourcesList(ctx context.Context, req *JSONRPCReques } func (h *mcpHandler) handleResourcesRead(ctx context.Context, req *JSONRPCRequest, session Session) (JSONRPCMessage, error) { + // Apply middleware chain for resource reads if middlewares are configured + if len(h.middlewares) > 0 { + // Parse the request to get ReadResourceRequest + var readResourceReq ReadResourceRequest + if err := parseJSONRPCParams(req.Params, &readResourceReq.Params); err != nil { + return newJSONRPCErrorResponse(req.ID, ErrCodeInvalidParams, "invalid params", err.Error()), nil + } + + // Define the final handler that calls the resource manager + handler := func(ctx context.Context, request interface{}) (interface{}, error) { + // Cast back to ReadResourceRequest + resourceReq := request.(*ReadResourceRequest) + + // Create a new JSON-RPC request with the potentially modified params + modifiedReq := &JSONRPCRequest{ + JSONRPC: req.JSONRPC, + ID: req.ID, + Request: Request{ + Method: req.Method, + }, + Params: resourceReq.Params, + } + + return h.resourceManager.handleReadResource(ctx, modifiedReq) + } + + // Execute the middleware chain + chainedHandler := Chain(handler, h.middlewares...) + result, err := chainedHandler(ctx, &readResourceReq) + + if err != nil { + return newJSONRPCErrorResponse(req.ID, ErrCodeInternal, "resource read failed", err.Error()), nil + } + + return result.(JSONRPCMessage), nil + } + + // Fallback to direct call without middleware return h.resourceManager.handleReadResource(ctx, req) } @@ -179,6 +281,44 @@ func (h *mcpHandler) handlePromptsList(ctx context.Context, req *JSONRPCRequest, } func (h *mcpHandler) handlePromptsGet(ctx context.Context, req *JSONRPCRequest, session Session) (JSONRPCMessage, error) { + // Apply middleware chain for prompt gets if middlewares are configured + if len(h.middlewares) > 0 { + // Parse the request to get GetPromptRequest + var getPromptReq GetPromptRequest + if err := parseJSONRPCParams(req.Params, &getPromptReq.Params); err != nil { + return newJSONRPCErrorResponse(req.ID, ErrCodeInvalidParams, "invalid params", err.Error()), nil + } + + // Define the final handler that calls the prompt manager + handler := func(ctx context.Context, request interface{}) (interface{}, error) { + // Cast back to GetPromptRequest + promptReq := request.(*GetPromptRequest) + + // Create a new JSON-RPC request with the potentially modified params + modifiedReq := &JSONRPCRequest{ + JSONRPC: req.JSONRPC, + ID: req.ID, + Request: Request{ + Method: req.Method, + }, + Params: promptReq.Params, + } + + return h.promptManager.handleGetPrompt(ctx, modifiedReq) + } + + // Execute the middleware chain + chainedHandler := Chain(handler, h.middlewares...) + result, err := chainedHandler(ctx, &getPromptReq) + + if err != nil { + return newJSONRPCErrorResponse(req.ID, ErrCodeInternal, "prompt get failed", err.Error()), nil + } + + return result.(JSONRPCMessage), nil + } + + // Fallback to direct call without middleware return h.promptManager.handleGetPrompt(ctx, req) } diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..1d182d0 --- /dev/null +++ b/middleware.go @@ -0,0 +1,554 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package mcp + +import ( + "context" + "fmt" + "log" + "runtime" + "sync" + "time" +) + +// MiddlewareError 中间件错误类型,支持错误分类和链路追踪 +type MiddlewareError struct { + Code string // 错误码 + Message string // 错误消息 + Cause error // 原始错误 + Context map[string]interface{} // 错误上下文 + Timestamp time.Time // 错误时间 + Trace []string // 调用堆栈 +} + +func (e *MiddlewareError) Error() string { + return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Cause) +} + +func (e *MiddlewareError) Unwrap() error { + return e.Cause +} + +// 错误码常量 +const ( + ErrCodeAuth = "AUTH_FAILED" + ErrCodeRateLimit = "RATE_LIMIT_EXCEEDED" + ErrCodeCircuitBreaker = "CIRCUIT_BREAKER_OPEN" + ErrCodeValidation = "VALIDATION_FAILED" + ErrCodeTimeout = "TIMEOUT" + ErrCodePanic = "PANIC_RECOVERED" + ErrCodeUnknown = "UNKNOWN_ERROR" +) + +// NewMiddlewareError 创建新的中间件错误 +func NewMiddlewareError(code, message string, cause error) *MiddlewareError { + // 获取调用堆栈 + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + + trace := make([]string, 0, n) + frames := runtime.CallersFrames(pcs[:n]) + for { + frame, more := frames.Next() + trace = append(trace, fmt.Sprintf("%s:%d %s", frame.File, frame.Line, frame.Function)) + if !more { + break + } + } + + return &MiddlewareError{ + Code: code, + Message: message, + Cause: cause, + Context: make(map[string]interface{}), + Timestamp: time.Now(), + Trace: trace, + } +} + +// AddContext 添加错误上下文信息 +func (e *MiddlewareError) AddContext(key string, value interface{}) *MiddlewareError { + e.Context[key] = value + return e +} + +// Handler 定义了中间件链末端处理请求的函数签名。 +// 这是实际执行业务逻辑(如发送网络请求或执行工具逻辑)的函数。 +type Handler func(ctx context.Context, req interface{}) (interface{}, error) + +// MiddlewareFunc 定义了中间件函数的接口。 +// 它接收上下文、请求以及链中的下一个处理器,允许在请求前后执行逻辑。 +type MiddlewareFunc func(ctx context.Context, req interface{}, next Handler) (interface{}, error) + +// MiddlewareChain 表示中间件执行链,按注册顺序执行中间件 +type MiddlewareChain struct { + middlewares []MiddlewareFunc +} + +// NewMiddlewareChain 创建一个新的中间件链 +func NewMiddlewareChain(middlewares ...MiddlewareFunc) *MiddlewareChain { + return &MiddlewareChain{ + middlewares: middlewares, + } +} + +// Use 添加中间件到链中 +func (mc *MiddlewareChain) Use(middleware MiddlewareFunc) { + mc.middlewares = append(mc.middlewares, middleware) +} + +// Execute 执行中间件链 +func (mc *MiddlewareChain) Execute(ctx context.Context, req interface{}, finalHandler Handler) (interface{}, error) { + return Chain(finalHandler, mc.middlewares...)(ctx, req) +} + +// Chain 将一系列中间件和一个最终的处理器链接起来,形成一个完整的执行链。 +// 中间件会按照参数顺序执行,最后一个参数的中间件在最外层最先执行。 +// 例如:Chain(handler, m1, m2) 的执行顺序是 m2 -> m1 -> handler +func Chain(handler Handler, middlewares ...MiddlewareFunc) Handler { + // 从最后一个中间件开始,将处理器逐层向内包装,使最后的中间件在最外层 + for i := len(middlewares) - 1; i >= 0; i-- { + handler = wrap(middlewares[i], handler) + } + return handler +} + +// wrap 是一个辅助函数,用于将一个中间件包装在下一个处理器周围。 +func wrap(m MiddlewareFunc, next Handler) Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + return m(ctx, req, next) + } +} + +// LoggingMiddleware 日志记录中间件,记录请求的基本信息和处理时间 +func LoggingMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + startTime := time.Now() + + // 记录请求开始 + log.Printf("[Middleware] Request started: %T", req) + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 记录请求结束和耗时 + duration := time.Since(startTime) + if err != nil { + log.Printf("[Middleware] Request failed after %v: %v", duration, err) + } else { + log.Printf("[Middleware] Request completed in %v", duration) + } + + return resp, err +} + +// RecoveryMiddleware 错误恢复中间件,捕获 panic 并转换为错误 +func RecoveryMiddleware(ctx context.Context, req interface{}, next Handler) (resp interface{}, err error) { + defer func() { + if r := recover(); r != nil { + log.Printf("[RecoveryMiddleware] Panic recovered: %v", r) + + // 将 panic 转换为结构化错误 + panicErr := NewMiddlewareError(ErrCodePanic, "panic recovered in middleware chain", + fmt.Errorf("panic: %v", r)) + panicErr.AddContext("panic_value", r) + panicErr.AddContext("request_type", fmt.Sprintf("%T", req)) + + resp = nil + err = panicErr + } + }() + + return next(ctx, req) +} + +// ToolHandlerMiddleware 工具处理中间件,专门处理 CallTool 请求 +func ToolHandlerMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 检查是否是工具调用请求 + if callToolReq, ok := req.(*CallToolRequest); ok { + log.Printf("[ToolMiddleware] Calling tool: %s", callToolReq.Params.Name) + + // 验证工具名称 + if callToolReq.Params.Name == "" { + return nil, fmt.Errorf("tool name is required") + } + + // 记录工具参数 + if len(callToolReq.Params.Arguments) > 0 { + log.Printf("[ToolMiddleware] Tool arguments: %v", callToolReq.Params.Arguments) + } + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 处理工具调用结果 + if err == nil { + if toolResult, ok := resp.(*CallToolResult); ok { + if toolResult.IsError { + log.Printf("[ToolMiddleware] Tool execution returned error") + } else { + log.Printf("[ToolMiddleware] Tool execution successful, content items: %d", len(toolResult.Content)) + } + } + } + + return resp, err + } + + // 对于非工具调用请求,直接传递给下一个处理器 + return next(ctx, req) +} + +// ResourceMiddleware 资源访问中间件,处理 ReadResource 请求 +func ResourceMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 检查是否是资源读取请求 + if readResourceReq, ok := req.(*ReadResourceRequest); ok { + log.Printf("[ResourceMiddleware] Reading resource: %s", readResourceReq.Params.URI) + + // 验证资源 URI + if readResourceReq.Params.URI == "" { + return nil, fmt.Errorf("resource URI is required") + } + + // 这里可以添加资源访问权限检查、缓存逻辑等 + // 例如:检查用户是否有权限访问该资源 + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 处理资源读取结果 + if err == nil { + log.Printf("[ResourceMiddleware] Resource read successful") + } + + return resp, err + } + + // 对于非资源读取请求,直接传递给下一个处理器 + return next(ctx, req) +} + +// PromptMiddleware 提示模板中间件,处理 GetPrompt 请求 +func PromptMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 检查是否是获取提示请求 + if getPromptReq, ok := req.(*GetPromptRequest); ok { + log.Printf("[PromptMiddleware] Getting prompt: %s", getPromptReq.Params.Name) + + // 验证提示名称 + if getPromptReq.Params.Name == "" { + return nil, fmt.Errorf("prompt name is required") + } + + // 这里可以添加提示模板的预处理、验证等逻辑 + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 处理获取提示结果 + if err == nil { + log.Printf("[PromptMiddleware] Prompt retrieved successfully") + } + + return resp, err + } + + // 对于非获取提示请求,直接传递给下一个处理器 + return next(ctx, req) +} + +// MetricsMiddleware 性能监控中间件,收集请求指标 +func MetricsMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + startTime := time.Now() + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 记录指标 + duration := time.Since(startTime) + + // 根据请求类型记录不同的指标 + requestType := fmt.Sprintf("%T", req) + status := "success" + if err != nil { + status = "error" + } + + // 这里可以集成到实际的监控系统(如 Prometheus) + log.Printf("[Metrics] RequestType: %s, Status: %s, Duration: %v", requestType, status, duration) + + return resp, err +} + +// AuthMiddleware 认证鉴权中间件 +func AuthMiddleware(apiKey string) MiddlewareFunc { + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 从上下文中获取认证信息 + if apiKey == "" { + err := NewMiddlewareError(ErrCodeAuth, "API key is required", nil) + err.AddContext("request_type", fmt.Sprintf("%T", req)) + return nil, err + } + + log.Printf("[AuthMiddleware] Request authenticated") + + // 在上下文中添加认证信息 + ctx = context.WithValue(ctx, "api_key", apiKey) + + return next(ctx, req) + } +} + +// RetryMiddleware 重试中间件,对失败的请求进行重试 +func RetryMiddleware(maxRetries int) MiddlewareFunc { + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + var lastErr error + + for attempt := 0; attempt <= maxRetries; attempt++ { + if attempt > 0 { + log.Printf("[RetryMiddleware] Retry attempt %d/%d", attempt, maxRetries) + // 添加退避延时 + time.Sleep(time.Duration(attempt) * time.Second) + } + + resp, err := next(ctx, req) + if err == nil { + return resp, nil + } + + lastErr = err + log.Printf("[RetryMiddleware] Attempt %d failed: %v", attempt+1, err) + } + + return nil, fmt.Errorf("request failed after %d retries: %v", maxRetries, lastErr) + } +} + +// CacheMiddleware 缓存中间件 +func CacheMiddleware(cache map[string]interface{}) MiddlewareFunc { + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 生成缓存键 + cacheKey := fmt.Sprintf("%T_%v", req, req) + + // 检查缓存 + if cached, exists := cache[cacheKey]; exists { + log.Printf("[CacheMiddleware] Cache hit for request: %T", req) + return cached, nil + } + + // 调用下一个处理器 + resp, err := next(ctx, req) + + // 如果成功,保存到缓存 + if err == nil { + cache[cacheKey] = resp + log.Printf("[CacheMiddleware] Cached response for request: %T", req) + } + + return resp, err + } +} + +// ValidationMiddleware 验证中间件,对请求进行验证 +func ValidationMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 根据请求类型进行不同的验证 + switch r := req.(type) { + case *CallToolRequest: + if r.Params.Name == "" { + return nil, fmt.Errorf("validation failed: tool name is required") + } + case *ReadResourceRequest: + if r.Params.URI == "" { + return nil, fmt.Errorf("validation failed: resource URI is required") + } + case *GetPromptRequest: + if r.Params.Name == "" { + return nil, fmt.Errorf("validation failed: prompt name is required") + } + } + + log.Printf("[ValidationMiddleware] Request validation passed for: %T", req) + return next(ctx, req) +} + +// RateLimitingMiddleware 限流中间件,控制请求频率 +func RateLimitingMiddleware(maxRequests int, window time.Duration) MiddlewareFunc { + requestCounts := make(map[string][]time.Time) + var mu sync.RWMutex // 添加线程安全保护 + + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 从上下文中获取客户端标识(可以是IP、用户ID等) + clientID := "default" // 简化实现,实际中应该从上下文获取 + if userID := ctx.Value("user_id"); userID != nil { + if id, ok := userID.(string); ok { + clientID = id + } + } + + now := time.Now() + + mu.Lock() + defer mu.Unlock() + + // 清理过期的请求记录 + if timestamps, exists := requestCounts[clientID]; exists { + var validTimestamps []time.Time + for _, ts := range timestamps { + if now.Sub(ts) < window { + validTimestamps = append(validTimestamps, ts) + } + } + requestCounts[clientID] = validTimestamps + } + + // 检查是否超过限制 + if len(requestCounts[clientID]) >= maxRequests { + err := NewMiddlewareError(ErrCodeRateLimit, + fmt.Sprintf("rate limit exceeded: %d requests per %v", maxRequests, window), nil) + err.AddContext("client_id", clientID) + err.AddContext("current_requests", len(requestCounts[clientID])) + err.AddContext("max_requests", maxRequests) + err.AddContext("window", window.String()) + return nil, err + } + + // 记录当前请求 + requestCounts[clientID] = append(requestCounts[clientID], now) + + log.Printf("[RateLimitingMiddleware] Request allowed for client %s (%d/%d)", + clientID, len(requestCounts[clientID]), maxRequests) + + return next(ctx, req) + } +} + +// CircuitBreakerMiddleware 熔断器中间件,防止级联故障 +func CircuitBreakerMiddleware(failureThreshold int, timeout time.Duration) MiddlewareFunc { + var ( + failureCount int + lastFailureTime time.Time + state string = "closed" // closed, open, half-open + mu sync.RWMutex // 添加线程安全保护 + ) + + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + now := time.Now() + + mu.Lock() + defer mu.Unlock() + + // 检查熔断器状态 + switch state { + case "open": + if now.Sub(lastFailureTime) > timeout { + state = "half-open" + log.Printf("[CircuitBreakerMiddleware] Circuit breaker transitioning to half-open") + } else { + err := NewMiddlewareError(ErrCodeCircuitBreaker, "circuit breaker is open", nil) + err.AddContext("failure_count", failureCount) + err.AddContext("failure_threshold", failureThreshold) + err.AddContext("time_until_retry", timeout-now.Sub(lastFailureTime)) + return nil, err + } + case "half-open": + // 半开状态,尝试一个请求 + case "closed": + // 关闭状态,正常处理 + } + + // 临时释放锁以执行下一个中间件 + mu.Unlock() + resp, err := next(ctx, req) + mu.Lock() + + if err != nil { + failureCount++ + lastFailureTime = now + + if failureCount >= failureThreshold && state != "open" { + state = "open" + log.Printf("[CircuitBreakerMiddleware] Circuit breaker opened due to %d failures", failureCount) + } + + return nil, err + } + + // 请求成功,重置失败计数 + if state == "half-open" { + state = "closed" + failureCount = 0 + log.Printf("[CircuitBreakerMiddleware] Circuit breaker closed after successful request") + } + + return resp, nil + } +} + +// CORSMiddleware CORS 中间件,处理跨域请求 +func CORSMiddleware(allowOrigins []string, allowMethods []string, allowHeaders []string) MiddlewareFunc { + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 这里可以添加 CORS 相关的处理逻辑 + // 在实际的 HTTP 层面处理 CORS,这里主要是记录 + log.Printf("[CORSMiddleware] Processing request with CORS policy") + + // 可以在上下文中添加 CORS 相关信息 + ctx = context.WithValue(ctx, "cors_allowed_origins", allowOrigins) + ctx = context.WithValue(ctx, "cors_allowed_methods", allowMethods) + ctx = context.WithValue(ctx, "cors_allowed_headers", allowHeaders) + + return next(ctx, req) + } +} + +// CompressionMiddleware 压缩中间件,处理响应压缩 +func CompressionMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + log.Printf("[CompressionMiddleware] Processing request for compression") + + // 在上下文中标记需要压缩 + ctx = context.WithValue(ctx, "compression_enabled", true) + + return next(ctx, req) +} + +// SecurityMiddleware 安全中间件,处理安全相关检查 +func SecurityMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 安全检查,如 SQL 注入防护、XSS 防护等 + log.Printf("[SecurityMiddleware] Performing security checks for request: %T", req) + + // 这里可以添加具体的安全检查逻辑 + // 例如:检查请求参数中的危险字符串 + + return next(ctx, req) +} + +// TimeoutMiddleware 超时中间件,设置请求超时 +func TimeoutMiddleware(timeout time.Duration) MiddlewareFunc { + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 创建带超时的上下文 + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // 使用通道来处理超时和正常完成 + type result struct { + resp interface{} + err error + } + + resultChan := make(chan result, 1) + + go func() { + resp, err := next(timeoutCtx, req) + resultChan <- result{resp, err} + }() + + select { + case res := <-resultChan: + return res.resp, res.err + case <-timeoutCtx.Done(): + return nil, fmt.Errorf("request timeout after %v", timeout) + } + } +} diff --git a/middleware_advanced_test.go b/middleware_advanced_test.go new file mode 100644 index 0000000..90ffbb2 --- /dev/null +++ b/middleware_advanced_test.go @@ -0,0 +1,318 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package mcp + +import ( + "context" + "fmt" + "sync" + "testing" + "time" +) + +// TestMiddlewareError 测试中间件错误系统 +func TestMiddlewareError(t *testing.T) { + t.Run("错误创建和堆栈追踪", func(t *testing.T) { + originalErr := fmt.Errorf("original error") + middlewareErr := NewMiddlewareError(ErrCodeAuth, "authentication failed", originalErr) + + if middlewareErr.Code != ErrCodeAuth { + t.Errorf("期望错误码 %s,得到 %s", ErrCodeAuth, middlewareErr.Code) + } + + if middlewareErr.Cause != originalErr { + t.Errorf("期望原始错误 %v,得到 %v", originalErr, middlewareErr.Cause) + } + + if len(middlewareErr.Trace) == 0 { + t.Error("期望非空的调用堆栈") + } + + // 测试错误链 + if unwrappedErr := middlewareErr.Unwrap(); unwrappedErr != originalErr { + t.Errorf("期望解包后的错误 %v,得到 %v", originalErr, unwrappedErr) + } + }) + + t.Run("错误上下文", func(t *testing.T) { + err := NewMiddlewareError(ErrCodeRateLimit, "rate limit exceeded", nil) + err.AddContext("client_id", "test-client") + err.AddContext("requests", 100) + + if clientID, exists := err.Context["client_id"]; !exists || clientID != "test-client" { + t.Error("期望上下文包含正确的 client_id") + } + + if requests, exists := err.Context["requests"]; !exists || requests != 100 { + t.Error("期望上下文包含正确的 requests") + } + }) +} + +// TestAuthMiddlewareError 测试认证中间件的错误处理 +func TestAuthMiddlewareError(t *testing.T) { + middleware := AuthMiddleware("") + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + ctx := context.Background() + req := &CallToolRequest{} + + resp, err := middleware(ctx, req, handler) + + if resp != nil { + t.Error("期望响应为 nil") + } + + if err == nil { + t.Fatal("期望返回错误") + } + + middlewareErr, ok := err.(*MiddlewareError) + if !ok { + t.Fatal("期望返回 MiddlewareError 类型") + } + + if middlewareErr.Code != ErrCodeAuth { + t.Errorf("期望错误码 %s,得到 %s", ErrCodeAuth, middlewareErr.Code) + } + + if requestType, exists := middlewareErr.Context["request_type"]; !exists { + t.Error("期望上下文包含 request_type") + } else if requestType != "*mcp.CallToolRequest" { + t.Errorf("期望 request_type 为 *mcp.CallToolRequest,得到 %v", requestType) + } +} + +// TestRateLimitingMiddlewareThreadSafety 测试限流中间件的线程安全 +func TestRateLimitingMiddlewareThreadSafety(t *testing.T) { + middleware := RateLimitingMiddleware(10, time.Second) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + var wg sync.WaitGroup + var successCount, errorCount int64 + var mu sync.Mutex + + // 并发执行 50 个请求 + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + + ctx := context.WithValue(context.Background(), "user_id", fmt.Sprintf("user-%d", i%5)) + req := &CallToolRequest{} + + _, err := middleware(ctx, req, handler) + + mu.Lock() + if err != nil { + errorCount++ + } else { + successCount++ + } + mu.Unlock() + }(i) + } + + wg.Wait() + + t.Logf("成功请求: %d, 失败请求: %d", successCount, errorCount) + + // 由于有5个不同的用户,每个用户可以有10个请求,所以最多50个请求都可能成功 + // 但由于并发执行,可能会有一些请求被限流 + if successCount == 0 { + t.Error("期望至少有一些成功的请求") + } +} + +// TestCircuitBreakerMiddlewareThreadSafety 测试熔断器中间件的线程安全 +func TestCircuitBreakerMiddlewareThreadSafety(t *testing.T) { + middleware := CircuitBreakerMiddleware(3, time.Second) + + // 创建一个会失败的处理器 + failingHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, fmt.Errorf("simulated failure") + } + + var wg sync.WaitGroup + var circuitBreakerErrors int64 + var otherErrors int64 + var mu sync.Mutex + + // 并发执行请求直到熔断器打开 + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + ctx := context.Background() + req := &CallToolRequest{} + + _, err := middleware(ctx, req, failingHandler) + + mu.Lock() + if err != nil { + if middlewareErr, ok := err.(*MiddlewareError); ok && middlewareErr.Code == ErrCodeCircuitBreaker { + circuitBreakerErrors++ + } else { + otherErrors++ + } + } + mu.Unlock() + }() + } + + wg.Wait() + + t.Logf("熔断器错误: %d, 其他错误: %d", circuitBreakerErrors, otherErrors) + + if circuitBreakerErrors == 0 { + t.Error("期望有一些熔断器错误") + } + + if otherErrors == 0 { + t.Error("期望有一些原始错误") + } +} + +// TestRecoveryMiddlewareError 测试恢复中间件的错误处理 +func TestRecoveryMiddlewareError(t *testing.T) { + panicHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + panic("test panic") + } + + ctx := context.Background() + req := &CallToolRequest{} + + resp, err := RecoveryMiddleware(ctx, req, panicHandler) + + if resp != nil { + t.Error("期望响应为 nil") + } + + if err == nil { + t.Fatal("期望返回错误") + } + + middlewareErr, ok := err.(*MiddlewareError) + if !ok { + t.Fatal("期望返回 MiddlewareError 类型") + } + + if middlewareErr.Code != ErrCodePanic { + t.Errorf("期望错误码 %s,得到 %s", ErrCodePanic, middlewareErr.Code) + } + + if panicValue, exists := middlewareErr.Context["panic_value"]; !exists { + t.Error("期望上下文包含 panic_value") + } else if panicValue != "test panic" { + t.Errorf("期望 panic_value 为 'test panic',得到 %v", panicValue) + } +} + +// TestMiddlewareChainErrorPropagation 测试中间件链中的错误传播 +func TestMiddlewareChainErrorPropagation(t *testing.T) { + // 创建一个会产生错误的中间件 + errorMiddleware := func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 不调用下一个中间件,直接返回错误 + return nil, NewMiddlewareError(ErrCodeValidation, "validation failed", nil) + } + + // 创建一个记录中间件 + var executed bool + loggingMiddleware := func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + executed = true + return next(ctx, req) + } + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + t.Error("不应该执行到最终处理器") + return "success", nil + } + + // 创建中间件链 + chain := NewMiddlewareChain() + // 调整中间件顺序,确保 loggingMiddleware 先于 errorMiddleware 执行 + chain.Use(loggingMiddleware) + chain.Use(errorMiddleware) + + ctx := context.Background() + req := &CallToolRequest{} + + resp, err := chain.Execute(ctx, req, handler) + + if resp != nil { + t.Error("期望响应为 nil") + } + + if err == nil { + t.Fatal("期望返回错误") + } + + if !executed { + t.Error("期望日志中间件被执行") + } + + middlewareErr, ok := err.(*MiddlewareError) + if !ok { + t.Fatal("期望返回 MiddlewareError 类型") + } + + if middlewareErr.Code != ErrCodeValidation { + t.Errorf("期望错误码 %s,得到 %s", ErrCodeValidation, middlewareErr.Code) + } +} + +// BenchmarkMiddlewareErrorCreation 基准测试错误创建性能 +func BenchmarkMiddlewareErrorCreation(b *testing.B) { + originalErr := fmt.Errorf("original error") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := NewMiddlewareError(ErrCodeAuth, "test error", originalErr) + err.AddContext("test", "value") + } +} + +// BenchmarkRateLimitingMiddleware 基准测试限流中间件性能 +func BenchmarkRateLimitingMiddleware(b *testing.B) { + middleware := RateLimitingMiddleware(1000, time.Second) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + ctx := context.Background() + req := &CallToolRequest{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + middleware(ctx, req, handler) + } +} + +// BenchmarkCircuitBreakerMiddleware 基准测试熔断器中间件性能 +func BenchmarkCircuitBreakerMiddleware(b *testing.B) { + middleware := CircuitBreakerMiddleware(100, time.Second) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + ctx := context.Background() + req := &CallToolRequest{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + middleware(ctx, req, handler) + } +} diff --git a/middleware_monitoring.go b/middleware_monitoring.go new file mode 100644 index 0000000..2e40877 --- /dev/null +++ b/middleware_monitoring.go @@ -0,0 +1,284 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "log" + "sync" + "sync/atomic" + "time" +) + +// MiddlewareMetrics 中间件性能指标 +type MiddlewareMetrics struct { + RequestCount int64 `json:"request_count"` + ErrorCount int64 `json:"error_count"` + TotalDuration time.Duration `json:"total_duration"` + AverageDuration time.Duration `json:"average_duration"` + MaxDuration time.Duration `json:"max_duration"` + MinDuration time.Duration `json:"min_duration"` + LastRequest time.Time `json:"last_request"` +} + +// MiddlewareMonitor 中间件监控器 +type MiddlewareMonitor struct { + metrics map[string]*MiddlewareMetrics + mu sync.RWMutex +} + +// NewMiddlewareMonitor 创建新的中间件监控器 +func NewMiddlewareMonitor() *MiddlewareMonitor { + return &MiddlewareMonitor{ + metrics: make(map[string]*MiddlewareMetrics), + } +} + +// RecordRequest 记录请求指标 +func (m *MiddlewareMonitor) RecordRequest(middlewareName string, duration time.Duration, hasError bool) { + m.mu.Lock() + defer m.mu.Unlock() + + metric, exists := m.metrics[middlewareName] + if !exists { + metric = &MiddlewareMetrics{ + MinDuration: duration, + MaxDuration: duration, + } + m.metrics[middlewareName] = metric + } + + atomic.AddInt64(&metric.RequestCount, 1) + if hasError { + atomic.AddInt64(&metric.ErrorCount, 1) + } + + metric.TotalDuration += duration + metric.AverageDuration = metric.TotalDuration / time.Duration(metric.RequestCount) + metric.LastRequest = time.Now() + + if duration > metric.MaxDuration { + metric.MaxDuration = duration + } + if duration < metric.MinDuration { + metric.MinDuration = duration + } +} + +// GetMetrics 获取指定中间件的指标 +func (m *MiddlewareMonitor) GetMetrics(middlewareName string) *MiddlewareMetrics { + m.mu.RLock() + defer m.mu.RUnlock() + + if metric, exists := m.metrics[middlewareName]; exists { + // 返回副本以避免并发访问问题 + return &MiddlewareMetrics{ + RequestCount: atomic.LoadInt64(&metric.RequestCount), + ErrorCount: atomic.LoadInt64(&metric.ErrorCount), + TotalDuration: metric.TotalDuration, + AverageDuration: metric.AverageDuration, + MaxDuration: metric.MaxDuration, + MinDuration: metric.MinDuration, + LastRequest: metric.LastRequest, + } + } + return nil +} + +// GetAllMetrics 获取所有中间件的指标 +func (m *MiddlewareMonitor) GetAllMetrics() map[string]*MiddlewareMetrics { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]*MiddlewareMetrics) + for name, metric := range m.metrics { + result[name] = &MiddlewareMetrics{ + RequestCount: atomic.LoadInt64(&metric.RequestCount), + ErrorCount: atomic.LoadInt64(&metric.ErrorCount), + TotalDuration: metric.TotalDuration, + AverageDuration: metric.AverageDuration, + MaxDuration: metric.MaxDuration, + MinDuration: metric.MinDuration, + LastRequest: metric.LastRequest, + } + } + return result +} + +// Reset 重置指定中间件的指标 +func (m *MiddlewareMonitor) Reset(middlewareName string) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.metrics, middlewareName) +} + +// ResetAll 重置所有指标 +func (m *MiddlewareMonitor) ResetAll() { + m.mu.Lock() + defer m.mu.Unlock() + m.metrics = make(map[string]*MiddlewareMetrics) +} + +// PrintReport 打印性能报告 +func (m *MiddlewareMonitor) PrintReport() { + metrics := m.GetAllMetrics() + + log.Println("=== 中间件性能报告 ===") + for name, metric := range metrics { + errorRate := float64(metric.ErrorCount) / float64(metric.RequestCount) * 100 + log.Printf("中间件: %s", name) + log.Printf(" 请求总数: %d", metric.RequestCount) + log.Printf(" 错误总数: %d (%.2f%%)", metric.ErrorCount, errorRate) + log.Printf(" 平均响应时间: %v", metric.AverageDuration) + log.Printf(" 最大响应时间: %v", metric.MaxDuration) + log.Printf(" 最小响应时间: %v", metric.MinDuration) + log.Printf(" 最后请求时间: %v", metric.LastRequest.Format("2006-01-02 15:04:05")) + log.Println() + } +} + +// ToJSON 将指标转换为JSON格式 +func (m *MiddlewareMonitor) ToJSON() (string, error) { + metrics := m.GetAllMetrics() + data, err := json.MarshalIndent(metrics, "", " ") + if err != nil { + return "", err + } + return string(data), nil +} + +// 全局监控器实例 +var globalMonitor = NewMiddlewareMonitor() + +// GetGlobalMonitor 获取全局监控器 +func GetGlobalMonitor() *MiddlewareMonitor { + return globalMonitor +} + +// MonitoringMiddleware 监控中间件,自动记录性能指标 +func MonitoringMiddleware(middlewareName string) MiddlewareFunc { + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + startTime := time.Now() + + resp, err := next(ctx, req) + + duration := time.Since(startTime) + hasError := err != nil + + globalMonitor.RecordRequest(middlewareName, duration, hasError) + + // 记录详细日志 + if hasError { + log.Printf("[MonitoringMiddleware] %s 执行失败 - 耗时: %v, 错误: %v", + middlewareName, duration, err) + } else { + log.Printf("[MonitoringMiddleware] %s 执行成功 - 耗时: %v", + middlewareName, duration) + } + + return resp, err + } +} + +// HealthCheckMiddleware 健康检查中间件 +func HealthCheckMiddleware(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 检查系统健康状态 + metrics := globalMonitor.GetAllMetrics() + + var unhealthyMiddlewares []string + for name, metric := range metrics { + // 检查错误率是否过高(超过50%) + if metric.RequestCount > 0 { + errorRate := float64(metric.ErrorCount) / float64(metric.RequestCount) + if errorRate > 0.5 { + unhealthyMiddlewares = append(unhealthyMiddlewares, name) + } + } + + // 检查是否长时间没有请求(超过5分钟) + if time.Since(metric.LastRequest) > 5*time.Minute && metric.RequestCount > 0 { + log.Printf("[HealthCheckMiddleware] 中间件 %s 长时间无请求", name) + } + } + + if len(unhealthyMiddlewares) > 0 { + log.Printf("[HealthCheckMiddleware] 检测到不健康的中间件: %v", unhealthyMiddlewares) + } + + return next(ctx, req) +} + +// AlertingMiddleware 告警中间件 +func AlertingMiddleware(errorThreshold int64, responseTimeThreshold time.Duration) MiddlewareFunc { + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + startTime := time.Now() + resp, err := next(ctx, req) + duration := time.Since(startTime) + + // 检查响应时间告警 + if duration > responseTimeThreshold { + log.Printf("[AlertingMiddleware] 响应时间过长告警: %v > %v", duration, responseTimeThreshold) + } + + // 检查错误计数告警 + if err != nil { + metrics := globalMonitor.GetAllMetrics() + for name, metric := range metrics { + if metric.ErrorCount > errorThreshold { + log.Printf("[AlertingMiddleware] 错误计数过高告警: %s 中间件错误数 %d > %d", + name, metric.ErrorCount, errorThreshold) + } + } + } + + return resp, err + } +} + +// SamplingMiddleware 采样中间件,只处理部分请求 +func SamplingMiddleware(sampleRate float64) MiddlewareFunc { + if sampleRate < 0 || sampleRate > 1 { + sampleRate = 1.0 // 默认处理所有请求 + } + + var counter int64 + + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + currentCount := atomic.AddInt64(&counter, 1) + + // 简单的采样算法:每 1/sampleRate 个请求处理一个 + if sampleRate == 1.0 || float64(currentCount%int64(1/sampleRate)) == 0 { + log.Printf("[SamplingMiddleware] 处理采样请求 #%d", currentCount) + return next(ctx, req) + } + + // 跳过的请求返回默认响应 + log.Printf("[SamplingMiddleware] 跳过请求 #%d (采样率: %.2f)", currentCount, sampleRate) + return fmt.Sprintf("sampled_response_%d", currentCount), nil + } +} + +// LoadBalancingMiddleware 负载均衡中间件 +func LoadBalancingMiddleware(handlers []Handler) MiddlewareFunc { + if len(handlers) == 0 { + panic("LoadBalancingMiddleware: 至少需要一个处理器") + } + + var counter int64 + + return func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + // 轮询选择处理器 + index := atomic.AddInt64(&counter, 1) % int64(len(handlers)) + selectedHandler := handlers[index] + + log.Printf("[LoadBalancingMiddleware] 选择处理器 #%d", index) + + return selectedHandler(ctx, req) + } +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..c7603d7 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,439 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package mcp + +import ( + "context" + "fmt" + "strings" + "testing" + "time" +) + +// TestMiddlewareChain tests the middleware chain functionality +func TestMiddlewareChain(t *testing.T) { + // Create a test handler + finalHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "final response", nil + } + + // Create middleware that adds a prefix + middleware1 := func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + resp, err := next(ctx, req) + if err != nil { + return nil, err + } + return "middleware1-" + resp.(string), nil + } + + // Create middleware that adds another prefix + middleware2 := func(ctx context.Context, req interface{}, next Handler) (interface{}, error) { + resp, err := next(ctx, req) + if err != nil { + return nil, err + } + return "middleware2-" + resp.(string), nil + } + + // Chain the middlewares + chainedHandler := Chain(finalHandler, middleware1, middleware2) + + // Execute the chain + result, err := chainedHandler(context.Background(), "test request") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + expected := "middleware1-middleware2-final response" + if result != expected { + t.Fatalf("Expected %q, got %q", expected, result) + } +} + +// TestLoggingMiddleware tests the logging middleware +func TestLoggingMiddleware(t *testing.T) { + finalHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "response", nil + } + + // Execute with logging middleware + result, err := LoggingMiddleware(context.Background(), "test", finalHandler) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if result != "response" { + t.Fatalf("Expected 'response', got %v", result) + } +} + +// TestRecoveryMiddleware tests the recovery middleware +func TestRecoveryMiddleware(t *testing.T) { + // Create a handler that panics + panicHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + panic("test panic") + } + + // Execute with recovery middleware (should not panic) + _, err := RecoveryMiddleware(context.Background(), "test", panicHandler) + if err != nil { + t.Fatalf("Recovery middleware should handle panic gracefully, got error: %v", err) + } +} + +// TestToolHandlerMiddleware tests the tool handler middleware +func TestToolHandlerMiddleware(t *testing.T) { + finalHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return &CallToolResult{ + Content: []Content{NewTextContent("tool response")}, + IsError: false, + }, nil + } + + // Create a tool call request + toolReq := &CallToolRequest{ + Params: CallToolParams{ + Name: "test-tool", + Arguments: map[string]interface{}{ + "param1": "value1", + }, + }, + } + + // Execute with tool handler middleware + result, err := ToolHandlerMiddleware(context.Background(), toolReq, finalHandler) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + toolResult, ok := result.(*CallToolResult) + if !ok { + t.Fatalf("Expected CallToolResult, got %T", result) + } + + if toolResult.IsError { + t.Fatalf("Expected successful tool result, got error") + } +} + +// TestToolHandlerMiddlewareValidation tests tool handler middleware validation +func TestToolHandlerMiddlewareValidation(t *testing.T) { + finalHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + } + + // Create a tool call request with empty name + toolReq := &CallToolRequest{ + Params: CallToolParams{ + Name: "", // Empty name should trigger validation error + }, + } + + // Execute with tool handler middleware + _, err := ToolHandlerMiddleware(context.Background(), toolReq, finalHandler) + if err == nil { + t.Fatalf("Expected validation error for empty tool name") + } + + expectedError := "tool name is required" + if err.Error() != expectedError { + t.Fatalf("Expected error %q, got %q", expectedError, err.Error()) + } +} + +// TestMetricsMiddleware tests the metrics middleware +func TestMetricsMiddleware(t *testing.T) { + finalHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + time.Sleep(10 * time.Millisecond) // Simulate some processing time + return "response", nil + } + + // Execute with metrics middleware + result, err := MetricsMiddleware(context.Background(), "test", finalHandler) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if result != "response" { + t.Fatalf("Expected 'response', got %v", result) + } +} + +// TestRetryMiddleware tests the retry middleware +func TestRetryMiddleware(t *testing.T) { + attemptCount := 0 + + // Create a handler that fails twice then succeeds + flakyHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + attemptCount++ + if attemptCount < 3 { + return nil, fmt.Errorf("attempt %d failed", attemptCount) + } + return "success", nil + } + + // Create retry middleware with max 3 retries + retryMiddleware := RetryMiddleware(3) + + // Execute with retry middleware + result, err := retryMiddleware(context.Background(), "test", flakyHandler) + if err != nil { + t.Fatalf("Expected no error after retries, got %v", err) + } + + if result != "success" { + t.Fatalf("Expected 'success', got %v", result) + } + + if attemptCount != 3 { + t.Fatalf("Expected 3 attempts, got %d", attemptCount) + } +} + +// TestValidationMiddleware tests the validation middleware +func TestValidationMiddleware(t *testing.T) { + finalHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "response", nil + } + + // Test with valid tool request + validToolReq := &CallToolRequest{ + Params: CallToolParams{ + Name: "valid-tool", + }, + } + + result, err := ValidationMiddleware(context.Background(), validToolReq, finalHandler) + if err != nil { + t.Fatalf("Expected no error for valid request, got %v", err) + } + + if result != "response" { + t.Fatalf("Expected 'response', got %v", result) + } + + // Test with invalid tool request (empty name) + invalidToolReq := &CallToolRequest{ + Params: CallToolParams{ + Name: "", // Empty name should fail validation + }, + } + + _, err = ValidationMiddleware(context.Background(), invalidToolReq, finalHandler) + if err == nil { + t.Fatalf("Expected validation error for empty tool name") + } + + expectedError := "validation failed: tool name is required" + if err.Error() != expectedError { + t.Fatalf("Expected error %q, got %q", expectedError, err.Error()) + } +} + +// TestCacheMiddleware tests the cache middleware +func TestCacheMiddleware(t *testing.T) { + callCount := 0 + cache := make(map[string]interface{}) + + // Create a handler that increments call count + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + callCount++ + return fmt.Sprintf("response-%d", callCount), nil + } + + cacheMiddleware := CacheMiddleware(cache) + + // First call should execute handler + result1, err := cacheMiddleware(context.Background(), "test", handler) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Second call should return cached result + result2, err := cacheMiddleware(context.Background(), "test", handler) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Should have called handler only once + if callCount != 1 { + t.Fatalf("Expected handler to be called once, got %d calls", callCount) + } + + // Both results should be the same (from cache) + if result1 != result2 { + t.Fatalf("Expected same result from cache, got %v and %v", result1, result2) + } +} + +// BenchmarkMiddlewareChain benchmarks the middleware chain performance +func BenchmarkMiddlewareChain(b *testing.B) { + finalHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "response", nil + } + + // Create a chain with multiple middlewares + chainedHandler := Chain(finalHandler, + LoggingMiddleware, + MetricsMiddleware, + ValidationMiddleware, + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := chainedHandler(context.Background(), &CallToolRequest{ + Params: CallToolParams{Name: "test-tool"}, + }) + if err != nil { + b.Fatalf("Unexpected error: %v", err) + } + } +} + +// TestRateLimitingMiddleware tests the rate limiting middleware +func TestRateLimitingMiddleware(t *testing.T) { + middleware := RateLimitingMiddleware(2, time.Second*1) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + // First two requests should succeed + for i := 0; i < 2; i++ { + _, err := middleware(context.Background(), "test", handler) + if err != nil { + t.Fatalf("Request %d should succeed, got error: %v", i+1, err) + } + } + + // Third request should be rate limited + _, err := middleware(context.Background(), "test", handler) + if err == nil { + t.Fatal("Third request should be rate limited") + } +} + +// TestCircuitBreakerMiddleware tests the circuit breaker middleware +func TestCircuitBreakerMiddleware(t *testing.T) { + middleware := CircuitBreakerMiddleware(2, time.Millisecond*100) + + failingHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, fmt.Errorf("simulated error") + } + + // First two requests should fail and trigger circuit breaker + for i := 0; i < 2; i++ { + _, err := middleware(context.Background(), "test", failingHandler) + if err == nil { + t.Fatalf("Request %d should fail", i+1) + } + } + + // Third request should be blocked by circuit breaker + _, err := middleware(context.Background(), "test", failingHandler) + if err == nil || err.Error() != "circuit breaker is open" { + t.Fatalf("Expected circuit breaker to be open, got: %v", err) + } + + // Wait for timeout and test half-open state + time.Sleep(time.Millisecond * 150) + + successHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + } + + // This should succeed and close the circuit breaker + _, err = middleware(context.Background(), "test", successHandler) + if err != nil { + t.Fatalf("Request should succeed in half-open state, got: %v", err) + } +} + +// TestTimeoutMiddleware tests the timeout middleware +func TestTimeoutMiddleware(t *testing.T) { + middleware := TimeoutMiddleware(time.Millisecond * 100) + + slowHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + time.Sleep(time.Millisecond * 200) + return "slow response", nil + } + + _, err := middleware(context.Background(), "test", slowHandler) + if err == nil || !strings.Contains(fmt.Sprintf("%v", err), "timeout") { + t.Fatalf("Expected timeout error, got: %v", err) + } + + fastHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "fast response", nil + } + + resp, err := middleware(context.Background(), "test", fastHandler) + if err != nil { + t.Fatalf("Fast handler should not timeout, got: %v", err) + } + if resp != "fast response" { + t.Fatalf("Expected 'fast response', got: %v", resp) + } +} + +// TestCORSMiddleware tests the CORS middleware +func TestCORSMiddleware(t *testing.T) { + allowOrigins := []string{"http://localhost:3000", "https://example.com"} + allowMethods := []string{"GET", "POST", "PUT", "DELETE"} + allowHeaders := []string{"Content-Type", "Authorization"} + + middleware := CORSMiddleware(allowOrigins, allowMethods, allowHeaders) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + // Check if CORS info is set in context + origins := ctx.Value("cors_allowed_origins") + if origins == nil { + return nil, fmt.Errorf("CORS origins not set in context") + } + return "success", nil + } + + _, err := middleware(context.Background(), "test", handler) + if err != nil { + t.Fatalf("CORS middleware should set context values, got error: %v", err) + } +} + +// TestSecurityMiddleware tests the security middleware +func TestSecurityMiddleware(t *testing.T) { + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return "secure response", nil + } + + resp, err := SecurityMiddleware(context.Background(), "test", handler) + if err != nil { + t.Fatalf("Security middleware should not fail for normal request, got: %v", err) + } + if resp != "secure response" { + t.Fatalf("Expected 'secure response', got: %v", resp) + } +} + +// TestCompressionMiddleware tests the compression middleware +func TestCompressionMiddleware(t *testing.T) { + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + // Check if compression is enabled in context + compressionEnabled := ctx.Value("compression_enabled") + if compressionEnabled != true { + return nil, fmt.Errorf("compression not enabled in context") + } + return "compressed response", nil + } + + resp, err := CompressionMiddleware(context.Background(), "test", handler) + if err != nil { + t.Fatalf("Compression middleware should set context values, got error: %v", err) + } + if resp != "compressed response" { + t.Fatalf("Expected 'compressed response', got: %v", resp) + } +} diff --git a/server.go b/server.go index 02c151b..19b1925 100644 --- a/server.go +++ b/server.go @@ -59,6 +59,9 @@ type serverConfig struct { // Method name modifier for external customization. methodNameModifier MethodNameModifier + + // Middleware chain for server request processing + middlewares []MiddlewareFunc } // Server MCP server @@ -85,6 +88,7 @@ func NewServer(name, version string, options ...ServerOption) *Server { postSSEEnabled: true, getSSEEnabled: true, notificationBufferSize: defaultNotificationBufferSize, + middlewares: []MiddlewareFunc{}, // Initialize middleware slice } // Create server with provided serverInfo @@ -147,6 +151,7 @@ func (s *Server) initComponents() { withLifecycleManager(lifecycleManager), withResourceManager(s.resourceManager), withPromptManager(s.promptManager), + withMiddlewares(s.config.middlewares), ) // Collect HTTP handler options. @@ -280,6 +285,20 @@ func WithToolListFilter(filter ToolListFilter) ServerOption { } } +// WithServerMiddleware adds a middleware to the server's request processing chain. +func WithServerMiddleware(m MiddlewareFunc) ServerOption { + return func(s *Server) { + s.config.middlewares = append(s.config.middlewares, m) + } +} + +// WithServerMiddlewares adds multiple middlewares to the server's request processing chain. +func WithServerMiddlewares(middlewares ...MiddlewareFunc) ServerOption { + return func(s *Server) { + s.config.middlewares = append(s.config.middlewares, middlewares...) + } +} + // WithServerAddress sets the server address func WithServerAddress(addr string) ServerOption { return func(s *Server) { @@ -475,3 +494,8 @@ func (s *Server) SetMethodNameModifier(modifier MethodNameModifier) { s.toolManager.withMethodNameModifier(modifier) } } + +// GetMiddlewares returns the server's middleware chain. +func (s *Server) GetMiddlewares() []MiddlewareFunc { + return s.config.middlewares +} diff --git a/verify.ps1 b/verify.ps1 new file mode 100644 index 0000000..bffcfc8 --- /dev/null +++ b/verify.ps1 @@ -0,0 +1,37 @@ +#!/usr/bin/env powershell + +Write-Host "🔧 验证中间件系统修复..." + +Write-Host "1. 检查编译错误..." +$compileResult = go build 2>&1 +if ($LASTEXITCODE -eq 0) { + Write-Host "✅ 编译成功" -ForegroundColor Green +} else { + Write-Host "❌ 编译失败:" -ForegroundColor Red + Write-Host $compileResult + exit 1 +} + +Write-Host "2. 运行中间件测试..." +$testResult = go test -v -run "TestMiddlewareChain" 2>&1 +if ($LASTEXITCODE -eq 0) { + Write-Host "✅ 测试通过" -ForegroundColor Green + Write-Host $testResult +} else { + Write-Host "❌ 测试失败:" -ForegroundColor Red + Write-Host $testResult +} + +Write-Host "3. 运行测试程序..." +Set-Location test_middleware +$runResult = go run main.go 2>&1 +if ($LASTEXITCODE -eq 0) { + Write-Host "✅ 测试程序运行成功" -ForegroundColor Green + Write-Host $runResult +} else { + Write-Host "❌ 测试程序运行失败:" -ForegroundColor Red + Write-Host $runResult +} +Set-Location .. + +Write-Host "🎉 验证完成!"