Skip to content

Commit 0f06109

Browse files
committed
Correct the indirection of message propagation
Added tests to verify that messages are propagated to the BeforeAny and AfterAny hooks
1 parent eaf863c commit 0f06109

File tree

3 files changed

+62
-138
lines changed

3 files changed

+62
-138
lines changed

server/hooks.go

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/internal/gen/hooks.go.tmpl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func (c *Hooks) beforeAny(id any, method mcp.MCPMethod, message any) {
129129
return
130130
}
131131
for _, hook := range c.OnBeforeAny {
132-
hook(id, method, &message)
132+
hook(id, method, message)
133133
}
134134
}
135135

@@ -138,7 +138,7 @@ func (c *Hooks) afterAny(id any, method mcp.MCPMethod, message any, result any)
138138
return
139139
}
140140
for _, hook := range c.OnAfterAny {
141-
hook(id, method, &message, result)
141+
hook(id, method, message, result)
142142
}
143143
}
144144

@@ -157,6 +157,7 @@ func (c *Hooks) afterAny(id any, method mcp.MCPMethod, message any, result any)
157157
// - ErrPromptNotFound: When a prompt is not found
158158
// - ErrToolNotFound: When a tool is not found
159159
func (c *Hooks) onError(id any, method mcp.MCPMethod, message any, err error) {
160+
c.afterAny(id, method, message, err)
160161
if c == nil {
161162
return
162163
}

server/server_test.go

Lines changed: 56 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -665,10 +665,10 @@ func TestMCPServer_PromptHandling(t *testing.T) {
665665
}
666666

667667
func TestMCPServer_HandleInvalidMessages(t *testing.T) {
668-
var errChan = make(chan error, 1)
668+
var errs []error
669669
hooks := &Hooks{}
670670
hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
671-
errChan <- err
671+
errs = append(errs, err)
672672
})
673673

674674
server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks))
@@ -710,114 +710,8 @@ func TestMCPServer_HandleInvalidMessages(t *testing.T) {
710710

711711
for _, tt := range tests {
712712
t.Run(tt.name, func(t *testing.T) {
713-
response := server.HandleMessage(
714-
context.Background(),
715-
[]byte(tt.message),
716-
)
717-
assert.NotNil(t, response)
718-
719-
errorResponse, ok := response.(mcp.JSONRPCError)
720-
assert.True(t, ok)
721-
assert.Equal(t, tt.expectedErr, errorResponse.Error.Code)
722-
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
723-
defer cancel()
724-
725-
if tt.validateErr != nil {
726-
select {
727-
case err := <-errChan:
728-
tt.validateErr(t, err)
729-
case <-ctx.Done():
730-
t.Errorf("Error not received")
731-
}
732-
}
733-
})
734-
}
735-
}
736-
737-
func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
738-
var errChan chan error
739-
hooks := &Hooks{}
740-
hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
741-
errChan <- err
742-
})
743-
744-
server := NewMCPServer("test-server", "1.0.0",
745-
WithResourceCapabilities(true, true),
746-
WithPromptCapabilities(true),
747-
WithToolCapabilities(true),
748-
WithHooks(hooks),
749-
)
750-
751-
// Add a test tool to enable tool capabilities
752-
server.AddTool(mcp.Tool{
753-
Name: "test-tool",
754-
Description: "Test tool",
755-
InputSchema: mcp.ToolInputSchema{
756-
Type: "object",
757-
Properties: map[string]interface{}{},
758-
},
759-
}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
760-
return &mcp.CallToolResult{}, nil
761-
})
762-
763-
tests := []struct {
764-
name string
765-
message string
766-
expectedErr int
767-
validateErr func(t *testing.T, err error)
768-
}{
769-
{
770-
name: "Undefined tool",
771-
message: `{
772-
"jsonrpc": "2.0",
773-
"id": 1,
774-
"method": "tools/call",
775-
"params": {
776-
"name": "undefined-tool",
777-
"arguments": {}
778-
}
779-
}`,
780-
expectedErr: mcp.INVALID_PARAMS,
781-
validateErr: func(t *testing.T, err error) {
782-
assert.True(t, errors.Is(err, ErrToolNotFound), "Error should be ErrToolNotFound but was %v", err)
783-
},
784-
},
785-
{
786-
name: "Undefined prompt",
787-
message: `{
788-
"jsonrpc": "2.0",
789-
"id": 1,
790-
"method": "prompts/get",
791-
"params": {
792-
"name": "undefined-prompt",
793-
"arguments": {}
794-
}
795-
}`,
796-
expectedErr: mcp.INVALID_PARAMS,
797-
validateErr: func(t *testing.T, err error) {
798-
assert.True(t, errors.Is(err, ErrPromptNotFound), "Error should be ErrPromptNotFound but was %v", err)
799-
},
800-
},
801-
{
802-
name: "Undefined resource",
803-
message: `{
804-
"jsonrpc": "2.0",
805-
"id": 1,
806-
"method": "resources/read",
807-
"params": {
808-
"uri": "undefined-resource"
809-
}
810-
}`,
811-
expectedErr: mcp.INVALID_PARAMS,
812-
validateErr: func(t *testing.T, err error) {
813-
assert.True(t, errors.Is(err, ErrResourceNotFound), "Error should be ErrResourceNotFound but was %v", err)
814-
},
815-
},
816-
}
713+
errs = nil // Reset errors for each test case
817714

