Commit 1773958

Implement basic rate-limiting based on session ID

Fixes: #22
Signed-off-by: Nimisha Mehta <nimishamehta5@gmail.com>

1 parent 057dd99 commit 1773958

5 files changed: 289 additions, 5 deletions

README.md

Lines changed: 20 additions & 0 deletions
@@ -15,6 +15,7 @@ MKP is a Model Context Protocol (MCP) server for Kubernetes that allows LLM-powe
 - Apply (create or update) clustered resources
 - Apply (create or update) namespaced resources
 - Generic and pluggable implementation using API Machinery's unstructured client
+- Built-in rate limiting for protection against excessive API calls
 
 ## Why MKP?
 
@@ -47,6 +48,7 @@ MKP offers several key advantages as a Model Context Protocol server for Kuberne
 ### Production-Ready Architecture
 - Designed for reliability and performance in production environments
 - Proper error handling and resource management
+- Built-in rate limiting to protect against excessive API calls
 - Testable design with comprehensive unit tests
 - Follows Kubernetes development best practices
 
@@ -291,6 +293,24 @@ By default, MKP operates in read-only mode, meaning it does not allow write oper
 ./build/mkp-server --kubeconfig=/path/to/kubeconfig --read-write=true
 ```
 
+### Rate Limiting
+
+MKP includes a built-in rate limiting mechanism to protect the server from excessive API calls, which is particularly important when used with AI agents. The rate limiter uses a token bucket algorithm and applies different limits based on the operation type:
+
+- Read operations (list_resources, get_resource): 120 requests per minute
+- Write operations (apply_resource, delete_resource): 30 requests per minute
+- Default for other operations: 60 requests per minute
+
+Rate limits are applied per client session, ensuring fair resource allocation across multiple clients. The rate limiting feature can be enabled or disabled via the server configuration.
+
+```bash
+# Run with rate limiting enabled (default)
+task run
+
+# Run with rate limiting disabled
+DISABLE_RATE_LIMITING=true task run
+```
+
 ## Development
 
 ### Running tests
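The DISABLE_RATE_LIMITING switch documented above presumably maps onto the new EnableRateLimiting field of the server configuration (see pkg/mcp/server.go below). A minimal sketch of that wiring, assuming the exported mcp.DefaultConfig and mcp.CreateServer from this commit; the buildServer helper and the exact environment-variable handling are illustrative, not part of the diff:

```go
package main

import (
	"os"

	"github.com/mark3labs/mcp-go/server"

	"github.com/StacklokLabs/mkp/pkg/k8s"
	"github.com/StacklokLabs/mkp/pkg/mcp"
)

// buildServer disables rate limiting when DISABLE_RATE_LIMITING=true and
// otherwise keeps the defaults (rate limiting enabled, read-only mode).
func buildServer(k8sClient *k8s.Client) *server.MCPServer {
	cfg := mcp.DefaultConfig()
	if os.Getenv("DISABLE_RATE_LIMITING") == "true" {
		cfg.EnableRateLimiting = false
	}
	return mcp.CreateServer(k8sClient, cfg)
}
```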

cmd/server/main.go

Lines changed: 7 additions & 0 deletions
@@ -95,10 +95,17 @@ func main() {
 	shutdownCh := make(chan error, 1)
 	go func() {
 		log.Println("Initiating server shutdown...")
+
+		// Stop the SSE server
 		err := sseServer.Shutdown(shutdownCtx)
 		if err != nil {
 			log.Printf("Error during shutdown: %v", err)
 		}
+
+		// Stop the MCP server resources (including rate limiter)
+		log.Println("Stopping MCP server resources...")
+		mcp.StopServer()
+
 		shutdownCh <- err
 		close(shutdownCh)
 	}()

pkg/mcp/server.go

Lines changed: 47 additions & 5 deletions
@@ -7,6 +7,7 @@ import (
 	"github.com/mark3labs/mcp-go/server"
 
 	"github.com/StacklokLabs/mkp/pkg/k8s"
+	"github.com/StacklokLabs/mkp/pkg/ratelimit"
 )
 
 // Config holds configuration options for the MCP server
@@ -18,16 +19,29 @@ type Config struct {
 	// ReadWrite determines whether the MCP server can modify resources in the cluster
 	// When false, the server operates in read-only mode and does not serve the apply_resource tool
 	ReadWrite bool
+
+	// EnableRateLimiting determines whether to enable rate limiting for tool calls
+	// When true, a default rate limiter will be used to prevent excessive API calls
+	EnableRateLimiting bool
 }
 
 // DefaultConfig returns a Config with default values
 func DefaultConfig() *Config {
 	return &Config{
-		ServeResources: true,  // Default to serving resources for backward compatibility
-		ReadWrite:      false, // Default to read-only mode
+		ServeResources:     true,  // Default to serving resources for backward compatibility
+		ReadWrite:          false, // Default to read-only mode
+		EnableRateLimiting: true,  // Default to enabling rate limiting
 	}
 }
 
+// serverResources holds resources that need to be cleaned up when the server is stopped
+type serverResources struct {
+	rateLimiter *ratelimit.RateLimiter
+}
+
+// Global variable to hold server resources
+var resources *serverResources
+
 // CreateServer creates a new MCP server for Kubernetes
 func CreateServer(k8sClient *k8s.Client, config *Config) *server.MCPServer {
 	// Use default config if none provided
@@ -37,12 +51,29 @@ func CreateServer(k8sClient *k8s.Client, config *Config) *server.MCPServer {
 	// Create MCP implementation
 	impl := NewImplementation(k8sClient)
 
-	// Create MCP server
+	options := []server.ServerOption{
+		server.WithResourceCapabilities(true, true),
+		server.WithToolCapabilities(true),
+	}
+
+	// Add rate limiting middleware if enabled
+	if config.EnableRateLimiting {
+		// Create and store the rate limiter for later cleanup
+		limiter := ratelimit.GetDefaultRateLimiter()
+
+		// Store the limiter for cleanup when the server is stopped
+		resources = &serverResources{
+			rateLimiter: limiter,
+		}
+
+		options = append(options, server.WithToolHandlerMiddleware(limiter.Middleware()))
+	}
+
+	// Create MCP server with all options
 	mcpServer := server.NewMCPServer(
 		"kubernetes-mcp-server",
 		"0.1.0",
-		server.WithResourceCapabilities(true, true),
-		server.WithToolCapabilities(true),
+		options...,
 	)
 
 	// Add tools
@@ -84,6 +115,17 @@ func CreateServer(k8sClient *k8s.Client, config *Config) *server.MCPServer {
 	return mcpServer
 }
 
+// StopServer stops the MCP server and cleans up resources
+func StopServer() {
+	// Clean up resources
+	if resources != nil {
+		// Stop the rate limiter if it exists
+		if resources.rateLimiter != nil {
+			resources.rateLimiter.Stop()
+		}
+	}
+}
+
 // CreateSSEServer creates a new SSE server for the MCP server
 func CreateSSEServer(mcpServer *server.MCPServer) *server.SSEServer {
 	return server.NewSSEServer(mcpServer)
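Because the limiter is attached through standard mcp-go tool-handler middleware, the same pattern works for any MCPServer, not just the one built by CreateServer above. A short sketch under that assumption; the newLimitedServer helper and its name/version strings are illustrative:

```go
package main

import (
	"github.com/mark3labs/mcp-go/server"

	"github.com/StacklokLabs/mkp/pkg/ratelimit"
)

// newLimitedServer attaches the default rate limiter to a freshly built MCP
// server and returns both, so the caller can Stop() the limiter on shutdown.
func newLimitedServer() (*server.MCPServer, *ratelimit.RateLimiter) {
	limiter := ratelimit.GetDefaultRateLimiter()
	s := server.NewMCPServer(
		"example-server",
		"0.0.1",
		server.WithToolCapabilities(true),
		server.WithToolHandlerMiddleware(limiter.Middleware()),
	)
	return s, limiter
}
```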

pkg/ratelimit/config.go

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
package ratelimit

const defaultLimit = 60

// DefaultConfig defines the default rate limits for different tools
var DefaultConfig = map[string]int{
	// Read operations - higher limits
	"list_resources": 120, // 120 requests per minute (2 per second)
	"get_resource":   120, // 120 requests per minute (2 per second)
	"read_resource":  120, // 120 requests per minute (2 per second)

	// Write operations - lower limits
	"apply_resource":  30, // 30 requests per minute (0.5 per second)
	"delete_resource": 30, // 30 requests per minute (0.5 per second)

	// Default for any other tool
	"default": defaultLimit,
}

// GetDefaultRateLimiter returns a RateLimiter with default configuration
func GetDefaultRateLimiter() *RateLimiter {
	options := []RateLimiterOption{
		WithDefaultLimit(DefaultConfig["default"]),
	}

	// Add tool-specific limits
	for tool, limit := range DefaultConfig {
		if tool != "default" {
			options = append(options, WithToolLimit(tool, limit))
		}
	}

	return NewRateLimiter(options...)
}
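The functional options used by GetDefaultRateLimiter also let callers build a limiter with different numbers than DefaultConfig. A small sketch; the newStrictLimiter name and the specific limits are illustrative, not taken from this commit:

```go
package main

import "github.com/StacklokLabs/mkp/pkg/ratelimit"

// newStrictLimiter builds a limiter that is tighter than the defaults:
// unlisted tools get 30 requests per minute, reads 60, writes 10.
func newStrictLimiter() *ratelimit.RateLimiter {
	return ratelimit.NewRateLimiter(
		ratelimit.WithDefaultLimit(30),
		ratelimit.WithToolLimit("list_resources", 60),
		ratelimit.WithToolLimit("get_resource", 60),
		ratelimit.WithToolLimit("apply_resource", 10),
	)
}
```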

pkg/ratelimit/ratelimit.go

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
package ratelimit

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/server"
)

const (
	cleanupInterval = 10 * time.Minute
	bucketTimeout   = 30 * time.Minute
)

// RateLimiter implements a rate limiting middleware for MCP server
// It uses a token bucket algorithm to limit the number of requests per minute for each tool
type RateLimiter struct {
	mu            sync.RWMutex
	limits        map[string]int                // Tool name to requests per minute
	defaultLimit  int                           // Default requests per minute
	buckets       map[string]map[string]*bucket // SessionID:[Tool:Bucket] mapping
	cleanupTicker *time.Ticker
}

// bucket represents a token bucket for rate limiting
type bucket struct {
	mu       sync.Mutex
	tokens   int       // Current number of tokens
	lastSeen time.Time // Last time this bucket was accessed
}

// RateLimiterOption is a function that configures a RateLimiter
type RateLimiterOption func(*RateLimiter)

// WithToolLimit sets the rate limit for a specific tool
func WithToolLimit(toolName string, requestsPerMinute int) RateLimiterOption {
	return func(rl *RateLimiter) {
		rl.limits[toolName] = requestsPerMinute
	}
}

// WithDefaultLimit sets the default rate limit for all tools
func WithDefaultLimit(requestsPerMinute int) RateLimiterOption {
	return func(rl *RateLimiter) {
		rl.defaultLimit = requestsPerMinute
	}
}

// NewRateLimiter creates a new rate limiter with the given options
func NewRateLimiter(opts ...RateLimiterOption) *RateLimiter {
	rl := &RateLimiter{
		limits:       make(map[string]int),
		defaultLimit: defaultLimit,
		buckets:      make(map[string]map[string]*bucket),
	}

	for _, opt := range opts {
		opt(rl)
	}

	// Start a cleanup ticker to remove old buckets
	rl.cleanupTicker = time.NewTicker(cleanupInterval)
	go func() {
		for range rl.cleanupTicker.C {
			rl.cleanup()
		}
	}()

	return rl
}

func (rl *RateLimiter) cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	for sessionID, toolBuckets := range rl.buckets {
		for tool, b := range toolBuckets {
			b.mu.Lock()
			// If bucket hasn't been used for bucketTimeout, remove it
			if now.Sub(b.lastSeen) > bucketTimeout {
				delete(toolBuckets, tool)
			}
			b.mu.Unlock()
		}
		// If no more buckets for this session, remove the session entry
		if len(toolBuckets) == 0 {
			delete(rl.buckets, sessionID)
		}
	}
}

// Stop stops the cleanup ticker
func (rl *RateLimiter) Stop() {
	if rl.cleanupTicker != nil {
		rl.cleanupTicker.Stop()
	}
}

// getSessionID extracts the session ID from the request context
func getSessionID(ctx context.Context) string {
	// Get the session from the context
	if session := server.ClientSessionFromContext(ctx); session != nil {
		return session.SessionID()
	}
	// If no session is available (which shouldn't happen in normal operation),
	// return a default identifier
	return "unknown"
}

// getBucket gets or creates a bucket for the given session ID and tool
func (rl *RateLimiter) getBucket(sessionID, tool string) *bucket {
	// Look up the limit before taking the write lock: getLimit acquires a read
	// lock on the same mutex, and sync.RWMutex is not reentrant.
	limit := rl.getLimit(tool)

	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Create session map if it doesn't exist
	if _, ok := rl.buckets[sessionID]; !ok {
		rl.buckets[sessionID] = make(map[string]*bucket)
	}

	// Create bucket if it doesn't exist
	if _, ok := rl.buckets[sessionID][tool]; !ok {
		rl.buckets[sessionID][tool] = &bucket{
			tokens:   limit, // Initialize with full tokens
			lastSeen: time.Now(),
		}
	}

	return rl.buckets[sessionID][tool]
}

// getLimit returns the rate limit for the given tool
func (rl *RateLimiter) getLimit(tool string) int {
	rl.mu.RLock()
	defer rl.mu.RUnlock()

	if limit, ok := rl.limits[tool]; ok {
		return limit
	}
	return rl.defaultLimit
}

// Middleware returns a middleware function for the MCP server
func (rl *RateLimiter) Middleware() server.ToolHandlerMiddleware {
	return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
		return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
			sessionID := getSessionID(ctx)
			tool := request.Params.Name

			// Resolve the bucket and limit before locking the bucket, so rl.mu
			// and b.mu are never held at the same time.
			b := rl.getBucket(sessionID, tool)
			limit := rl.getLimit(tool)

			b.mu.Lock()
			defer b.mu.Unlock()

			// Calculate tokens to add based on time elapsed since the last
			// request, then record this access
			now := time.Now()
			tokensPerSecond := float64(limit) / 60.0
			elapsed := now.Sub(b.lastSeen).Seconds()
			tokensToAdd := int(elapsed * tokensPerSecond)
			b.lastSeen = now

			// Add tokens, but don't exceed the limit
			b.tokens = min(b.tokens+tokensToAdd, limit)

			// Check if we have enough tokens
			if b.tokens <= 0 {
				return mcp.NewToolResultError(fmt.Sprintf("Rate limit exceeded for tool '%s'. Try again later.", tool)), nil
			}

			// Consume a token
			b.tokens--

			// Call the next handler
			return next(ctx, request)
		}
	}
}
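To make the refill arithmetic in Middleware concrete, a small self-contained sketch with illustrative numbers: at 120 requests per minute the bucket refills at 2 tokens per second, so 1.5 seconds of idle time adds 3 tokens, and the bucket never holds more than the limit.

```go
package main

import "fmt"

func main() {
	// Refill math from the middleware: limit/60 tokens per second.
	limit := 120
	tokensPerSecond := float64(limit) / 60.0 // 2 tokens per second
	elapsed := 1.5                           // seconds since the bucket was last seen
	tokensToAdd := int(elapsed * tokensPerSecond)

	tokens := 119                           // tokens left before the refill
	tokens = min(tokens+tokensToAdd, limit) // capped at the limit

	fmt.Println(tokensToAdd, tokens) // 3 120
}
```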
