@@ -349,9 +349,11 @@ async def _perform_authorization(self) -> tuple[str, str]:
349349 # Wait for callback
350350 auth_code , returned_state = await self .context .callback_handler ()
351351
352+ # Validate state parameter for CSRF protection
352353 if returned_state is None or not secrets .compare_digest (returned_state , self .auth_state ):
353354 raise OAuthFlowError (f"State parameter mismatch: { returned_state } != { self .auth_state } " )
354355
356+ # Clear state after validation
355357 self .auth_state = None
356358
357359 if not auth_code :
@@ -487,6 +489,62 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
487489 if self .context .client_metadata .scope is None and metadata .scopes_supported is not None :
488490 self .context .client_metadata .scope = " " .join (metadata .scopes_supported )
489491
492+ async def _ensure_token (self ) -> None :
493+ """Ensure valid access token, refreshing or re-authenticating as needed."""
494+ # Return early if token is valid
495+ if self .context .is_token_valid ():
496+ return
497+
498+ # Try refreshing existing token
499+ if self .context .can_refresh_token ():
500+ try :
501+ refresh_request = await self ._refresh_token ()
502+ async with httpx .AsyncClient (timeout = self .context .timeout ) as client :
503+ refresh_response = await client .send (refresh_request )
504+ if await self ._handle_refresh_response (refresh_response ):
505+ return # Refresh succeeded
506+ except Exception :
507+ logger .warning ("Token refresh failed, will perform full OAuth flow" )
508+
509+ # Fall back to full OAuth flow
510+ await self ._perform_oauth_flow ()
511+
512+ async def _perform_oauth_flow (self ) -> None :
513+ """Execute OAuth2 authorization code flow with PKCE."""
514+ logger .debug ("Starting OAuth authentication flow" )
515+
516+ async with httpx .AsyncClient (timeout = self .context .timeout ) as client :
517+ # Step 1: Discover protected resource metadata (if we have a 401 response)
518+ # Note: We can't access the 401 response here, so skip PRM discovery
519+
520+ # Step 2: Discover OAuth metadata
521+ discovery_urls = self ._get_discovery_urls ()
522+ for url in discovery_urls :
523+ oauth_metadata_request = self ._create_oauth_metadata_request (url )
524+ oauth_metadata_response = await client .send (oauth_metadata_request )
525+ if oauth_metadata_response .status_code == 200 :
526+ try :
527+ await self ._handle_oauth_metadata_response (oauth_metadata_response )
528+ break
529+ except ValidationError :
530+ continue
531+ elif oauth_metadata_response .status_code < 400 or oauth_metadata_response .status_code >= 500 :
532+ break
533+
534+ # Step 3: Register client if needed
535+ registration_request = await self ._register_client ()
536+ if registration_request :
537+ registration_response = await client .send (registration_request )
538+ await self ._handle_registration_response (registration_response )
539+
540+ # Step 4: Perform authorization
541+ auth_code , code_verifier = await self ._perform_authorization ()
542+
543+ # Step 5: Exchange authorization code for tokens
544+ token_request = await self ._exchange_token (auth_code , code_verifier )
545+ token_response = await client .send (token_request )
546+ await self ._handle_token_response (token_response )
547+
490548 async def async_auth_flow (self , request : httpx .Request ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
491549 """HTTPX auth flow integration."""
492550 async with self .context .lock :
@@ -496,61 +554,32 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
496554 # Capture protocol version from request headers
497555 self .context .protocol_version = request .headers .get (MCP_PROTOCOL_VERSION )
498556
499- if not self .context .is_token_valid () and self .context .can_refresh_token ():
500- # Try to refresh token
501- refresh_request = await self ._refresh_token ()
502- refresh_response = yield refresh_request
503-
504- if not await self ._handle_refresh_response (refresh_response ):
505- # Refresh failed, need full re-authentication
506- self ._initialized = False
507-
508- if self .context .is_token_valid ():
509- self ._add_auth_header (request )
557+ # Ensure we have a valid token
558+ await self ._ensure_token ()
559+
560+ # Add auth header
561+ self ._add_auth_header (request )
510562
563+ # Send the request
511564 response = yield request
512565
566+ # Handle 401 - clear tokens and retry once
513567 if response .status_code == 401 :
514- # Perform full OAuth flow
568+ logger .debug ("Got 401 response, clearing tokens and retrying" )
569+ self .context .clear_tokens ()
570+
571+ # Try to extract protected resource metadata from 401 response
515572 try :
516- # OAuth flow must be inline due to generator constraints
517- # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
518- discovery_request = await self ._discover_protected_resource (response )
519- discovery_response = yield discovery_request
520- await self ._handle_protected_resource_response (discovery_response )
521-
522- # Step 2: Discover OAuth metadata (with fallback for legacy servers)
523- discovery_urls = self ._get_discovery_urls ()
524- for url in discovery_urls :
525- oauth_metadata_request = self ._create_oauth_metadata_request (url )
526- oauth_metadata_response = yield oauth_metadata_request
527-
528- if oauth_metadata_response .status_code == 200 :
529- try :
530- await self ._handle_oauth_metadata_response (oauth_metadata_response )
531- break
532- except ValidationError :
533- continue
534- elif oauth_metadata_response .status_code < 400 or oauth_metadata_response .status_code >= 500 :
535- break # Non-4XX error, stop trying
536-
537- # Step 3: Register client if needed
538- registration_request = await self ._register_client ()
539- if registration_request :
540- registration_response = yield registration_request
541- await self ._handle_registration_response (registration_response )
542-
543- # Step 4: Perform authorization
544- auth_code , code_verifier = await self ._perform_authorization ()
545-
546- # Step 5: Exchange authorization code for tokens
547- token_request = await self ._exchange_token (auth_code , code_verifier )
548- token_response = yield token_request
549- await self ._handle_token_response (token_response )
573+ async with httpx .AsyncClient (timeout = self .context .timeout ) as client :
574+ discovery_request = await self ._discover_protected_resource (response )
575+ discovery_response = await client .send (discovery_request )
576+ await self ._handle_protected_resource_response (discovery_response )
550577 except Exception :
551- logger .exception ("OAuth flow error" )
552- raise
553-
554- # Retry with new tokens
555- self ._add_auth_header (request )
556- yield request
578+ logger .debug ("Failed to discover protected resource metadata" )
579+
580+ # Perform full OAuth flow
581+ await self ._perform_oauth_flow ()
582+
583+ # Retry request with new tokens
584+ self ._add_auth_header (request )
585+ response = yield request
0 commit comments