818-
for _, tt := range tests {
819-
t.Run(tt.name, func(t *testing.T) {
820-
errChan = make(chan error, 1)
821715
response := server.HandleMessage(
822716
context.Background(),
823717
[]byte(tt.message),
@@ -827,16 +721,10 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
827721
errorResponse, ok := response.(mcp.JSONRPCError)
828722
assert.True(t, ok)
829723
assert.Equal(t, tt.expectedErr, errorResponse.Error.Code)
830-
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
831-
defer cancel()
832724

833725
if tt.validateErr != nil {
834-
select {
835-
case err := <-errChan:
836-
tt.validateErr(t, err)
837-
case <-ctx.Done():
838-
t.Errorf("Error not received")
839-
}
726+
require.Len(t, errs, 1, "Expected exactly one error")
727+
tt.validateErr(t, errs[0])
840728
}
841729
})
842730
}
@@ -981,10 +869,10 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
981869
}
982870

983871
func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) {
984-
var errChan chan error
872+
var errs []error
985873
hooks := &Hooks{}
986874
hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
987-
errChan <- err
875+
errs = append(errs, err)
988876
})
989877
hooksOption := WithHooks(hooks)
990878

@@ -1041,7 +929,8 @@ func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) {
1041929

1042930
for _, tt := range tests {
1043931
t.Run(tt.name, func(t *testing.T) {
1044-
errChan = make(chan error, 1)
932+
errs = nil // Reset errors for each test case
933+
1045934
server := NewMCPServer("test-server", "1.0.0", tt.options...)
1046935
response := server.HandleMessage(
1047936
context.Background(),
@@ -1052,16 +941,10 @@ func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) {
1052941
errorResponse, ok := response.(mcp.JSONRPCError)
1053942
assert.True(t, ok)
1054943
assert.Equal(t, tt.expectedErr, errorResponse.Error.Code)
1055-
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
1056-
defer cancel()
1057-
var err error
1058-
select {
1059-
case err = <-errChan:
1060-
case <-ctx.Done():
1061-
t.Errorf("Error not received")
1062-
}
1063-
assert.True(t, errors.Is(err, ErrUnsupported), "Error should be ErrUnsupported but was %v", err)
1064-
assert.Contains(t, err.Error(), tt.errString)
944+
945+
require.Len(t, errs, 1, "Expected exactly one error")
946+
assert.True(t, errors.Is(errs[0], ErrUnsupported), "Error should be ErrUnsupported but was %v", errs[0])
947+
assert.Contains(t, errs[0].Error(), tt.errString)
1065948
})
1066949
}
1067950
}
@@ -1290,29 +1173,57 @@ func TestMCPServer_WithHooks(t *testing.T) {
12901173
afterToolsCount int
12911174
)
12921175

1176+
// Collectors for message and result types
1177+
var beforeAnyMessages []any
1178+
var afterAnyData []struct {
1179+
msg any
1180+
res any
1181+
}
1182+
var beforePingMessages []*mcp.PingRequest
1183+
var afterPingData []struct {
1184+
msg *mcp.PingRequest
1185+
res *mcp.EmptyResult
1186+
}
1187+
12931188
// Initialize hook handlers
12941189
hooks := &Hooks{}
12951190

1296-
// Register "any" hooks
1191+
// Register "any" hooks with type verification
12971192
hooks.AddBeforeAny(func(id any, method mcp.MCPMethod, message any) {
12981193
beforeAnyCount++
1194+
// Only collect ping messages for our test
1195+
if method == mcp.MethodPing {
1196+
beforeAnyMessages = append(beforeAnyMessages, message)
1197+
}
12991198
})
13001199

13011200
hooks.AddAfterAny(func(id any, method mcp.MCPMethod, message any, result any) {
13021201
afterAnyCount++
1202+
// Only collect ping responses for our test
1203+
if method == mcp.MethodPing {
1204+
afterAnyData = append(afterAnyData, struct {
1205+
msg any
1206+
res any
1207+
}{message, result})
1208+
}
13031209
})
13041210

13051211
hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
13061212
onErrorCount++
13071213
})
13081214

