55 "bytes"
66 "context"
77 "encoding/json"
8+ "errors"
89 "fmt"
910 "io"
1011 "net/http"
@@ -36,6 +37,9 @@ type SSE struct {
3637 started atomic.Bool
3738 closed atomic.Bool
3839 cancelSSEStream context.CancelFunc
40+
41+ // OAuth support
42+ oauthHandler * OAuthHandler
3943}
4044
4145type ClientOption func (* SSE )
@@ -58,6 +62,12 @@ func WithHTTPClient(httpClient *http.Client) ClientOption {
5862 }
5963}
6064
65+ func WithOAuth (config OAuthConfig ) ClientOption {
66+ return func (sc * SSE ) {
67+ sc .oauthHandler = NewOAuthHandler (config )
68+ }
69+ }
70+
6171// NewSSE creates a new SSE-based MCP client with the given base URL.
6272// Returns an error if the URL is invalid.
6373func NewSSE (baseURL string , options ... ClientOption ) (* SSE , error ) {
@@ -78,6 +88,13 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
7888 opt (smc )
7989 }
8090
91+ // If OAuth is configured, set the base URL for metadata discovery
92+ if smc .oauthHandler != nil {
93+ // Extract base URL from server URL for metadata discovery
94+ baseURL := fmt .Sprintf ("%s://%s" , parsedURL .Scheme , parsedURL .Host )
95+ smc .oauthHandler .SetBaseURL (baseURL )
96+ }
97+
8198 return smc , nil
8299}
83100
@@ -112,13 +129,34 @@ func (c *SSE) Start(ctx context.Context) error {
112129 }
113130 }
114131
132+ // Add OAuth authorization if configured
133+ if c .oauthHandler != nil {
134+ authHeader , err := c .oauthHandler .GetAuthorizationHeader (ctx )
135+ if err != nil {
136+ // If we get an authorization error, return a specific error that can be handled by the client
137+ if err .Error () == "no valid token available, authorization required" {
138+ return & OAuthAuthorizationRequiredError {
139+ Handler : c .oauthHandler ,
140+ }
141+ }
142+ return fmt .Errorf ("failed to get authorization header: %w" , err )
143+ }
144+ req .Header .Set ("Authorization" , authHeader )
145+ }
146+
115147 resp , err := c .httpClient .Do (req )
116148 if err != nil {
117149 return fmt .Errorf ("failed to connect to SSE stream: %w" , err )
118150 }
119151
120152 if resp .StatusCode != http .StatusOK {
121153 resp .Body .Close ()
154+ // Handle OAuth unauthorized error
155+ if resp .StatusCode == http .StatusUnauthorized && c .oauthHandler != nil {
156+ return & OAuthAuthorizationRequiredError {
157+ Handler : c .oauthHandler ,
158+ }
159+ }
122160 return fmt .Errorf ("unexpected status code: %d" , resp .StatusCode )
123161 }
124162
@@ -281,6 +319,22 @@ func (c *SSE) SendRequest(
281319 for k , v := range c .headers {
282320 req .Header .Set (k , v )
283321 }
322+
323+ // Add OAuth authorization if configured
324+ if c .oauthHandler != nil {
325+ authHeader , err := c .oauthHandler .GetAuthorizationHeader (ctx )
326+ if err != nil {
327+ // If we get an authorization error, return a specific error that can be handled by the client
328+ if err .Error () == "no valid token available, authorization required" {
329+ return nil , & OAuthAuthorizationRequiredError {
330+ Handler : c .oauthHandler ,
331+ }
332+ }
333+ return nil , fmt .Errorf ("failed to get authorization header: %w" , err )
334+ }
335+ req .Header .Set ("Authorization" , authHeader )
336+ }
337+
284338 if c .headerFunc != nil {
285339 for k , v := range c .headerFunc (ctx ) {
286340 req .Header .Set (k , v )
@@ -320,6 +374,13 @@ func (c *SSE) SendRequest(
320374 if resp .StatusCode != http .StatusOK && resp .StatusCode != http .StatusAccepted {
321375 deleteResponseChan ()
322376
377+ // Handle OAuth unauthorized error
378+ if resp .StatusCode == http .StatusUnauthorized && c .oauthHandler != nil {
379+ return nil , & OAuthAuthorizationRequiredError {
380+ Handler : c .oauthHandler ,
381+ }
382+ }
383+
323384 return nil , fmt .Errorf ("request failed with status %d: %s" , resp .StatusCode , body )
324385 }
325386
@@ -385,6 +446,22 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
385446 for k , v := range c .headers {
386447 req .Header .Set (k , v )
387448 }
449+
450+ // Add OAuth authorization if configured
451+ if c .oauthHandler != nil {
452+ authHeader , err := c .oauthHandler .GetAuthorizationHeader (ctx )
453+ if err != nil {
454+ // If we get an authorization error, return a specific error that can be handled by the client
455+ if errors .Is (err , ErrOAuthAuthorizationRequired ) {
456+ return & OAuthAuthorizationRequiredError {
457+ Handler : c .oauthHandler ,
458+ }
459+ }
460+ return fmt .Errorf ("failed to get authorization header: %w" , err )
461+ }
462+ req .Header .Set ("Authorization" , authHeader )
463+ }
464+
388465 if c .headerFunc != nil {
389466 for k , v := range c .headerFunc (ctx ) {
390467 req .Header .Set (k , v )
@@ -398,6 +475,13 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
398475 defer resp .Body .Close ()
399476
400477 if resp .StatusCode != http .StatusOK && resp .StatusCode != http .StatusAccepted {
478+ // Handle OAuth unauthorized error
479+ if resp .StatusCode == http .StatusUnauthorized && c .oauthHandler != nil {
480+ return & OAuthAuthorizationRequiredError {
481+ Handler : c .oauthHandler ,
482+ }
483+ }
484+
401485 body , _ := io .ReadAll (resp .Body )
402486 return fmt .Errorf (
403487 "notification failed with status %d: %s" ,
@@ -418,3 +502,13 @@ func (c *SSE) GetEndpoint() *url.URL {
418502func (c * SSE ) GetBaseURL () * url.URL {
419503 return c .baseURL
420504}
505+
506+ // GetOAuthHandler returns the OAuth handler if configured
507+ func (c * SSE ) GetOAuthHandler () * OAuthHandler {
508+ return c .oauthHandler
509+ }
510+
511+ // IsOAuthEnabled returns true if OAuth is enabled
512+ func (c * SSE ) IsOAuthEnabled () bool {
513+ return c .oauthHandler != nil
514+ }
0 commit comments