@@ -588,38 +588,48 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
588588
589589 // Find the corresponding session and deliver the response
590590 // The response is delivered to the specific session identified by sessionID
591- s .deliverSamplingResponse (sessionID , response )
591+ if err := s .deliverSamplingResponse (sessionID , response ); err != nil {
592+ s .logger .Errorf ("Failed to deliver sampling response: %v" , err )
593+ http .Error (w , "Failed to deliver response" , http .StatusInternalServerError )
594+ return err
595+ }
592596
593597 // Acknowledge receipt
594598 w .WriteHeader (http .StatusOK )
595599 return nil
596600}
597601
598602// deliverSamplingResponse delivers a sampling response to the appropriate session
599- func (s * StreamableHTTPServer ) deliverSamplingResponse (sessionID string , response samplingResponseItem ) {
603+ func (s * StreamableHTTPServer ) deliverSamplingResponse (sessionID string , response samplingResponseItem ) error {
600604 // Look up the active session
601- if sessionInterface , ok := s .activeSessions .Load (sessionID ); ok {
602- if session , ok := sessionInterface .(* streamableHttpSession ); ok {
603- // Look up the dedicated response channel for this specific request
604- if responseChannelInterface , exists := session .samplingRequests .Load (response .requestID ); exists {
605- if responseChan , ok := responseChannelInterface .(chan samplingResponseItem ); ok {
606- select {
607- case responseChan <- response :
608- s .logger .Infof ("Delivered sampling response for session %s, request %d" , sessionID , response .requestID )
609- default :
610- s .logger .Errorf ("Failed to deliver sampling response for session %s, request %d: channel full" , sessionID , response .requestID )
611- }
612- } else {
613- s .logger .Errorf ("Invalid response channel type for session %s, request %d" , sessionID , response .requestID )
614- }
615- } else {
616- s .logger .Errorf ("No pending request found for session %s, request %d" , sessionID , response .requestID )
617- }
618- } else {
619- s .logger .Errorf ("Invalid session type for session %s" , sessionID )
620- }
621- } else {
622- s .logger .Errorf ("No active session found for session %s" , sessionID )
605+ sessionInterface , ok := s .activeSessions .Load (sessionID )
606+ if ! ok {
607+ return fmt .Errorf ("no active session found for session %s" , sessionID )
608+ }
609+
610+ session , ok := sessionInterface .(* streamableHttpSession )
611+ if ! ok {
612+ return fmt .Errorf ("invalid session type for session %s" , sessionID )
613+ }
614+
615+ // Look up the dedicated response channel for this specific request
616+ responseChannelInterface , exists := session .samplingRequests .Load (response .requestID )
617+ if ! exists {
618+ return fmt .Errorf ("no pending request found for session %s, request %d" , sessionID , response .requestID )
619+ }
620+
621+ responseChan , ok := responseChannelInterface .(chan samplingResponseItem )
622+ if ! ok {
623+ return fmt .Errorf ("invalid response channel type for session %s, request %d" , sessionID , response .requestID )
624+ }
625+
626+ // Attempt to deliver the response with timeout to prevent indefinite blocking
627+ select {
628+ case responseChan <- response :
629+ s .logger .Infof ("Delivered sampling response for session %s, request %d" , sessionID , response .requestID )
630+ return nil
631+ default :
632+ return fmt .Errorf ("failed to deliver sampling response for session %s, request %d: channel full or blocked" , sessionID , response .requestID )
623633 }
624634}
625635
0 commit comments