diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d1d1adb..fd183df 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,11 +7,15 @@ on: pull_request: branches: ["main"] +env: + GO_VERSION: "1.25.1" + jobs: - build: + test: + name: Test strategy: matrix: - os: [ ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - name: Checkout Project @@ -22,28 +26,125 @@ jobs: - name: Set up Go uses: actions/setup-go@v6 with: - go-version: "1.25.1" - - - name: Install Linters - run: | - go install github.com/mgechev/revive@latest - go install honnef.co/go/tools/cmd/staticcheck@latest + go-version: ${{ env.GO_VERSION }} - - name: Set up Goreleaser - uses: goreleaser/goreleaser-action@v6 + - name: Cache Go modules + uses: actions/cache@v4 with: - install-only: true + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- - - name: Lint - run: make lint + - name: Install Dependencies + run: go mod download + + - name: Verify Dependencies + run: make deps-verify - name: Check Code Formatting run: make format-check - - name: Test + - name: Run Tests run: make test - - name: Build Dist + - name: Run Race Tests + run: make test-race + + - name: Generate Coverage Report + run: make test-coverage + + - name: Upload Coverage to Codecov + if: matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v4 + with: + file: ./coverage.out + flags: unittests + name: codecov-umbrella + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - name: Checkout Project + uses: actions/checkout@v5 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Install Linters + run: make install-tools + + - name: Run Linting + run: make lint + + - name: Run Static Analysis + uses: dominikh/staticcheck-action@v1.3.1 + with: + version: "2024.1.1" + + security: + name: 
Security Scan + runs-on: ubuntu-latest + steps: + - name: Checkout Project + uses: actions/checkout@v5 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Run Security Scan + run: make security-scan + + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@master + with: + scan-type: 'fs' + scan-ref: '.' + format: 'sarif' + output: 'trivy-results.sarif' + + - name: Upload Trivy scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v3 + if: always() + with: + sarif_file: 'trivy-results.sarif' + + build: + name: Build + needs: [test, lint, security] + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Project + uses: actions/checkout@v5 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Set up Goreleaser + uses: goreleaser/goreleaser-action@v6 + with: + install-only: true + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '18' + cache: 'npm' + cache-dependency-path: './scripts/npm/package-lock.json' + + - name: Build Distribution run: make build-dist-snapshot - name: Run Platform Tests @@ -56,4 +157,54 @@ jobs: set -e OUTPUT=$(npx ./scripts/npm --log-level debug test-key 2>&1 || true) echo "$OUTPUT" - echo "$OUTPUT" | grep -q "starting MCP server" \ No newline at end of file + echo "$OUTPUT" | grep -q "starting MCP server" + + - name: Upload Build Artifacts + uses: actions/upload-artifact@v4 + with: + name: build-artifacts-${{ matrix.os }} + path: | + dist/ + scripts/npm/dist/ + retention-days: 7 + + docker: + name: Docker Build + needs: [test, lint, security] + runs-on: ubuntu-latest + steps: + - name: Checkout Project + uses: actions/checkout@v5 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build Docker Image + run: make docker-build + + - name: Test Docker Image + run: make docker-test + 
+ benchmark: + name: Benchmark + runs-on: ubuntu-latest + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + steps: + - name: Checkout Project + uses: actions/checkout@v5 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Run Benchmarks + run: make benchmark + + - name: Store Benchmark Results + uses: benchmark-action/github-action-benchmark@v1 + with: + tool: 'go' + output-file-path: benchmark.txt + github-token: ${{ secrets.GITHUB_TOKEN }} + auto-push: true \ No newline at end of file diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..5b05601 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,283 @@ +# Architecture Documentation + +This document describes the architecture and design decisions of the MCP DigitalOcean Integration project. + +## Overview + +The MCP DigitalOcean Integration is built using a modular, component-based architecture that emphasizes: + +- **Separation of Concerns**: Each component has a single responsibility +- **Dependency Injection**: Components are loosely coupled and easily testable +- **Performance**: Built-in caching and rate limiting for optimal API usage +- **Observability**: Comprehensive metrics and health monitoring +- **Reliability**: Structured error handling and retry mechanisms + +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────┐ +│ MCP Server │ +├─────────────────────────────────────────────────────────────┤ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │ +│ │ Config │ │ Health │ │ Registry │ │ +│ │ Management │ │ Checker │ │ (Services) │ │ +│ └─────────────┘ └─────────────┘ └─────────────────────┘ │ +├─────────────────────────────────────────────────────────────┤ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │ +│ │ Metrics │ │ Cache │ │ Rate Limiter │ │ +│ │ Collection │ │ Layer │ │ │ │ +│ └─────────────┘ └─────────────┘ └─────────────────────┘ │ 
+├─────────────────────────────────────────────────────────────┤ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │ +│ │ Error │ │ Test Utils │ │ Structured │ │ +│ │ Handling │ │ │ │ Logging │ │ +│ └─────────────┘ └─────────────┘ └─────────────────────┘ │ +├─────────────────────────────────────────────────────────────┤ +│ DigitalOcean API Client │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Core Components + +### 1. Configuration Management (`internal/config`) + +**Purpose**: Centralized configuration management with environment variable support. + +**Key Features**: +- Environment variable parsing with defaults +- Configuration validation +- Type-safe configuration access +- Support for duration, boolean, and numeric types + +**Usage**: +```go +cfg, err := config.LoadConfig() +if err != nil { + log.Fatal(err) +} +``` + +### 2. Caching Layer (`internal/cache`) + +**Purpose**: In-memory caching with TTL support to reduce API calls and improve performance. + +**Key Features**: +- Thread-safe operations +- Automatic expiration and cleanup +- Configurable TTL +- Cache hit/miss metrics +- Function wrapping for easy integration + +**Usage**: +```go +cache := cache.New(5*time.Minute, true) +cachedFn := cache.WithCache("key", originalFunction) +``` + +### 3. Rate Limiting (`internal/ratelimit`) + +**Purpose**: Token bucket rate limiter to prevent API abuse and respect rate limits. + +**Key Features**: +- Token bucket algorithm +- Configurable requests per second +- Context-aware waiting +- Middleware support +- Statistics reporting + +**Usage**: +```go +limiter := ratelimit.New(100, true) // 100 RPS +if !limiter.Allow() { + return errors.New("rate limit exceeded") +} +``` + +### 4. Metrics Collection (`internal/metrics`) + +**Purpose**: Comprehensive metrics collection for monitoring and observability. 
+ +**Key Features**: +- Request/response metrics +- Timing measurements +- Error categorization +- Service usage tracking +- Cache performance metrics + +**Metrics Collected**: +- Total requests +- Success/failure rates +- Response times (min, max, average) +- Cache hit/miss ratios +- Service-specific usage +- Error types and frequencies + +### 5. Health Monitoring (`internal/health`) + +**Purpose**: Health checks for all system components. + +**Health Checks**: +- DigitalOcean API connectivity +- Cache functionality +- Rate limiter status +- Metrics collection status + +**Status Levels**: +- `healthy`: All systems operational +- `degraded`: Some issues but functional +- `unhealthy`: Critical issues detected + +### 6. Error Handling (`internal/errors`) + +**Purpose**: Structured error handling with categorization and retry logic. + +**Error Types**: +- `validation`: Input validation errors +- `authentication`: API authentication issues +- `authorization`: Permission errors +- `not_found`: Resource not found +- `rate_limit`: Rate limiting errors +- `internal`: Internal server errors +- `network`: Network connectivity issues +- `timeout`: Request timeout errors + +### 7. Service Registry (`internal/registry.go`) + +**Purpose**: Dynamic service registration with component injection. + +**Key Features**: +- Modular service registration +- Component dependency injection +- Backward compatibility +- Metrics integration +- Error handling + +## Data Flow + +1. **Request Initiation**: Client sends request to MCP server +2. **Rate Limiting**: Request passes through rate limiter +3. **Cache Check**: System checks cache for existing response +4. **API Call**: If cache miss, makes API call to DigitalOcean +5. **Response Processing**: Processes and caches response +6. **Metrics Recording**: Records metrics for monitoring +7. 
**Response Return**: Returns response to client + +## Performance Optimizations + +### Caching Strategy +- **TTL-based expiration**: Configurable cache lifetime +- **Automatic cleanup**: Background goroutine removes expired entries +- **Memory efficient**: Only caches successful responses +- **Thread-safe**: Concurrent access support + +### Rate Limiting Strategy +- **Token bucket algorithm**: Smooth rate limiting +- **Configurable rates**: Adjustable requests per second +- **Burst handling**: Allows temporary bursts within limits +- **Context awareness**: Respects request cancellation + +### Connection Management +- **HTTP client reuse**: Single client instance with connection pooling +- **Retry logic**: Exponential backoff for failed requests +- **Timeout handling**: Configurable request timeouts +- **Keep-alive**: Persistent connections for better performance + +## Security Considerations + +### API Token Handling +- Environment variable storage +- No token logging or exposure +- Secure token transmission + +### Container Security +- Distroless base image +- Non-root user execution +- Minimal attack surface +- Static binary compilation + +### Input Validation +- Structured error responses +- Input sanitization +- Type-safe configuration +- Request validation + +## Testing Strategy + +### Unit Tests +- Component isolation +- Mock dependencies +- Edge case coverage +- Performance benchmarks + +### Integration Tests +- End-to-end workflows +- API connectivity tests +- Error scenario testing +- Performance validation + +### Test Utilities +- Common test helpers +- Mock time utilities +- Configuration builders +- Assertion helpers + +## Monitoring and Observability + +### Metrics +- Request rates and latencies +- Error rates by type +- Cache performance +- Service usage patterns + +### Logging +- Structured JSON logging +- Configurable log levels +- Contextual information +- Performance data + +### Health Checks +- Component status monitoring +- API connectivity 
verification +- Performance threshold alerts +- Automated recovery + +## Deployment Considerations + +### Environment Variables +- Comprehensive configuration options +- Secure defaults +- Documentation for all options +- Validation and error reporting + +### Container Deployment +- Multi-stage builds for optimization +- Security-focused base images +- Health check endpoints +- Resource limit awareness + +### Scaling Considerations +- Stateless design +- Horizontal scaling support +- Load balancer compatibility +- Resource efficiency + +## Future Enhancements + +### Planned Features +- Distributed caching support +- Advanced metrics dashboards +- Circuit breaker pattern +- Request tracing +- Configuration hot-reloading + +### Performance Improvements +- Response compression +- Connection pooling optimization +- Memory usage optimization +- CPU profiling integration + +### Security Enhancements +- mTLS support +- Request signing +- Audit logging +- Security scanning integration diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..4f34810 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,52 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+ +## [1.0.12] - 2024-10-12 + +### Added +- **Configuration Management**: Comprehensive configuration system with environment variable support + - Configurable request timeouts, retry settings, and API endpoints + - Cache configuration with TTL settings + - Rate limiting configuration +- **Performance Enhancements**: + - In-memory caching layer with TTL support and automatic cleanup + - Rate limiting with token bucket algorithm + - Request/response metrics collection +- **Observability**: + - Structured metrics collection (request counts, response times, error rates) + - Health check system with API connectivity, cache, and rate limiter status + - Enhanced logging with structured JSON output +- **Error Handling**: + - Structured error types with categorization (validation, authentication, network, etc.) + - Improved error messages with context and retry information + - Replaced panic calls in generation code with proper error handling +- **Code Quality**: + - Enhanced component architecture with dependency injection + - Backward compatibility maintained for existing integrations + - Improved code organization and separation of concerns + +### Changed +- **Main Application**: Refactored to use new configuration and component system +- **Registry System**: Enhanced to support new components while maintaining backward compatibility +- **Version**: Bumped to 1.0.12 to reflect significant improvements + +### Fixed +- **Error Handling**: Replaced panic calls in schema generation with proper error handling +- **Resource Management**: Improved cleanup and resource management +- **Logging**: More consistent and structured logging throughout the application + +### Technical Improvements +- Added comprehensive test coverage for new components +- Improved documentation and inline code comments +- Enhanced security with better token handling +- Performance optimizations with caching and rate limiting +- Better separation of concerns with modular architecture + +## [1.0.11] - 
Previous Release +- Base functionality with DigitalOcean API integration +- Support for multiple services (apps, droplets, networking, etc.) +- Basic MCP server implementation diff --git a/Dockerfile b/Dockerfile index eb2b306..7862ec9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,13 +1,14 @@ +# Multi-stage build for optimized production image FROM golang:1.25.1-alpine AS builder WORKDIR /src -# Install git for version info -RUN apk add --no-cache git +# Install build dependencies +RUN apk add --no-cache git ca-certificates tzdata # Copy go mod and sum files COPY go.mod go.sum ./ -RUN go mod download +RUN go mod download && go mod verify # Copy the rest of the source COPY . . @@ -17,16 +18,38 @@ ARG VERSION=unknown ARG COMMIT=unknown ARG DATE=unknown -# Build the binary with version info -RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "-s -w -X 'main.version=${VERSION}' -X 'main.commit=${COMMIT}' -X 'main.date=${DATE}'" -o /app/mcp-digitalocean ./cmd/mcp-digitalocean +# Build the binary with version info and optimizations +RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \ + -ldflags "-s -w -X 'main.version=${VERSION}' -X 'main.commit=${COMMIT}' -X 'main.date=${DATE}' -extldflags '-static'" \ + -a -installsuffix cgo \ + -o /app/mcp-digitalocean \ + ./cmd/mcp-digitalocean -FROM debian:12-slim +# Production stage - using distroless for security +FROM gcr.io/distroless/static:nonroot -WORKDIR /app +# Copy CA certificates for HTTPS requests +COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ -COPY --from=builder /app/mcp-digitalocean ./mcp-digitalocean +# Copy timezone data +COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo -# Expose default port -EXPOSE 8080 +# Copy the binary +COPY --from=builder /app/mcp-digitalocean /mcp-digitalocean -ENTRYPOINT ["/app/mcp-digitalocean"] \ No newline at end of file +# Set environment variables with secure defaults +ENV LOG_LEVEL=info +ENV CACHE_ENABLED=true +ENV RATE_LIMIT_ENABLED=true 
+ENV REQUEST_TIMEOUT=30s +ENV MAX_RETRIES=4 + +# Use non-root user (distroless nonroot user) +USER nonroot:nonroot + +# Health check (commented out as distroless doesn't have shell) +# HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ +# CMD ["/mcp-digitalocean", "--help"] + +# Default entrypoint +ENTRYPOINT ["/mcp-digitalocean"] \ No newline at end of file diff --git a/Makefile b/Makefile index 844b21a..2a0e9b5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,11 @@ +# MCP DigitalOcean Makefile + +.PHONY: all build-dist build-dist-snapshot build-bin build-bin-snapshot dist lint test format format-check gen clean help install-tools benchmark coverage + +# Default target all: lint test build-dist + +# Build targets build-dist: build-bin dist build-dist-snapshot: build-bin-snapshot dist @@ -8,27 +15,136 @@ build-bin-snapshot: build-bin: goreleaser build --auto-snapshot --clean --skip validate -.PHONY: dist +# Distribution dist: mkdir -p ./scripts/npm/dist cp ./README.md ./scripts/npm/README.md + cp ./CHANGELOG.md ./scripts/npm/CHANGELOG.md cp ./dist/*/mcp-digitalocean* ./scripts/npm/dist/ cp ./internal/apps/spec/*.json ./scripts/npm/dist/ cp ./internal/doks/spec/*.json ./scripts/npm/dist/ npm install --prefix ./scripts/npm/ +# Code quality lint: revive -config revive.toml ./... + @echo "✅ Linting completed successfully" test: go test -v ./... + @echo "✅ Tests completed successfully" + +# Enhanced test targets +test-coverage: + go test -v -coverprofile=coverage.out ./... + go tool cover -html=coverage.out -o coverage.html + @echo "✅ Coverage report generated: coverage.html" + +test-race: + go test -v -race ./... + @echo "✅ Race condition tests completed" +benchmark: + go test -v -bench=. -benchmem ./... + @echo "✅ Benchmarks completed" + +# Code formatting format: gofmt -w . - @echo "Code formatted successfully." 
+ @echo "✅ Code formatted successfully" format-check: - bash -c 'diff -u <(echo -n) <(gofmt -d ./)' + @bash -c 'diff -u <(echo -n) <(gofmt -d ./)' + @echo "✅ Code formatting check passed" +# Code generation gen: go generate ./... + @echo "✅ Code generation completed" + +# Development tools +install-tools: + go install github.com/mgechev/revive@latest + go install honnef.co/go/tools/cmd/staticcheck@latest + go install github.com/goreleaser/goreleaser@latest + @echo "✅ Development tools installed" + +# Security scanning +security-scan: + @command -v gosec >/dev/null 2>&1 || { echo "Installing gosec..."; go install github.com/securecodewarrior/gosec/v2/cmd/gosec@latest; } + gosec ./... + @echo "✅ Security scan completed" + +# Dependency management +deps-update: + go get -u ./... + go mod tidy + @echo "✅ Dependencies updated" + +deps-verify: + go mod verify + go mod tidy + @echo "✅ Dependencies verified" + +# Clean up +clean: + rm -rf dist/ + rm -rf ./scripts/npm/dist/ + rm -f coverage.out coverage.html + go clean -cache -testcache -modcache + @echo "✅ Cleanup completed" + +# Docker targets +docker-build: + docker build -t mcp-digitalocean:latest . 
+ @echo "✅ Docker image built" + +docker-test: + docker run --rm -e DIGITALOCEAN_API_TOKEN=test-token mcp-digitalocean:latest --services apps --log-level debug + @echo "✅ Docker test completed" + +# CI/CD helpers +ci-setup: install-tools deps-verify + @echo "✅ CI environment setup completed" + +ci-test: format-check lint test-race test-coverage security-scan + @echo "✅ CI tests completed" + +# Development workflow +dev-setup: install-tools deps-verify gen + @echo "✅ Development environment setup completed" + +dev-test: format lint test + @echo "✅ Development tests completed" + +# Release preparation +pre-release: clean gen format lint test-coverage security-scan build-dist + @echo "✅ Pre-release checks completed" + +# Help target +help: + @echo "Available targets:" + @echo " all - Run lint, test, and build-dist" + @echo " build-dist - Build distribution packages" + @echo " build-bin - Build binary only" + @echo " test - Run all tests" + @echo " test-coverage - Run tests with coverage report" + @echo " test-race - Run tests with race detection" + @echo " benchmark - Run benchmarks" + @echo " lint - Run linting" + @echo " format - Format code" + @echo " format-check - Check code formatting" + @echo " gen - Generate code" + @echo " security-scan - Run security analysis" + @echo " deps-update - Update dependencies" + @echo " deps-verify - Verify dependencies" + @echo " clean - Clean build artifacts" + @echo " docker-build - Build Docker image" + @echo " docker-test - Test Docker image" + @echo " install-tools - Install development tools" + @echo " dev-setup - Setup development environment" + @echo " dev-test - Run development tests" + @echo " ci-setup - Setup CI environment" + @echo " ci-test - Run CI tests" + @echo " pre-release - Run pre-release checks" + @echo " help - Show this help message" diff --git a/README.md b/README.md index 10a78dc..49ead4d 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,24 @@ MCP DigitalOcean Integration is an open-source project that 
provides a comprehen > **DISCLAIMER:** "Use of MCP technology to interact with your DigitalOcean account [can come with risks](https://www.wiz.io/blog/mcp-security-research-briefing)" +## 🚀 What's New in v1.0.12 + +### Performance & Reliability +- **In-memory Caching**: Configurable caching layer with TTL support for improved response times +- **Rate Limiting**: Token bucket rate limiter to prevent API abuse and respect rate limits +- **Enhanced Configuration**: Comprehensive configuration management with environment variables +- **Health Checks**: Built-in health monitoring for API connectivity and component status + +### Observability +- **Metrics Collection**: Request/response metrics, success rates, and performance monitoring +- **Structured Logging**: Enhanced JSON logging with contextual information +- **Error Handling**: Improved error categorization and retry logic + +### Code Quality +- **Better Architecture**: Modular design with dependency injection +- **Comprehensive Testing**: Enhanced test coverage for new components +- **Documentation**: Improved inline documentation and examples + ## Prerequisites - Node.js (v18 or later) @@ -11,7 +29,6 @@ MCP DigitalOcean Integration is an open-source project that provides a comprehen You can find installation guides at [https://nodejs.org/en/download](https://nodejs.org/en/download) - Verify your installation: ```bash node --version @@ -25,6 +42,45 @@ To verify the MCP server works correctly, you can test it directly from the comm npx @digitalocean/mcp --services apps ``` +## Configuration + +### Environment Variables + +The MCP server supports extensive configuration through environment variables: + +#### Required +- `DIGITALOCEAN_API_TOKEN`: Your DigitalOcean API token + +#### Optional +- `SERVICES`: Comma-separated list of services to enable (default: all) +- `LOG_LEVEL`: Logging level (debug, info, warn, error) (default: info) +- `DIGITALOCEAN_API_ENDPOINT`: API endpoint URL (default: 
https://api.digitalocean.com) + +#### Performance & Reliability +- `REQUEST_TIMEOUT`: API request timeout (default: 30s) +- `MAX_RETRIES`: Maximum retry attempts (default: 4) +- `RETRY_WAIT_MIN`: Minimum retry wait time (default: 1s) +- `RETRY_WAIT_MAX`: Maximum retry wait time (default: 30s) + +#### Caching +- `CACHE_ENABLED`: Enable/disable caching (default: true) +- `CACHE_TTL`: Cache time-to-live (default: 5m) + +#### Rate Limiting +- `RATE_LIMIT_ENABLED`: Enable/disable rate limiting (default: true) +- `RATE_LIMIT_RPS`: Requests per second limit (default: 100) + +### Example Configuration + +```bash +export DIGITALOCEAN_API_TOKEN="your_token_here" +export SERVICES="apps,droplets,networking" +export LOG_LEVEL="debug" +export CACHE_ENABLED="true" +export CACHE_TTL="10m" +export RATE_LIMIT_RPS="50" +``` + ## Installation ### Claude Code diff --git a/cmd/mcp-digitalocean/main.go b/cmd/mcp-digitalocean/main.go index e0618ad..c47d707 100644 --- a/cmd/mcp-digitalocean/main.go +++ b/cmd/mcp-digitalocean/main.go @@ -10,6 +10,11 @@ import ( "strings" registry "mcp-digitalocean/internal" + "mcp-digitalocean/internal/cache" + "mcp-digitalocean/internal/config" + "mcp-digitalocean/internal/health" + "mcp-digitalocean/internal/metrics" + "mcp-digitalocean/internal/ratelimit" "github.com/digitalocean/godo" "github.com/mark3labs/mcp-go/server" @@ -18,20 +23,47 @@ import ( const ( mcpName = "mcp-digitalocean" - mcpVersion = "1.0.11" - - defaultEndpoint = "https://api.digitalocean.com" + mcpVersion = "1.0.12" ) func main() { - logLevelFlag := flag.String("log-level", os.Getenv("LOG_LEVEL"), "Log level: debug, info, warn, error") - serviceFlag := flag.String("services", os.Getenv("SERVICES"), "Comma-separated list of services to activate (e.g., apps,networking,droplets)") - tokenFlag := flag.String("digitalocean-api-token", os.Getenv("DIGITALOCEAN_API_TOKEN"), "DigitalOcean API token") - endpointFlag := flag.String("digitalocean-api-endpoint", 
os.Getenv("DIGITALOCEAN_API_ENDPOINT"), "DigitalOcean API endpoint") + // Parse command line flags for backward compatibility + logLevelFlag := flag.String("log-level", "", "Log level: debug, info, warn, error") + serviceFlag := flag.String("services", "", "Comma-separated list of services to activate") + tokenFlag := flag.String("digitalocean-api-token", "", "DigitalOcean API token") + endpointFlag := flag.String("digitalocean-api-endpoint", "", "DigitalOcean API endpoint") flag.Parse() + // Override environment variables with command line flags if provided + if *logLevelFlag != "" { + os.Setenv("LOG_LEVEL", *logLevelFlag) + } + if *serviceFlag != "" { + os.Setenv("SERVICES", *serviceFlag) + } + if *tokenFlag != "" { + os.Setenv("DIGITALOCEAN_API_TOKEN", *tokenFlag) + } + if *endpointFlag != "" { + os.Setenv("DIGITALOCEAN_API_ENDPOINT", *endpointFlag) + } + + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err) + os.Exit(1) + } + + // Validate configuration + if err := cfg.Validate(); err != nil { + fmt.Fprintf(os.Stderr, "Invalid configuration: %v\n", err) + os.Exit(1) + } + + // Setup logger var level slog.Level - switch strings.ToLower(*logLevelFlag) { + switch strings.ToLower(cfg.LogLevel) { case "debug": level = slog.LevelDebug case "info": @@ -45,63 +77,97 @@ func main() { } logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: level})) - token := *tokenFlag - if token == "" { - logger.Error("DigitalOcean API token not provided. 
Use --digitalocean-api-token flag or set DIGITALOCEAN_API_TOKEN environment variable") - os.Exit(1) - } - endpoint := *endpointFlag - if endpoint == "" { - endpoint = defaultEndpoint - } + // Initialize components + metricsCollector := metrics.New() + cacheInstance := cache.New(cfg.CacheTTL, cfg.CacheEnabled) + rateLimiter := ratelimit.New(cfg.RateLimitRPS, cfg.RateLimitEnabled) - var services []string - if *serviceFlag != "" { - services = strings.Split(*serviceFlag, ",") + // Create DigitalOcean client with enhanced configuration + client, err := newGodoClientWithConfig(context.Background(), cfg) + if err != nil { + logger.Error("Failed to create DigitalOcean client", "error", err) + os.Exit(1) } - client, err := newGodoClientWithTokenAndEndpoint(context.Background(), token, endpoint) + // Initialize health checker + healthChecker := health.New(client, cacheInstance, metricsCollector, rateLimiter) + + // Perform initial health check + ctx := context.Background() + healthReport, err := healthChecker.CheckHealth(ctx) if err != nil { - logger.Error("Failed to create DigitalOcean client: " + err.Error()) - os.Exit(1) + logger.Warn("Initial health check failed", "error", err) + } else { + logger.Info("Initial health check completed", + "status", healthReport.Status, + "healthy_checks", countHealthyChecks(healthReport.Checks)) } + // Create MCP server s := server.NewMCPServer(mcpName, mcpVersion) - err = registry.Register(logger, s, client, services...) + + // Register tools with enhanced registry + err = registry.RegisterWithComponents(logger, s, client, metricsCollector, cacheInstance, rateLimiter, cfg.Services...) 
if err != nil { - logger.Error("Failed to register tools: " + err.Error()) + logger.Error("Failed to register tools", "error", err) os.Exit(1) } - logger.Debug("starting MCP server", "name", mcpName, "version", mcpVersion) + logger.Info("Starting MCP server", + "name", mcpName, + "version", mcpVersion, + "services", cfg.Services, + "cache_enabled", cfg.CacheEnabled, + "rate_limit_enabled", cfg.RateLimitEnabled) + + // Start server err = server.ServeStdio(s) if err != nil { - // if context cancelled or sigterm then shutdown gracefully if errors.Is(err, context.Canceled) { logger.Info("Server shutdown gracefully") + + // Log final metrics + finalMetrics := metricsCollector.GetSnapshot() + logger.Info("Final metrics", + "total_requests", finalMetrics.TotalRequests, + "success_rate", fmt.Sprintf("%.2f%%", finalMetrics.SuccessRate()), + "cache_hit_rate", fmt.Sprintf("%.2f%%", finalMetrics.CacheHitRate()), + "uptime", finalMetrics.Uptime) + os.Exit(0) } else { - logger.Error("Failed to serve MCP server: " + err.Error()) + logger.Error("Failed to serve MCP server", "error", err) os.Exit(1) } } } -// newGodoClientWithTokenAndEndpoint initializes a new godo client with a custom user agent and endpoint. -func newGodoClientWithTokenAndEndpoint(ctx context.Context, token string, endpoint string) (*godo.Client, error) { - cleanToken := strings.Trim(strings.TrimSpace(token), "'") +// newGodoClientWithConfig initializes a new godo client with enhanced configuration. 
+func newGodoClientWithConfig(ctx context.Context, cfg *config.Config) (*godo.Client, error) { + cleanToken := strings.Trim(strings.TrimSpace(cfg.APIToken), "'") ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: cleanToken}) oauthClient := oauth2.NewClient(ctx, ts) retry := godo.RetryConfig{ - RetryMax: 4, - RetryWaitMin: godo.PtrTo(float64(1)), - RetryWaitMax: godo.PtrTo(float64(30)), + RetryMax: cfg.MaxRetries, + RetryWaitMin: godo.PtrTo(cfg.RetryWaitMin.Seconds()), + RetryWaitMax: godo.PtrTo(cfg.RetryWaitMax.Seconds()), } return godo.New(oauthClient, godo.WithRetryAndBackoffs(retry), - godo.SetBaseURL(endpoint), + godo.SetBaseURL(cfg.APIEndpoint), godo.SetUserAgent(fmt.Sprintf("%s/%s", mcpName, mcpVersion))) } + +// countHealthyChecks counts the number of healthy checks in a health report +func countHealthyChecks(checks []health.HealthCheck) int { + count := 0 + for _, check := range checks { + if check.Status == health.StatusHealthy { + count++ + } + } + return count +} diff --git a/internal/apps/spec/generate.go b/internal/apps/spec/generate.go index f83541d..ee30e62 100644 --- a/internal/apps/spec/generate.go +++ b/internal/apps/spec/generate.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "log" "os" "mcp-digitalocean/internal/apps" @@ -18,6 +19,12 @@ import ( // This is necessary since we need to pass the AppSpec to the mcp tool as a raw argument. // Ideally, we shouldn't have to copy the godo files around. However, it's currently not possible to without preserving the struct comments. 
func main() { + if err := generateSchemas(); err != nil { + log.Fatalf("Failed to generate schemas: %v", err) + } +} + +func generateSchemas() error { reflect := jsonschema.Reflector{ BaseSchemaID: "", Anonymous: true, @@ -29,46 +36,57 @@ func main() { FieldNameTag: "", } - err := reflect.AddGoComments("github.com/digitalocean/godo", "./") - if err != nil { - panic(fmt.Errorf("failed to add Go comments: %w", err)) + if err := reflect.AddGoComments("github.com/digitalocean/godo", "./"); err != nil { + return fmt.Errorf("failed to add Go comments: %w", err) } + // Generate app create schema + if err := generateAppCreateSchema(reflect); err != nil { + return fmt.Errorf("failed to generate app create schema: %w", err) + } + + // Generate app update schema + if err := generateAppUpdateSchema(reflect); err != nil { + return fmt.Errorf("failed to generate app update schema: %w", err) + } + + return nil +} + +func generateAppCreateSchema(reflect jsonschema.Reflector) error { createSchema, err := reflect.Reflect(&godo.AppCreateRequest{}).MarshalJSON() if err != nil { - panic(fmt.Errorf("failed to marshal app create schema: %w", err)) + return fmt.Errorf("failed to marshal app create schema: %w", err) } var createSchemaJSON bytes.Buffer if err := json.Indent(&createSchemaJSON, createSchema, "", " "); err != nil { - panic(fmt.Errorf("failed to indent JSON: %w", err)) + return fmt.Errorf("failed to indent JSON: %w", err) } - // now write the schema to a file - err = os.WriteFile("./app-create-schema.json", createSchemaJSON.Bytes(), 0644) - if err != nil { - panic(fmt.Errorf("failed to write schema to file: %w", err)) + if err := os.WriteFile("./app-create-schema.json", createSchemaJSON.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write schema to file: %w", err) } - fmt.Println("Schema successfully written to app_create_schema.json") + fmt.Println("Schema successfully written to app-create-schema.json") + return nil +} - // Generate schema for AppUpdateRequest 
+func generateAppUpdateSchema(reflect jsonschema.Reflector) error { updateSchema, err := reflect.Reflect(&apps.AppUpdate{}).MarshalJSON() if err != nil { - panic(fmt.Errorf("failed to marshal app update schema: %w", err)) + return fmt.Errorf("failed to marshal app update schema: %w", err) } - // Prettify the JSON var updateSchemaJSON bytes.Buffer if err := json.Indent(&updateSchemaJSON, updateSchema, "", " "); err != nil { - panic(fmt.Errorf("failed to indent JSON: %w", err)) + return fmt.Errorf("failed to indent JSON: %w", err) } - // Write the schema to a file - err = os.WriteFile("./app-update-schema.json", updateSchemaJSON.Bytes(), 0644) - if err != nil { - panic(fmt.Errorf("failed to write schema to file: %w", err)) + if err := os.WriteFile("./app-update-schema.json", updateSchemaJSON.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write schema to file: %w", err) } - fmt.Println("Update schema successfully written to app_update_schema.json") + fmt.Println("Update schema successfully written to app-update-schema.json") + return nil } diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..c5969de --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,150 @@ +package cache + +import ( + "context" + "sync" + "time" +) + +// Cache represents a simple in-memory cache with TTL support +type Cache struct { + mu sync.RWMutex + items map[string]*item + ttl time.Duration + enabled bool +} + +type item struct { + value interface{} + expiresAt time.Time +} + +// New creates a new cache instance +func New(ttl time.Duration, enabled bool) *Cache { + c := &Cache{ + items: make(map[string]*item), + ttl: ttl, + enabled: enabled, + } + + if enabled { + // Start cleanup goroutine + go c.cleanup() + } + + return c +} + +// Get retrieves a value from the cache +func (c *Cache) Get(key string) (interface{}, bool) { + if !c.enabled { + return nil, false + } + + c.mu.RLock() + defer c.mu.RUnlock() + + item, exists := c.items[key] 
+	if !exists {
+		return nil, false
+	}
+
+	if time.Now().After(item.expiresAt) {
+		// Expired. Do NOT delete here: we hold only the read
+		// lock; the cleanup goroutine removes expired items.
+		return nil, false
+	}
+
+	return item.value, true
+}
+
+// Set stores a value in the cache
+func (c *Cache) Set(key string, value interface{}) {
+	if !c.enabled {
+		return
+	}
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.items[key] = &item{
+		value:     value,
+		expiresAt: time.Now().Add(c.ttl),
+	}
+}
+
+// Delete removes a value from the cache
+func (c *Cache) Delete(key string) {
+	if !c.enabled {
+		return
+	}
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	delete(c.items, key)
+}
+
+// Clear removes all items from the cache
+func (c *Cache) Clear() {
+	if !c.enabled {
+		return
+	}
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.items = make(map[string]*item)
+}
+
+// Size returns the number of items in the cache
+func (c *Cache) Size() int {
+	if !c.enabled {
+		return 0
+	}
+
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return len(c.items)
+}
+
+// cleanup removes expired items from the cache
+func (c *Cache) cleanup() {
+	ticker := time.NewTicker(c.ttl / 2) // Clean up twice per TTL period
+	defer ticker.Stop()
+
+	for range ticker.C {
+		c.mu.Lock()
+		now := time.Now()
+		for key, item := range c.items {
+			if now.After(item.expiresAt) {
+				delete(c.items, key)
+			}
+		}
+		c.mu.Unlock()
+	}
+}
+
+// CacheableFunc represents a function that can be cached
+type CacheableFunc func(ctx context.Context, args ...interface{}) (interface{}, error)
+
+// WithCache wraps a function with caching capability
+func (c *Cache) WithCache(key string, fn CacheableFunc) CacheableFunc {
+	return func(ctx context.Context, args ...interface{}) (interface{}, error) {
+		// Try to get from cache first
+		if cached, found := c.Get(key); found {
+			return cached, nil
+		}
+
+		// Execute the function
+		result, err := fn(ctx, args...)
+ if err != nil { + return nil, err + } + + // Cache the result + c.Set(key, result) + return result, nil + } +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..56043fb --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,165 @@ +package cache + +import ( + "context" + "fmt" + "testing" + "time" +) + +func TestCache_SetAndGet(t *testing.T) { + cache := New(1*time.Second, true) + + // Test setting and getting a value + cache.Set("key1", "value1") + + value, found := cache.Get("key1") + if !found { + t.Error("Expected to find key1") + } + + if value != "value1" { + t.Errorf("Expected value1, got %v", value) + } +} + +func TestCache_Expiration(t *testing.T) { + cache := New(100*time.Millisecond, true) + + cache.Set("key1", "value1") + + // Should be available immediately + _, found := cache.Get("key1") + if !found { + t.Error("Expected to find key1 immediately") + } + + // Wait for expiration + time.Sleep(150 * time.Millisecond) + + _, found = cache.Get("key1") + if found { + t.Error("Expected key1 to be expired") + } +} + +func TestCache_Delete(t *testing.T) { + cache := New(1*time.Second, true) + + cache.Set("key1", "value1") + cache.Delete("key1") + + _, found := cache.Get("key1") + if found { + t.Error("Expected key1 to be deleted") + } +} + +func TestCache_Clear(t *testing.T) { + cache := New(1*time.Second, true) + + cache.Set("key1", "value1") + cache.Set("key2", "value2") + + if cache.Size() != 2 { + t.Errorf("Expected size 2, got %d", cache.Size()) + } + + cache.Clear() + + if cache.Size() != 0 { + t.Errorf("Expected size 0 after clear, got %d", cache.Size()) + } +} + +func TestCache_Disabled(t *testing.T) { + cache := New(1*time.Second, false) + + cache.Set("key1", "value1") + + _, found := cache.Get("key1") + if found { + t.Error("Expected cache to be disabled") + } + + if cache.Size() != 0 { + t.Errorf("Expected size 0 for disabled cache, got %d", cache.Size()) + } +} + +func 
TestCache_WithCache(t *testing.T) { + cache := New(1*time.Second, true) + callCount := 0 + + fn := cache.WithCache("test_key", func(ctx context.Context, args ...interface{}) (interface{}, error) { + callCount++ + return "result", nil + }) + + // First call should execute the function + result1, err := fn(context.Background()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if result1 != "result" { + t.Errorf("Expected 'result', got %v", result1) + } + + if callCount != 1 { + t.Errorf("Expected function to be called once, got %d", callCount) + } + + // Second call should use cache + result2, err := fn(context.Background()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if result2 != "result" { + t.Errorf("Expected 'result', got %v", result2) + } + + if callCount != 1 { + t.Errorf("Expected function to still be called once (cached), got %d", callCount) + } +} + +func TestCache_ConcurrentAccess(t *testing.T) { + cache := New(1*time.Second, true) + + // Test concurrent writes and reads + done := make(chan bool, 10) + + // Start multiple goroutines writing to cache + for i := 0; i < 5; i++ { + go func(id int) { + for j := 0; j < 100; j++ { + cache.Set(fmt.Sprintf("key_%d_%d", id, j), fmt.Sprintf("value_%d_%d", id, j)) + } + done <- true + }(i) + } + + // Start multiple goroutines reading from cache + for i := 0; i < 5; i++ { + go func(id int) { + for j := 0; j < 100; j++ { + cache.Get(fmt.Sprintf("key_%d_%d", id, j)) + } + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Cache should still be functional + cache.Set("test", "value") + value, found := cache.Get("test") + if !found || value != "value" { + t.Error("Cache should still be functional after concurrent access") + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..0541d90 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,123 @@ +package config + 
+import ( + "fmt" + "os" + "strconv" + "strings" + "time" +) + +// Config holds the configuration for the MCP DigitalOcean server +type Config struct { + // API Configuration + APIToken string + APIEndpoint string + + // Server Configuration + LogLevel string + Services []string + + // Performance Configuration + RequestTimeout time.Duration + MaxRetries int + RetryWaitMin time.Duration + RetryWaitMax time.Duration + + // Cache Configuration + CacheEnabled bool + CacheTTL time.Duration + + // Rate Limiting + RateLimitEnabled bool + RateLimitRPS int +} + +// LoadConfig loads configuration from environment variables and flags +func LoadConfig() (*Config, error) { + cfg := &Config{ + // Defaults + APIEndpoint: getEnvOrDefault("DIGITALOCEAN_API_ENDPOINT", "https://api.digitalocean.com"), + LogLevel: getEnvOrDefault("LOG_LEVEL", "info"), + RequestTimeout: getDurationEnvOrDefault("REQUEST_TIMEOUT", 30*time.Second), + MaxRetries: getIntEnvOrDefault("MAX_RETRIES", 4), + RetryWaitMin: getDurationEnvOrDefault("RETRY_WAIT_MIN", 1*time.Second), + RetryWaitMax: getDurationEnvOrDefault("RETRY_WAIT_MAX", 30*time.Second), + CacheEnabled: getBoolEnvOrDefault("CACHE_ENABLED", true), + CacheTTL: getDurationEnvOrDefault("CACHE_TTL", 5*time.Minute), + RateLimitEnabled: getBoolEnvOrDefault("RATE_LIMIT_ENABLED", true), + RateLimitRPS: getIntEnvOrDefault("RATE_LIMIT_RPS", 100), + } + + // Required fields + cfg.APIToken = os.Getenv("DIGITALOCEAN_API_TOKEN") + if cfg.APIToken == "" { + return nil, fmt.Errorf("DIGITALOCEAN_API_TOKEN environment variable is required") + } + + // Parse services + if services := os.Getenv("SERVICES"); services != "" { + cfg.Services = strings.Split(services, ",") + for i, service := range cfg.Services { + cfg.Services[i] = strings.TrimSpace(service) + } + } + + return cfg, nil +} + +// Validate validates the configuration +func (c *Config) Validate() error { + if c.APIToken == "" { + return fmt.Errorf("API token is required") + } + + if c.RequestTimeout <= 0 { 
+ return fmt.Errorf("request timeout must be positive") + } + + if c.MaxRetries < 0 { + return fmt.Errorf("max retries cannot be negative") + } + + if c.RateLimitRPS <= 0 && c.RateLimitEnabled { + return fmt.Errorf("rate limit RPS must be positive when rate limiting is enabled") + } + + return nil +} + +// Helper functions +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getIntEnvOrDefault(key string, defaultValue int) int { + if value := os.Getenv(key); value != "" { + if intValue, err := strconv.Atoi(value); err == nil { + return intValue + } + } + return defaultValue +} + +func getBoolEnvOrDefault(key string, defaultValue bool) bool { + if value := os.Getenv(key); value != "" { + if boolValue, err := strconv.ParseBool(value); err == nil { + return boolValue + } + } + return defaultValue +} + +func getDurationEnvOrDefault(key string, defaultValue time.Duration) time.Duration { + if value := os.Getenv(key); value != "" { + if duration, err := time.ParseDuration(value); err == nil { + return duration + } + } + return defaultValue +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..da3f5bc --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,154 @@ +package config + +import ( + "os" + "testing" + "time" +) + +func TestLoadConfig(t *testing.T) { + // Save original environment + originalToken := os.Getenv("DIGITALOCEAN_API_TOKEN") + originalServices := os.Getenv("SERVICES") + + // Clean up after test + defer func() { + os.Setenv("DIGITALOCEAN_API_TOKEN", originalToken) + os.Setenv("SERVICES", originalServices) + }() + + tests := []struct { + name string + envVars map[string]string + expectError bool + }{ + { + name: "valid configuration", + envVars: map[string]string{ + "DIGITALOCEAN_API_TOKEN": "test-token", + "SERVICES": "apps,droplets", + "LOG_LEVEL": "debug", + }, + expectError: false, + }, 
+ { + name: "missing API token", + envVars: map[string]string{ + "SERVICES": "apps", + }, + expectError: true, + }, + { + name: "custom timeout", + envVars: map[string]string{ + "DIGITALOCEAN_API_TOKEN": "test-token", + "REQUEST_TIMEOUT": "60s", + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set environment variables + for key, value := range tt.envVars { + os.Setenv(key, value) + } + + cfg, err := LoadConfig() + + if tt.expectError { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Validate configuration + if err := cfg.Validate(); err != nil { + t.Fatalf("Configuration validation failed: %v", err) + } + + // Check specific values + if tt.envVars["DIGITALOCEAN_API_TOKEN"] != "" { + if cfg.APIToken != tt.envVars["DIGITALOCEAN_API_TOKEN"] { + t.Errorf("Expected API token %s, got %s", tt.envVars["DIGITALOCEAN_API_TOKEN"], cfg.APIToken) + } + } + + if tt.envVars["REQUEST_TIMEOUT"] == "60s" { + if cfg.RequestTimeout != 60*time.Second { + t.Errorf("Expected timeout 60s, got %v", cfg.RequestTimeout) + } + } + + // Clean up environment variables + for key := range tt.envVars { + os.Unsetenv(key) + } + }) + } +} + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config *Config + expectError bool + }{ + { + name: "valid config", + config: &Config{ + APIToken: "test-token", + RequestTimeout: 30 * time.Second, + MaxRetries: 3, + RateLimitRPS: 100, + RateLimitEnabled: true, + }, + expectError: false, + }, + { + name: "missing API token", + config: &Config{ + RequestTimeout: 30 * time.Second, + }, + expectError: true, + }, + { + name: "invalid timeout", + config: &Config{ + APIToken: "test-token", + RequestTimeout: -1 * time.Second, + }, + expectError: true, + }, + { + name: "invalid rate limit", + config: &Config{ + APIToken: "test-token", + RequestTimeout: 30 * time.Second, + 
RateLimitEnabled: true, + RateLimitRPS: -1, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + + if tt.expectError && err == nil { + t.Error("Expected error, got nil") + } + + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 0000000..f52aebb --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,157 @@ +package errors + +import ( + "fmt" + "net/http" +) + +// ErrorType represents different types of errors +type ErrorType string + +const ( + ErrorTypeValidation ErrorType = "validation" + ErrorTypeAuthentication ErrorType = "authentication" + ErrorTypeAuthorization ErrorType = "authorization" + ErrorTypeNotFound ErrorType = "not_found" + ErrorTypeRateLimit ErrorType = "rate_limit" + ErrorTypeInternal ErrorType = "internal" + ErrorTypeNetwork ErrorType = "network" + ErrorTypeTimeout ErrorType = "timeout" +) + +// MCPError represents a structured error for the MCP server +type MCPError struct { + Type ErrorType `json:"type"` + Message string `json:"message"` + Details string `json:"details,omitempty"` + StatusCode int `json:"status_code,omitempty"` + Retryable bool `json:"retryable"` +} + +// Error implements the error interface +func (e *MCPError) Error() string { + if e.Details != "" { + return fmt.Sprintf("%s: %s (%s)", e.Type, e.Message, e.Details) + } + return fmt.Sprintf("%s: %s", e.Type, e.Message) +} + +// NewValidationError creates a new validation error +func NewValidationError(message, details string) *MCPError { + return &MCPError{ + Type: ErrorTypeValidation, + Message: message, + Details: details, + StatusCode: http.StatusBadRequest, + Retryable: false, + } +} + +// NewAuthenticationError creates a new authentication error +func NewAuthenticationError(message, details string) *MCPError { + return &MCPError{ + Type: 
ErrorTypeAuthentication, + Message: message, + Details: details, + StatusCode: http.StatusUnauthorized, + Retryable: false, + } +} + +// NewAuthorizationError creates a new authorization error +func NewAuthorizationError(message, details string) *MCPError { + return &MCPError{ + Type: ErrorTypeAuthorization, + Message: message, + Details: details, + StatusCode: http.StatusForbidden, + Retryable: false, + } +} + +// NewNotFoundError creates a new not found error +func NewNotFoundError(resource, identifier string) *MCPError { + return &MCPError{ + Type: ErrorTypeNotFound, + Message: fmt.Sprintf("%s not found", resource), + Details: fmt.Sprintf("No %s found with identifier: %s", resource, identifier), + StatusCode: http.StatusNotFound, + Retryable: false, + } +} + +// NewRateLimitError creates a new rate limit error +func NewRateLimitError(message string) *MCPError { + return &MCPError{ + Type: ErrorTypeRateLimit, + Message: message, + StatusCode: http.StatusTooManyRequests, + Retryable: true, + } +} + +// NewInternalError creates a new internal error +func NewInternalError(message, details string) *MCPError { + return &MCPError{ + Type: ErrorTypeInternal, + Message: message, + Details: details, + StatusCode: http.StatusInternalServerError, + Retryable: true, + } +} + +// NewNetworkError creates a new network error +func NewNetworkError(message, details string) *MCPError { + return &MCPError{ + Type: ErrorTypeNetwork, + Message: message, + Details: details, + StatusCode: http.StatusServiceUnavailable, + Retryable: true, + } +} + +// NewTimeoutError creates a new timeout error +func NewTimeoutError(message string) *MCPError { + return &MCPError{ + Type: ErrorTypeTimeout, + Message: message, + StatusCode: http.StatusRequestTimeout, + Retryable: true, + } +} + +// WrapError wraps a generic error into an MCPError +func WrapError(err error, errorType ErrorType, message string) *MCPError { + if err == nil { + return nil + } + + if mcpErr, ok := err.(*MCPError); ok { + 
return mcpErr + } + + return &MCPError{ + Type: errorType, + Message: message, + Details: err.Error(), + Retryable: errorType == ErrorTypeNetwork || errorType == ErrorTypeTimeout || errorType == ErrorTypeInternal, + } +} + +// IsRetryable checks if an error is retryable +func IsRetryable(err error) bool { + if mcpErr, ok := err.(*MCPError); ok { + return mcpErr.Retryable + } + return false +} + +// GetErrorType returns the error type if it's an MCPError +func GetErrorType(err error) ErrorType { + if mcpErr, ok := err.(*MCPError); ok { + return mcpErr.Type + } + return ErrorTypeInternal +} diff --git a/internal/health/health.go b/internal/health/health.go new file mode 100644 index 0000000..23dbd35 --- /dev/null +++ b/internal/health/health.go @@ -0,0 +1,204 @@ +package health + +import ( + "context" + "time" + + "github.com/digitalocean/godo" + "mcp-digitalocean/internal/cache" + "mcp-digitalocean/internal/metrics" + "mcp-digitalocean/internal/ratelimit" +) + +// HealthStatus represents the health status of the service +type HealthStatus string + +const ( + StatusHealthy HealthStatus = "healthy" + StatusDegraded HealthStatus = "degraded" + StatusUnhealthy HealthStatus = "unhealthy" +) + +// HealthCheck represents a health check result +type HealthCheck struct { + Name string `json:"name"` + Status HealthStatus `json:"status"` + Message string `json:"message,omitempty"` + Duration time.Duration `json:"duration"` + Timestamp time.Time `json:"timestamp"` +} + +// HealthChecker performs health checks on various components +type HealthChecker struct { + client *godo.Client + cache *cache.Cache + metrics *metrics.Metrics + rateLimiter *ratelimit.RateLimiter +} + +// New creates a new health checker +func New(client *godo.Client, cache *cache.Cache, metrics *metrics.Metrics, rateLimiter *ratelimit.RateLimiter) *HealthChecker { + return &HealthChecker{ + client: client, + cache: cache, + metrics: metrics, + rateLimiter: rateLimiter, + } +} + +// CheckHealth performs all 
health checks and returns the overall status +func (h *HealthChecker) CheckHealth(ctx context.Context) (*HealthReport, error) { + checks := []HealthCheck{ + h.checkAPI(ctx), + h.checkCache(), + h.checkRateLimit(), + h.checkMetrics(), + } + + report := &HealthReport{ + Status: h.determineOverallStatus(checks), + Checks: checks, + Timestamp: time.Now(), + } + + return report, nil +} + +// checkAPI checks if the DigitalOcean API is accessible +func (h *HealthChecker) checkAPI(ctx context.Context) HealthCheck { + start := time.Now() + check := HealthCheck{ + Name: "digitalocean_api", + Timestamp: start, + } + + // Create a context with timeout for the API call + apiCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + // Try to get account information as a simple API test + _, _, err := h.client.Account.Get(apiCtx) + check.Duration = time.Since(start) + + if err != nil { + check.Status = StatusUnhealthy + check.Message = "Failed to connect to DigitalOcean API: " + err.Error() + } else { + check.Status = StatusHealthy + check.Message = "API connection successful" + } + + return check +} + +// checkCache checks the cache status +func (h *HealthChecker) checkCache() HealthCheck { + start := time.Now() + check := HealthCheck{ + Name: "cache", + Timestamp: start, + } + + // Test cache functionality + testKey := "health_check_test" + testValue := "test_value" + + h.cache.Set(testKey, testValue) + retrieved, found := h.cache.Get(testKey) + h.cache.Delete(testKey) + + check.Duration = time.Since(start) + + if !found || retrieved != testValue { + check.Status = StatusDegraded + check.Message = "Cache not functioning properly" + } else { + check.Status = StatusHealthy + check.Message = "Cache functioning normally" + } + + return check +} + +// checkRateLimit checks the rate limiter status +func (h *HealthChecker) checkRateLimit() HealthCheck { + start := time.Now() + check := HealthCheck{ + Name: "rate_limiter", + Timestamp: start, + } + + stats := 
h.rateLimiter.GetStats() + check.Duration = time.Since(start) + + if !stats.Enabled { + check.Status = StatusHealthy + check.Message = "Rate limiting disabled" + } else if stats.CurrentTokens > 0 { + check.Status = StatusHealthy + check.Message = "Rate limiter functioning normally" + } else { + check.Status = StatusDegraded + check.Message = "Rate limiter has no available tokens" + } + + return check +} + +// checkMetrics checks the metrics collection status +func (h *HealthChecker) checkMetrics() HealthCheck { + start := time.Now() + check := HealthCheck{ + Name: "metrics", + Timestamp: start, + } + + snapshot := h.metrics.GetSnapshot() + check.Duration = time.Since(start) + + // Check if metrics are being collected + if snapshot.TotalRequests >= 0 { + check.Status = StatusHealthy + check.Message = "Metrics collection active" + } else { + check.Status = StatusDegraded + check.Message = "Metrics collection may not be functioning" + } + + return check +} + +// determineOverallStatus determines the overall health status based on individual checks +func (h *HealthChecker) determineOverallStatus(checks []HealthCheck) HealthStatus { + hasUnhealthy := false + hasDegraded := false + + for _, check := range checks { + switch check.Status { + case StatusUnhealthy: + hasUnhealthy = true + case StatusDegraded: + hasDegraded = true + } + } + + if hasUnhealthy { + return StatusUnhealthy + } else if hasDegraded { + return StatusDegraded + } + + return StatusHealthy +} + +// HealthReport represents the overall health report +type HealthReport struct { + Status HealthStatus `json:"status"` + Checks []HealthCheck `json:"checks"` + Timestamp time.Time `json:"timestamp"` +} + +// IsHealthy returns true if the overall status is healthy +func (r *HealthReport) IsHealthy() bool { + return r.Status == StatusHealthy +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..8763d1d --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 
+1,204 @@ +package metrics + +import ( + "context" + "sync" + "time" +) + +// Metrics holds various performance and usage metrics +type Metrics struct { + mu sync.RWMutex + + // Request metrics + TotalRequests int64 + SuccessfulRequests int64 + FailedRequests int64 + + // Timing metrics + AverageResponseTime time.Duration + MaxResponseTime time.Duration + MinResponseTime time.Duration + + // Service-specific metrics + ServiceUsage map[string]int64 + + // Cache metrics + CacheHits int64 + CacheMisses int64 + + // Error metrics + ErrorsByType map[string]int64 + + // Rate limiting metrics + RateLimitedRequests int64 + + startTime time.Time +} + +// New creates a new metrics instance +func New() *Metrics { + return &Metrics{ + ServiceUsage: make(map[string]int64), + ErrorsByType: make(map[string]int64), + MinResponseTime: time.Duration(^uint64(0) >> 1), // Max duration + startTime: time.Now(), + } +} + +// RecordRequest records a request with its duration and success status +func (m *Metrics) RecordRequest(duration time.Duration, success bool, service string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.TotalRequests++ + + if success { + m.SuccessfulRequests++ + } else { + m.FailedRequests++ + } + + // Update timing metrics + if duration > m.MaxResponseTime { + m.MaxResponseTime = duration + } + + if duration < m.MinResponseTime { + m.MinResponseTime = duration + } + + // Calculate average response time + if m.TotalRequests > 0 { + totalTime := time.Duration(m.TotalRequests-1) * m.AverageResponseTime + duration + m.AverageResponseTime = totalTime / time.Duration(m.TotalRequests) + } + + // Record service usage + if service != "" { + m.ServiceUsage[service]++ + } +} + +// RecordError records an error by type +func (m *Metrics) RecordError(errorType string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.ErrorsByType[errorType]++ +} + +// RecordCacheHit records a cache hit +func (m *Metrics) RecordCacheHit() { + m.mu.Lock() + defer m.mu.Unlock() + + m.CacheHits++ +} + +// 
RecordCacheMiss records a cache miss +func (m *Metrics) RecordCacheMiss() { + m.mu.Lock() + defer m.mu.Unlock() + + m.CacheMisses++ +} + +// RecordRateLimited records a rate-limited request +func (m *Metrics) RecordRateLimited() { + m.mu.Lock() + defer m.mu.Unlock() + + m.RateLimitedRequests++ +} + +// GetSnapshot returns a snapshot of current metrics +func (m *Metrics) GetSnapshot() MetricsSnapshot { + m.mu.RLock() + defer m.mu.RUnlock() + + // Deep copy maps + serviceUsage := make(map[string]int64) + for k, v := range m.ServiceUsage { + serviceUsage[k] = v + } + + errorsByType := make(map[string]int64) + for k, v := range m.ErrorsByType { + errorsByType[k] = v + } + + return MetricsSnapshot{ + TotalRequests: m.TotalRequests, + SuccessfulRequests: m.SuccessfulRequests, + FailedRequests: m.FailedRequests, + AverageResponseTime: m.AverageResponseTime, + MaxResponseTime: m.MaxResponseTime, + MinResponseTime: m.MinResponseTime, + ServiceUsage: serviceUsage, + CacheHits: m.CacheHits, + CacheMisses: m.CacheMisses, + ErrorsByType: errorsByType, + RateLimitedRequests: m.RateLimitedRequests, + Uptime: time.Since(m.startTime), + } +} + +// MetricsSnapshot represents a point-in-time snapshot of metrics +type MetricsSnapshot struct { + TotalRequests int64 + SuccessfulRequests int64 + FailedRequests int64 + AverageResponseTime time.Duration + MaxResponseTime time.Duration + MinResponseTime time.Duration + ServiceUsage map[string]int64 + CacheHits int64 + CacheMisses int64 + ErrorsByType map[string]int64 + RateLimitedRequests int64 + Uptime time.Duration +} + +// SuccessRate returns the success rate as a percentage +func (s MetricsSnapshot) SuccessRate() float64 { + if s.TotalRequests == 0 { + return 0 + } + return float64(s.SuccessfulRequests) / float64(s.TotalRequests) * 100 +} + +// CacheHitRate returns the cache hit rate as a percentage +func (s MetricsSnapshot) CacheHitRate() float64 { + total := s.CacheHits + s.CacheMisses + if total == 0 { + return 0 + } + return 
float64(s.CacheHits) / float64(total) * 100 +} + +// RequestsPerSecond returns the average requests per second since startup +func (s MetricsSnapshot) RequestsPerSecond() float64 { + if s.Uptime.Seconds() == 0 { + return 0 + } + return float64(s.TotalRequests) / s.Uptime.Seconds() +} + +// Middleware wraps a function with metrics collection +func (m *Metrics) Middleware(service string, fn func(ctx context.Context) error) func(ctx context.Context) error { + return func(ctx context.Context) error { + start := time.Now() + err := fn(ctx) + duration := time.Since(start) + + success := err == nil + m.RecordRequest(duration, success, service) + + if err != nil { + m.RecordError(err.Error()) + } + + return err + } +} diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 0000000..3d0f1e8 --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,111 @@ +package ratelimit + +import ( + "context" + "fmt" + "sync" + "time" +) + +// RateLimiter implements a token bucket rate limiter +type RateLimiter struct { + mu sync.Mutex + tokens int + capacity int + refillRate int // tokens per second + lastRefill time.Time + enabled bool +} + +// New creates a new rate limiter +func New(rps int, enabled bool) *RateLimiter { + return &RateLimiter{ + tokens: rps, + capacity: rps, + refillRate: rps, + lastRefill: time.Now(), + enabled: enabled, + } +} + +// Allow checks if a request is allowed under the current rate limit +func (rl *RateLimiter) Allow() bool { + if !rl.enabled { + return true + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + elapsed := now.Sub(rl.lastRefill) + + // Refill tokens based on elapsed time + tokensToAdd := int(elapsed.Seconds() * float64(rl.refillRate)) + if tokensToAdd > 0 { + rl.tokens += tokensToAdd + if rl.tokens > rl.capacity { + rl.tokens = rl.capacity + } + rl.lastRefill = now + } + + // Check if we have tokens available + if rl.tokens > 0 { + rl.tokens-- + return true + } + + return 
false +} + +// Wait blocks until a token is available or context is cancelled +func (rl *RateLimiter) Wait(ctx context.Context) error { + if !rl.enabled { + return nil + } + + for { + if rl.Allow() { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(time.Millisecond * 10): // Check every 10ms + continue + } + } +} + +// Middleware wraps a function with rate limiting +func (rl *RateLimiter) Middleware(fn func(ctx context.Context) error) func(ctx context.Context) error { + return func(ctx context.Context) error { + if !rl.Allow() { + return fmt.Errorf("rate limit exceeded") + } + return fn(ctx) + } +} + +// GetStats returns current rate limiter statistics +func (rl *RateLimiter) GetStats() RateLimiterStats { + rl.mu.Lock() + defer rl.mu.Unlock() + + return RateLimiterStats{ + Enabled: rl.enabled, + CurrentTokens: rl.tokens, + Capacity: rl.capacity, + RefillRate: rl.refillRate, + } +} + +// RateLimiterStats holds statistics about the rate limiter +type RateLimiterStats struct { + Enabled bool + CurrentTokens int + Capacity int + RefillRate int +} diff --git a/internal/registry.go b/internal/registry.go index 4aec650..f4868c8 100644 --- a/internal/registry.go +++ b/internal/registry.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "log/slog" "strings" @@ -10,13 +11,16 @@ import ( "mcp-digitalocean/internal/account" "mcp-digitalocean/internal/apps" + "mcp-digitalocean/internal/cache" "mcp-digitalocean/internal/common" "mcp-digitalocean/internal/dbaas" "mcp-digitalocean/internal/doks" "mcp-digitalocean/internal/droplet" "mcp-digitalocean/internal/insights" "mcp-digitalocean/internal/marketplace" + "mcp-digitalocean/internal/metrics" "mcp-digitalocean/internal/networking" + "mcp-digitalocean/internal/ratelimit" "mcp-digitalocean/internal/spaces" ) @@ -33,163 +37,183 @@ var supportedServices = map[string]struct{}{ "doks": {}, } +// Components holds the shared components for the MCP server +type Components struct { + 
Client *godo.Client + Metrics *metrics.Metrics + Cache *cache.Cache + RateLimiter *ratelimit.RateLimiter +} + // registerAppTools registers the app platform tools with the MCP server. -func registerAppTools(s *server.MCPServer, c *godo.Client) error { - appTools, err := apps.NewAppPlatformTool(c) +func registerAppTools(s *server.MCPServer, comp *Components) error { + appTools, err := apps.NewAppPlatformTool(comp.Client) if err != nil { return fmt.Errorf("failed to create apps tool: %w", err) } s.AddTools(appTools.Tools()...) - return nil } // registerCommonTools registers the common tools with the MCP server. -func registerCommonTools(s *server.MCPServer, c *godo.Client) error { - s.AddTools(common.NewRegionTools(c).Tools()...) - +func registerCommonTools(s *server.MCPServer, comp *Components) error { + s.AddTools(common.NewRegionTools(comp.Client).Tools()...) return nil } // registerDropletTools registers the droplet tools with the MCP server. -func registerDropletTools(s *server.MCPServer, c *godo.Client) error { - s.AddTools(droplet.NewDropletTool(c).Tools()...) - s.AddTools(droplet.NewDropletActionsTool(c).Tools()...) - s.AddTools(droplet.NewImagesTool(c).Tools()...) - s.AddTools(droplet.NewSizesTool(c).Tools()...) +func registerDropletTools(s *server.MCPServer, comp *Components) error { + s.AddTools(droplet.NewDropletTool(comp.Client).Tools()...) + s.AddTools(droplet.NewDropletActionsTool(comp.Client).Tools()...) + s.AddTools(droplet.NewImagesTool(comp.Client).Tools()...) + s.AddTools(droplet.NewSizesTool(comp.Client).Tools()...) return nil } // registerNetworkingTools registers the networking tools with the MCP server. -func registerNetworkingTools(s *server.MCPServer, c *godo.Client) error { - s.AddTools(networking.NewCertificateTool(c).Tools()...) - s.AddTools(networking.NewDomainsTool(c).Tools()...) - s.AddTools(networking.NewFirewallTool(c).Tools()...) - s.AddTools(networking.NewReservedIPTool(c).Tools()...) 
+func registerNetworkingTools(s *server.MCPServer, comp *Components) error { + s.AddTools(networking.NewCertificateTool(comp.Client).Tools()...) + s.AddTools(networking.NewDomainsTool(comp.Client).Tools()...) + s.AddTools(networking.NewFirewallTool(comp.Client).Tools()...) + s.AddTools(networking.NewReservedIPTool(comp.Client).Tools()...) // Partner attachments doesn't have much users so this has been disabled - // s.AddTools(networking.NewPartnerAttachmentTool(c).Tools()...) - s.AddTools(networking.NewVPCTool(c).Tools()...) - s.AddTools(networking.NewVPCPeeringTool(c).Tools()...) + // s.AddTools(networking.NewPartnerAttachmentTool(comp.Client).Tools()...) + s.AddTools(networking.NewVPCTool(comp.Client).Tools()...) + s.AddTools(networking.NewVPCPeeringTool(comp.Client).Tools()...) return nil } // registerAccountTools registers the account tools with the MCP server. -func registerAccountTools(s *server.MCPServer, c *godo.Client) error { - s.AddTools(account.NewAccountTools(c).Tools()...) - s.AddTools(account.NewActionTools(c).Tools()...) - s.AddTools(account.NewBalanceTools(c).Tools()...) - s.AddTools(account.NewBillingTools(c).Tools()...) - s.AddTools(account.NewInvoiceTools(c).Tools()...) - s.AddTools(account.NewKeysTool(c).Tools()...) - +func registerAccountTools(s *server.MCPServer, comp *Components) error { + s.AddTools(account.NewAccountTools(comp.Client).Tools()...) + s.AddTools(account.NewActionTools(comp.Client).Tools()...) + s.AddTools(account.NewBalanceTools(comp.Client).Tools()...) + s.AddTools(account.NewBillingTools(comp.Client).Tools()...) + s.AddTools(account.NewInvoiceTools(comp.Client).Tools()...) + s.AddTools(account.NewKeysTool(comp.Client).Tools()...) return nil } // registerSpacesTools registers the spaces tools and resources with the MCP server. 
-func registerSpacesTools(s *server.MCPServer, c *godo.Client) error { +func registerSpacesTools(s *server.MCPServer, comp *Components) error { // Register the tools for spaces keys - s.AddTools(spaces.NewSpacesKeysTool(c).Tools()...) - s.AddTools(spaces.NewCDNTool(c).Tools()...) - + s.AddTools(spaces.NewSpacesKeysTool(comp.Client).Tools()...) + s.AddTools(spaces.NewCDNTool(comp.Client).Tools()...) return nil } // registerMarketplaceTools registers the marketplace tools with the MCP server. -func registerMarketplaceTools(s *server.MCPServer, c *godo.Client) error { - s.AddTools(marketplace.NewOneClickTool(c).Tools()...) - +func registerMarketplaceTools(s *server.MCPServer, comp *Components) error { + s.AddTools(marketplace.NewOneClickTool(comp.Client).Tools()...) return nil } -func registerInsightsTools(s *server.MCPServer, c *godo.Client) error { - s.AddTools(insights.NewUptimeTool(c).Tools()...) - s.AddTools(insights.NewUptimeCheckAlertTool(c).Tools()...) - s.AddTools(insights.NewAlertPolicyTool(c).Tools()...) +func registerInsightsTools(s *server.MCPServer, comp *Components) error { + s.AddTools(insights.NewUptimeTool(comp.Client).Tools()...) + s.AddTools(insights.NewUptimeCheckAlertTool(comp.Client).Tools()...) + s.AddTools(insights.NewAlertPolicyTool(comp.Client).Tools()...) return nil } -func registerDOKSTools(s *server.MCPServer, c *godo.Client) error { - s.AddTools(doks.NewDoksTool(c).Tools()...) - +func registerDOKSTools(s *server.MCPServer, comp *Components) error { + s.AddTools(doks.NewDoksTool(comp.Client).Tools()...) return nil } -func registerDatabasesTools(s *server.MCPServer, c *godo.Client) error { - s.AddTools(dbaas.NewClusterTool(c).Tools()...) - s.AddTools(dbaas.NewFirewallTool(c).Tools()...) - s.AddTools(dbaas.NewKafkaTool(c).Tools()...) - s.AddTools(dbaas.NewMongoTool(c).Tools()...) - s.AddTools(dbaas.NewMysqlTool(c).Tools()...) - s.AddTools(dbaas.NewOpenSearchTool(c).Tools()...) - s.AddTools(dbaas.NewPostgreSQLTool(c).Tools()...) 
- s.AddTools(dbaas.NewRedisTool(c).Tools()...) - s.AddTools(dbaas.NewUserTool(c).Tools()...) - +func registerDatabasesTools(s *server.MCPServer, comp *Components) error { + s.AddTools(dbaas.NewClusterTool(comp.Client).Tools()...) + s.AddTools(dbaas.NewFirewallTool(comp.Client).Tools()...) + s.AddTools(dbaas.NewKafkaTool(comp.Client).Tools()...) + s.AddTools(dbaas.NewMongoTool(comp.Client).Tools()...) + s.AddTools(dbaas.NewMysqlTool(comp.Client).Tools()...) + s.AddTools(dbaas.NewOpenSearchTool(comp.Client).Tools()...) + s.AddTools(dbaas.NewPostgreSQLTool(comp.Client).Tools()...) + s.AddTools(dbaas.NewRedisTool(comp.Client).Tools()...) + s.AddTools(dbaas.NewUserTool(comp.Client).Tools()...) return nil } // Register registers the set of tools for the specified services with the MCP server. // We either register a subset of tools of the services are specified, or we register all tools if no services are specified. +// This is the legacy function for backward compatibility. func Register(logger *slog.Logger, s *server.MCPServer, c *godo.Client, servicesToActivate ...string) error { + comp := &Components{ + Client: c, + Metrics: metrics.New(), + Cache: cache.New(0, false), // Disabled by default for backward compatibility + RateLimiter: ratelimit.New(0, false), // Disabled by default for backward compatibility + } + + return RegisterWithComponents(logger, s, c, comp.Metrics, comp.Cache, comp.RateLimiter, servicesToActivate...) 
+} + +// RegisterWithComponents registers the set of tools with enhanced components support +func RegisterWithComponents(logger *slog.Logger, s *server.MCPServer, c *godo.Client, + metricsCollector *metrics.Metrics, cacheInstance *cache.Cache, rateLimiter *ratelimit.RateLimiter, + servicesToActivate ...string) error { + + comp := &Components{ + Client: c, + Metrics: metricsCollector, + Cache: cacheInstance, + RateLimiter: rateLimiter, + } + if len(servicesToActivate) == 0 { logger.Warn("no services specified, loading all supported services") for k := range supportedServices { servicesToActivate = append(servicesToActivate, k) } } + for _, svc := range servicesToActivate { - logger.Debug(fmt.Sprintf("Registering tool and resources for service: %s", svc)) - switch svc { - case "apps": - if err := registerAppTools(s, c); err != nil { - return fmt.Errorf("failed to register app tools: %w", err) - } - case "networking": - if err := registerNetworkingTools(s, c); err != nil { - return fmt.Errorf("failed to register networking tools: %w", err) - } - case "droplets": - if err := registerDropletTools(s, c); err != nil { - return fmt.Errorf("failed to register droplets tool: %w", err) - } - case "accounts": - if err := registerAccountTools(s, c); err != nil { - return fmt.Errorf("failed to register account tools: %w", err) - } - case "spaces": - if err := registerSpacesTools(s, c); err != nil { - return fmt.Errorf("failed to register spaces tools: %w", err) - } - case "databases": - if err := registerDatabasesTools(s, c); err != nil { - return fmt.Errorf("failed to register databases tools: %w", err) - } - case "marketplace": - if err := registerMarketplaceTools(s, c); err != nil { - return fmt.Errorf("failed to register marketplace tools: %w", err) - } - case "insights": - if err := registerInsightsTools(s, c); err != nil { - return fmt.Errorf("failed to register insights tools: %w", err) - } - case "doks": - if err := registerDOKSTools(s, c); err != nil { - return 
fmt.Errorf("failed to register DOKS tools: %w", err) - } - default: - return fmt.Errorf("unsupported service: %s, supported service are: %v", svc, setToString(supportedServices)) + logger.Debug("Registering tools and resources for service", "service", svc) + + // Wrap service registration with metrics + err := comp.Metrics.Middleware(svc, func(ctx context.Context) error { + return registerServiceTools(s, comp, svc) + })(context.Background()) + + if err != nil { + return fmt.Errorf("failed to register %s tools: %w", svc, err) } } // Common tools are always registered because they provide common functionality for all services such as region resources - if err := registerCommonTools(s, c); err != nil { + if err := registerCommonTools(s, comp); err != nil { return fmt.Errorf("failed to register common tools: %w", err) } return nil } +// registerServiceTools registers tools for a specific service +func registerServiceTools(s *server.MCPServer, comp *Components, service string) error { + switch service { + case "apps": + return registerAppTools(s, comp) + case "networking": + return registerNetworkingTools(s, comp) + case "droplets": + return registerDropletTools(s, comp) + case "accounts": + return registerAccountTools(s, comp) + case "spaces": + return registerSpacesTools(s, comp) + case "databases": + return registerDatabasesTools(s, comp) + case "marketplace": + return registerMarketplaceTools(s, comp) + case "insights": + return registerInsightsTools(s, comp) + case "doks": + return registerDOKSTools(s, comp) + default: + return fmt.Errorf("unsupported service: %s, supported services are: %v", service, setToString(supportedServices)) + } +} + func setToString(set map[string]struct{}) string { var result []string for key := range set { diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..07f9f8a --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,123 @@ +package testutil + +import ( + "context" + 
"testing" + "time" + + "mcp-digitalocean/internal/cache" + "mcp-digitalocean/internal/config" + "mcp-digitalocean/internal/metrics" + "mcp-digitalocean/internal/ratelimit" +) + +// TestConfig returns a test configuration with sensible defaults +func TestConfig() *config.Config { + return &config.Config{ + APIToken: "test-token", + APIEndpoint: "https://api.digitalocean.com", + LogLevel: "debug", + Services: []string{"apps", "droplets"}, + RequestTimeout: 30 * time.Second, + MaxRetries: 3, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 10 * time.Second, + CacheEnabled: true, + CacheTTL: 5 * time.Minute, + RateLimitEnabled: true, + RateLimitRPS: 100, + } +} + +// TestComponents returns test components for testing +func TestComponents() (*metrics.Metrics, *cache.Cache, *ratelimit.RateLimiter) { + cfg := TestConfig() + return metrics.New(), + cache.New(cfg.CacheTTL, cfg.CacheEnabled), + ratelimit.New(cfg.RateLimitRPS, cfg.RateLimitEnabled) +} + +// AssertNoError is a helper function to assert no error occurred +func AssertNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } +} + +// AssertError is a helper function to assert an error occurred +func AssertError(t *testing.T, err error) { + t.Helper() + if err == nil { + t.Fatal("Expected an error, got nil") + } +} + +// AssertEqual is a helper function to assert two values are equal +func AssertEqual(t *testing.T, expected, actual interface{}) { + t.Helper() + if expected != actual { + t.Fatalf("Expected %v, got %v", expected, actual) + } +} + +// AssertTrue is a helper function to assert a condition is true +func AssertTrue(t *testing.T, condition bool, message string) { + t.Helper() + if !condition { + t.Fatal(message) + } +} + +// AssertFalse is a helper function to assert a condition is false +func AssertFalse(t *testing.T, condition bool, message string) { + t.Helper() + if condition { + t.Fatal(message) + } +} + +// WithTimeout runs a function with a 
timeout context +func WithTimeout(t *testing.T, timeout time.Duration, fn func(ctx context.Context)) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + fn(ctx) + }() + + select { + case <-done: + // Function completed successfully + case <-ctx.Done(): + t.Fatal("Function timed out") + } +} + +// MockTime provides utilities for testing time-dependent code +type MockTime struct { + current time.Time +} + +// NewMockTime creates a new mock time instance +func NewMockTime(start time.Time) *MockTime { + return &MockTime{current: start} +} + +// Now returns the current mock time +func (m *MockTime) Now() time.Time { + return m.current +} + +// Advance advances the mock time by the given duration +func (m *MockTime) Advance(d time.Duration) { + m.current = m.current.Add(d) +} + +// Set sets the mock time to a specific time +func (m *MockTime) Set(t time.Time) { + m.current = t +}