params = new HashMap<>();
+ request.getParameterMap().forEach((key, values) -> {
+ if (values.length > 0) {
+ params.put(key, values[0]);
+ }
+ });
+
+ try {
+ String redirectUrl = authorizationHandler.handle(params).join().getRedirectUrl();
+ response.setHeader("Location", redirectUrl);
+ response.setHeader("Cache-Control", "no-store");
+ response.setStatus(302); // Found
+ }
+ catch (CompletionException ex) {
+ response.setStatus(400);
+ }
+ }
+
+ // private void handleAuthorizeRequest(HttpServletRequest request, HttpServletResponse
+ // response) throws IOException {
+ // if ("GET".equalsIgnoreCase(request.getMethod())) {
+ // // Render consent page
+ // String clientId = request.getParameter("client_id");
+ // String redirectUri = request.getParameter("redirect_uri");
+ // String scope = request.getParameter("scope");
+ // String state = request.getParameter("state");
+ // String codeChallenge = request.getParameter("code_challenge");
+ // String codeChallengeMethod = request.getParameter("code_challenge_method");
+ // String responseType = request.getParameter("response_type");
+
+ // response.setContentType("text/html");
+ // response.getWriter()
+ // .write("" + "Authorize Access
" + "Client " + clientId
+ // + " is requesting access with scope: " + scope + "
"
+ // + "" + "");
+ // }
+ // else if ("POST".equalsIgnoreCase(request.getMethod())) {
+ // // Extract parameters from form
+ // Map params = new HashMap<>();
+ // request.getParameterMap().forEach((key, values) -> {
+ // if (values.length > 0) {
+ // params.put(key, values[0]);
+ // }
+ // });
+
+ // try {
+ // String redirectUrl = authorizationHandler.handle(params).join().getRedirectUrl();
+ // response.setHeader("Location", redirectUrl);
+ // response.setHeader("Cache-Control", "no-store");
+ // response.setStatus(302); // Found
+ // }
+ // catch (CompletionException ex) {
+ // response.setStatus(400);
+ // }
+ // }
+ // }
+
+ private void handleTokenRequest(HttpServletRequest request, HttpServletResponse response) throws IOException {
+ // Extract parameters
+ Map params = new HashMap<>();
+ request.getParameterMap().forEach((key, values) -> {
+ if (values.length > 0) {
+ params.put(key, values[0]);
+ }
+ });
+
+ System.out.println("TOKEN REQUEST PARAMS: " + params);
+
+ try {
+ io.modelcontextprotocol.auth.OAuthToken token = tokenHandler.handle(params).join();
+ response.setContentType("application/json");
+ response.setHeader("Cache-Control", "no-store");
+ response.setHeader("Pragma", "no-cache");
+ response.setStatus(200);
+ getObjectMapper().writeValue(response.getOutputStream(), token);
+ }
+ catch (CompletionException ex) {
+ response.setStatus(400);
+ }
+ }
+
+ private void handleRegisterRequest(HttpServletRequest request, HttpServletResponse response) throws IOException {
+ // Read request body
+ BufferedReader reader = request.getReader();
+ StringBuilder body = new StringBuilder();
+ String line;
+ while ((line = reader.readLine()) != null) {
+ body.append(line);
+ }
+
+ try {
+ OAuthClientMetadata clientMetadata = getObjectMapper().readValue(body.toString(),
+ OAuthClientMetadata.class);
+
+ Object clientInfo = registrationHandler.handle(clientMetadata).join();
+ response.setContentType("application/json");
+ response.setStatus(201); // Created
+ getObjectMapper().writeValue(response.getOutputStream(), clientInfo);
+ }
+ catch (CompletionException ex) {
+ response.setStatus(400);
+ }
+ }
+
+ private void handleRevokeRequest(HttpServletRequest request, HttpServletResponse response) throws IOException {
+ // Extract parameters
+ Map params = new HashMap<>();
+ request.getParameterMap().forEach((key, values) -> {
+ if (values.length > 0) {
+ params.put(key, values[0]);
+ }
+ });
+
+ try {
+ revocationHandler.handle(params).join();
+ response.setStatus(200);
+ }
+ catch (CompletionException ex) {
+ response.setStatus(400);
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java
index 86906d859..e7d84decc 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java
@@ -9,6 +9,9 @@
import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
+import io.modelcontextprotocol.server.auth.middleware.AuthContext;
+import io.modelcontextprotocol.server.auth.middleware.AuthContextProvider;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
@@ -242,7 +245,16 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti
return Mono.defer(() -> {
if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) {
this.state.lazySet(STATE_INITIALIZED);
- exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get()));
+
+ // Get auth context from transport if it supports it
+ AuthContext authContext = null;
+ if (transport instanceof AuthContextProvider authProvider) {
+ authContext = authProvider.getAuthContext();
+ }
+
+ exchangeSink.tryEmitValue(
+ new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get(), authContext));
+
return this.initNotificationHandler.handle();
}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java
new file mode 100644
index 000000000..398cde5ca
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java
@@ -0,0 +1,120 @@
+package io.modelcontextprotocol.auth;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.net.URI;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import io.modelcontextprotocol.client.auth.AuthCallbackResult;
+import io.modelcontextprotocol.client.auth.OAuthClientProvider;
+import io.modelcontextprotocol.client.auth.TokenStorage;
+
+/**
+ * Tests for the OAuthClientProvider class.
+ */
+public class OAuthClientProviderTest {
+
+ private OAuthClientMetadata clientMetadata;
+
+ private TokenStorage mockStorage;
+
+ private Function> mockRedirectHandler;
+
+ private Function> mockCallbackHandler;
+
+ private OAuthClientProvider clientProvider;
+
+ private OAuthToken token;
+
+ private OAuthClientInformation clientInfo;
+
+ @SuppressWarnings("unchecked")
+ @BeforeEach
+ public void setup() throws Exception {
+ // Setup client metadata
+ clientMetadata = new OAuthClientMetadata();
+ clientMetadata.setRedirectUris(List.of(new URI("https://example.com/callback")));
+ clientMetadata.setScope("read write");
+
+ // Setup mock storage
+ mockStorage = mock(TokenStorage.class);
+
+ // Setup mock handlers
+ mockRedirectHandler = mock(Function.class);
+ mockCallbackHandler = mock(Function.class);
+
+ // Setup token and client info
+ token = new OAuthToken();
+ token.setAccessToken("test-access-token");
+ token.setRefreshToken("test-refresh-token");
+ token.setExpiresIn(3600);
+ token.setScope("read write");
+
+ clientInfo = new OAuthClientInformation();
+ clientInfo.setClientId("test-client-id");
+ clientInfo.setClientSecret("test-client-secret");
+ clientInfo.setRedirectUris(List.of(new URI("https://example.com/callback")));
+ clientInfo.setScope("read write");
+
+ // Configure mocks
+ when(mockStorage.getTokens()).thenReturn(CompletableFuture.completedFuture(token));
+ when(mockStorage.getClientInfo()).thenReturn(CompletableFuture.completedFuture(clientInfo));
+ when(mockStorage.setTokens(any())).thenReturn(CompletableFuture.completedFuture(null));
+ when(mockStorage.setClientInfo(any())).thenReturn(CompletableFuture.completedFuture(null));
+
+ when(mockRedirectHandler.apply(anyString())).thenReturn(CompletableFuture.completedFuture(null));
+
+ AuthCallbackResult callbackResult = new AuthCallbackResult("test-auth-code", "test-state");
+ when(mockCallbackHandler.apply(any())).thenReturn(CompletableFuture.completedFuture(callbackResult));
+
+ // Create client provider
+ clientProvider = new OAuthClientProvider("https://auth.example.com", clientMetadata, mockStorage,
+ mockRedirectHandler, mockCallbackHandler, Duration.ofSeconds(30));
+ }
+
+ @Test
+ public void testInitialize() throws Exception {
+ // Test initialization
+ CompletableFuture initFuture = clientProvider.initialize();
+ initFuture.get();
+
+ // Test access token retrieval
+ String accessToken = clientProvider.getAccessToken();
+ assertNotNull(accessToken);
+ assertEquals("test-access-token", accessToken);
+
+ // Test token retrieval
+ OAuthToken retrievedToken = clientProvider.getCurrentTokens();
+ assertNotNull(retrievedToken);
+ assertEquals(token.getAccessToken(), retrievedToken.getAccessToken());
+ assertEquals(token.getRefreshToken(), retrievedToken.getRefreshToken());
+ }
+
+ @Test
+ public void testEnsureToken() throws Exception {
+ // Initialize first
+ clientProvider.initialize().get();
+
+ // Test token validation
+ CompletableFuture tokenFuture = clientProvider.ensureToken();
+ tokenFuture.get();
+
+ // Token should be valid and accessible
+ String accessToken = clientProvider.getAccessToken();
+ assertNotNull(accessToken);
+ assertEquals("test-access-token", accessToken);
+ }
+
+}
\ No newline at end of file
diff --git a/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java
new file mode 100644
index 000000000..e8de4e9f7
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java
@@ -0,0 +1,113 @@
+package io.modelcontextprotocol.auth;
+
+import java.net.URI;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for the OAuth authentication flow.
+ */
+public class OAuthFlowTest {
+
+ private OAuthAuthorizationServerProvider mockProvider;
+
+ private OAuthClientInformation clientInfo;
+
+ private AuthorizationCode authCode;
+
+ private OAuthToken token;
+
+ @BeforeEach
+ public void setup() throws Exception {
+ // Setup mock provider
+ mockProvider = mock(OAuthAuthorizationServerProvider.class);
+
+ // Setup test client
+ clientInfo = new OAuthClientInformation();
+ clientInfo.setClientId("test-client-id");
+ clientInfo.setClientSecret("test-client-secret");
+ clientInfo.setRedirectUris(List.of(new URI("https://example.com/callback")));
+ clientInfo.setScope("read write");
+
+ // Setup test auth code
+ authCode = new AuthorizationCode();
+ authCode.setCode("test-auth-code");
+ authCode.setClientId(clientInfo.getClientId());
+ authCode.setScopes(Arrays.asList("read", "write"));
+ authCode.setExpiresAt(Instant.now().plusSeconds(600).getEpochSecond());
+ authCode.setCodeChallenge("test-code-challenge");
+ authCode.setRedirectUri(clientInfo.getRedirectUris().get(0));
+ authCode.setRedirectUriProvidedExplicitly(true);
+
+ // Setup test token
+ token = new OAuthToken();
+ token.setAccessToken("test-access-token");
+ token.setRefreshToken("test-refresh-token");
+ token.setExpiresIn(3600);
+ token.setScope("read write");
+
+ // Configure mock provider
+ when(mockProvider.getClient(clientInfo.getClientId()))
+ .thenReturn(CompletableFuture.completedFuture(clientInfo));
+
+ when(mockProvider.authorize(any(), any()))
+ .thenReturn(CompletableFuture.completedFuture("https://example.com/auth?code=test-auth-code"));
+
+ when(mockProvider.loadAuthorizationCode(any(), any())).thenReturn(CompletableFuture.completedFuture(authCode));
+
+ when(mockProvider.exchangeAuthorizationCode(any(), any())).thenReturn(CompletableFuture.completedFuture(token));
+ }
+
+ @Test
+ public void testAuthorizationCodeFlow() throws Exception {
+ // Test client lookup
+ CompletableFuture clientFuture = mockProvider.getClient(clientInfo.getClientId());
+ OAuthClientInformation retrievedClient = clientFuture.get();
+
+ assertNotNull(retrievedClient);
+ assertEquals(clientInfo.getClientId(), retrievedClient.getClientId());
+
+ // Test authorization
+ AuthorizationParams params = new AuthorizationParams();
+ params.setState(UUID.randomUUID().toString());
+ params.setScopes(Arrays.asList("read", "write"));
+ params.setCodeChallenge("test-code-challenge");
+ params.setRedirectUri(clientInfo.getRedirectUris().get(0));
+ params.setRedirectUriProvidedExplicitly(true);
+
+ CompletableFuture authUrlFuture = mockProvider.authorize(clientInfo, params);
+ String authUrl = authUrlFuture.get();
+
+ assertNotNull(authUrl);
+ assertTrue(authUrl.startsWith("https://example.com/auth?code="));
+
+ // Test code exchange
+ CompletableFuture codeFuture = mockProvider.loadAuthorizationCode(clientInfo,
+ "test-auth-code");
+ AuthorizationCode retrievedCode = codeFuture.get();
+
+ assertNotNull(retrievedCode);
+ assertEquals(authCode.getCode(), retrievedCode.getCode());
+
+ CompletableFuture tokenFuture = mockProvider.exchangeAuthorizationCode(clientInfo, retrievedCode);
+ OAuthToken retrievedToken = tokenFuture.get();
+
+ assertNotNull(retrievedToken);
+ assertEquals(token.getAccessToken(), retrievedToken.getAccessToken());
+ assertEquals(token.getRefreshToken(), retrievedToken.getRefreshToken());
+ assertEquals(token.getExpiresIn(), retrievedToken.getExpiresIn());
+ }
+
+}
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index c2327ee8d..42d1d906d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -104,6 +104,7 @@
mcp-spring/mcp-spring-webflux
mcp-spring/mcp-spring-webmvc
mcp-test
+ examples/auth-example
@@ -368,4 +369,4 @@
-
+
\ No newline at end of file