1309-
// Register method-specific hooks
1215+
// Register method-specific hooks with type verification
13101216
hooks.AddBeforePing(func(id any, message *mcp.PingRequest) {
13111217
beforePingCount++
1218+
beforePingMessages = append(beforePingMessages, message)
13121219
})
13131220

13141221
hooks.AddAfterPing(func(id any, message *mcp.PingRequest, result *mcp.EmptyResult) {
13151222
afterPingCount++
1223+
afterPingData = append(afterPingData, struct {
1224+
msg *mcp.PingRequest
1225+
res *mcp.EmptyResult
1226+
}{message, result})
13161227
})
13171228

13181229
hooks.AddBeforeListTools(func(id any, message *mcp.ListToolsRequest) {
@@ -1390,9 +1301,20 @@ func TestMCPServer_WithHooks(t *testing.T) {
13901301
// General hooks should be called for all methods
13911302
// beforeAny is called for all 4 methods (initialize, ping, tools/list, tools/call)
13921303
assert.Equal(t, 4, beforeAnyCount, "beforeAny should be called for each method")
1393-
// afterAny is called only for successful methods (initialize, ping, tools/list)
1394-
assert.Equal(t, 3, afterAnyCount, "afterAny should be called for successful methods only")
1304+
// afterAny is called for all 3 success methods (initialize, ping, tools/list) plus 1 error (tools/call)
1305+
assert.Equal(t, 4, afterAnyCount, "afterAny should be called for all methods including errors")
13951306

13961307
// Error hook should be called once for the failed tools/call
13971308
assert.Equal(t, 1, onErrorCount, "onError should be called once")
1309+
1310+
// Verify type matching between BeforeAny and BeforePing
1311+
require.Len(t, beforePingMessages, 1, "Expected one BeforePing message")
1312+
require.Len(t, beforeAnyMessages, 1, "Expected one BeforeAny Ping message")
1313+
assert.IsType(t, beforePingMessages[0], beforeAnyMessages[0], "BeforeAny message should be same type as BeforePing message")
1314+
1315+
// Verify type matching between AfterAny and AfterPing
1316+
require.Len(t, afterPingData, 1, "Expected one AfterPing message/result pair")
1317+
require.Len(t, afterAnyData, 1, "Expected one AfterAny Ping message/result pair")
1318+
assert.IsType(t, afterPingData[0].msg, afterAnyData[0].msg, "AfterAny message should be same type as AfterPing message")
1319+
assert.IsType(t, afterPingData[0].res, afterAnyData[0].res, "AfterAny result should be same type as AfterPing result")
13981320
}

0 commit comments

Comments
 (0)