Skip to content

Commit 3529090

Browse files
revert yield pattern from modelcontextprotocol#982
1 parent c210339 commit 3529090

File tree

1 file changed

+81
-52
lines changed

1 file changed

+81
-52
lines changed

src/mcp/client/auth.py

Lines changed: 81 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,11 @@ async def _perform_authorization(self) -> tuple[str, str]:
349349
# Wait for callback
350350
auth_code, returned_state = await self.context.callback_handler()
351351

352+
# Validate state parameter for CSRF protection
352353
if returned_state is None or not secrets.compare_digest(returned_state, self.auth_state):
353354
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {self.auth_state}")
354355

356+
# Clear state after validation
355357
self.auth_state = None
356358

357359
if not auth_code:
@@ -487,6 +489,62 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
487489
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
488490
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
489491

492+
async def _ensure_token(self) -> None:
493+
"""Ensure valid access token, refreshing or re-authenticating as needed."""
494+
# Return early if token is valid
495+
if self.context.is_token_valid():
496+
return
497+
498+
# Try refreshing existing token
499+
if self.context.can_refresh_token():
500+
try:
501+
refresh_request = await self._refresh_token()
502+
async with httpx.AsyncClient(timeout=self.context.timeout) as client:
503+
refresh_response = await client.send(refresh_request)
504+
if await self._handle_refresh_response(refresh_response):
505+
return # Refresh succeeded
506+
except Exception:
507+
logger.warning("Token refresh failed, will perform full OAuth flow")
508+
509+
# Fall back to full OAuth flow
510+
await self._perform_oauth_flow()
511+
512+
async def _perform_oauth_flow(self) -> None:
513+
"""Execute OAuth2 authorization code flow with PKCE."""
514+
logger.debug("Starting OAuth authentication flow")
515+
516+
async with httpx.AsyncClient(timeout=self.context.timeout) as client:
517+
# Step 1: Discover protected resource metadata (if we have a 401 response)
518+
# Note: We can't access the 401 response here, so skip PRM discovery
519+
520+
# Step 2: Discover OAuth metadata
521+
discovery_urls = self._get_discovery_urls()
522+
for url in discovery_urls:
523+
oauth_metadata_request = self._create_oauth_metadata_request(url)
524+
oauth_metadata_response = await client.send(oauth_metadata_request)
525+
if oauth_metadata_response.status_code == 200:
526+
try:
527+
await self._handle_oauth_metadata_response(oauth_metadata_response)
528+
break
529+
except ValidationError:
530+
continue
531+
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
532+
break
533+
534+
# Step 3: Register client if needed
535+
registration_request = await self._register_client()
536+
if registration_request:
537+
registration_response = await client.send(registration_request)
538+
await self._handle_registration_response(registration_response)
539+
540+
# Step 4: Perform authorization
541+
auth_code, code_verifier = await self._perform_authorization()
542+
543+
# Step 5: Exchange authorization code for tokens
544+
token_request = await self._exchange_token(auth_code, code_verifier)
545+
token_response = await client.send(token_request)
546+
await self._handle_token_response(token_response)
547+
490548
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
491549
"""HTTPX auth flow integration."""
492550
async with self.context.lock:
@@ -496,61 +554,32 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
496554
# Capture protocol version from request headers
497555
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
498556

499-
if not self.context.is_token_valid() and self.context.can_refresh_token():
500-
# Try to refresh token
501-
refresh_request = await self._refresh_token()
502-
refresh_response = yield refresh_request
503-
504-
if not await self._handle_refresh_response(refresh_response):
505-
# Refresh failed, need full re-authentication
506-
self._initialized = False
507-
508-
if self.context.is_token_valid():
509-
self._add_auth_header(request)
557+
# Ensure we have a valid token
558+
await self._ensure_token()
559+
560+
# Add auth header
561+
self._add_auth_header(request)
510562

563+
# Send the request
511564
response = yield request
512565

566+
# Handle 401 - clear tokens and retry once
513567
if response.status_code == 401:
514-
# Perform full OAuth flow
568+
logger.debug("Got 401 response, clearing tokens and retrying")
569+
self.context.clear_tokens()
570+
571+
# Try to extract protected resource metadata from 401 response
515572
try:
516-
# OAuth flow must be inline due to generator constraints
517-
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
518-
discovery_request = await self._discover_protected_resource(response)
519-
discovery_response = yield discovery_request
520-
await self._handle_protected_resource_response(discovery_response)
521-
522-
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
523-
discovery_urls = self._get_discovery_urls()
524-
for url in discovery_urls:
525-
oauth_metadata_request = self._create_oauth_metadata_request(url)
526-
oauth_metadata_response = yield oauth_metadata_request
527-
528-
if oauth_metadata_response.status_code == 200:
529-
try:
530-
await self._handle_oauth_metadata_response(oauth_metadata_response)
531-
break
532-
except ValidationError:
533-
continue
534-
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
535-
break # Non-4XX error, stop trying
536-
537-
# Step 3: Register client if needed
538-
registration_request = await self._register_client()
539-
if registration_request:
540-
registration_response = yield registration_request
541-
await self._handle_registration_response(registration_response)
542-
543-
# Step 4: Perform authorization
544-
auth_code, code_verifier = await self._perform_authorization()
545-
546-
# Step 5: Exchange authorization code for tokens
547-
token_request = await self._exchange_token(auth_code, code_verifier)
548-
token_response = yield token_request
549-
await self._handle_token_response(token_response)
573+
async with httpx.AsyncClient(timeout=self.context.timeout) as client:
574+
discovery_request = await self._discover_protected_resource(response)
575+
discovery_response = await client.send(discovery_request)
576+
await self._handle_protected_resource_response(discovery_response)
550577
except Exception:
551-
logger.exception("OAuth flow error")
552-
raise
553-
554-
# Retry with new tokens
555-
self._add_auth_header(request)
556-
yield request
578+
logger.debug("Failed to discover protected resource metadata")
579+
580+
# Perform full OAuth flow
581+
await self._perform_oauth_flow()
582+
583+
# Retry request with new tokens
584+
self._add_auth_header(request)
585+
response = yield request

0 commit comments

Comments
 (0)