@@ -665,10 +665,10 @@ func TestMCPServer_PromptHandling(t *testing.T) {
665665}
666666
667667func TestMCPServer_HandleInvalidMessages (t * testing.T ) {
668- var errChan = make ( chan error , 1 )
668+ var errs [] error
669669 hooks := & Hooks {}
670670 hooks .AddOnError (func (id any , method mcp.MCPMethod , message any , err error ) {
671- errChan <- err
671+ errs = append ( errs , err )
672672 })
673673
674674 server := NewMCPServer ("test-server" , "1.0.0" , WithHooks (hooks ))
@@ -710,114 +710,8 @@ func TestMCPServer_HandleInvalidMessages(t *testing.T) {
710710
711711 for _ , tt := range tests {
712712 t .Run (tt .name , func (t * testing.T ) {
713- response := server .HandleMessage (
714- context .Background (),
715- []byte (tt .message ),
716- )
717- assert .NotNil (t , response )
718-
719- errorResponse , ok := response .(mcp.JSONRPCError )
720- assert .True (t , ok )
721- assert .Equal (t , tt .expectedErr , errorResponse .Error .Code )
722- ctx , cancel := context .WithTimeout (context .Background (), 50 * time .Millisecond )
723- defer cancel ()
724-
725- if tt .validateErr != nil {
726- select {
727- case err := <- errChan :
728- tt .validateErr (t , err )
729- case <- ctx .Done ():
730- t .Errorf ("Error not received" )
731- }
732- }
733- })
734- }
735- }
736-
737- func TestMCPServer_HandleUndefinedHandlers (t * testing.T ) {
738- var errChan chan error
739- hooks := & Hooks {}
740- hooks .AddOnError (func (id any , method mcp.MCPMethod , message any , err error ) {
741- errChan <- err
742- })
743-
744- server := NewMCPServer ("test-server" , "1.0.0" ,
745- WithResourceCapabilities (true , true ),
746- WithPromptCapabilities (true ),
747- WithToolCapabilities (true ),
748- WithHooks (hooks ),
749- )
750-
751- // Add a test tool to enable tool capabilities
752- server .AddTool (mcp.Tool {
753- Name : "test-tool" ,
754- Description : "Test tool" ,
755- InputSchema : mcp.ToolInputSchema {
756- Type : "object" ,
757- Properties : map [string ]interface {}{},
758- },
759- }, func (ctx context.Context , request mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
760- return & mcp.CallToolResult {}, nil
761- })
762-
763- tests := []struct {
764- name string
765- message string
766- expectedErr int
767- validateErr func (t * testing.T , err error )
768- }{
769- {
770- name : "Undefined tool" ,
771- message : `{
772- "jsonrpc": "2.0",
773- "id": 1,
774- "method": "tools/call",
775- "params": {
776- "name": "undefined-tool",
777- "arguments": {}
778- }
779- }` ,
780- expectedErr : mcp .INVALID_PARAMS ,
781- validateErr : func (t * testing.T , err error ) {
782- assert .True (t , errors .Is (err , ErrToolNotFound ), "Error should be ErrToolNotFound but was %v" , err )
783- },
784- },
785- {
786- name : "Undefined prompt" ,
787- message : `{
788- "jsonrpc": "2.0",
789- "id": 1,
790- "method": "prompts/get",
791- "params": {
792- "name": "undefined-prompt",
793- "arguments": {}
794- }
795- }` ,
796- expectedErr : mcp .INVALID_PARAMS ,
797- validateErr : func (t * testing.T , err error ) {
798- assert .True (t , errors .Is (err , ErrPromptNotFound ), "Error should be ErrPromptNotFound but was %v" , err )
799- },
800- },
801- {
802- name : "Undefined resource" ,
803- message : `{
804- "jsonrpc": "2.0",
805- "id": 1,
806- "method": "resources/read",
807- "params": {
808- "uri": "undefined-resource"
809- }
810- }` ,
811- expectedErr : mcp .INVALID_PARAMS ,
812- validateErr : func (t * testing.T , err error ) {
813- assert .True (t , errors .Is (err , ErrResourceNotFound ), "Error should be ErrResourceNotFound but was %v" , err )
814- },
815- },
816- }
713+ errs = nil // Reset errors for each test case
817714
818- for _ , tt := range tests {
819- t .Run (tt .name , func (t * testing.T ) {
820- errChan = make (chan error , 1 )
821715 response := server .HandleMessage (
822716 context .Background (),
823717 []byte (tt .message ),
@@ -827,16 +721,10 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
827721 errorResponse , ok := response .(mcp.JSONRPCError )
828722 assert .True (t , ok )
829723 assert .Equal (t , tt .expectedErr , errorResponse .Error .Code )
830- ctx , cancel := context .WithTimeout (context .Background (), 50 * time .Millisecond )
831- defer cancel ()
832724
833725 if tt .validateErr != nil {
834- select {
835- case err := <- errChan :
836- tt .validateErr (t , err )
837- case <- ctx .Done ():
838- t .Errorf ("Error not received" )
839- }
726+ require .Len (t , errs , 1 , "Expected exactly one error" )
727+ tt .validateErr (t , errs [0 ])
840728 }
841729 })
842730 }
@@ -981,10 +869,10 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
981869}
982870
983871func TestMCPServer_HandleMethodsWithoutCapabilities (t * testing.T ) {
984- var errChan chan error
872+ var errs [] error
985873 hooks := & Hooks {}
986874 hooks .AddOnError (func (id any , method mcp.MCPMethod , message any , err error ) {
987- errChan <- err
875+ errs = append ( errs , err )
988876 })
989877 hooksOption := WithHooks (hooks )
990878
@@ -1041,7 +929,8 @@ func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) {
1041929
1042930 for _ , tt := range tests {
1043931 t .Run (tt .name , func (t * testing.T ) {
1044- errChan = make (chan error , 1 )
932+ errs = nil // Reset errors for each test case
933+
1045934 server := NewMCPServer ("test-server" , "1.0.0" , tt .options ... )
1046935 response := server .HandleMessage (
1047936 context .Background (),
@@ -1052,16 +941,10 @@ func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) {
1052941 errorResponse , ok := response .(mcp.JSONRPCError )
1053942 assert .True (t , ok )
1054943 assert .Equal (t , tt .expectedErr , errorResponse .Error .Code )
1055- ctx , cancel := context .WithTimeout (context .Background (), 50 * time .Millisecond )
1056- defer cancel ()
1057- var err error
1058- select {
1059- case err = <- errChan :
1060- case <- ctx .Done ():
1061- t .Errorf ("Error not received" )
1062- }
1063- assert .True (t , errors .Is (err , ErrUnsupported ), "Error should be ErrUnsupported but was %v" , err )
1064- assert .Contains (t , err .Error (), tt .errString )
944+
945+ require .Len (t , errs , 1 , "Expected exactly one error" )
946+ assert .True (t , errors .Is (errs [0 ], ErrUnsupported ), "Error should be ErrUnsupported but was %v" , errs [0 ])
947+ assert .Contains (t , errs [0 ].Error (), tt .errString )
1065948 })
1066949 }
1067950}
@@ -1290,29 +1173,57 @@ func TestMCPServer_WithHooks(t *testing.T) {
12901173 afterToolsCount int
12911174 )
12921175
1176+ // Collectors for message and result types
1177+ var beforeAnyMessages []any
1178+ var afterAnyData []struct {
1179+ msg any
1180+ res any
1181+ }
1182+ var beforePingMessages []* mcp.PingRequest
1183+ var afterPingData []struct {
1184+ msg * mcp.PingRequest
1185+ res * mcp.EmptyResult
1186+ }
1187+
12931188 // Initialize hook handlers
12941189 hooks := & Hooks {}
12951190
1296- // Register "any" hooks
1191+ // Register "any" hooks with type verification
12971192 hooks .AddBeforeAny (func (id any , method mcp.MCPMethod , message any ) {
12981193 beforeAnyCount ++
1194+ // Only collect ping messages for our test
1195+ if method == mcp .MethodPing {
1196+ beforeAnyMessages = append (beforeAnyMessages , message )
1197+ }
12991198 })
13001199
13011200 hooks .AddAfterAny (func (id any , method mcp.MCPMethod , message any , result any ) {
13021201 afterAnyCount ++
1202+ // Only collect ping responses for our test
1203+ if method == mcp .MethodPing {
1204+ afterAnyData = append (afterAnyData , struct {
1205+ msg any
1206+ res any
1207+ }{message , result })
1208+ }
13031209 })
13041210
13051211 hooks .AddOnError (func (id any , method mcp.MCPMethod , message any , err error ) {
13061212 onErrorCount ++
13071213 })
13081214
1309- // Register method-specific hooks
1215+ // Register method-specific hooks with type verification
13101216 hooks .AddBeforePing (func (id any , message * mcp.PingRequest ) {
13111217 beforePingCount ++
1218+ beforePingMessages = append (beforePingMessages , message )
13121219 })
13131220
13141221 hooks .AddAfterPing (func (id any , message * mcp.PingRequest , result * mcp.EmptyResult ) {
13151222 afterPingCount ++
1223+ afterPingData = append (afterPingData , struct {
1224+ msg * mcp.PingRequest
1225+ res * mcp.EmptyResult
1226+ }{message , result })
13161227 })
13171228
13181229 hooks .AddBeforeListTools (func (id any , message * mcp.ListToolsRequest ) {
@@ -1390,9 +1301,20 @@ func TestMCPServer_WithHooks(t *testing.T) {
13901301 // General hooks should be called for all methods
13911302 // beforeAny is called for all 4 methods (initialize, ping, tools/list, tools/call)
13921303 assert .Equal (t , 4 , beforeAnyCount , "beforeAny should be called for each method" )
1393- // afterAny is called only for successful methods (initialize, ping, tools/list)
1394- assert .Equal (t , 3 , afterAnyCount , "afterAny should be called for successful methods only " )
1304+ // afterAny is called for all 3 success methods (initialize, ping, tools/list) plus 1 error (tools/call )
1305+ assert .Equal (t , 4 , afterAnyCount , "afterAny should be called for all methods including errors " )
13951306
13961307 // Error hook should be called once for the failed tools/call
13971308 assert .Equal (t , 1 , onErrorCount , "onError should be called once" )
1309+
1310+ // Verify type matching between BeforeAny and BeforePing
1311+ require .Len (t , beforePingMessages , 1 , "Expected one BeforePing message" )
1312+ require .Len (t , beforeAnyMessages , 1 , "Expected one BeforeAny Ping message" )
1313+ assert .IsType (t , beforePingMessages [0 ], beforeAnyMessages [0 ], "BeforeAny message should be same type as BeforePing message" )
1314+
1315+ // Verify type matching between AfterAny and AfterPing
1316+ require .Len (t , afterPingData , 1 , "Expected one AfterPing message/result pair" )
1317+ require .Len (t , afterAnyData , 1 , "Expected one AfterAny Ping message/result pair" )
1318+ assert .IsType (t , afterPingData [0 ].msg , afterAnyData [0 ].msg , "AfterAny message should be same type as AfterPing message" )
1319+ assert .IsType (t , afterPingData [0 ].res , afterAnyData [0 ].res , "AfterAny result should be same type as AfterPing result" )
13981320}
0 commit comments