|
| 1 | +package ratelimit |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "sync" |
| 7 | + "time" |
| 8 | + |
| 9 | + "github.com/mark3labs/mcp-go/mcp" |
| 10 | + "github.com/mark3labs/mcp-go/server" |
| 11 | +) |
| 12 | + |
// cleanupInterval is the period of the ticker that triggers the sweep for
// idle buckets.
const cleanupInterval = 10 * time.Minute

// bucketTimeout is how long a bucket may go unused before the sweep removes
// it.
const bucketTimeout = 30 * time.Minute
| 17 | + |
| 18 | +// RateLimiter implements a rate limiting middleware for MCP server |
| 19 | +// It uses a token bucket algorithm to limit the number of requests per minute for each tool |
| 20 | +type RateLimiter struct { |
| 21 | + mu sync.RWMutex |
| 22 | + limits map[string]int // Tool name to requests per minute |
| 23 | + defaultLimit int // Default requests per minute |
| 24 | + buckets map[string]map[string]*bucket // SessionID:[Tool:Bucket] mapping |
| 25 | + cleanupTicker *time.Ticker |
| 26 | +} |
| 27 | + |
| 28 | +// bucket represents a token bucket for rate limiting |
| 29 | +type bucket struct { |
| 30 | + mu sync.Mutex |
| 31 | + tokens int // Current number of tokens |
| 32 | + lastSeen time.Time // Last time this bucket was accessed |
| 33 | +} |
| 34 | + |
| 35 | +// RateLimiterOption is a function that configures a RateLimiter |
| 36 | +type RateLimiterOption func(*RateLimiter) |
| 37 | + |
| 38 | +// WithToolLimit sets the rate limit for a specific tool |
| 39 | +func WithToolLimit(toolName string, requestsPerMinute int) RateLimiterOption { |
| 40 | + return func(rl *RateLimiter) { |
| 41 | + rl.limits[toolName] = requestsPerMinute |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +// WithDefaultLimit sets the default rate limit for all tools |
| 46 | +func WithDefaultLimit(requestsPerMinute int) RateLimiterOption { |
| 47 | + return func(rl *RateLimiter) { |
| 48 | + rl.defaultLimit = requestsPerMinute |
| 49 | + } |
| 50 | +} |
| 51 | + |
| 52 | +// NewRateLimiter creates a new rate limiter with the given options |
| 53 | +func NewRateLimiter(opts ...RateLimiterOption) *RateLimiter { |
| 54 | + rl := &RateLimiter{ |
| 55 | + limits: make(map[string]int), |
| 56 | + defaultLimit: defaultLimit, |
| 57 | + buckets: make(map[string]map[string]*bucket), |
| 58 | + } |
| 59 | + |
| 60 | + for _, opt := range opts { |
| 61 | + opt(rl) |
| 62 | + } |
| 63 | + |
| 64 | + // Start a cleanup ticker to remove old buckets |
| 65 | + rl.cleanupTicker = time.NewTicker(cleanupInterval) |
| 66 | + go func() { |
| 67 | + for range rl.cleanupTicker.C { |
| 68 | + rl.cleanup() |
| 69 | + } |
| 70 | + }() |
| 71 | + |
| 72 | + return rl |
| 73 | +} |
| 74 | + |
| 75 | +func (rl *RateLimiter) cleanup() { |
| 76 | + rl.mu.Lock() |
| 77 | + defer rl.mu.Unlock() |
| 78 | + |
| 79 | + now := time.Now() |
| 80 | + for sessionID, toolBuckets := range rl.buckets { |
| 81 | + for tool, b := range toolBuckets { |
| 82 | + b.mu.Lock() |
| 83 | + // If bucket hasn't been used for bucketTimeout, remove it |
| 84 | + if now.Sub(b.lastSeen) > bucketTimeout { |
| 85 | + delete(toolBuckets, tool) |
| 86 | + } |
| 87 | + b.mu.Unlock() |
| 88 | + } |
| 89 | + // If no more buckets for this session, remove the session entry |
| 90 | + if len(toolBuckets) == 0 { |
| 91 | + delete(rl.buckets, sessionID) |
| 92 | + } |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +// Stop stops the cleanup ticker |
| 97 | +func (rl *RateLimiter) Stop() { |
| 98 | + if rl.cleanupTicker != nil { |
| 99 | + rl.cleanupTicker.Stop() |
| 100 | + } |
| 101 | +} |
| 102 | + |
| 103 | +// getSessionID extracts the session ID from the request context |
| 104 | +func getSessionID(ctx context.Context) string { |
| 105 | + // Get the session from the context |
| 106 | + if session := server.ClientSessionFromContext(ctx); session != nil { |
| 107 | + return session.SessionID() |
| 108 | + } |
| 109 | + // If no session is available (which shouldn't happen in normal operation), |
| 110 | + // return a default identifier |
| 111 | + return "unknown" |
| 112 | +} |
| 113 | + |
| 114 | +// getBucket gets or creates a bucket for the given session ID and tool |
| 115 | +func (rl *RateLimiter) getBucket(sessionID, tool string) *bucket { |
| 116 | + rl.mu.Lock() |
| 117 | + defer rl.mu.Unlock() |
| 118 | + |
| 119 | + // Create session map if it doesn't exist |
| 120 | + if _, ok := rl.buckets[sessionID]; !ok { |
| 121 | + rl.buckets[sessionID] = make(map[string]*bucket) |
| 122 | + } |
| 123 | + |
| 124 | + // Create bucket if it doesn't exist |
| 125 | + if _, ok := rl.buckets[sessionID][tool]; !ok { |
| 126 | + rl.buckets[sessionID][tool] = &bucket{ |
| 127 | + tokens: rl.getLimit(tool), // Initialize with full tokens |
| 128 | + lastSeen: time.Now(), |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + return rl.buckets[sessionID][tool] |
| 133 | +} |
| 134 | + |
| 135 | +// getLimit returns the rate limit for the given tool |
| 136 | +func (rl *RateLimiter) getLimit(tool string) int { |
| 137 | + rl.mu.RLock() |
| 138 | + defer rl.mu.RUnlock() |
| 139 | + |
| 140 | + if limit, ok := rl.limits[tool]; ok { |
| 141 | + return limit |
| 142 | + } |
| 143 | + return rl.defaultLimit |
| 144 | +} |
| 145 | + |
| 146 | +// Middleware returns a middleware function for the MCP server |
| 147 | +func (rl *RateLimiter) Middleware() server.ToolHandlerMiddleware { |
| 148 | + return func(next server.ToolHandlerFunc) server.ToolHandlerFunc { |
| 149 | + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { |
| 150 | + sessionID := getSessionID(ctx) |
| 151 | + tool := request.Params.Name |
| 152 | + |
| 153 | + b := rl.getBucket(sessionID, tool) |
| 154 | + b.mu.Lock() |
| 155 | + defer b.mu.Unlock() |
| 156 | + |
| 157 | + now := time.Now() |
| 158 | + b.lastSeen = now |
| 159 | + |
| 160 | + // Calculate tokens to add based on time elapsed |
| 161 | + limit := rl.getLimit(tool) |
| 162 | + tokensPerSecond := float64(limit) / 60.0 |
| 163 | + elapsed := now.Sub(b.lastSeen).Seconds() |
| 164 | + tokensToAdd := int(elapsed * tokensPerSecond) |
| 165 | + |
| 166 | + // Add tokens, but don't exceed the limit |
| 167 | + b.tokens = min(b.tokens+tokensToAdd, limit) |
| 168 | + |
| 169 | + // Check if we have enough tokens |
| 170 | + if b.tokens <= 0 { |
| 171 | + return mcp.NewToolResultError(fmt.Sprintf("Rate limit exceeded for tool '%s'. Try again later.", tool)), nil |
| 172 | + } |
| 173 | + |
| 174 | + // Consume a token |
| 175 | + b.tokens-- |
| 176 | + |
| 177 | + // Call the next handler |
| 178 | + return next(ctx, request) |
| 179 | + } |
| 180 | + } |
| 181 | +} |
0 commit comments