Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion internal/guard/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package guard

import (
"fmt"
"log"
"sync"

"github.com/githubnext/gh-aw-mcpg/internal/logger"
Expand Down
63 changes: 20 additions & 43 deletions internal/server/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
// TestLoggingResponseWriter_WriteHeader tests the WriteHeader method
func TestLoggingResponseWriter_WriteHeader(t *testing.T) {
tests := []struct {
name string
statusCode int
wantStatusCode int
writeMultipleTimes bool
name string
statusCode int
wantStatusCode int
}{
{
name: "StatusOK",
Expand All @@ -43,35 +42,18 @@
statusCode: http.StatusUnauthorized,
wantStatusCode: http.StatusUnauthorized,
},
{
name: "MultipleWrites_FirstWins",
statusCode: http.StatusOK,
wantStatusCode: http.StatusOK,
writeMultipleTimes: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
lw := &loggingResponseWriter{
ResponseWriter: w,
body: []byte{},
statusCode: 0,
}
lw := newResponseWriter(w)

// Write header
lw.WriteHeader(tt.statusCode)

// Verify status code is captured
assert.Equal(t, tt.wantStatusCode, lw.statusCode, "Status code should be captured")

// If testing multiple writes, try writing again (should be ignored by stdlib)
if tt.writeMultipleTimes {
lw.WriteHeader(http.StatusBadRequest)
// First status code should win
assert.Equal(t, tt.wantStatusCode, lw.statusCode, "First status code should be preserved")
}
assert.Equal(t, tt.wantStatusCode, lw.StatusCode(), "Status code should be captured")
})
}
}
Expand Down Expand Up @@ -119,11 +101,7 @@
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
lw := &loggingResponseWriter{
ResponseWriter: w,
body: []byte{},
statusCode: http.StatusOK,
}
lw := newResponseWriter(w)

totalWritten := 0
for _, data := range tt.writes {
Expand All @@ -135,29 +113,28 @@
// Verify total bytes written
assert.Equal(t, tt.wantWritten, totalWritten, "Total bytes written should match")

// Verify body is captured
assert.Equal(t, tt.wantBody, lw.body, "Body should be captured correctly")

// Verify body is also written to underlying response writer
assert.Equal(t, tt.wantBody, w.Body.Bytes(), "Body should be written to underlying writer")
// Verify body is captured (use Len for empty check to handle nil vs empty slice)
if len(tt.wantBody) == 0 {
assert.Empty(t, lw.Body(), "Body should be empty")
assert.Empty(t, w.Body.Bytes(), "Underlying writer body should be empty")
} else {
assert.Equal(t, tt.wantBody, lw.Body(), "Body should be captured correctly")
assert.Equal(t, tt.wantBody, w.Body.Bytes(), "Body should be written to underlying writer")
}
})
}
}

// TestLoggingResponseWriter_DefaultStatusCode tests that default status code is 200
func TestLoggingResponseWriter_DefaultStatusCode(t *testing.T) {
w := httptest.NewRecorder()
lw := &loggingResponseWriter{
ResponseWriter: w,
body: []byte{},
statusCode: http.StatusOK, // Constructor sets this
}
lw := newResponseWriter(w)

// Write without explicit WriteHeader
lw.Write([]byte("test"))

// Default status code should be 200
assert.Equal(t, http.StatusOK, lw.statusCode, "Default status code should be 200")
assert.Equal(t, http.StatusOK, lw.StatusCode(), "Default status code should be 200")
}

// TestWithResponseLogging tests the withResponseLogging middleware
Expand Down Expand Up @@ -340,9 +317,9 @@

// Check required fields
assert.Contains(t, response, "status", "Response should contain status")
assert.Contains(t, response, "protocolVersion", "Response should contain protocolVersion")
assert.Contains(t, response, "gatewayVersion", "Response should contain gatewayVersion")
assert.Contains(t, response, "specVersion", "Response should contain specVersion")
assert.Contains(t, response, "gatewayVersion", "Response should contain gatewayVersion")
assert.Contains(t, response, "servers", "Response should contain servers")
})
}
}
Expand Down Expand Up @@ -439,7 +416,7 @@

// Verify response based on error expectation
if tt.wantError {
if tt.wantStatusCode == http.StatusMethodNotAllowed {

Check failure on line 419 in internal/server/transport_test.go

View workflow job for this annotation

GitHub Actions / lint

QF1003: could use tagged switch on tt.wantStatusCode (staticcheck)
// http.Error writes plain text for 405
assert.Contains(t, w.Body.String(), "Method not allowed")
} else if tt.wantStatusCode == http.StatusUnauthorized {
Expand Down Expand Up @@ -508,13 +485,13 @@
wantStatusCode int
}{
{
name: "MCP_GET_NotAllowed",
name: "MCP_GET_RequiresSession",
path: "/mcp",
method: "GET",
apiKey: "",
authHeader: "",
body: nil,
wantStatusCode: http.StatusMethodNotAllowed,
wantStatusCode: http.StatusBadRequest, // GET requires active session per MCP streamable HTTP spec
},
{
name: "MCP_POST_NoAuth_WithAPIKey",
Expand Down
Loading