diff --git a/internal/guard/registry.go b/internal/guard/registry.go index 47a4a687..e73d5080 100644 --- a/internal/guard/registry.go +++ b/internal/guard/registry.go @@ -2,7 +2,6 @@ package guard import ( "fmt" - "log" "sync" "github.com/githubnext/gh-aw-mcpg/internal/logger" diff --git a/internal/server/transport_test.go b/internal/server/transport_test.go index 164e2dd6..3aca6540 100644 --- a/internal/server/transport_test.go +++ b/internal/server/transport_test.go @@ -18,10 +18,9 @@ import ( // TestLoggingResponseWriter_WriteHeader tests the WriteHeader method func TestLoggingResponseWriter_WriteHeader(t *testing.T) { tests := []struct { - name string - statusCode int - wantStatusCode int - writeMultipleTimes bool + name string + statusCode int + wantStatusCode int }{ { name: "StatusOK", @@ -43,35 +42,18 @@ func TestLoggingResponseWriter_WriteHeader(t *testing.T) { statusCode: http.StatusUnauthorized, wantStatusCode: http.StatusUnauthorized, }, - { - name: "MultipleWrites_FirstWins", - statusCode: http.StatusOK, - wantStatusCode: http.StatusOK, - writeMultipleTimes: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() - lw := &loggingResponseWriter{ - ResponseWriter: w, - body: []byte{}, - statusCode: 0, - } + lw := newResponseWriter(w) // Write header lw.WriteHeader(tt.statusCode) // Verify status code is captured - assert.Equal(t, tt.wantStatusCode, lw.statusCode, "Status code should be captured") - - // If testing multiple writes, try writing again (should be ignored by stdlib) - if tt.writeMultipleTimes { - lw.WriteHeader(http.StatusBadRequest) - // First status code should win - assert.Equal(t, tt.wantStatusCode, lw.statusCode, "First status code should be preserved") - } + assert.Equal(t, tt.wantStatusCode, lw.StatusCode(), "Status code should be captured") }) } } @@ -119,11 +101,7 @@ func TestLoggingResponseWriter_Write(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() - lw := &loggingResponseWriter{ - ResponseWriter: w, - body: []byte{}, - statusCode: http.StatusOK, - } + lw := newResponseWriter(w) totalWritten := 0 for _, data := range tt.writes { @@ -135,11 +113,14 @@ func TestLoggingResponseWriter_Write(t *testing.T) { // Verify total bytes written assert.Equal(t, tt.wantWritten, totalWritten, "Total bytes written should match") - // Verify body is captured - assert.Equal(t, tt.wantBody, lw.body, "Body should be captured correctly") - - // Verify body is also written to underlying response writer - assert.Equal(t, tt.wantBody, w.Body.Bytes(), "Body should be written to underlying writer") + // Verify body is captured (use Len for empty check to handle nil vs empty slice) + if len(tt.wantBody) == 0 { + assert.Empty(t, lw.Body(), "Body should be empty") + assert.Empty(t, w.Body.Bytes(), "Underlying writer body should be empty") + } else { + assert.Equal(t, tt.wantBody, lw.Body(), "Body should be captured correctly") + assert.Equal(t, tt.wantBody, w.Body.Bytes(), "Body should be written to underlying writer") + } }) } } @@ -147,17 +128,13 @@ func TestLoggingResponseWriter_Write(t *testing.T) { // TestLoggingResponseWriter_DefaultStatusCode tests that default status code is 200 func TestLoggingResponseWriter_DefaultStatusCode(t *testing.T) { w := httptest.NewRecorder() - lw := &loggingResponseWriter{ - ResponseWriter: w, - body: []byte{}, - statusCode: http.StatusOK, // Constructor sets this - } + lw := newResponseWriter(w) // Write without explicit WriteHeader lw.Write([]byte("test")) // Default status code should be 200 - assert.Equal(t, http.StatusOK, lw.statusCode, "Default status code should be 200") + assert.Equal(t, http.StatusOK, lw.StatusCode(), "Default status code should be 200") } // TestWithResponseLogging tests the withResponseLogging middleware @@ -340,9 +317,9 @@ func TestCreateHTTPServerForMCP_Health(t *testing.T) { // Check required fields assert.Contains(t, response, "status", "Response should contain status") - assert.Contains(t, response, "protocolVersion", "Response should contain protocolVersion") - assert.Contains(t, response, "gatewayVersion", "Response should contain gatewayVersion") assert.Contains(t, response, "specVersion", "Response should contain specVersion") + assert.Contains(t, response, "gatewayVersion", "Response should contain gatewayVersion") + assert.Contains(t, response, "servers", "Response should contain servers") }) } } @@ -508,13 +485,13 @@ func TestCreateHTTPServerForMCP_MCPEndpoint(t *testing.T) { wantStatusCode int }{ { - name: "MCP_GET_NotAllowed", + name: "MCP_GET_RequiresSession", path: "/mcp", method: "GET", apiKey: "", authHeader: "", body: nil, - wantStatusCode: http.StatusMethodNotAllowed, + wantStatusCode: http.StatusBadRequest, // GET requires active session per MCP streamable HTTP spec }, { name: "MCP_POST_NoAuth_WithAPIKey",