diff --git a/examples/auth-example/README.md b/examples/auth-example/README.md new file mode 100644 index 000000000..a85f37379 --- /dev/null +++ b/examples/auth-example/README.md @@ -0,0 +1,70 @@ +# MCP Authentication Example + +This example demonstrates how to implement OAuth 2.0 authentication with the Model Context Protocol (MCP) Java SDK. + +## Overview + +The example consists of: + +1. A simple MCP server with OAuth authentication +2. A simple MCP client that authenticates using OAuth +3. A tool that requires authentication to access + +## Running the Example + +### 1. Build the Project + +```bash +cd examples/auth-example +mvn clean package +``` + +### 2. Run the Server + +In one terminal window: + +```bash +cd examples/auth-example +mvn exec:java -Dexec.mainClass="io.modelcontextprotocol.examples.auth.server.SimpleAuthServer" +``` + +### 3. Run the Client + +In another terminal window: + +```bash +cd examples/auth-example +mvn exec:java -Dexec.mainClass="io.modelcontextprotocol.examples.auth.client.SimpleAuthClient" +``` + +## Using the Client + +Once the client is running, you can use these commands: +- `list` - List available tools +- `call get_user_profile` - Call the user profile tool +- `quit` - Exit the client + +## Authentication Flow + +1. Client initiates the OAuth flow +2. Server redirects to the authorization page +3. User approves the authorization +4. Server redirects back to the client with an authorization code +5. Client exchanges the code for access and refresh tokens +6. Client uses the access token for authenticated MCP requests + +## Implementation Details + +### Server + +The server implements the `OAuthAuthorizationServerProvider` interface to provide OAuth authentication. It uses in-memory storage for clients, tokens, and authorization codes. + +### Client + +The client uses the `OAuthClientProvider` class to handle OAuth authentication. It opens a browser for the authorization flow and starts a local server to receive the OAuth callback. + +## Code Structure + +- `SimpleAuthServer.java` - Server implementation +- `SimpleAuthClient.java` - Client implementation +- `Constants.java` - Shared constants \ No newline at end of file diff --git a/examples/auth-example/pom.xml b/examples/auth-example/pom.xml new file mode 100644 index 000000000..3561415e5 --- /dev/null +++ b/examples/auth-example/pom.xml @@ -0,0 +1,58 @@ + + + 4.0.0 + + + io.modelcontextprotocol.sdk + mcp-parent + 0.11.0-SNAPSHOT + ../../pom.xml + + + auth-example + MCP Auth Example + Example of MCP authentication using OAuth 2.0 + + + + + io.modelcontextprotocol.sdk + mcp + ${project.version} + + + + + org.springframework.boot + spring-boot-starter-web + 3.2.3 + + + + + ch.qos.logback + logback-classic + ${logback.version} + + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.1.0 + + false + + + + org.springframework.boot + spring-boot-maven-plugin + 3.2.3 + + + + \ No newline at end of file diff --git a/examples/auth-example/run-client.sh b/examples/auth-example/run-client.sh new file mode 100755 index 000000000..c0d4322b9 --- /dev/null +++ b/examples/auth-example/run-client.sh @@ -0,0 +1,3 @@ +#!/bin/bash +cd "$(dirname "$0")" +mvn exec:java -Dexec.mainClass="io.modelcontextprotocol.examples.auth.client.SimpleAuthClient" \ No newline at end of file diff --git a/examples/auth-example/run-server.sh b/examples/auth-example/run-server.sh new file mode 100755 index 000000000..3f246b517 --- /dev/null +++ b/examples/auth-example/run-server.sh @@ -0,0 +1,3 @@ +#!/bin/bash +cd "$(dirname "$0")" +mvn exec:java -Dexec.mainClass="io.modelcontextprotocol.examples.auth.server.SimpleAuthServer" \ No newline at end of file diff --git a/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/client/SimpleAuthClient.java b/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/client/SimpleAuthClient.java new file mode 100644 index 000000000..848641022 --- /dev/null +++ b/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/client/SimpleAuthClient.java @@ -0,0 +1,320 @@ +package io.modelcontextprotocol.examples.auth.client; + +import java.awt.Desktop; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.URI; +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthClientMetadata; +import io.modelcontextprotocol.auth.OAuthToken; +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClientFactory; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.auth.AuthCallbackResult; +import io.modelcontextprotocol.client.auth.OAuthClientProvider; +import io.modelcontextprotocol.client.auth.TokenStorage; +import io.modelcontextprotocol.examples.auth.shared.Constants; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Simple MCP client with OAuth authentication. + */ +public class SimpleAuthClient { + + private final McpAsyncClient client; + + private static final int CALLBACK_PORT = 3000; + + private static final Logger logger = LoggerFactory.getLogger(SimpleAuthClient.class); + + /** + * Creates a new SimpleAuthClient. + */ + public SimpleAuthClient() throws Exception { + // Create a simple in-memory token storage + TokenStorage tokenStorage = new InMemoryTokenStorage(); + + System.out.println("šŸ”‘ Initializing OAuth client..."); + + // Create client metadata + OAuthClientMetadata clientMetadata = new OAuthClientMetadata(); + clientMetadata.setRedirectUris(Collections.singletonList(new URI(Constants.REDIRECT_URI))); + clientMetadata.setScope(Constants.SCOPE); + + // Create redirect handler that opens a browser + Function> redirectHandler = url -> { + CompletableFuture future = new CompletableFuture<>(); + try { + System.out.println("Opening browser to: " + url); + Desktop.getDesktop().browse(new URI(url)); + future.complete(null); + } + catch (Exception e) { + System.out.println("Failed to open browser. Please navigate to: " + url); + future.complete(null); + } + return future; + }; + + // Create callback handler with a simple HTTP server + Function> callbackHandler = v -> { + CompletableFuture future = new CompletableFuture<>(); + + new Thread(() -> { + try (ServerSocket serverSocket = new ServerSocket(CALLBACK_PORT)) { + System.out.println("Waiting for callback on port " + CALLBACK_PORT + "..."); + Socket clientSocket = serverSocket.accept(); + + // Read the request + BufferedReader reader = new BufferedReader(new InputStreamReader(clientSocket.getInputStream())); + String line = reader.readLine(); + + // Parse the request line + String[] parts = line.split(" "); + String path = parts[1]; + + // Extract code and state from query parameters + String query = path.substring(path.indexOf('?') + 1); + String[] params = query.split("&"); + String code = null; + String state = null; + + for (String param : params) { + String[] keyValue = param.split("="); + if (keyValue.length == 2) { + if ("code".equals(keyValue[0])) { + code = keyValue[1]; + } + else if ("state".equals(keyValue[0])) { + state = keyValue[1]; + } + } + } + + // Send a simple response + String response = "HTTP/1.1 200 OK\r\n" + "Content-Type: text/html\r\n\r\n" + + "

Authorization successful!

" + "

You can close this window now.

" + + "" + ""; + + OutputStream output = clientSocket.getOutputStream(); + output.write(response.getBytes()); + output.flush(); + + // Complete the future with the result + future.complete(new AuthCallbackResult(code, state)); + + } + catch (IOException e) { + future.completeExceptionally(e); + } + }).start(); + + return future; + }; + + // Create the OAuth client provider + OAuthClientProvider authProvider = new OAuthClientProvider(Constants.SERVER_URL, clientMetadata, tokenStorage, + redirectHandler, callbackHandler, Duration.ofSeconds(60)); + + // Initialize the auth provider + System.out.println("Initializing auth provider..."); + authProvider.initialize().get(); + tokenStorage.setTokens(null).get(); + authProvider.ensureToken().get(); + System.out.println("Auth provider initialized, access token: " + authProvider.getAccessToken()); + + try { + System.out.println("Creating authenticated client..."); + client = McpClientFactory.createAuthenticatedClient(Constants.SERVER_URL, authProvider); + + System.out.println("Initializing sync client..."); + try { + client.initialize().block(); + System.out.println("Client initialized successfully!"); + logger.info("OAuth client initialized successfully!"); + + // Verify initialization + System.out.println("Testing client initialization..."); + boolean isInitialized = client.isInitialized(); + System.out.println("Client is initialized: " + isInitialized); + } + catch (Exception e) { + System.err.println("Error initializing client: " + e.getMessage()); + e.printStackTrace(); + throw e; + } + } + catch (Exception e) { + System.err.println("Failed to initialize client: " + e.getMessage()); + e.printStackTrace(); + throw e; + } + } + + /** + * Lists available tools from the server. + */ + public List listTools() { + try { + System.out.println("Listing tools..."); + McpSchema.ListToolsResult result = client.listTools().block(); + if (result == null) { + System.err.println("Error: ListToolsResult is null"); + return Collections.emptyList(); + } + return result.tools(); + } + catch (Exception e) { + System.err.println("Error listing tools: " + e.getMessage()); + e.printStackTrace(); + return Collections.emptyList(); + } + } + + /** + * Calls a tool on the server. + */ + public McpSchema.CallToolResult callTool(String toolName, Map arguments) { + return client.callTool(new McpSchema.CallToolRequest(toolName, arguments)).block(); + } + + /** + * Runs an interactive command loop. + */ + public void runInteractiveLoop() throws Exception { + System.out.println("\nšŸŽÆ Interactive MCP Client"); + System.out.println("Commands:"); + System.out.println(" list - List available tools"); + System.out.println(" call - Call a tool"); + System.out.println(" quit - Exit the client"); + System.out.println(); + + BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); + + while (true) { + System.out.print("mcp> "); + // For testing, use a hardcoded command + String command = reader.readLine().trim(); + + if (command.isEmpty()) { + continue; + } + + if (command.equals("quit")) { + break; + } + else if (command.equals("list")) { + List tools = listTools(); + System.out.println("\nšŸ“‹ Available tools:"); + for (int i = 0; i < tools.size(); i++) { + McpSchema.Tool tool = tools.get(i); + System.out.println((i + 1) + ". " + tool.name()); + if (tool.description() != null) { + System.out.println(" Description: " + tool.description()); + } + System.out.println(); + } + } + else if (command.startsWith("call ")) { + String[] parts = command.split(" ", 2); + String toolName = parts.length > 1 ? parts[1] : ""; + + if (toolName.isEmpty()) { + System.out.println("āŒ Please specify a tool name"); + continue; + } + + Map arguments = new HashMap<>(); + + // Prompt for arguments if tool is fetch_url + if (toolName.equals("fetch_url")) { + System.out.print("Enter URL to fetch: "); + String url = reader.readLine().trim(); + arguments.put("url", url); + } + + McpSchema.CallToolResult result = callTool(toolName, arguments); + System.out.println("\nšŸ”§ Tool '" + toolName + "' result:"); + for (McpSchema.Content content : result.content()) { + if (content instanceof McpSchema.TextContent) { + System.out.println(((McpSchema.TextContent) content).text()); + } + else { + System.out.println(content); + } + } + } + else { + System.out.println("āŒ Unknown command. Try 'list', 'call ', or 'quit'"); + } + } + + System.out.println("\nšŸ‘‹ Goodbye!"); + } + + /** + * Simple in-memory token storage implementation. + */ + private static class InMemoryTokenStorage implements TokenStorage { + + private OAuthToken tokens; + + private OAuthClientInformation clientInfo; + + @Override + public CompletableFuture getTokens() { + return CompletableFuture.completedFuture(tokens); + } + + @Override + public CompletableFuture setTokens(OAuthToken tokens) { + this.tokens = tokens; + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture getClientInfo() { + return CompletableFuture.completedFuture(clientInfo); + } + + @Override + public CompletableFuture setClientInfo(OAuthClientInformation clientInfo) { + this.clientInfo = clientInfo; + return CompletableFuture.completedFuture(null); + } + + } + + /** + * Main method to start the client. + */ + public static void main(String[] args) { + try { + System.out.println("šŸš€ Simple MCP Auth Client"); + System.out.println("Connecting to: " + Constants.SERVER_URL); + + SimpleAuthClient client = new SimpleAuthClient(); + client.runInteractiveLoop(); + } + catch (Exception e) { + System.err.println("āŒ Error: " + e.getMessage()); + e.printStackTrace(); + } + } + +} \ No newline at end of file diff --git a/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/server/SimpleAuthServer.java b/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/server/SimpleAuthServer.java new file mode 100644 index 000000000..a69237dee --- /dev/null +++ b/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/server/SimpleAuthServer.java @@ -0,0 +1,346 @@ +package io.modelcontextprotocol.examples.auth.server; + +import java.net.URI; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.web.client.RestClient; + +import io.modelcontextprotocol.auth.AccessToken; +import io.modelcontextprotocol.auth.AuthorizationCode; +import io.modelcontextprotocol.auth.AuthorizationParams; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthToken; +import io.modelcontextprotocol.auth.RefreshToken; +import io.modelcontextprotocol.auth.exception.AuthorizeException; +import io.modelcontextprotocol.auth.exception.RegistrationException; +import io.modelcontextprotocol.auth.exception.TokenException; +import io.modelcontextprotocol.examples.auth.shared.Constants; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.auth.settings.ClientRegistrationOptions; +import io.modelcontextprotocol.server.auth.settings.RevocationOptions; +import io.modelcontextprotocol.server.transport.OAuthHttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import jakarta.servlet.Servlet; + +/** + * Simple MCP server with OAuth authentication. + */ +@SpringBootApplication +public class SimpleAuthServer { + + private static final Logger logger = LoggerFactory.getLogger(SimpleAuthServer.class); + + private static void startTomcat(Servlet transportProvider) { + var tomcat = new Tomcat(); + tomcat.setPort(9200); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext("", baseDir); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(transportProvider); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + + try { + tomcat.start(); + assert tomcat.getServer().getState().equals(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + /** + * Simple in-memory auth provider implementation. + */ + private static class SimpleAuthProvider implements OAuthAuthorizationServerProvider { + + private final Map clients = new ConcurrentHashMap<>(); + + private final Map authCodes = new ConcurrentHashMap<>(); + + private final Map refreshTokens = new ConcurrentHashMap<>(); + + private final Map accessTokens = new ConcurrentHashMap<>(); + + private final Map stateMapping = new ConcurrentHashMap<>(); + + @Override + public CompletableFuture getClient(String clientId) { + return CompletableFuture.completedFuture(clients.get(clientId)); + } + + @Override + public CompletableFuture registerClient(OAuthClientInformation clientInfo) throws RegistrationException { + clients.put(clientInfo.getClientId(), clientInfo); + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture authorize(OAuthClientInformation client, AuthorizationParams params) + throws AuthorizeException { + // Generate a random authorization code + String code = UUID.randomUUID().toString(); + + // Store state mapping + if (params.getState() != null) { + stateMapping.put(params.getState(), client.getClientId()); + } + + // Create and store the authorization code + AuthorizationCode authCode = new AuthorizationCode(); + authCode.setCode(code); + authCode.setClientId(client.getClientId()); + authCode.setScopes(params.getScopes()); + authCode.setExpiresAt(Instant.now().plusSeconds(600).getEpochSecond()); + authCode.setCodeChallenge(params.getCodeChallenge()); + authCode.setRedirectUri(params.getRedirectUri()); + authCode.setRedirectUriProvidedExplicitly(params.isRedirectUriProvidedExplicitly()); + + authCodes.put(code, authCode); + + // Build the redirect URL with the code + String redirectUri = params.getRedirectUri().toString(); + String state = params.getState(); + String url = redirectUri + "?code=" + code; + if (state != null) { + url += "&state=" + state; + } + + return CompletableFuture.completedFuture(url); + } + + @Override + public CompletableFuture loadAuthorizationCode(OAuthClientInformation client, + String authorizationCode) { + return CompletableFuture.completedFuture(authCodes.get(authorizationCode)); + } + + @Override + public CompletableFuture exchangeAuthorizationCode(OAuthClientInformation client, + AuthorizationCode authorizationCode) throws TokenException { + // Remove the used authorization code + authCodes.remove(authorizationCode.getCode()); + + // Generate tokens + String accessTokenValue = UUID.randomUUID().toString(); + String refreshTokenValue = UUID.randomUUID().toString(); + + // Create access token + AccessToken accessToken = new AccessToken(); + accessToken.setToken(accessTokenValue); + accessToken.setClientId(client.getClientId()); + accessToken.setScopes(authorizationCode.getScopes()); + accessToken.setExpiresAt((int) Instant.now().plusSeconds(3600).getEpochSecond()); + + // Create refresh token + RefreshToken refreshToken = new RefreshToken(); + refreshToken.setToken(refreshTokenValue); + refreshToken.setClientId(client.getClientId()); + refreshToken.setScopes(authorizationCode.getScopes()); + refreshToken.setExpiresAt((int) Instant.now().plusSeconds(86400).getEpochSecond()); + + // Store tokens + accessTokens.put(accessTokenValue, accessToken); + refreshTokens.put(refreshTokenValue, refreshToken); + + // Create OAuth token response + OAuthToken token = new OAuthToken(); + token.setAccessToken(accessTokenValue); + token.setRefreshToken(refreshTokenValue); + token.setExpiresIn(3600); + token.setScope(String.join(" ", authorizationCode.getScopes())); + + return CompletableFuture.completedFuture(token); + } + + @Override + public CompletableFuture loadRefreshToken(OAuthClientInformation client, String refreshToken) { + return CompletableFuture.completedFuture(refreshTokens.get(refreshToken)); + } + + @Override + public CompletableFuture exchangeRefreshToken(OAuthClientInformation client, + RefreshToken refreshToken, List scopes) throws TokenException { + // Remove the used refresh token + refreshTokens.remove(refreshToken.getToken()); + + // Generate new tokens + String accessTokenValue = UUID.randomUUID().toString(); + String refreshTokenValue = UUID.randomUUID().toString(); + + // Create access token + AccessToken accessToken = new AccessToken(); + accessToken.setToken(accessTokenValue); + accessToken.setClientId(client.getClientId()); + accessToken.setScopes(scopes); + accessToken.setExpiresAt((int) Instant.now().plusSeconds(3600).getEpochSecond()); + + // Create refresh token + RefreshToken newRefreshToken = new RefreshToken(); + newRefreshToken.setToken(refreshTokenValue); + newRefreshToken.setClientId(client.getClientId()); + newRefreshToken.setScopes(scopes); + newRefreshToken.setExpiresAt((int) Instant.now().plusSeconds(86400).getEpochSecond()); + + // Store tokens + accessTokens.put(accessTokenValue, accessToken); + refreshTokens.put(refreshTokenValue, newRefreshToken); + + // Create OAuth token response + OAuthToken token = new OAuthToken(); + token.setAccessToken(accessTokenValue); + token.setRefreshToken(refreshTokenValue); + token.setExpiresIn(3600); + token.setScope(String.join(" ", scopes)); + + return CompletableFuture.completedFuture(token); + } + + @Override + public CompletableFuture loadAccessToken(String token) { + return CompletableFuture.completedFuture(accessTokens.get(token)); + } + + @Override + public CompletableFuture revokeToken(Object token) { + if (token instanceof AccessToken) { + accessTokens.remove(((AccessToken) token).getToken()); + } + else if (token instanceof RefreshToken) { + refreshTokens.remove(((RefreshToken) token).getToken()); + } + return CompletableFuture.completedFuture(null); + } + + } + + /** + * Main method to start the server. + */ + public static void main(String[] args) throws Exception { + + // Create a simple auth provider + SimpleAuthProvider authProvider = new SimpleAuthProvider(); + + // Register a default client + OAuthClientInformation clientInfo = new OAuthClientInformation(); + clientInfo.setClientId(Constants.CLIENT_ID); + clientInfo.setClientSecret(Constants.CLIENT_SECRET); + clientInfo.setRedirectUris(Collections.singletonList(new URI(Constants.REDIRECT_URI))); + clientInfo.setTokenEndpointAuthMethod("client_secret_post"); + clientInfo.setGrantTypes(Arrays.asList("authorization_code", "refresh_token")); + clientInfo.setResponseTypes(Collections.singletonList("code")); + clientInfo.setScope(Constants.SCOPE); + + authProvider.registerClient(clientInfo).get(); + + // Create registration options + ClientRegistrationOptions registrationOptions = new ClientRegistrationOptions(); + registrationOptions.setAllowLocalhostRedirect(true); + registrationOptions.setValidScopes(Arrays.asList("read", "write")); + + // Create revocation options + RevocationOptions revocationOptions = new RevocationOptions(); + revocationOptions.setEnabled(true); + + // Create and configure the MCP server using the builder + com.fasterxml.jackson.databind.ObjectMapper objectMapper = new com.fasterxml.jackson.databind.ObjectMapper(); + + // Use the OAuth-enabled transport provider that handles both MCP and OAuth + // routes + OAuthHttpServletSseServerTransportProvider transportProvider = new OAuthHttpServletSseServerTransportProvider( + objectMapper, "/mcp/message", Constants.SERVER_URL, authProvider, new URI(Constants.SERVER_URL), + registrationOptions, revocationOptions); + + startTomcat(transportProvider); + + // String emptyJsonSchema = """ + // { + // "$schema": "http://json-schema.org/draft-07/schema#", + // "type": "object", + // "properties": {} + // } + // """; + // var callResponse = new McpSchema.CallToolResult(List.of(new + // McpSchema.TextContent("CALL RESPONSE")), null); + // McpServerFeatures.SyncToolSpecification tool1 = new + // McpServerFeatures.SyncToolSpecification( + // new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, + // request) -> { + // // perform a blocking call to a remote service + // String response = RestClient.create() + // .get() + // .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + // .retrieve() + // .body(String.class); + // return callResponse; + // }); + + String fetchUrlSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "description": "The URL to fetch" + } + }, + "required": ["url"] + } + """; + + McpServerFeatures.SyncToolSpecification fetchUrlTool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("fetch_url", "Fetches the content of a given URL", fetchUrlSchema), + (exchange, request) -> { + String url = (String) request.get("url"); + try { + String content = RestClient.create().get().uri(url).retrieve().body(String.class); + // Return only the first 500 characters for brevity + String snippet = content.length() > 500 ? content.substring(0, 500) + "..." : content; + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Fetched content:\n" + snippet)), null); + } + catch (Exception e) { + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Error fetching URL: " + e.getMessage())), null); + } + }); + + McpServer.sync(transportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(fetchUrlTool) + .build(); + + logger.info("MCP server is now ready"); + } + +} \ No newline at end of file diff --git a/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/shared/Constants.java b/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/shared/Constants.java new file mode 100644 index 000000000..fcf564a5d --- /dev/null +++ b/examples/auth-example/src/main/java/io/modelcontextprotocol/examples/auth/shared/Constants.java @@ -0,0 +1,18 @@ +package io.modelcontextprotocol.examples.auth.shared; + +/** + * Shared constants for the auth example. + */ +public class Constants { + + public static final String SERVER_URL = "http://localhost:9200"; + + public static final String CLIENT_ID = "example-client"; + + public static final String CLIENT_SECRET = "example-secret"; + + public static final String REDIRECT_URI = "http://localhost:3000/callback"; + + public static final String SCOPE = "read write"; + +} \ No newline at end of file diff --git a/examples/auth-example/src/main/resources/application.properties b/examples/auth-example/src/main/resources/application.properties new file mode 100644 index 000000000..b5baca5a5 --- /dev/null +++ b/examples/auth-example/src/main/resources/application.properties @@ -0,0 +1,9 @@ +# Server configuration +server.port=9200 + +# Spring configuration +spring.main.banner-mode=off + +# Logging configuration +logging.level.root=INFO +logging.level.io.modelcontextprotocol=INFO \ No newline at end of file diff --git a/examples/auth-example/src/main/resources/logback.xml b/examples/auth-example/src/main/resources/logback.xml new file mode 100644 index 000000000..95129a281 --- /dev/null +++ b/examples/auth-example/src/main/resources/logback.xml @@ -0,0 +1,15 @@ + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/AccessToken.java b/mcp/src/main/java/io/modelcontextprotocol/auth/AccessToken.java new file mode 100644 index 000000000..1c9f8954e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/AccessToken.java @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.auth; + +import java.util.List; + +/** + * Represents an OAuth access token. + */ +public class AccessToken { + + private String token; + + private String clientId; + + private List scopes; + + private Integer expiresAt; + + public AccessToken() { + } + + public AccessToken(String token, String clientId, List scopes, Integer expiresAt) { + this.token = token; + this.clientId = clientId; + this.scopes = scopes; + this.expiresAt = expiresAt; + } + + public String getToken() { + return token; + } + + public void setToken(String token) { + this.token = token; + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public List getScopes() { + return scopes; + } + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public Integer getExpiresAt() { + return expiresAt; + } + + public void setExpiresAt(Integer expiresAt) { + this.expiresAt = expiresAt; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationCode.java b/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationCode.java new file mode 100644 index 000000000..fdd5c14c2 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationCode.java @@ -0,0 +1,95 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.util.List; + +/** + * Represents an OAuth authorization code. + */ +public class AuthorizationCode { + + private String code; + + private List scopes; + + private double expiresAt; + + private String clientId; + + private String codeChallenge; + + private URI redirectUri; + + private boolean redirectUriProvidedExplicitly; + + public AuthorizationCode() { + } + + public AuthorizationCode(String code, List scopes, double expiresAt, String clientId, String codeChallenge, + URI redirectUri, boolean redirectUriProvidedExplicitly) { + this.code = code; + this.scopes = scopes; + this.expiresAt = expiresAt; + this.clientId = clientId; + this.codeChallenge = codeChallenge; + this.redirectUri = redirectUri; + this.redirectUriProvidedExplicitly = redirectUriProvidedExplicitly; + } + + public String getCode() { + return code; + } + + public void setCode(String code) { + this.code = code; + } + + public List getScopes() { + return scopes; + } + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public double getExpiresAt() { + return expiresAt; + } + + public void setExpiresAt(double expiresAt) { + this.expiresAt = expiresAt; + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public String getCodeChallenge() { + return codeChallenge; + } + + public void setCodeChallenge(String codeChallenge) { + this.codeChallenge = codeChallenge; + } + + public URI getRedirectUri() { + return redirectUri; + } + + public void setRedirectUri(URI redirectUri) { + this.redirectUri = redirectUri; + } + + public boolean isRedirectUriProvidedExplicitly() { + return redirectUriProvidedExplicitly; + } + + public void setRedirectUriProvidedExplicitly(boolean redirectUriProvidedExplicitly) { + this.redirectUriProvidedExplicitly = redirectUriProvidedExplicitly; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationParams.java b/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationParams.java new file mode 100644 index 000000000..996c48bda --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationParams.java @@ -0,0 +1,73 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.util.List; + +/** + * Parameters for an authorization request. + */ +public class AuthorizationParams { + + private String state; + + private List scopes; + + private String codeChallenge; + + private URI redirectUri; + + private boolean redirectUriProvidedExplicitly; + + public AuthorizationParams() { + } + + public AuthorizationParams(String state, List scopes, String codeChallenge, URI redirectUri, + boolean redirectUriProvidedExplicitly) { + this.state = state; + this.scopes = scopes; + this.codeChallenge = codeChallenge; + this.redirectUri = redirectUri; + this.redirectUriProvidedExplicitly = redirectUriProvidedExplicitly; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public List getScopes() { + return scopes; + } + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public String getCodeChallenge() { + return codeChallenge; + } + + public void setCodeChallenge(String codeChallenge) { + this.codeChallenge = codeChallenge; + } + + public URI getRedirectUri() { + return redirectUri; + } + + public void setRedirectUri(URI redirectUri) { + this.redirectUri = redirectUri; + } + + public boolean isRedirectUriProvidedExplicitly() { + return redirectUriProvidedExplicitly; + } + + public void setRedirectUriProvidedExplicitly(boolean redirectUriProvidedExplicitly) { + this.redirectUriProvidedExplicitly = redirectUriProvidedExplicitly; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidRedirectUriException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidRedirectUriException.java new file mode 100644 index 000000000..91a6fd320 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidRedirectUriException.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.auth; + +/** + * Exception thrown when a redirect URI is invalid. + */ +public class InvalidRedirectUriException extends Exception { + + public InvalidRedirectUriException(String message) { + super(message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidScopeException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidScopeException.java new file mode 100644 index 000000000..620a1f1ee --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidScopeException.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.auth; + +/** + * Exception thrown when a requested scope is invalid. + */ +public class InvalidScopeException extends Exception { + + public InvalidScopeException(String message) { + super(message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthAuthorizationServerProvider.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthAuthorizationServerProvider.java new file mode 100644 index 000000000..cefe1498f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthAuthorizationServerProvider.java @@ -0,0 +1,99 @@ +package io.modelcontextprotocol.auth; + +import io.modelcontextprotocol.auth.exception.AuthorizeException; +import io.modelcontextprotocol.auth.exception.RegistrationException; +import io.modelcontextprotocol.auth.exception.TokenException; + +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * Interface for OAuth authorization server providers. + */ +public interface OAuthAuthorizationServerProvider { + + /** + * Retrieves client information by client ID. + * @param clientId The ID of the client to retrieve. + * @return A CompletableFuture that resolves to the client information, or null if the + * client does not exist. + */ + CompletableFuture getClient(String clientId); + + /** + * Saves client information as part of registering it. + * @param clientInfo The client metadata to register. + * @return A CompletableFuture that completes when the registration is done. + * @throws RegistrationException If the client metadata is invalid. + */ + CompletableFuture registerClient(OAuthClientInformation clientInfo) throws RegistrationException; + + /** + * Called as part of the /authorize endpoint, and returns a URL that the client will + * be redirected to. + * @param client The client requesting authorization. + * @param params The parameters of the authorization request. + * @return A CompletableFuture that resolves to a URL to redirect the client to for + * authorization. + * @throws AuthorizeException If the authorization request is invalid. + */ + CompletableFuture authorize(OAuthClientInformation client, AuthorizationParams params) + throws AuthorizeException; + + /** + * Loads an AuthorizationCode by its code. + * @param client The client that requested the authorization code. + * @param authorizationCode The authorization code to get the challenge for. + * @return A CompletableFuture that resolves to the AuthorizationCode, or null if not + * found. + */ + CompletableFuture loadAuthorizationCode(OAuthClientInformation client, String authorizationCode); + + /** + * Exchanges an authorization code for an access token and refresh token. + * @param client The client exchanging the authorization code. + * @param authorizationCode The authorization code to exchange. + * @return A CompletableFuture that resolves to the OAuth token, containing access and + * refresh tokens. + * @throws TokenException If the request is invalid. + */ + CompletableFuture exchangeAuthorizationCode(OAuthClientInformation client, + AuthorizationCode authorizationCode) throws TokenException; + + /** + * Loads a RefreshToken by its token string. + * @param client The client that is requesting to load the refresh token. + * @param refreshToken The refresh token string to load. + * @return A CompletableFuture that resolves to the RefreshToken object if found, or + * null if not found. + */ + CompletableFuture loadRefreshToken(OAuthClientInformation client, String refreshToken); + + /** + * Exchanges a refresh token for an access token and refresh token. + * @param client The client exchanging the refresh token. + * @param refreshToken The refresh token to exchange. + * @param scopes Optional scopes to request with the new access token. + * @return A CompletableFuture that resolves to the OAuth token, containing access and + * refresh tokens. + * @throws TokenException If the request is invalid. + */ + CompletableFuture exchangeRefreshToken(OAuthClientInformation client, RefreshToken refreshToken, + List scopes) throws TokenException; + + /** + * Loads an access token by its token. + * @param token The access token to verify. + * @return A CompletableFuture that resolves to the AccessToken, or null if the token + * is invalid. + */ + CompletableFuture loadAccessToken(String token); + + /** + * Revokes an access or refresh token. + * @param token The token to revoke. + * @return A CompletableFuture that completes when the token is revoked. + */ + CompletableFuture revokeToken(Object token); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientInformation.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientInformation.java new file mode 100644 index 000000000..d45761271 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientInformation.java @@ -0,0 +1,61 @@ +package io.modelcontextprotocol.auth; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * RFC 7591 OAuth 2.0 Dynamic Client Registration full response (client information plus + * metadata). + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class OAuthClientInformation extends OAuthClientMetadata { + + @JsonProperty("client_id") + private String clientId; + + @JsonProperty("client_secret") + private String clientSecret; + + @JsonProperty("client_id_issued_at") + private Long clientIdIssuedAt; + + @JsonProperty("client_secret_expires_at") + private Long clientSecretExpiresAt; + + public OAuthClientInformation() { + super(); + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public String getClientSecret() { + return clientSecret; + } + + public void setClientSecret(String clientSecret) { + this.clientSecret = clientSecret; + } + + public Long getClientIdIssuedAt() { + return clientIdIssuedAt; + } + + public void setClientIdIssuedAt(Long clientIdIssuedAt) { + this.clientIdIssuedAt = clientIdIssuedAt; + } + + public Long getClientSecretExpiresAt() { + return clientSecretExpiresAt; + } + + public void setClientSecretExpiresAt(Long clientSecretExpiresAt) { + this.clientSecretExpiresAt = clientSecretExpiresAt; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientMetadata.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientMetadata.java new file mode 100644 index 000000000..292623f23 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientMetadata.java @@ -0,0 +1,236 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. See + * https://datatracker.ietf.org/doc/html/rfc7591#section-2 for the full specification. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class OAuthClientMetadata { + + @JsonProperty("redirect_uris") + private List redirectUris; + + @JsonProperty("token_endpoint_auth_method") + private String tokenEndpointAuthMethod; + + @JsonProperty("grant_types") + private List grantTypes; + + @JsonProperty("response_types") + private List responseTypes; + + @JsonProperty("scope") + private String scope; + + // Optional metadata fields + @JsonProperty("client_name") + private String clientName; + + @JsonProperty("client_uri") + private URI clientUri; + + @JsonProperty("logo_uri") + private URI logoUri; + + @JsonProperty("contacts") + private List contacts; + + @JsonProperty("tos_uri") + private URI tosUri; + + @JsonProperty("policy_uri") + private URI policyUri; + + @JsonProperty("jwks_uri") + private URI jwksUri; + + @JsonProperty("jwks") + private Object jwks; + + @JsonProperty("software_id") + private String softwareId; + + @JsonProperty("software_version") + private String softwareVersion; + + public OAuthClientMetadata() { + this.tokenEndpointAuthMethod = "client_secret_post"; + this.grantTypes = Arrays.asList("authorization_code", "refresh_token"); + this.responseTypes = Arrays.asList("code"); + } + + /** + * Validates the requested scope against the client's allowed scopes. + * @param requestedScope The scope requested by the client + * @return List of validated scopes or null if no scope was requested + * @throws InvalidScopeException if the requested scope is not allowed + */ + public List validateScope(String requestedScope) throws InvalidScopeException { + if (requestedScope == null) { + return null; + } + + List requestedScopes = Arrays.asList(requestedScope.split(" ")); + List allowedScopes = scope == null ? new ArrayList<>() : Arrays.asList(scope.split(" ")); + + for (String scope : requestedScopes) { + if (!allowedScopes.contains(scope)) { + throw new InvalidScopeException("Client was not registered with scope " + scope); + } + } + + return requestedScopes; + } + + /** + * Validates the redirect URI against the client's registered redirect URIs. + * @param redirectUri The redirect URI to validate + * @return The validated redirect URI + * @throws InvalidRedirectUriException if the redirect URI is invalid + */ + public URI validateRedirectUri(URI redirectUri) throws InvalidRedirectUriException { + if (redirectUri != null) { + if (!redirectUris.contains(redirectUri)) { + throw new InvalidRedirectUriException("Redirect URI '" + redirectUri + "' not registered for client"); + } + return redirectUri; + } + else if (redirectUris.size() == 1) { + return redirectUris.get(0); + } + else { + throw new InvalidRedirectUriException( + "redirect_uri must be specified when client has multiple registered URIs"); + } + } + + // Getters and setters + public List getRedirectUris() { + return redirectUris; + } + + public void setRedirectUris(List redirectUris) { + this.redirectUris = redirectUris; + } + + public String getTokenEndpointAuthMethod() { + return tokenEndpointAuthMethod; + } + + public void setTokenEndpointAuthMethod(String tokenEndpointAuthMethod) { + this.tokenEndpointAuthMethod = tokenEndpointAuthMethod; + } + + public List getGrantTypes() { + return grantTypes; + } + + public void setGrantTypes(List grantTypes) { + this.grantTypes = grantTypes; + } + + public List getResponseTypes() { + return responseTypes; + } + + public void setResponseTypes(List responseTypes) { + this.responseTypes = responseTypes; + } + + public String getScope() { + return scope; + } + + public void setScope(String scope) { + this.scope = scope; + } + + public String getClientName() { + return clientName; + } + + public void setClientName(String clientName) { + this.clientName = clientName; + } + + public URI getClientUri() { + return clientUri; + } + + public void setClientUri(URI clientUri) { + this.clientUri = clientUri; + } + + public URI getLogoUri() { + return logoUri; + } + + public void setLogoUri(URI logoUri) { + this.logoUri = logoUri; + } + + public List getContacts() { + return contacts; + } + + public void setContacts(List contacts) { + this.contacts = contacts; + } + + public URI getTosUri() { + return tosUri; + } + + public void setTosUri(URI tosUri) { + this.tosUri = tosUri; + } + + public URI getPolicyUri() { + return policyUri; + } + + public void setPolicyUri(URI policyUri) { + this.policyUri = policyUri; + } + + public URI getJwksUri() { + return jwksUri; + } + + public void setJwksUri(URI jwksUri) { + this.jwksUri = jwksUri; + } + + public Object getJwks() { + return jwks; + } + + public void setJwks(Object jwks) { + this.jwks = jwks; + } + + public String getSoftwareId() { + return softwareId; + } + + public void setSoftwareId(String softwareId) { + this.softwareId = softwareId; + } + + public String getSoftwareVersion() { + return softwareVersion; + } + + public void setSoftwareVersion(String softwareVersion) { + this.softwareVersion = softwareVersion; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthMetadata.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthMetadata.java new file mode 100644 index 000000000..6c9c7b595 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthMetadata.java @@ -0,0 +1,230 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; + +/** + * RFC 8414 OAuth 2.0 Authorization Server Metadata. See + * https://datatracker.ietf.org/doc/html/rfc8414#section-2 + */ +public class OAuthMetadata { + + private URI issuer; + + private URI authorizationEndpoint; + + private URI tokenEndpoint; + + private URI registrationEndpoint; + + private List scopesSupported; + + private List responseTypesSupported; + + private List responseModesSupported; + + private List grantTypesSupported; + + private List tokenEndpointAuthMethodsSupported; + + private List tokenEndpointAuthSigningAlgValuesSupported; + + private URI serviceDocumentation; + + private List uiLocalesSupported; + + private URI opPolicyUri; + + private URI opTosUri; + + private URI revocationEndpoint; + + private List revocationEndpointAuthMethodsSupported; + + private List revocationEndpointAuthSigningAlgValuesSupported; + + private URI introspectionEndpoint; + + private List introspectionEndpointAuthMethodsSupported; + + private List introspectionEndpointAuthSigningAlgValuesSupported; + + private List codeChallengeMethodsSupported; + + public OAuthMetadata() { + this.responseTypesSupported = Arrays.asList("code"); + } + + // Getters and setters + public URI getIssuer() { + return issuer; + } + + public void setIssuer(URI issuer) { + this.issuer = issuer; + } + + public URI getAuthorizationEndpoint() { + return authorizationEndpoint; + } + + public void setAuthorizationEndpoint(URI authorizationEndpoint) { + this.authorizationEndpoint = authorizationEndpoint; + } + + public URI getTokenEndpoint() { + return tokenEndpoint; + } + + public void setTokenEndpoint(URI tokenEndpoint) { + this.tokenEndpoint = tokenEndpoint; + } + + public URI getRegistrationEndpoint() { + return registrationEndpoint; + } + + public void setRegistrationEndpoint(URI registrationEndpoint) { + this.registrationEndpoint = registrationEndpoint; + } + + public List getScopesSupported() { + return scopesSupported; + } + + public void setScopesSupported(List scopesSupported) { + this.scopesSupported = scopesSupported; + } + + public List getResponseTypesSupported() { + return responseTypesSupported; + } + + public void setResponseTypesSupported(List responseTypesSupported) { + this.responseTypesSupported = responseTypesSupported; + } + + public List getResponseModesSupported() { + return responseModesSupported; + } + + public void setResponseModesSupported(List responseModesSupported) { + this.responseModesSupported = responseModesSupported; + } + + public List getGrantTypesSupported() { + return grantTypesSupported; + } + + public void setGrantTypesSupported(List grantTypesSupported) { + this.grantTypesSupported = grantTypesSupported; + } + + public List getTokenEndpointAuthMethodsSupported() { + return tokenEndpointAuthMethodsSupported; + } + + public void setTokenEndpointAuthMethodsSupported(List tokenEndpointAuthMethodsSupported) { + this.tokenEndpointAuthMethodsSupported = tokenEndpointAuthMethodsSupported; + } + + public List getTokenEndpointAuthSigningAlgValuesSupported() { + return tokenEndpointAuthSigningAlgValuesSupported; + } + + public void setTokenEndpointAuthSigningAlgValuesSupported(List tokenEndpointAuthSigningAlgValuesSupported) { + this.tokenEndpointAuthSigningAlgValuesSupported = tokenEndpointAuthSigningAlgValuesSupported; + } + + public URI getServiceDocumentation() { + return serviceDocumentation; + } + + public void setServiceDocumentation(URI serviceDocumentation) { + this.serviceDocumentation = serviceDocumentation; + } + + public List getUiLocalesSupported() { + return uiLocalesSupported; + } + + public void setUiLocalesSupported(List uiLocalesSupported) { + this.uiLocalesSupported = uiLocalesSupported; + } + + public URI getOpPolicyUri() { + return opPolicyUri; + } + + public void setOpPolicyUri(URI opPolicyUri) { + this.opPolicyUri = opPolicyUri; + } + + public URI getOpTosUri() { + return opTosUri; + } + + public void setOpTosUri(URI opTosUri) { + this.opTosUri = opTosUri; + } + + public URI getRevocationEndpoint() { + return revocationEndpoint; + } + + public void setRevocationEndpoint(URI revocationEndpoint) { + this.revocationEndpoint = revocationEndpoint; + } + + public List getRevocationEndpointAuthMethodsSupported() { + return revocationEndpointAuthMethodsSupported; + } + + public void setRevocationEndpointAuthMethodsSupported(List revocationEndpointAuthMethodsSupported) { + this.revocationEndpointAuthMethodsSupported = revocationEndpointAuthMethodsSupported; + } + + public List getRevocationEndpointAuthSigningAlgValuesSupported() { + return revocationEndpointAuthSigningAlgValuesSupported; + } + + public void setRevocationEndpointAuthSigningAlgValuesSupported( + List revocationEndpointAuthSigningAlgValuesSupported) { + this.revocationEndpointAuthSigningAlgValuesSupported = revocationEndpointAuthSigningAlgValuesSupported; + } + + public URI getIntrospectionEndpoint() { + return introspectionEndpoint; + } + + public void setIntrospectionEndpoint(URI introspectionEndpoint) { + this.introspectionEndpoint = introspectionEndpoint; + } + + public List getIntrospectionEndpointAuthMethodsSupported() { + return introspectionEndpointAuthMethodsSupported; + } + + public void setIntrospectionEndpointAuthMethodsSupported(List introspectionEndpointAuthMethodsSupported) { + this.introspectionEndpointAuthMethodsSupported = introspectionEndpointAuthMethodsSupported; + } + + public List getIntrospectionEndpointAuthSigningAlgValuesSupported() { + return introspectionEndpointAuthSigningAlgValuesSupported; + } + + public void setIntrospectionEndpointAuthSigningAlgValuesSupported( + List introspectionEndpointAuthSigningAlgValuesSupported) { + this.introspectionEndpointAuthSigningAlgValuesSupported = introspectionEndpointAuthSigningAlgValuesSupported; + } + + public List getCodeChallengeMethodsSupported() { + return codeChallengeMethodsSupported; + } + + public void setCodeChallengeMethodsSupported(List codeChallengeMethodsSupported) { + this.codeChallengeMethodsSupported = codeChallengeMethodsSupported; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthToken.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthToken.java new file mode 100644 index 000000000..993010d41 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthToken.java @@ -0,0 +1,71 @@ +package io.modelcontextprotocol.auth; + +/** + * OAuth token as defined in RFC 6749 section 5.1 + * https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 + */ +public class OAuthToken { + + private String accessToken; + + private String tokenType; + + private Integer expiresIn; + + private String scope; + + private String refreshToken; + + public OAuthToken() { + this.tokenType = "bearer"; + } + + public OAuthToken(String accessToken, Integer expiresIn, String scope, String refreshToken) { + this.accessToken = accessToken; + this.tokenType = "bearer"; + this.expiresIn = expiresIn; + this.scope = scope; + this.refreshToken = refreshToken; + } + + public String getAccessToken() { + return accessToken; + } + + public void setAccessToken(String accessToken) { + this.accessToken = accessToken; + } + + public String getTokenType() { + return tokenType; + } + + public void setTokenType(String tokenType) { + this.tokenType = tokenType; + } + + public Integer getExpiresIn() { + return expiresIn; + } + + public void setExpiresIn(Integer expiresIn) { + this.expiresIn = expiresIn; + } + + public String getScope() { + return scope; + } + + public void setScope(String scope) { + this.scope = scope; + } + + public String getRefreshToken() { + return refreshToken; + } + + public void setRefreshToken(String refreshToken) { + this.refreshToken = refreshToken; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/README.md b/mcp/src/main/java/io/modelcontextprotocol/auth/README.md new file mode 100644 index 000000000..e8a77c515 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/README.md @@ -0,0 +1,86 @@ +# Authentication Implementation for Java SDK + +This package provides OAuth 2.0 authentication functionality for the Java SDK, based on the implementation in the Python SDK. + +## Overview + +The authentication implementation follows the OAuth 2.0 specification and includes: + +1. **Core OAuth Models**: + - `OAuthToken`: Represents an OAuth token with access and refresh tokens + - `OAuthClientMetadata`: Client registration metadata + - `OAuthClientInformation`: Client information including credentials + - `OAuthMetadata`: Authorization server metadata + +2. **Token Models**: + - `AccessToken`: Represents an OAuth access token + - `RefreshToken`: Represents an OAuth refresh token + - `AuthorizationCode`: Represents an OAuth authorization code + +3. **Authentication Middleware**: + - `BearerAuthenticator`: Validates Bearer tokens in Authorization headers + - `ClientAuthenticator`: Validates client credentials + - `AuthContext`: Holds authentication context for a request + +4. **Provider Interface**: + - `OAuthAuthorizationServerProvider`: Interface for OAuth authorization server providers + +5. **Exceptions**: + - `RegistrationException`: Thrown during client registration errors + - `AuthorizeException`: Thrown during authorization errors + - `TokenException`: Thrown during token operations errors + - `InvalidScopeException`: Thrown when a requested scope is invalid + - `InvalidRedirectUriException`: Thrown when a redirect URI is invalid + +## Usage + +To use the authentication functionality: + +1. Implement the `OAuthAuthorizationServerProvider` interface +2. Use the `BearerAuthenticator` to validate Bearer tokens +3. Use the `ClientAuthenticator` to validate client credentials + +Example: + +```java +// Create an OAuth provider implementation +OAuthAuthorizationServerProvider provider = new MyOAuthProvider(); + +// Create authenticators +BearerAuthenticator bearerAuth = new BearerAuthenticator(provider); +ClientAuthenticator clientAuth = new ClientAuthenticator(provider); + +// Authenticate a request with a Bearer token +String authHeader = "Bearer abc123"; +bearerAuth.authenticate(authHeader) + .thenAccept(user -> { + if (user != null) { + // User is authenticated + String clientId = user.getClientId(); + // ... + } else { + // Authentication failed + } + }); + +// Authenticate a client +String clientId = "client123"; +String clientSecret = "secret456"; +clientAuth.authenticate(clientId, clientSecret) + .thenAccept(client -> { + // Client is authenticated + // ... + }) + .exceptionally(ex -> { + // Authentication failed + // ... + return null; + }); +``` + +## Implementation Notes + +- The implementation uses CompletableFuture for asynchronous operations +- Token validation includes expiration checks +- Client authentication supports both secret and no-secret modes +- URI utilities are provided for constructing redirect URIs \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/RefreshToken.java b/mcp/src/main/java/io/modelcontextprotocol/auth/RefreshToken.java new file mode 100644 index 000000000..31014515e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/RefreshToken.java @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.auth; + +import java.util.List; + +/** + * Represents an OAuth refresh token. + */ +public class RefreshToken { + + private String token; + + private String clientId; + + private List scopes; + + private Integer expiresAt; + + public RefreshToken() { + } + + public RefreshToken(String token, String clientId, List scopes, Integer expiresAt) { + this.token = token; + this.clientId = clientId; + this.scopes = scopes; + this.expiresAt = expiresAt; + } + + public String getToken() { + return token; + } + + public void setToken(String token) { + this.token = token; + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public List getScopes() { + return scopes; + } + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public Integer getExpiresAt() { + return expiresAt; + } + + public void setExpiresAt(Integer expiresAt) { + this.expiresAt = expiresAt; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/exception/AuthorizeException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/AuthorizeException.java new file mode 100644 index 000000000..c16d2759c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/AuthorizeException.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.auth.exception; + +/** + * Exception thrown during authorization. + */ +public class AuthorizeException extends Exception { + + private final String error; + + private final String errorDescription; + + public AuthorizeException(String error, String errorDescription) { + super(errorDescription != null ? errorDescription : error); + this.error = error; + this.errorDescription = errorDescription; + } + + public String getError() { + return error; + } + + public String getErrorDescription() { + return errorDescription; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/exception/RegistrationException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/RegistrationException.java new file mode 100644 index 000000000..0f4a1a109 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/RegistrationException.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.auth.exception; + +/** + * Exception thrown during client registration. + */ +public class RegistrationException extends Exception { + + private final String error; + + private final String errorDescription; + + public RegistrationException(String error, String errorDescription) { + super(errorDescription != null ? errorDescription : error); + this.error = error; + this.errorDescription = errorDescription; + } + + public String getError() { + return error; + } + + public String getErrorDescription() { + return errorDescription; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/exception/TokenException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/TokenException.java new file mode 100644 index 000000000..e0fa8ca6d --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/TokenException.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.auth.exception; + +/** + * Exception thrown during token operations. + */ +public class TokenException extends Exception { + + private final String error; + + private final String errorDescription; + + public TokenException(String error, String errorDescription) { + super(errorDescription != null ? errorDescription : error); + this.error = error; + this.errorDescription = errorDescription; + } + + public String getError() { + return error; + } + + public String getErrorDescription() { + return errorDescription; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFactory.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFactory.java new file mode 100644 index 000000000..0ecab28d4 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFactory.java @@ -0,0 +1,45 @@ +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.auth.OAuthClientProvider; +import io.modelcontextprotocol.client.transport.AuthenticatedTransportBuilder; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; + +/** + * Factory for creating MCP clients with authentication. + */ +public class McpClientFactory { + + /** + * Creates a new MCP client with authentication. + * @param serverUrl The server URL + * @param authProvider The OAuth client provider + * @return The MCP client + */ + public static McpAsyncClient createAuthenticatedClient(String serverUrl, OAuthClientProvider authProvider) { + // Create transport with authentication + HttpClientSseClientTransport transport = AuthenticatedTransportBuilder + .withAuthentication(HttpClientSseClientTransport.builder(serverUrl).sseEndpoint("/sse"), authProvider) + .build(); + + // Create MCP client + return McpClient.async(transport).build(); + } + + /** + * Creates a new MCP client with authentication. + * @param transportBuilder The transport builder to use + * @param authProvider The OAuth client provider + * @return The MCP client + */ + public static McpSyncClient createAuthenticatedClient(HttpClientSseClientTransport.Builder transportBuilder, + OAuthClientProvider authProvider) { + // Create transport with authentication + HttpClientSseClientTransport transport = AuthenticatedTransportBuilder + .withAuthentication(transportBuilder, authProvider) + .build(); + + // Create MCP client + return McpClient.sync(transport).build(); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/AuthCallbackResult.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/AuthCallbackResult.java new file mode 100644 index 000000000..07ac9d2da --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/AuthCallbackResult.java @@ -0,0 +1,38 @@ +package io.modelcontextprotocol.client.auth; + +/** + * Result of an OAuth authorization callback. + */ +public class AuthCallbackResult { + + private final String code; + + private final String state; + + /** + * Creates a new AuthCallbackResult. + * @param code The authorization code. + * @param state The state parameter. + */ + public AuthCallbackResult(String code, String state) { + this.code = code; + this.state = state; + } + + /** + * Get the authorization code. + * @return The authorization code. + */ + public String getCode() { + return code; + } + + /** + * Get the state parameter. + * @return The state parameter. + */ + public String getState() { + return state; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/HttpClientAuthenticator.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/HttpClientAuthenticator.java new file mode 100644 index 000000000..57845ecac --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/HttpClientAuthenticator.java @@ -0,0 +1,50 @@ +package io.modelcontextprotocol.client.auth; + +import java.net.http.HttpRequest; +import java.util.concurrent.CompletableFuture; + +/** + * Authenticator for HTTP requests using OAuth. + */ +public class HttpClientAuthenticator { + + private final OAuthClientProvider oauthProvider; + + /** + * Creates a new HttpClientAuthenticator. + * @param oauthProvider The OAuth client provider. + */ + public HttpClientAuthenticator(OAuthClientProvider oauthProvider) { + this.oauthProvider = oauthProvider; + } + + /** + * Authenticate an HTTP request by adding an Authorization header with the OAuth + * token. + * @param requestBuilder The HTTP request builder. + * @return A CompletableFuture that completes with the authenticated request builder. + */ + public CompletableFuture authenticate(HttpRequest.Builder requestBuilder) { + return oauthProvider.ensureToken().thenApply(v -> { + String accessToken = oauthProvider.getAccessToken(); + if (accessToken != null) { + return requestBuilder.header("Authorization", "Bearer " + accessToken); + } + return requestBuilder; + }); + } + + /** + * Handle an HTTP response, refreshing the token if needed. + * @param statusCode The HTTP status code. + * @return A CompletableFuture that completes when the response is handled. + */ + public CompletableFuture handleResponse(int statusCode) { + if (statusCode == 401) { + // Force token refresh on 401 Unauthorized + return oauthProvider.ensureToken(); + } + return CompletableFuture.completedFuture(null); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/OAuthClientProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/OAuthClientProvider.java new file mode 100644 index 000000000..7d6ad78ef --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/OAuthClientProvider.java @@ -0,0 +1,614 @@ +package io.modelcontextprotocol.client.auth; + +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthClientMetadata; +import io.modelcontextprotocol.auth.OAuthMetadata; +import io.modelcontextprotocol.auth.OAuthToken; + +import java.io.IOException; +import java.net.URI; +import java.net.URLEncoder; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * OAuth client provider that handles the OAuth 2.0 authorization code flow with PKCE. + */ +public class OAuthClientProvider { + + private final String serverUrl; + + private final OAuthClientMetadata clientMetadata; + + private final TokenStorage storage; + + private final Function> redirectHandler; + + private final Function> callbackHandler; + + private final Duration timeout; + + private final HttpClient httpClient; + + private final ObjectMapper objectMapper; + + // Cached authentication state + private OAuthToken currentTokens; + + private OAuthMetadata metadata; + + private OAuthClientInformation clientInfo; + + private Long tokenExpiryTime; + + // PKCE flow parameters + private String codeVerifier; + + private String codeChallenge; + + // State parameter for CSRF protection + private String authState; + + // Thread safety lock + private final ReentrantLock tokenLock = new ReentrantLock(); + + /** + * Creates a new OAuthClientProvider. + * @param serverUrl Base URL of the OAuth server + * @param clientMetadata OAuth client metadata + * @param storage Token storage implementation + * @param redirectHandler Function to handle authorization URL (e.g., opening a + * browser) + * @param callbackHandler Function to wait for callback and return auth code and state + * @param timeout Timeout for OAuth flow + */ + public OAuthClientProvider(String serverUrl, OAuthClientMetadata clientMetadata, TokenStorage storage, + Function> redirectHandler, + Function> callbackHandler, Duration timeout, + HttpClient httpClient) { + + this.serverUrl = serverUrl; + this.clientMetadata = clientMetadata; + this.storage = storage; + this.redirectHandler = redirectHandler; + this.callbackHandler = callbackHandler; + this.timeout = timeout; + this.httpClient = httpClient != null ? httpClient + : HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(30)).build(); + this.objectMapper = new ObjectMapper(); + } + + /** + * Creates a new OAuthClientProvider with default HTTP client. + * @param serverUrl Base URL of the OAuth server + * @param clientMetadata OAuth client metadata + * @param storage Token storage implementation + * @param redirectHandler Function to handle authorization URL (e.g., opening a + * browser) + * @param callbackHandler Function to wait for callback and return auth code and state + * @param timeout Timeout for OAuth flow + */ + public OAuthClientProvider(String serverUrl, OAuthClientMetadata clientMetadata, TokenStorage storage, + Function> redirectHandler, + Function> callbackHandler, Duration timeout) { + + this(serverUrl, clientMetadata, storage, redirectHandler, callbackHandler, timeout, null); + } + + /** + * Initialize the provider by loading stored tokens and client info. + * @return A CompletableFuture that completes when initialization is done. + */ + public CompletableFuture initialize() { + return storage.getTokens() + .thenAccept(tokens -> this.currentTokens = tokens) + .thenCompose(v -> storage.getClientInfo()) + .thenAccept(clientInfo -> this.clientInfo = clientInfo); + } + + /** + * Ensure a valid access token is available, refreshing or re-authenticating as + * needed. + * @return A CompletableFuture that completes when a valid token is available. + */ + public CompletableFuture ensureToken() { + if (hasValidToken()) { + return CompletableFuture.completedFuture(null); + } + + tokenLock.lock(); + try { + // Check again after acquiring lock + if (hasValidToken()) { + return CompletableFuture.completedFuture(null); + } + + // Try refreshing existing token + if (currentTokens != null && currentTokens.getRefreshToken() != null) { + return refreshAccessToken().thenCompose(refreshed -> { + if (Boolean.TRUE.equals(refreshed)) { + return CompletableFuture.completedFuture(null); + } + else { + // Fall back to full OAuth flow if refresh fails + return performOAuthFlow(); + } + }); + } + else { + // No refresh token, perform full OAuth flow + return performOAuthFlow(); + } + } + finally { + tokenLock.unlock(); + } + } + + /** + * Check if the current token is valid. + * @return true if a valid token exists, false otherwise. + */ + private boolean hasValidToken() { + if (currentTokens == null || currentTokens.getAccessToken() == null) { + return false; + } + + // Check expiry time + return tokenExpiryTime == null || System.currentTimeMillis() < tokenExpiryTime; + } + + /** + * Perform the OAuth 2.0 authorization code flow with PKCE. + * @return A CompletableFuture that completes when the flow is done. + */ + private CompletableFuture performOAuthFlow() { + // Discover OAuth metadata + return discoverOAuthMetadata(serverUrl).thenCompose(metadata -> { + this.metadata = metadata; + return getOrRegisterClient(); + }).thenCompose(clientInfo -> { + // Generate PKCE challenge + this.codeVerifier = PkceUtils.generateCodeVerifier(); + this.codeChallenge = PkceUtils.generateCodeChallenge(codeVerifier); + + // Generate state for CSRF protection + byte[] stateBytes = new byte[32]; + new java.security.SecureRandom().nextBytes(stateBytes); + this.authState = java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(stateBytes); + + // Build authorization URL + String authUrl = buildAuthorizationUrl(clientInfo); + + // Redirect user for authorization + return redirectHandler.apply(authUrl) + .thenCompose(v -> callbackHandler.apply(null)) + .thenCompose(callbackResult -> { + // Validate state parameter + if (callbackResult.getState() == null || !callbackResult.getState().equals(authState)) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally( + new SecurityException("State parameter mismatch: possible CSRF attack")); + return future; + } + + // Clear state after validation + authState = null; + + if (callbackResult.getCode() == null) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new IllegalStateException("No authorization code received")); + return future; + } + + // Exchange authorization code for tokens + return exchangeCodeForToken(callbackResult.getCode(), clientInfo); + }); + }); + } + + /** + * Discover OAuth metadata from server's well-known endpoint. + * @param serverUrl The server URL. + * @return A CompletableFuture that resolves to the OAuth metadata. + */ + private CompletableFuture discoverOAuthMetadata(String serverUrl) { + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + String url = authBaseUrl + "/.well-known/oauth-authorization-server"; + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("MCP-Protocol-Version", "0.1") + .GET() + .build(); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> { + if (response.statusCode() == 404) { + return null; + } + if (response.statusCode() != 200) { + throw new RuntimeException("Failed to discover OAuth metadata: " + response.statusCode()); + } + try { + return objectMapper.readValue(response.body(), OAuthMetadata.class); + } + catch (IOException e) { + throw new RuntimeException("Failed to parse OAuth metadata", e); + } + }).exceptionally(ex -> { + // Try again without MCP header + HttpRequest retryRequest = HttpRequest.newBuilder().uri(URI.create(url)).GET().build(); + + try { + HttpResponse response = httpClient.send(retryRequest, HttpResponse.BodyHandlers.ofString()); + if (response.statusCode() == 404) { + return null; + } + if (response.statusCode() != 200) { + return null; + } + return objectMapper.readValue(response.body(), OAuthMetadata.class); + } + catch (Exception e) { + return null; + } + }); + } + + /** + * Get or register client with server. + * @return A CompletableFuture that resolves to the client information. + */ + private CompletableFuture getOrRegisterClient() { + System.out.println("Client info: " + clientInfo); + if (clientInfo != null) { + return CompletableFuture.completedFuture(clientInfo); + } + + return registerOAuthClient(serverUrl, clientMetadata, metadata).thenCompose(registeredClient -> { + this.clientInfo = registeredClient; + return storage.setClientInfo(registeredClient).thenApply(v -> registeredClient); + }); + } + + /** + * Register OAuth client with server. + * @param serverUrl The server URL. + * @param clientMetadata The client metadata. + * @param metadata The OAuth metadata. + * @return A CompletableFuture that resolves to the registered client information. + */ + private CompletableFuture registerOAuthClient(String serverUrl, + OAuthClientMetadata clientMetadata, OAuthMetadata metadata) { + + String registrationUrl; + if (metadata != null && metadata.getRegistrationEndpoint() != null) { + registrationUrl = metadata.getRegistrationEndpoint().toString(); + } + else { + // Use fallback registration endpoint + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + registrationUrl = authBaseUrl + "/register"; + } + + // Handle default scope + if (clientMetadata.getScope() == null && metadata != null && metadata.getScopesSupported() != null + && !metadata.getScopesSupported().isEmpty()) { + clientMetadata.setScope(String.join(" ", metadata.getScopesSupported())); + } + + try { + String requestBody = objectMapper.writeValueAsString(clientMetadata); + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(registrationUrl)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(requestBody)) + .build(); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> { + if (response.statusCode() != 200 && response.statusCode() != 201) { + throw new RuntimeException("Registration failed: " + response.statusCode()); + } + try { + return objectMapper.readValue(response.body(), OAuthClientInformation.class); + } + catch (IOException e) { + throw new RuntimeException("Failed to parse client information", e); + } + }); + } + catch (Exception e) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(e); + return future; + } + } + + /** + * Build authorization URL for the OAuth flow. + * @param clientInfo The client information. + * @return The authorization URL. + */ + private String buildAuthorizationUrl(OAuthClientInformation clientInfo) { + String authUrlBase; + if (metadata != null && metadata.getAuthorizationEndpoint() != null) { + authUrlBase = metadata.getAuthorizationEndpoint().toString(); + } + else { + // Use fallback authorization endpoint + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + authUrlBase = authBaseUrl + "/authorize"; + } + + Map params = new HashMap<>(); + params.put("response_type", "code"); + params.put("client_id", clientInfo.getClientId()); + params.put("redirect_uri", clientInfo.getRedirectUris().get(0).toString()); + params.put("state", authState); + params.put("code_challenge", codeChallenge); + params.put("code_challenge_method", "S256"); + + // Include explicit scopes only + if (clientMetadata.getScope() != null) { + params.put("scope", clientMetadata.getScope()); + } + + return authUrlBase + "?" + formatQueryParams(params); + } + + /** + * Exchange authorization code for access token. + * @param authCode The authorization code. + * @param clientInfo The client information. + * @return A CompletableFuture that completes when the exchange is done. + */ + private CompletableFuture exchangeCodeForToken(String authCode, OAuthClientInformation clientInfo) { + String tokenUrl; + if (metadata != null && metadata.getTokenEndpoint() != null) { + tokenUrl = metadata.getTokenEndpoint().toString(); + } + else { + // Use fallback token endpoint + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + tokenUrl = authBaseUrl + "/token"; + } + + Map formData = new HashMap<>(); + formData.put("grant_type", "authorization_code"); + formData.put("code", authCode); + formData.put("redirect_uri", clientInfo.getRedirectUris().get(0).toString()); + formData.put("client_id", clientInfo.getClientId()); + formData.put("code_verifier", codeVerifier); + + if (clientInfo.getClientSecret() != null) { + formData.put("client_secret", clientInfo.getClientSecret()); + } + + String requestBody = formatQueryParams(formData); + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(tokenUrl)) + .header("Content-Type", "application/x-www-form-urlencoded") + .timeout(Duration.ofSeconds(30)) + .POST(HttpRequest.BodyPublishers.ofString(requestBody)) + .build(); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenCompose(response -> { + if (response.statusCode() != 200) { + try { + Map errorData = objectMapper.readValue(response.body(), Map.class); + Object errorDesc = errorData.get("error_description"); + if (errorDesc == null) { + errorDesc = errorData.get("error"); + } + if (errorDesc == null) { + errorDesc = "Unknown error"; + } + String errorMsg = errorDesc.toString(); + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException( + "Token exchange failed: " + errorMsg + " (HTTP " + response.statusCode() + ")")); + return future; + } + catch (Exception e) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException( + "Token exchange failed: " + response.statusCode() + " " + response.body())); + return future; + } + } + + try { + OAuthToken tokenResponse = objectMapper.readValue(response.body(), OAuthToken.class); + + // Validate token scopes + validateTokenScopes(tokenResponse); + + // Calculate token expiry + if (tokenResponse.getExpiresIn() != null) { + tokenExpiryTime = System.currentTimeMillis() + (tokenResponse.getExpiresIn() * 1000L); + } + else { + tokenExpiryTime = null; + } + + // Store tokens + currentTokens = tokenResponse; + return storage.setTokens(tokenResponse); + } + catch (Exception e) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(e); + return future; + } + }); + } + + /** + * Refresh access token using refresh token. + * @return A CompletableFuture that resolves to true if refresh was successful, false + * otherwise. + */ + private CompletableFuture refreshAccessToken() { + if (currentTokens == null || currentTokens.getRefreshToken() == null) { + return CompletableFuture.completedFuture(false); + } + + return getOrRegisterClient().thenCompose(clientInfo -> { + String tokenUrl; + if (metadata != null && metadata.getTokenEndpoint() != null) { + tokenUrl = metadata.getTokenEndpoint().toString(); + } + else { + // Use fallback token endpoint + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + tokenUrl = authBaseUrl + "/token"; + } + + Map formData = new HashMap<>(); + formData.put("grant_type", "refresh_token"); + formData.put("refresh_token", currentTokens.getRefreshToken()); + formData.put("client_id", clientInfo.getClientId()); + + if (clientInfo.getClientSecret() != null) { + formData.put("client_secret", clientInfo.getClientSecret()); + } + + String requestBody = formatQueryParams(formData); + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(tokenUrl)) + .header("Content-Type", "application/x-www-form-urlencoded") + .timeout(Duration.ofSeconds(30)) + .POST(HttpRequest.BodyPublishers.ofString(requestBody)) + .build(); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenCompose(response -> { + if (response.statusCode() != 200) { + return CompletableFuture.completedFuture(false); + } + + try { + OAuthToken tokenResponse = objectMapper.readValue(response.body(), OAuthToken.class); + + // Validate token scopes + validateTokenScopes(tokenResponse); + + // Calculate token expiry + if (tokenResponse.getExpiresIn() != null) { + tokenExpiryTime = System.currentTimeMillis() + (tokenResponse.getExpiresIn() * 1000L); + } + else { + tokenExpiryTime = null; + } + + // Store refreshed tokens + currentTokens = tokenResponse; + return storage.setTokens(tokenResponse).thenApply(v -> true); + } + catch (Exception e) { + return CompletableFuture.completedFuture(false); + } + }).exceptionally(ex -> false); + }); + } + + /** + * Validate returned scopes against requested scopes. + * @param tokenResponse The token response. + */ + private void validateTokenScopes(OAuthToken tokenResponse) { + if (tokenResponse.getScope() == null) { + // No scope returned = validation passes + return; + } + + // Check explicitly requested scopes only + if (clientMetadata.getScope() != null) { + // Validate against explicit scope request + String[] requestedScopes = clientMetadata.getScope().split(" "); + String[] returnedScopes = tokenResponse.getScope().split(" "); + + // Check for unauthorized scopes + for (String returnedScope : returnedScopes) { + boolean found = false; + for (String requestedScope : requestedScopes) { + if (returnedScope.equals(requestedScope)) { + found = true; + break; + } + } + + if (!found) { + throw new IllegalStateException("Server granted unauthorized scope: " + returnedScope); + } + } + } + } + + /** + * Extract base URL by removing path component. + * @param serverUrl The server URL. + * @return The base URL. + */ + private String getAuthorizationBaseUrl(String serverUrl) { + try { + URI uri = new URI(serverUrl); + return new URI(uri.getScheme(), uri.getAuthority(), null, null, null).toString(); + } + catch (Exception e) { + throw new IllegalArgumentException("Invalid server URL: " + serverUrl, e); + } + } + + /** + * Format query parameters for URL or form data. + * @param params The parameters. + * @return The formatted query string. + */ + private String formatQueryParams(Map params) { + StringBuilder result = new StringBuilder(); + boolean first = true; + + for (Map.Entry entry : params.entrySet()) { + if (!first) { + result.append("&"); + } + first = false; + + result.append(URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8)); + result.append("="); + result.append(URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8)); + } + + return result.toString(); + } + + /** + * Get the current access token. + * @return The access token, or null if none exists. + */ + public String getAccessToken() { + return currentTokens != null ? currentTokens.getAccessToken() : null; + } + + /** + * Get the current OAuth tokens. + * @return The OAuth tokens, or null if none exist. + */ + public OAuthToken getCurrentTokens() { + return currentTokens; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/PkceUtils.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/PkceUtils.java new file mode 100644 index 000000000..3c33e5992 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/PkceUtils.java @@ -0,0 +1,46 @@ +package io.modelcontextprotocol.client.auth; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.Base64; + +/** + * Utility class for PKCE (Proof Key for Code Exchange) operations. + */ +public class PkceUtils { + + private static final SecureRandom secureRandom = new SecureRandom(); + + private static final String ALLOWED_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"; + + /** + * Generates a cryptographically random code verifier for PKCE. + * @return A random code verifier string. + */ + public static String generateCodeVerifier() { + StringBuilder codeVerifier = new StringBuilder(128); + for (int i = 0; i < 128; i++) { + codeVerifier.append(ALLOWED_CHARS.charAt(secureRandom.nextInt(ALLOWED_CHARS.length()))); + } + return codeVerifier.toString(); + } + + /** + * Generates a code challenge from a code verifier using SHA-256. + * @param codeVerifier The code verifier to hash. + * @return The code challenge string. + */ + public static String generateCodeChallenge(String codeVerifier) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(codeVerifier.getBytes(StandardCharsets.UTF_8)); + return Base64.getUrlEncoder().withoutPadding().encodeToString(hash); + } + catch (NoSuchAlgorithmException e) { + throw new RuntimeException("SHA-256 algorithm not available", e); + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/README.md b/mcp/src/main/java/io/modelcontextprotocol/client/auth/README.md new file mode 100644 index 000000000..133ab5c82 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/README.md @@ -0,0 +1,96 @@ +# OAuth 2.0 Client Implementation + +This package provides an OAuth 2.0 client implementation for the MCP Java SDK, supporting the Authorization Code flow with PKCE (Proof Key for Code Exchange). + +## Components + +- `OAuthClientProvider`: Main class that handles the OAuth 2.0 flow +- `TokenStorage`: Interface for storing OAuth tokens and client information +- `HttpClientAuthenticator`: Authenticator for HTTP requests using OAuth +- `PkceUtils`: Utility class for PKCE operations +- `AuthCallbackResult`: Class to hold the result of an OAuth authorization callback + +## Usage Example + +```java +// Create client metadata +OAuthClientMetadata clientMetadata = new OAuthClientMetadata(); +clientMetadata.setRedirectUris(List.of(URI.create("http://localhost:8080/callback"))); +clientMetadata.setScope("read write"); + +// Create token storage +TokenStorage storage = new InMemoryTokenStorage(); + +// Create redirect handler (e.g., open browser) +Function> redirectHandler = url -> { + try { + Desktop.getDesktop().browse(URI.create(url)); + return CompletableFuture.completedFuture(null); + } catch (IOException e) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(e); + return future; + } +}; + +// Create callback handler (e.g., start local server to receive callback) +Function> callbackHandler = v -> { + // Implementation to start a local server and wait for callback + // Return CompletableFuture with code and state +}; + +// Create OAuth client provider +OAuthClientProvider provider = new OAuthClientProvider( + "https://api.example.com", + clientMetadata, + storage, + redirectHandler, + callbackHandler, + Duration.ofMinutes(5) +); + +// Initialize provider +provider.initialize() + .thenCompose(v -> provider.ensureToken()) + .thenRun(() -> { + // Now you have a valid token + String accessToken = provider.getAccessToken(); + System.out.println("Access token: " + accessToken); + }) + .exceptionally(ex -> { + System.err.println("Authentication failed: " + ex.getMessage()); + return null; + }); +``` + +## HTTP Client Integration + +```java +HttpClientAuthenticator authenticator = new HttpClientAuthenticator(provider); + +HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .uri(URI.create("https://api.example.com/resource")) + .GET(); + +authenticator.authenticate(requestBuilder) + .thenCompose(builder -> { + HttpRequest request = builder.build(); + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()); + }) + .thenCompose(response -> { + // Handle 401 responses by refreshing the token + return authenticator.handleResponse(response.statusCode()) + .thenApply(v -> response); + }) + .thenAccept(response -> { + System.out.println("Response: " + response.body()); + }); +``` + +## Security Features + +- PKCE support to prevent authorization code interception attacks +- State parameter to prevent CSRF attacks +- Automatic token refresh +- Thread-safe token management +- Scope validation \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/TokenStorage.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/TokenStorage.java new file mode 100644 index 000000000..1b67b6d04 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/TokenStorage.java @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.client.auth; + +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthToken; + +import java.util.concurrent.CompletableFuture; + +/** + * Interface for token storage implementations. + */ +public interface TokenStorage { + + /** + * Get stored tokens. + * @return A CompletableFuture that resolves to the stored tokens, or null if none + * exist. + */ + CompletableFuture getTokens(); + + /** + * Store tokens. + * @param tokens The tokens to store. + * @return A CompletableFuture that completes when the tokens are stored. + */ + CompletableFuture setTokens(OAuthToken tokens); + + /** + * Get stored client information. + * @return A CompletableFuture that resolves to the stored client information, or null + * if none exists. + */ + CompletableFuture getClientInfo(); + + /** + * Store client information. + * @param clientInfo The client information to store. + * @return A CompletableFuture that completes when the client information is stored. + */ + CompletableFuture setClientInfo(OAuthClientInformation clientInfo); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/AuthenticatedTransportBuilder.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/AuthenticatedTransportBuilder.java new file mode 100644 index 000000000..26af5c972 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/AuthenticatedTransportBuilder.java @@ -0,0 +1,89 @@ +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.client.auth.HttpClientAuthenticator; +import io.modelcontextprotocol.client.auth.OAuthClientProvider; + +import java.net.http.HttpRequest; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +/** + * Extension methods for transport builders to add authentication support. + */ +public final class AuthenticatedTransportBuilder { + + private AuthenticatedTransportBuilder() { + // Utility class + } + + /** + * Adds authentication to an HttpClientSseClientTransport.Builder. + * @param builder The builder to extend + * @param authProvider The OAuth client provider + * @return The modified builder + */ + public static HttpClientSseClientTransport.Builder withAuthentication(HttpClientSseClientTransport.Builder builder, + OAuthClientProvider authProvider) { + + HttpClientAuthenticator authenticator = new HttpClientAuthenticator(authProvider); + + // First, ensure token is available for the initial SSE connection + try { + authProvider.ensureToken().get(); + } + catch (Exception e) { + System.err.println("Failed to ensure token: " + e.getMessage()); + } + + // Add authentication to initial requests (including SSE connection) + builder = builder.customizeRequest(requestBuilder -> { + String token = authProvider.getAccessToken(); + if (token != null) { + requestBuilder.setHeader("Authorization", "Bearer " + token); + } + }); + + // Add interceptor for dynamic token refresh and retry + return builder.requestInterceptor(new RequestResponseInterceptor() { + @Override + public CompletableFuture interceptRequest(HttpRequest.Builder requestBuilder) { + // Skip setting the header here since it's already set by customizeRequest + return CompletableFuture.completedFuture(requestBuilder.build()); + } + + @Override + public CompletableFuture interceptResponse(HttpRequest request, + Function> responseHandler) { + + return responseHandler.apply(request).thenCompose(response -> { + if (response instanceof java.net.http.HttpResponse) { + int statusCode = ((java.net.http.HttpResponse) response).statusCode(); + if (statusCode == 401) { + // Handle 401 by refreshing token and retrying + return authenticator.handleResponse(statusCode).thenCompose(v -> { + // Rebuild request with new token + HttpRequest.Builder newRequestBuilder = HttpRequest.newBuilder(request.uri()) + .method(request.method(), + request.bodyPublisher().orElse(HttpRequest.BodyPublishers.noBody())); + + // Copy headers except Authorization + request.headers().map().forEach((name, values) -> { + if (!name.equalsIgnoreCase("Authorization")) { + values.forEach(value -> newRequestBuilder.header(name, value)); + } + }); + + // Add new auth header + return authenticator.authenticate(newRequestBuilder) + .thenApply(HttpRequest.Builder::build) + .thenCompose(responseHandler); + }); + } + } + return CompletableFuture.completedFuture(response); + }); + } + }); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 99cf2a625..715238e16 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -16,8 +16,12 @@ import java.util.function.Consumer; import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; + import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; @@ -25,8 +29,6 @@ import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; /** @@ -78,6 +80,8 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** SSE client for handling server-sent events. Uses the /sse endpoint */ private final FlowSseClient sseClient; + private RequestResponseInterceptor requestInterceptor; + /** * HTTP client for sending messages to the server. Uses HTTP POST over the message * endpoint @@ -179,6 +183,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; @@ -188,6 +193,14 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); } + /** + * Sets the request interceptor. + * @param requestInterceptor the request interceptor + */ + void setRequestInterceptor(RequestResponseInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + } + /** * Creates a new builder for {@link HttpClientSseClientTransport}. * @param baseUri the base URI of the MCP server @@ -215,6 +228,8 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() .header("Content-Type", "application/json"); + private RequestResponseInterceptor requestInterceptor; + /** * Creates a new builder instance. */ @@ -312,13 +327,29 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Sets the request interceptor. + * @param requestInterceptor the request interceptor + * @return this builder + */ + public Builder requestInterceptor(RequestResponseInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + return this; + } + /** * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance */ public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); + HttpClientSseClientTransport transport = new HttpClientSseClientTransport(clientBuilder.build(), + requestBuilder, baseUri, sseEndpoint, objectMapper); + + if (requestInterceptor != null) { + transport.setRequestInterceptor(requestInterceptor); + } + + return transport; } } @@ -342,6 +373,7 @@ public Mono connect(Function, Mono> h connectionFuture.set(future); URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + logger.info("Connecting to {}", clientUri); sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { @@ -415,17 +447,35 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); URI requestUri = Utils.resolveUri(baseUri, endpoint); - HttpRequest request = this.requestBuilder.uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(jsonText)) - .build(); - - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - })); + HttpRequest.Builder builder = this.requestBuilder.copy() + .uri(requestUri) + .POST(HttpRequest.BodyPublishers.ofString(jsonText)); + + // Apply request interceptor if available + if (requestInterceptor != null) { + return Mono.fromFuture(requestInterceptor.interceptRequest(builder) + .thenCompose(request -> requestInterceptor.interceptResponse(request, + req -> httpClient.sendAsync(req, HttpResponse.BodyHandlers.discarding()) + .thenApply(response -> { + if (response.statusCode() != 200 && response.statusCode() != 201 + && response.statusCode() != 202 && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); + } + return response; + }))) + .thenApply(response -> null)); + } + else { + // Original behavior without interceptor + HttpRequest request = builder.build(); + return Mono.fromFuture( + httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { + if (response.statusCode() != 200 && response.statusCode() != 201 + && response.statusCode() != 202 && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); + } + })); + } } catch (IOException e) { if (!isClosing) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/RequestResponseInterceptor.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/RequestResponseInterceptor.java new file mode 100644 index 000000000..73a1472f5 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/RequestResponseInterceptor.java @@ -0,0 +1,30 @@ +package io.modelcontextprotocol.client.transport; + +import java.net.http.HttpRequest; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +/** + * Interface for intercepting HTTP requests and responses. + */ +public interface RequestResponseInterceptor { + + /** + * Intercept and potentially modify an HTTP request before it is sent. + * @param requestBuilder The request builder + * @return A CompletableFuture that resolves to the modified request + */ + CompletableFuture interceptRequest(HttpRequest.Builder requestBuilder); + + /** + * Intercept the response handling process, allowing for retries or other processing. + * @param request The original request + * @param responseHandler The function that processes the request and returns a + * response + * @param The response type + * @return A CompletableFuture that resolves to the response + */ + CompletableFuture interceptResponse(HttpRequest request, + Function> responseHandler); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 889dc66d0..13441d276 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -5,6 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.server.auth.middleware.AuthContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; @@ -28,6 +30,8 @@ public class McpAsyncServerExchange { private final McpSchema.Implementation clientInfo; + private final AuthContext authContext; + private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { @@ -45,9 +49,23 @@ public class McpAsyncServerExchange { */ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this(session, clientCapabilities, clientInfo, AuthContext.getCurrent()); + } + + /** + * Create a new asynchronous exchange with the client and authentication context. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + * @param authContext The authentication context. + */ + public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, AuthContext authContext) { this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; + this.authContext = authContext; } /** @@ -145,4 +163,12 @@ private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { return loggingLevel.level() >= this.minLoggingLevel.level(); } + /** + * Gets the authentication context for the current session. + * @return The authentication context, or null if not authenticated + */ + public AuthContext getAuthContext() { + return authContext; + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d6ec2cc30..896359ce8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -14,6 +14,7 @@ import java.util.function.BiFunction; import com.fasterxml.jackson.databind.ObjectMapper; + import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/OAuthRoutes.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/OAuthRoutes.java new file mode 100644 index 000000000..0e14def6b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/OAuthRoutes.java @@ -0,0 +1,161 @@ +package io.modelcontextprotocol.server.auth; + +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthMetadata; +import io.modelcontextprotocol.server.auth.handlers.AuthorizationHandler; +import io.modelcontextprotocol.server.auth.handlers.MetadataHandler; +import io.modelcontextprotocol.server.auth.handlers.RegistrationHandler; +import io.modelcontextprotocol.server.auth.handlers.RevocationHandler; +import io.modelcontextprotocol.server.auth.handlers.TokenHandler; +import io.modelcontextprotocol.server.auth.middleware.ClientAuthenticator; +import io.modelcontextprotocol.server.auth.settings.ClientRegistrationOptions; +import io.modelcontextprotocol.server.auth.settings.RevocationOptions; +import io.modelcontextprotocol.server.auth.util.UriUtils; + +import java.net.URI; +import java.util.ArrayList; +import java.util.List; + +/** + * Helper class for creating OAuth routes. + */ +public class OAuthRoutes { + + public static final String AUTHORIZATION_PATH = "/authorize"; + + public static final String TOKEN_PATH = "/token"; + + public static final String REGISTRATION_PATH = "/register"; + + public static final String REVOCATION_PATH = "/revoke"; + + public static final String METADATA_PATH = "/.well-known/oauth-authorization-server"; + + /** + * Create OAuth metadata for the server. + * @param issuerUrl The issuer URL + * @param serviceDocumentationUrl The service documentation URL + * @param clientRegistrationOptions The client registration options + * @param revocationOptions The revocation options + * @return The OAuth metadata + */ + public static OAuthMetadata buildMetadata(URI issuerUrl, URI serviceDocumentationUrl, + ClientRegistrationOptions clientRegistrationOptions, RevocationOptions revocationOptions) { + + UriUtils.validateIssuerUrl(issuerUrl); + + URI authorizationUrl = UriUtils.buildEndpointUrl(issuerUrl, AUTHORIZATION_PATH); + URI tokenUrl = UriUtils.buildEndpointUrl(issuerUrl, TOKEN_PATH); + + OAuthMetadata metadata = new OAuthMetadata(); + metadata.setIssuer(issuerUrl); + metadata.setAuthorizationEndpoint(authorizationUrl); + metadata.setTokenEndpoint(tokenUrl); + metadata.setScopesSupported(clientRegistrationOptions.getValidScopes()); + metadata.setResponseTypesSupported(List.of("code")); + metadata.setGrantTypesSupported(List.of("authorization_code", "refresh_token")); + metadata.setTokenEndpointAuthMethodsSupported(List.of("client_secret_post")); + metadata.setServiceDocumentation(serviceDocumentationUrl); + metadata.setCodeChallengeMethodsSupported(List.of("S256")); + + // Add registration endpoint if supported + if (clientRegistrationOptions.isEnabled()) { + metadata.setRegistrationEndpoint(UriUtils.buildEndpointUrl(issuerUrl, REGISTRATION_PATH)); + } + + // Add revocation endpoint if supported + if (revocationOptions.isEnabled()) { + metadata.setRevocationEndpoint(UriUtils.buildEndpointUrl(issuerUrl, REVOCATION_PATH)); + metadata.setRevocationEndpointAuthMethodsSupported(List.of("client_secret_post")); + } + + return metadata; + } + + /** + * Create handlers for OAuth routes. + * @param provider The OAuth authorization server provider + * @param metadata The OAuth metadata + * @param clientRegistrationOptions The client registration options + * @param revocationOptions The revocation options + * @return A map of route handlers + */ + public static OAuthHandlers createHandlers(OAuthAuthorizationServerProvider provider, OAuthMetadata metadata, + ClientRegistrationOptions clientRegistrationOptions, RevocationOptions revocationOptions) { + + ClientAuthenticator clientAuthenticator = new ClientAuthenticator(provider); + + OAuthHandlers handlers = new OAuthHandlers(); + handlers.setMetadataHandler(new MetadataHandler(metadata)); + handlers.setAuthorizationHandler(new AuthorizationHandler(provider)); + handlers.setTokenHandler(new TokenHandler(provider, clientAuthenticator)); + + if (clientRegistrationOptions.isEnabled()) { + handlers.setRegistrationHandler(new RegistrationHandler(provider, clientRegistrationOptions)); + } + + if (revocationOptions.isEnabled()) { + handlers.setRevocationHandler(new RevocationHandler(provider, clientAuthenticator)); + } + + return handlers; + } + + /** + * Container for OAuth route handlers. + */ + public static class OAuthHandlers { + + private MetadataHandler metadataHandler; + + private AuthorizationHandler authorizationHandler; + + private TokenHandler tokenHandler; + + private RegistrationHandler registrationHandler; + + private RevocationHandler revocationHandler; + + public MetadataHandler getMetadataHandler() { + return metadataHandler; + } + + public void setMetadataHandler(MetadataHandler metadataHandler) { + this.metadataHandler = metadataHandler; + } + + public AuthorizationHandler getAuthorizationHandler() { + return authorizationHandler; + } + + public void setAuthorizationHandler(AuthorizationHandler authorizationHandler) { + this.authorizationHandler = authorizationHandler; + } + + public TokenHandler getTokenHandler() { + return tokenHandler; + } + + public void setTokenHandler(TokenHandler tokenHandler) { + this.tokenHandler = tokenHandler; + } + + public RegistrationHandler getRegistrationHandler() { + return registrationHandler; + } + + public void setRegistrationHandler(RegistrationHandler registrationHandler) { + this.registrationHandler = registrationHandler; + } + + public RevocationHandler getRevocationHandler() { + return revocationHandler; + } + + public void setRevocationHandler(RevocationHandler revocationHandler) { + this.revocationHandler = revocationHandler; + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/README.md b/mcp/src/main/java/io/modelcontextprotocol/server/auth/README.md new file mode 100644 index 000000000..855441322 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/README.md @@ -0,0 +1,167 @@ +# OAuth 2.0 Server Implementation + +This package provides an OAuth 2.0 server implementation for the MCP Java SDK, supporting the Authorization Code flow with PKCE (Proof Key for Code Exchange). + +## Components + +### Handlers +- `AuthorizationHandler`: Handles OAuth authorization requests +- `TokenHandler`: Handles OAuth token requests +- `RegistrationHandler`: Handles OAuth client registration requests +- `RevocationHandler`: Handles OAuth token revocation requests +- `MetadataHandler`: Handles OAuth metadata requests + +### Middleware +- `ClientAuthenticator`: Authenticates OAuth clients +- `BearerAuthenticator`: Authenticates requests with bearer tokens + +### Settings +- `ClientRegistrationOptions`: Options for OAuth client registration +- `RevocationOptions`: Options for OAuth token revocation + +### Utilities +- `OAuthRoutes`: Helper class for creating OAuth routes and metadata + +## Usage Example + +```java +// Create provider implementation +OAuthAuthorizationServerProvider provider = new MyOAuthProvider(); + +// Create options +ClientRegistrationOptions registrationOptions = new ClientRegistrationOptions(); +registrationOptions.setValidScopes(List.of("read", "write")); + +RevocationOptions revocationOptions = new RevocationOptions(); + +// Create metadata +URI issuerUrl = URI.create("https://api.example.com"); +URI docsUrl = URI.create("https://docs.example.com"); +OAuthMetadata metadata = OAuthRoutes.buildMetadata( + issuerUrl, + docsUrl, + registrationOptions, + revocationOptions +); + +// Create handlers +OAuthRoutes.OAuthHandlers handlers = OAuthRoutes.createHandlers( + provider, + metadata, + registrationOptions, + revocationOptions +); + +// Use handlers in your web framework +// For example, with Spring MVC: + +@RestController +public class OAuthController { + + private final OAuthRoutes.OAuthHandlers handlers; + + public OAuthController(OAuthRoutes.OAuthHandlers handlers) { + this.handlers = handlers; + } + + @GetMapping("/.well-known/oauth-authorization-server") + public OAuthMetadata getMetadata() { + return handlers.getMetadataHandler().handle().join(); + } + + @GetMapping("/authorize") + public ResponseEntity authorize(@RequestParam Map params) { + try { + String redirectUrl = handlers.getAuthorizationHandler().handle(params).join(); + return ResponseEntity.status(HttpStatus.FOUND) + .header("Location", redirectUrl) + .header("Cache-Control", "no-store") + .build(); + } catch (CompletionException e) { + // Handle errors + return ResponseEntity.badRequest().body("Error: " + e.getCause().getMessage()); + } + } + + @PostMapping("/token") + public ResponseEntity token(@RequestParam Map params) { + try { + OAuthToken token = handlers.getTokenHandler().handle(params).join(); + return ResponseEntity.ok() + .header("Cache-Control", "no-store") + .header("Pragma", "no-cache") + .body(token); + } catch (CompletionException e) { + // Handle errors + return ResponseEntity.badRequest().body(null); + } + } + + // Add other endpoints for registration and revocation +} +``` + +## Provider Implementation + +You need to implement the `OAuthAuthorizationServerProvider` interface to provide the actual OAuth functionality: + +```java +public class MyOAuthProvider implements OAuthAuthorizationServerProvider { + + // Store clients, authorization codes, tokens, etc. + private final Map clients = new ConcurrentHashMap<>(); + private final Map authCodes = new ConcurrentHashMap<>(); + private final Map accessTokens = new ConcurrentHashMap<>(); + private final Map refreshTokens = new ConcurrentHashMap<>(); + + @Override + public CompletableFuture getClient(String clientId) { + return CompletableFuture.completedFuture(clients.get(clientId)); + } + + @Override + public CompletableFuture registerClient(OAuthClientInformation clientInfo) { + clients.put(clientInfo.getClientId(), clientInfo); + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture authorize(OAuthClientInformation client, AuthorizationParams params) { + // In a real implementation, you would show a UI to the user + // and get their consent before generating an authorization code + + // For this example, we'll just generate a code immediately + String code = generateRandomCode(); + + AuthorizationCode authCode = new AuthorizationCode(); + authCode.setClientId(client.getClientId()); + authCode.setCodeChallenge(params.getCodeChallenge()); + authCode.setRedirectUri(params.getRedirectUri()); + authCode.setRedirectUriProvidedExplicitly(params.isRedirectUriProvidedExplicitly()); + authCode.setScopes(params.getScopes()); + authCode.setExpiresAt(Instant.now().plusSeconds(600).getEpochSecond()); // 10 minutes + + authCodes.put(code, authCode); + + // Build redirect URI with code and state + String redirectUri = params.getRedirectUri().toString(); + redirectUri += "?code=" + code; + if (params.getState() != null) { + redirectUri += "&state=" + params.getState(); + } + + return CompletableFuture.completedFuture(redirectUri); + } + + // Implement other methods... +} +``` + +## Security Features + +- PKCE support to prevent authorization code interception attacks +- State parameter to prevent CSRF attacks +- Strict redirect URI validation +- Token expiration +- Scope validation +- HTTPS requirement (with localhost exception for testing) \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/AuthorizationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/AuthorizationHandler.java new file mode 100644 index 000000000..148435c07 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/AuthorizationHandler.java @@ -0,0 +1,179 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.AuthorizationParams; +import io.modelcontextprotocol.auth.InvalidRedirectUriException; +import io.modelcontextprotocol.auth.InvalidScopeException; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.exception.AuthorizeException; +import io.modelcontextprotocol.server.auth.model.AuthorizationErrorResponse; +import io.modelcontextprotocol.server.auth.util.UriUtils; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth authorization requests. + */ +public class AuthorizationHandler { + + private final OAuthAuthorizationServerProvider provider; + + public AuthorizationHandler(OAuthAuthorizationServerProvider provider) { + this.provider = provider; + } + + /** + * Handle an authorization request. + * @param params The request parameters + * @return A CompletableFuture that resolves to a response object containing either a + * redirect URL or an error + */ + public CompletableFuture handle(Map params) { + String clientId = params.get("client_id"); + String redirectUriStr = params.get("redirect_uri"); + String responseType = params.get("response_type"); + String codeChallenge = params.get("code_challenge"); + String codeChallengeMethod = params.get("code_challenge_method"); + String state = params.get("state"); + String scope = params.get("scope"); + + // Validate required parameters + if (clientId == null || responseType == null || codeChallenge == null) { + return CompletableFuture + .completedFuture(createErrorResponse("invalid_request", "Missing required parameters", state, null)); + } + + // Validate response type + if (!"code".equals(responseType)) { + return CompletableFuture.completedFuture(createErrorResponse("unsupported_response_type", + "Only 'code' response type is supported", state, null)); + } + + // Validate code challenge method + if (codeChallengeMethod != null && !"S256".equals(codeChallengeMethod)) { + return CompletableFuture.completedFuture(createErrorResponse("invalid_request", + "Only 'S256' code challenge method is supported", state, null)); + } + + // Get client information + return provider.getClient(clientId).thenCompose(client -> { + if (client == null) { + return CompletableFuture + .completedFuture(createErrorResponse("invalid_request", "Client ID not found", state, null)); + } + + // Validate redirect URI + URI redirectUri; + try { + URI tempUri = redirectUriStr != null ? URI.create(redirectUriStr) : null; + redirectUri = client.validateRedirectUri(tempUri); + } + catch (InvalidRedirectUriException e) { + return CompletableFuture + .completedFuture(createErrorResponse("invalid_request", e.getMessage(), state, null)); + } + + // Validate scope + List scopes; + try { + scopes = client.validateScope(scope); + } + catch (InvalidScopeException e) { + return CompletableFuture + .completedFuture(createErrorResponse("invalid_scope", e.getMessage(), state, redirectUri)); + } + + // Setup authorization parameters + AuthorizationParams authParams = new AuthorizationParams(); + authParams.setState(state); + authParams.setScopes(scopes); + authParams.setCodeChallenge(codeChallenge); + authParams.setRedirectUri(redirectUri); + authParams.setRedirectUriProvidedExplicitly(redirectUriStr != null); + + // Let the provider handle the authorization + try { + return provider.authorize(client, authParams) + .thenApply(url -> new AuthorizationResponse(url, true, null)) + .exceptionally(ex -> { + if (ex.getCause() instanceof AuthorizeException) { + AuthorizeException authEx = (AuthorizeException) ex.getCause(); + return createErrorResponse(authEx.getError(), authEx.getErrorDescription(), state, + redirectUri); + } + else { + return createErrorResponse("server_error", "An unexpected error occurred", state, + redirectUri); + } + }); + } + catch (AuthorizeException e) { + return CompletableFuture + .completedFuture(createErrorResponse(e.getError(), e.getErrorDescription(), state, redirectUri)); + } + }); + } + + /** + * Create an error response. + * @param error The error code + * @param errorDescription The error description + * @param state The state parameter + * @param redirectUri The redirect URI, or null if not available + * @return An AuthorizationResponse containing the error + */ + private AuthorizationResponse createErrorResponse(String error, String errorDescription, String state, + URI redirectUri) { + + AuthorizationErrorResponse errorResponse = new AuthorizationErrorResponse(error, errorDescription, state); + + if (redirectUri != null) { + // Redirect with error parameters + String redirectUrl = UriUtils.constructRedirectUri(redirectUri.toString(), errorResponse.toQueryParams()); + + return new AuthorizationResponse(redirectUrl, true, errorResponse); + } + else { + // Direct error response + return new AuthorizationResponse(null, false, errorResponse); + } + } + + /** + * Response object for authorization requests. + */ + public static class AuthorizationResponse { + + private final String redirectUrl; + + private final boolean isRedirect; + + private final AuthorizationErrorResponse error; + + public AuthorizationResponse(String redirectUrl, boolean isRedirect, AuthorizationErrorResponse error) { + this.redirectUrl = redirectUrl; + this.isRedirect = isRedirect; + this.error = error; + } + + public String getRedirectUrl() { + return redirectUrl; + } + + public boolean isRedirect() { + return isRedirect; + } + + public AuthorizationErrorResponse getError() { + return error; + } + + public boolean isSuccess() { + return error == null; + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/MetadataHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/MetadataHandler.java new file mode 100644 index 000000000..f9efc2cab --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/MetadataHandler.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.OAuthMetadata; + +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth metadata requests. + */ +public class MetadataHandler { + + private final OAuthMetadata metadata; + + public MetadataHandler(OAuthMetadata metadata) { + this.metadata = metadata; + } + + /** + * Handle a metadata request. + * @return A CompletableFuture that resolves to the OAuth metadata + */ + public CompletableFuture handle() { + return CompletableFuture.completedFuture(metadata); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RegistrationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RegistrationHandler.java new file mode 100644 index 000000000..6ed97efb7 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RegistrationHandler.java @@ -0,0 +1,134 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import java.net.URI; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthClientMetadata; +import io.modelcontextprotocol.auth.exception.RegistrationException; +import io.modelcontextprotocol.server.auth.settings.ClientRegistrationOptions; + +/** + * Handler for OAuth client registration requests. + */ +public class RegistrationHandler { + + private final OAuthAuthorizationServerProvider provider; + + private final ClientRegistrationOptions options; + + public RegistrationHandler(OAuthAuthorizationServerProvider provider, ClientRegistrationOptions options) { + this.provider = provider; + this.options = options; + } + + /** + * Handle a client registration request. + * @param clientMetadata The client metadata + * @return A CompletableFuture that resolves to the registered client information + */ + public CompletableFuture handle(OAuthClientMetadata clientMetadata) { + // Validate client metadata + if (clientMetadata.getRedirectUris() == null || clientMetadata.getRedirectUris().isEmpty()) { + return CompletableFuture.failedFuture( + new RegistrationException("invalid_redirect_uri", "At least one redirect URI is required")); + } + + // Validate redirect URIs + for (URI redirectUri : clientMetadata.getRedirectUris()) { + if (!isValidRedirectUri(redirectUri)) { + return CompletableFuture.failedFuture( + new RegistrationException("invalid_redirect_uri", "Invalid redirect URI: " + redirectUri)); + } + } + + // Validate scopes if provided + if (clientMetadata.getScope() != null && options.getValidScopes() != null) { + String[] requestedScopes = clientMetadata.getScope().split(" "); + for (String scope : requestedScopes) { + if (!options.getValidScopes().contains(scope)) { + return CompletableFuture + .failedFuture(new RegistrationException("invalid_scope", "Invalid scope: " + scope)); + } + } + } + + // Create client information + OAuthClientInformation clientInfo = new OAuthClientInformation(); + + // Copy metadata fields + clientInfo.setRedirectUris(clientMetadata.getRedirectUris()); + clientInfo.setTokenEndpointAuthMethod(clientMetadata.getTokenEndpointAuthMethod()); + clientInfo.setGrantTypes(clientMetadata.getGrantTypes()); + clientInfo.setResponseTypes(clientMetadata.getResponseTypes()); + clientInfo.setScope(clientMetadata.getScope()); + clientInfo.setClientName(clientMetadata.getClientName()); + clientInfo.setClientUri(clientMetadata.getClientUri()); + clientInfo.setLogoUri(clientMetadata.getLogoUri()); + clientInfo.setContacts(clientMetadata.getContacts()); + clientInfo.setTosUri(clientMetadata.getTosUri()); + clientInfo.setPolicyUri(clientMetadata.getPolicyUri()); + clientInfo.setJwksUri(clientMetadata.getJwksUri()); + clientInfo.setJwks(clientMetadata.getJwks()); + clientInfo.setSoftwareId(clientMetadata.getSoftwareId()); + clientInfo.setSoftwareVersion(clientMetadata.getSoftwareVersion()); + + // Generate client ID and secret + clientInfo.setClientId(generateClientId()); + + // Generate client secret if using client_secret_post auth method + if ("client_secret_post".equals(clientMetadata.getTokenEndpointAuthMethod())) { + clientInfo.setClientSecret(generateClientSecret()); + } + + // Set issuance time + clientInfo.setClientIdIssuedAt(System.currentTimeMillis() / 1000); + + // Register client with provider + try { + return provider.registerClient(clientInfo).thenApply(v -> clientInfo); + } + catch (RegistrationException e) { + return CompletableFuture.failedFuture(e); + } + } + + /** + * Validate a redirect URI. + * @param redirectUri The redirect URI to validate + * @return true if the redirect URI is valid, false otherwise + */ + private boolean isValidRedirectUri(URI redirectUri) { + String scheme = redirectUri.getScheme(); + + // Check if localhost is allowed for non-HTTPS URIs + if (options.isAllowLocalhostRedirect() && ("http".equals(scheme) || "custom".equals(scheme))) { + String host = redirectUri.getHost(); + if ("localhost".equals(host) || host.startsWith("127.0.0.1")) { + return true; + } + } + + // Require HTTPS for all other URIs + return "https".equals(scheme); + } + + /** + * Generate a random client ID. + * @return A random client ID + */ + private String generateClientId() { + return UUID.randomUUID().toString(); + } + + /** + * Generate a random client secret. + * @return A random client secret + */ + private String generateClientSecret() { + return UUID.randomUUID().toString() + UUID.randomUUID().toString(); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RevocationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RevocationHandler.java new file mode 100644 index 000000000..d356cf0b7 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RevocationHandler.java @@ -0,0 +1,68 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.AccessToken; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.RefreshToken; +import io.modelcontextprotocol.server.auth.middleware.ClientAuthenticator; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth token revocation requests. + */ +public class RevocationHandler { + + private final OAuthAuthorizationServerProvider provider; + + private final ClientAuthenticator clientAuthenticator; + + public RevocationHandler(OAuthAuthorizationServerProvider provider, ClientAuthenticator clientAuthenticator) { + this.provider = provider; + this.clientAuthenticator = clientAuthenticator; + } + + /** + * Handle a token revocation request. + * @param params The request parameters + * @return A CompletableFuture that completes when the token is revoked + */ + public CompletableFuture handle(Map params) { + String token = params.get("token"); + String tokenTypeHint = params.get("token_type_hint"); + String clientId = params.get("client_id"); + String clientSecret = params.get("client_secret"); + + if (token == null || clientId == null) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new IllegalArgumentException("Missing required parameters")); + return future; + } + + // Authenticate client + return clientAuthenticator.authenticate(clientId, clientSecret).thenCompose(client -> { + // Try to load token based on token_type_hint + if ("refresh_token".equals(tokenTypeHint)) { + return provider.loadRefreshToken(client, token).thenCompose(refreshToken -> { + if (refreshToken != null && refreshToken.getClientId().equals(client.getClientId())) { + return provider.revokeToken(refreshToken); + } + return CompletableFuture.completedFuture(null); + }); + } + else if ("access_token".equals(tokenTypeHint) || tokenTypeHint == null) { + return provider.loadAccessToken(token).thenCompose(accessToken -> { + if (accessToken != null && accessToken.getClientId().equals(client.getClientId())) { + return provider.revokeToken(accessToken); + } + return CompletableFuture.completedFuture(null); + }); + } + else { + // Unknown token type hint + return CompletableFuture.completedFuture(null); + } + }); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/TokenHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/TokenHandler.java new file mode 100644 index 000000000..6bebe1a04 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/TokenHandler.java @@ -0,0 +1,188 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.AccessToken; +import io.modelcontextprotocol.auth.AuthorizationCode; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthToken; +import io.modelcontextprotocol.auth.RefreshToken; +import io.modelcontextprotocol.auth.exception.TokenException; +import io.modelcontextprotocol.server.auth.middleware.ClientAuthenticator; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth token requests. + */ +public class TokenHandler { + + private final OAuthAuthorizationServerProvider provider; + + private final ClientAuthenticator clientAuthenticator; + + public TokenHandler(OAuthAuthorizationServerProvider provider, ClientAuthenticator clientAuthenticator) { + this.provider = provider; + this.clientAuthenticator = clientAuthenticator; + } + + /** + * Handle a token request. + * @param params The request parameters + * @return A CompletableFuture that resolves to an OAuth token + */ + public CompletableFuture handle(Map params) { + String grantType = params.get("grant_type"); + String clientId = params.get("client_id"); + String clientSecret = params.get("client_secret"); + + if (grantType == null || clientId == null) { + return CompletableFuture.failedFuture(new TokenException("invalid_request", "Missing required parameters")); + } + + // Authenticate client + return clientAuthenticator.authenticate(clientId, clientSecret).thenCompose(client -> { + // Check if grant type is supported + if (!client.getGrantTypes().contains(grantType)) { + return CompletableFuture.failedFuture(new TokenException("unsupported_grant_type", + "Unsupported grant type (supported grant types are " + client.getGrantTypes() + ")")); + } + + // Handle different grant types + if ("authorization_code".equals(grantType)) { + return handleAuthorizationCode(client, params); + } + else if ("refresh_token".equals(grantType)) { + return handleRefreshToken(client, params); + } + else { + return CompletableFuture + .failedFuture(new TokenException("unsupported_grant_type", "Unsupported grant type")); + } + }); + } + + /** + * Handle authorization code grant type. + * @param client The authenticated client + * @param params The request parameters + * @return A CompletableFuture that resolves to an OAuth token + */ + private CompletableFuture handleAuthorizationCode(OAuthClientInformation client, + Map params) { + + String code = params.get("code"); + String redirectUri = params.get("redirect_uri"); + String codeVerifier = params.get("code_verifier"); + + if (code == null || codeVerifier == null) { + return CompletableFuture.failedFuture(new TokenException("invalid_request", "Missing required parameters")); + } + + // Load authorization code + return provider.loadAuthorizationCode(client, code).thenCompose(authCode -> { + if (authCode == null || !authCode.getClientId().equals(client.getClientId())) { + return CompletableFuture + .failedFuture(new TokenException("invalid_grant", "Authorization code does not exist")); + } + + // Check if code has expired + if (authCode.getExpiresAt() < Instant.now().getEpochSecond()) { + return CompletableFuture + .failedFuture(new TokenException("invalid_grant", "Authorization code has expired")); + } + + // Verify redirect URI matches + if (authCode.isRedirectUriProvidedExplicitly()) { + if (redirectUri == null || !redirectUri.equals(authCode.getRedirectUri().toString())) { + return CompletableFuture.failedFuture(new TokenException("invalid_request", + "Redirect URI did not match the one used when creating auth code")); + } + } + + // Verify PKCE code verifier + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(codeVerifier.getBytes(StandardCharsets.UTF_8)); + String hashedCodeVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(hash); + + if (!hashedCodeVerifier.equals(authCode.getCodeChallenge())) { + return CompletableFuture + .failedFuture(new TokenException("invalid_grant", "Incorrect code_verifier")); + } + } + catch (NoSuchAlgorithmException e) { + return CompletableFuture + .failedFuture(new TokenException("server_error", "Failed to verify code challenge")); + } + + // Exchange authorization code for tokens + try { + return provider.exchangeAuthorizationCode(client, authCode); + } + catch (TokenException e) { + return CompletableFuture.failedFuture(e); + } + }); + } + + /** + * Handle refresh token grant type. + * @param client The authenticated client + * @param params The request parameters + * @return A CompletableFuture that resolves to an OAuth token + */ + private CompletableFuture handleRefreshToken(OAuthClientInformation client, + Map params) { + + String refreshTokenStr = params.get("refresh_token"); + String scope = params.get("scope"); + + if (refreshTokenStr == null) { + return CompletableFuture + .failedFuture(new TokenException("invalid_request", "Missing refresh_token parameter")); + } + + // Load refresh token + return provider.loadRefreshToken(client, refreshTokenStr).thenCompose(refreshToken -> { + if (refreshToken == null || !refreshToken.getClientId().equals(client.getClientId())) { + return CompletableFuture + .failedFuture(new TokenException("invalid_grant", "Refresh token does not exist")); + } + + // Check if token has expired + if (refreshToken.getExpiresAt() != null && refreshToken.getExpiresAt() < Instant.now().getEpochSecond()) { + return CompletableFuture.failedFuture(new TokenException("invalid_grant", "Refresh token has expired")); + } + + // Parse scopes if provided + List scopes = scope != null ? Arrays.asList(scope.split(" ")) : refreshToken.getScopes(); + + // Validate requested scopes against refresh token scopes + if (scopes != null) { + for (String s : scopes) { + if (!refreshToken.getScopes().contains(s)) { + return CompletableFuture.failedFuture(new TokenException("invalid_scope", + "Cannot request scope `" + s + "` not provided by refresh token")); + } + } + } + + // Exchange refresh token for new tokens + try { + return provider.exchangeRefreshToken(client, refreshToken, scopes); + } + catch (TokenException e) { + return CompletableFuture.failedFuture(e); + } + }); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContext.java new file mode 100644 index 000000000..ee225d682 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContext.java @@ -0,0 +1,70 @@ +package io.modelcontextprotocol.server.auth.middleware; + +import io.modelcontextprotocol.auth.AccessToken; + +/** + * Holds authentication context for a request. + */ +public class AuthContext { + + private final AccessToken accessToken; + + private static final ThreadLocal currentContext = new ThreadLocal<>(); + + /** + * Creates a new AuthContext. + * @param accessToken The authenticated access token. + */ + public AuthContext(AccessToken accessToken) { + this.accessToken = accessToken; + } + + /** + * Gets the access token. + * @return The access token. + */ + public AccessToken getAccessToken() { + return accessToken; + } + + /** + * Gets the client ID. + * @return The client ID. + */ + public String getClientId() { + return accessToken != null ? accessToken.getClientId() : null; + } + + /** + * Checks if the user has the specified scope. + * @param scope The scope to check. + * @return True if the user has the scope, false otherwise. + */ + public boolean hasScope(String scope) { + return accessToken != null && accessToken.getScopes().contains(scope); + } + + /** + * Sets the current auth context for this thread. + * @param authContext The auth context to set + */ + public static void setCurrent(AuthContext authContext) { + currentContext.set(authContext); + } + + /** + * Gets the current auth context for this thread. + * @return The current auth context, or null if not set + */ + public static AuthContext getCurrent() { + return currentContext.get(); + } + + /** + * Clears the current auth context for this thread. + */ + public static void clearCurrent() { + currentContext.remove(); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContextProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContextProvider.java new file mode 100644 index 000000000..ef6b6fa9c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContextProvider.java @@ -0,0 +1,20 @@ +package io.modelcontextprotocol.server.auth.middleware; + +/** + * Interface for transports that support authentication context. + */ +public interface AuthContextProvider { + + /** + * Gets the authentication context. + * @return The authentication context, or null if not authenticated + */ + AuthContext getAuthContext(); + + /** + * Sets the authentication context. + * @param authContext The authentication context + */ + void setAuthContext(AuthContext authContext); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/BearerAuthenticator.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/BearerAuthenticator.java new file mode 100644 index 000000000..d33025396 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/BearerAuthenticator.java @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.server.auth.middleware; + +import io.modelcontextprotocol.auth.AccessToken; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; + +import java.util.concurrent.CompletableFuture; + +/** + * Authenticator for OAuth bearer tokens. + */ +public class BearerAuthenticator { + + private final OAuthAuthorizationServerProvider provider; + + public BearerAuthenticator(OAuthAuthorizationServerProvider provider) { + this.provider = provider; + } + + /** + * Authenticate a request using a bearer token. + * @param authHeader The Authorization header value + * @return A CompletableFuture that resolves to the authenticated access token + */ + public CompletableFuture authenticate(String authHeader) { + if (authHeader == null || !authHeader.startsWith("Bearer ")) { + return CompletableFuture + .failedFuture(new AuthenticationException("Missing or invalid Authorization header")); + } + + String token = authHeader.substring("Bearer ".length()).trim(); + if (token.isEmpty()) { + return CompletableFuture.failedFuture(new AuthenticationException("Empty bearer token")); + } + + return provider.loadAccessToken(token).thenCompose(accessToken -> { + if (accessToken == null) { + return CompletableFuture.failedFuture(new AuthenticationException("Invalid access token")); + } + + // Check if token has expired + if (accessToken.getExpiresAt() != null && accessToken.getExpiresAt() < System.currentTimeMillis() / 1000) { + return CompletableFuture.failedFuture(new AuthenticationException("Access token has expired")); + } + + return CompletableFuture.completedFuture(accessToken); + }); + } + + /** + * Exception thrown when bearer authentication fails. + */ + public static class AuthenticationException extends Exception { + + public AuthenticationException(String message) { + super(message); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/ClientAuthenticator.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/ClientAuthenticator.java new file mode 100644 index 000000000..b5198626f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/ClientAuthenticator.java @@ -0,0 +1,61 @@ +package io.modelcontextprotocol.server.auth.middleware; + +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientInformation; + +import java.util.concurrent.CompletableFuture; + +/** + * Authenticator for OAuth clients. + */ +public class ClientAuthenticator { + + private final OAuthAuthorizationServerProvider provider; + + public ClientAuthenticator(OAuthAuthorizationServerProvider provider) { + this.provider = provider; + } + + /** + * Authenticate a client using client ID and optional client secret. + * @param clientId The client ID + * @param clientSecret The client secret (may be null) + * @return A CompletableFuture that resolves to the authenticated client information + */ + public CompletableFuture authenticate(String clientId, String clientSecret) { + if (clientId == null) { + return CompletableFuture.failedFuture(new AuthenticationException("Missing client_id parameter")); + } + + return provider.getClient(clientId).thenCompose(client -> { + if (client == null) { + return CompletableFuture.failedFuture(new AuthenticationException("Client not found")); + } + + // If client has a secret, verify it + if (client.getClientSecret() != null) { + if (clientSecret == null) { + return CompletableFuture.failedFuture(new AuthenticationException("Client secret required")); + } + + if (!client.getClientSecret().equals(clientSecret)) { + return CompletableFuture.failedFuture(new AuthenticationException("Invalid client secret")); + } + } + + return CompletableFuture.completedFuture(client); + }); + } + + /** + * Exception thrown when client authentication fails. + */ + public static class AuthenticationException extends Exception { + + public AuthenticationException(String message) { + super(message); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/model/AuthorizationErrorResponse.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/model/AuthorizationErrorResponse.java new file mode 100644 index 000000000..cccfe594e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/model/AuthorizationErrorResponse.java @@ -0,0 +1,117 @@ +package io.modelcontextprotocol.server.auth.model; + +import java.net.URI; + +/** + * OAuth authorization error response as defined in RFC 6749 Section 4.1.2.1. + */ +public class AuthorizationErrorResponse { + + private String error; + + private String errorDescription; + + private URI errorUri; + + private String state; + + /** + * Creates a new AuthorizationErrorResponse. + * @param error The error code + * @param errorDescription The error description + * @param state The state parameter from the request + */ + public AuthorizationErrorResponse(String error, String errorDescription, String state) { + this.error = error; + this.errorDescription = errorDescription; + this.state = state; + } + + /** + * Gets the error code. + * @return The error code + */ + public String getError() { + return error; + } + + /** + * Sets the error code. + * @param error The error code + */ + public void setError(String error) { + this.error = error; + } + + /** + * Gets the error description. + * @return The error description + */ + public String getErrorDescription() { + return errorDescription; + } + + /** + * Sets the error description. + * @param errorDescription The error description + */ + public void setErrorDescription(String errorDescription) { + this.errorDescription = errorDescription; + } + + /** + * Gets the error URI. + * @return The error URI + */ + public URI getErrorUri() { + return errorUri; + } + + /** + * Sets the error URI. + * @param errorUri The error URI + */ + public void setErrorUri(URI errorUri) { + this.errorUri = errorUri; + } + + /** + * Gets the state parameter. + * @return The state parameter + */ + public String getState() { + return state; + } + + /** + * Sets the state parameter. + * @param state The state parameter + */ + public void setState(String state) { + this.state = state; + } + + /** + * Converts the error response to a map of query parameters. + * @return A map of query parameters + */ + public java.util.Map toQueryParams() { + java.util.Map params = new java.util.HashMap<>(); + params.put("error", error); + + if (errorDescription != null) { + params.put("error_description", errorDescription); + } + + if (errorUri != null) { + params.put("error_uri", errorUri.toString()); + } + + if (state != null) { + params.put("state", state); + } + + return params; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/ClientRegistrationOptions.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/ClientRegistrationOptions.java new file mode 100644 index 000000000..1e4e18c81 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/ClientRegistrationOptions.java @@ -0,0 +1,64 @@ +package io.modelcontextprotocol.server.auth.settings; + +import java.util.List; + +/** + * Options for client registration. + */ +public class ClientRegistrationOptions { + + private boolean enabled = true; + + private boolean allowLocalhostRedirect; + + private List validScopes; + + /** + * Check if client registration is enabled. + * @return true if client registration is enabled, false otherwise + */ + public boolean isEnabled() { + return enabled; + } + + /** + * Set whether client registration is enabled. + * @param enabled true to enable client registration, false to disable + */ + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + /** + * Gets whether localhost redirects are allowed. + * @return true if localhost redirects are allowed, false otherwise + */ + public boolean isAllowLocalhostRedirect() { + return allowLocalhostRedirect; + } + + /** + * Sets whether localhost redirects are allowed. + * @param allowLocalhostRedirect true to allow localhost redirects, false otherwise + */ + public void setAllowLocalhostRedirect(boolean allowLocalhostRedirect) { + this.allowLocalhostRedirect = allowLocalhostRedirect; + } + + /** + * Gets the valid scopes. + * @return the valid scopes + */ + public List getValidScopes() { + return validScopes; + } + + /** + * Sets the valid scopes. + * @param validScopes the valid scopes + */ + public void setValidScopes(List validScopes) { + this.validScopes = validScopes; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/RevocationOptions.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/RevocationOptions.java new file mode 100644 index 000000000..6e1805565 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/RevocationOptions.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.server.auth.settings; + +/** + * Options for OAuth token revocation. + */ +public class RevocationOptions { + + private boolean enabled = true; + + /** + * Check if token revocation is enabled. + * @return true if token revocation is enabled, false otherwise + */ + public boolean isEnabled() { + return enabled; + } + + /** + * Set whether token revocation is enabled. + * @param enabled true to enable token revocation, false to disable + */ + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/util/UriUtils.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/util/UriUtils.java new file mode 100644 index 000000000..464246a99 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/util/UriUtils.java @@ -0,0 +1,123 @@ +package io.modelcontextprotocol.server.auth.util; + +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Utility class for URI operations. + */ +public class UriUtils { + + /** + * Constructs a redirect URI with query parameters. + * @param redirectUriBase The base redirect URI. + * @param params The parameters to add to the query string. + * @return The constructed redirect URI. + */ + public static String constructRedirectUri(String redirectUriBase, Map params) { + try { + URI uri = new URI(redirectUriBase); + + // Get existing query + String query = uri.getQuery(); + StringBuilder queryBuilder = new StringBuilder(); + + // Append existing query parameters if any + if (query != null && !query.isEmpty()) { + queryBuilder.append(query); + if (!params.isEmpty()) { + queryBuilder.append("&"); + } + } + + // Append new parameters + if (!params.isEmpty()) { + String newParams = params.entrySet() + .stream() + .filter(entry -> entry.getValue() != null) + .map(entry -> URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8) + "=" + + URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8)) + .collect(Collectors.joining("&")); + queryBuilder.append(newParams); + } + + // Create new URI with updated query + return new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), queryBuilder.toString(), + uri.getFragment()) + .toString(); + } + catch (URISyntaxException e) { + throw new IllegalArgumentException("Invalid redirect URI: " + redirectUriBase, e); + } + } + + /** + * Modify a URI's path using the provided mapper function. + * @param uri The URI to modify + * @param pathMapper Function to transform the path + * @return The modified URI + */ + public static URI modifyUriPath(URI uri, Function pathMapper) { + String path = uri.getPath(); + if (path == null) { + path = ""; + } + + String newPath = pathMapper.apply(path); + + try { + return new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(), newPath, uri.getQuery(), + uri.getFragment()); + } + catch (Exception e) { + throw new IllegalArgumentException("Failed to modify URI path", e); + } + } + + /** + * Validate that the issuer URL meets OAuth 2.0 requirements. + * @param url The issuer URL to validate + */ + public static void validateIssuerUrl(URI url) { + // RFC 8414 requires HTTPS, but we allow localhost HTTP for testing + String scheme = url.getScheme(); + String host = url.getHost(); + + if (!"https".equals(scheme) && !"localhost".equals(host) && !host.startsWith("127.0.0.1")) { + throw new IllegalArgumentException("Issuer URL must be HTTPS"); + } + + // No fragments or query parameters allowed + if (url.getFragment() != null) { + throw new IllegalArgumentException("Issuer URL must not have a fragment"); + } + if (url.getQuery() != null) { + throw new IllegalArgumentException("Issuer URL must not have a query string"); + } + } + + /** + * Build an endpoint URL by appending a path to the issuer URL. + * @param issuerUrl The issuer URL + * @param path The path to append + * @return The endpoint URL + */ + public static URI buildEndpointUrl(URI issuerUrl, String path) { + String baseUrl = issuerUrl.toString(); + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } + + if (!path.startsWith("/")) { + path = "/" + path; + } + + return URI.create(baseUrl + path); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff472..3510992ec 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -13,6 +13,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.auth.middleware.AuthContext; +import io.modelcontextprotocol.server.auth.middleware.AuthContextProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -202,6 +205,11 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } + // Call the authentication hook + if (!authenticateRequest(request, response)) { + return; // Authentication failed, response already set + } + response.setContentType("text/event-stream"); response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); @@ -214,9 +222,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) PrintWriter writer = response.getWriter(); - // Create a new session transport - HttpServletMcpSessionTransport sessionTransport = new HttpServletMcpSessionTransport(sessionId, asyncContext, - writer); + // Create a new session transport using the hook method + HttpServletMcpSessionTransport sessionTransport = createSessionTransport(sessionId, asyncContext, writer); // Create a new session using the session factory McpServerSession session = sessionFactory.create(sessionTransport); @@ -356,11 +363,39 @@ public void destroy() { super.destroy(); } + /** + * Hook method for authentication. Subclasses can override this to provide + * authentication. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @return true if authentication succeeded, false if it failed + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + protected boolean authenticateRequest(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + // Default implementation does no authentication + return true; + } + + /** + * Hook method to create a session transport. Subclasses can override this to provide + * custom transport implementations. + * @param sessionId The session ID + * @param asyncContext The async context + * @param writer The writer + * @return A new session transport + */ + protected HttpServletMcpSessionTransport createSessionTransport(String sessionId, AsyncContext asyncContext, + PrintWriter writer) { + return new HttpServletMcpSessionTransport(sessionId, asyncContext, writer); + } + /** * Implementation of McpServerTransport for HttpServlet SSE sessions. This class * handles the transport-level communication for a specific client session. */ - private class HttpServletMcpSessionTransport implements McpServerTransport { + protected class HttpServletMcpSessionTransport implements McpServerTransport, AuthContextProvider { private final String sessionId; @@ -368,6 +403,8 @@ private class HttpServletMcpSessionTransport implements McpServerTransport { private final PrintWriter writer; + private AuthContext authContext; + /** * Creates a new session transport with the specified ID and SSE writer. * @param sessionId The unique identifier for this session @@ -381,6 +418,16 @@ private class HttpServletMcpSessionTransport implements McpServerTransport { logger.debug("Session transport {} initialized with SSE writer", sessionId); } + @Override + public void setAuthContext(AuthContext authContext) { + this.authContext = authContext; + } + + @Override + public AuthContext getAuthContext() { + return authContext; + } + /** * Sends a JSON-RPC message to the client through the SSE connection. * @param message The JSON-RPC message to send diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/OAuthHttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/OAuthHttpServletSseServerTransportProvider.java new file mode 100644 index 000000000..f892c20bc --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/OAuthHttpServletSseServerTransportProvider.java @@ -0,0 +1,386 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletionException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.auth.AccessToken; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientMetadata; +import io.modelcontextprotocol.auth.OAuthMetadata; +import io.modelcontextprotocol.server.auth.handlers.AuthorizationHandler; +import io.modelcontextprotocol.server.auth.handlers.MetadataHandler; +import io.modelcontextprotocol.server.auth.handlers.RegistrationHandler; +import io.modelcontextprotocol.server.auth.handlers.RevocationHandler; +import io.modelcontextprotocol.server.auth.handlers.TokenHandler; +import io.modelcontextprotocol.server.auth.middleware.AuthContext; +import io.modelcontextprotocol.server.auth.middleware.BearerAuthenticator; +import io.modelcontextprotocol.server.auth.middleware.ClientAuthenticator; +import io.modelcontextprotocol.server.auth.settings.ClientRegistrationOptions; +import io.modelcontextprotocol.server.auth.settings.RevocationOptions; +import io.modelcontextprotocol.spec.McpError; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +/** + * Extended transport provider that handles both MCP messages and OAuth routes. This class + * integrates OAuth authentication routes directly into the transport layer It also adds + * authentication middleware to validate requests for SSE and message endpoints. + */ +public class OAuthHttpServletSseServerTransportProvider extends HttpServletSseServerTransportProvider { + + /** Logger for this class */ + private static final Logger logger = LoggerFactory.getLogger(OAuthHttpServletSseServerTransportProvider.class); + + private final MetadataHandler metadataHandler; + + private final AuthorizationHandler authorizationHandler; + + private final TokenHandler tokenHandler; + + private final RegistrationHandler registrationHandler; + + private final RevocationHandler revocationHandler; + + private final ClientAuthenticator clientAuthenticator; + + private final BearerAuthenticator bearerAuthenticator; + + private final ObjectMapper objectMapper; + + /** + * Creates a new OAuthHttpServletSseServerTransportProvider. + * @param objectMapper The JSON object mapper + * @param mcpEndpoint The MCP endpoint path + * @param authProvider The OAuth authorization server provider + * @param issuerUrl The issuer URL for OAuth metadata + * @param registrationOptions The client registration options + * @param revocationOptions The token revocation options + */ + public OAuthHttpServletSseServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, String baseUrl, + OAuthAuthorizationServerProvider authProvider, URI issuerUrl, ClientRegistrationOptions registrationOptions, + RevocationOptions revocationOptions) { + super(objectMapper, baseUrl, mcpEndpoint, "/sse"); + // HttpServletSseServerTransportProvider.builder().baseUrl(baseUrl).messageEndpoint(mcpEndpoint).build(); + this.objectMapper = objectMapper; + + logger.info("Initializing OAuthHttpServletSseServerTransportProvider with base URL: " + baseUrl); + + // Create authenticators + this.clientAuthenticator = new ClientAuthenticator(authProvider); + this.bearerAuthenticator = new BearerAuthenticator(authProvider); + + // Create metadata + OAuthMetadata metadata = new OAuthMetadata(); + metadata.setIssuer(issuerUrl); + metadata.setAuthorizationEndpoint(URI.create(issuerUrl + "/authorize")); + metadata.setTokenEndpoint(URI.create(issuerUrl + "/token")); + metadata.setScopesSupported(registrationOptions.getValidScopes()); + metadata.setResponseTypesSupported(java.util.Arrays.asList("code")); + metadata.setGrantTypesSupported(java.util.Arrays.asList("authorization_code", "refresh_token")); + metadata.setTokenEndpointAuthMethodsSupported(java.util.Arrays.asList("client_secret_post")); + metadata.setCodeChallengeMethodsSupported(java.util.Arrays.asList("S256")); + + if (registrationOptions.isEnabled()) { + metadata.setRegistrationEndpoint(URI.create(issuerUrl + "/register")); + } + + if (revocationOptions.isEnabled()) { + metadata.setRevocationEndpoint(URI.create(issuerUrl + "/revoke")); + metadata.setRevocationEndpointAuthMethodsSupported(java.util.Arrays.asList("client_secret_post")); + } + + // Create handlers + this.metadataHandler = new MetadataHandler(metadata); + this.authorizationHandler = new AuthorizationHandler(authProvider); + this.tokenHandler = new TokenHandler(authProvider, clientAuthenticator); + this.registrationHandler = registrationOptions.isEnabled() + ? new RegistrationHandler(authProvider, registrationOptions) : null; + this.revocationHandler = revocationOptions.isEnabled() + ? new RevocationHandler(authProvider, clientAuthenticator) : null; + + logger.info("OAuthHttpServletSseServerTransportProvider initialized with base URL: " + baseUrl); + } + + /** + * Gets the object mapper. + * @return The object mapper + */ + protected ObjectMapper getObjectMapper() { + return objectMapper; + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + logger.info("Handling OAuth GET request: " + request.getRequestURI()); + String path = request.getRequestURI(); + + // Handle OAuth GET routes + if (path.endsWith("/.well-known/oauth-authorization-server")) { + handleMetadataRequest(request, response); + } + else if (path.endsWith("/authorize")) { + handleAuthorizeRequest(request, response); + } + else { + // Handle other GET requests using the parent class + super.doGet(request, response); + } + } + + /** + * Authenticates a request using the Bearer token in the Authorization header. + * @param request The HTTP request + * @param response The HTTP response + * @return true if authentication succeeded, false otherwise + */ + @Override + protected boolean authenticateRequest(HttpServletRequest request, HttpServletResponse response) throws IOException { + String authHeader = request.getHeader("Authorization"); + + try { + // Use the BearerAuthenticator to validate the token + AccessToken token = bearerAuthenticator.authenticate(authHeader).join(); + + // Create auth context and store it in request attributes and thread-local + AuthContext authContext = new AuthContext(token); + request.setAttribute("authContext", authContext); + AuthContext.setCurrent(authContext); + + return true; + } + catch (Exception e) { + // Clear auth context in case of failure + AuthContext.clearCurrent(); + // Extract the root cause message + String message = e.getCause() != null ? e.getCause().getMessage() : e.getMessage(); + sendAuthError(response, message, HttpServletResponse.SC_UNAUTHORIZED); + return false; + } + } + + @Override + protected HttpServletMcpSessionTransport createSessionTransport(String sessionId, AsyncContext asyncContext, + PrintWriter writer) { + HttpServletMcpSessionTransport transport = super.createSessionTransport(sessionId, asyncContext, writer); + + AuthContext authContext = AuthContext.getCurrent(); + if (authContext != null) { + transport.setAuthContext(authContext); + } + + return transport; + } + + /** + * Sends an authentication error response. + * @param response The HTTP response + * @param message The error message + * @param statusCode The HTTP status code + */ + private void sendAuthError(HttpServletResponse response, String message, int statusCode) throws IOException { + response.setContentType("application/json"); + response.setCharacterEncoding("UTF-8"); + response.setStatus(statusCode); + + McpError error = new McpError(message); + String jsonError = getObjectMapper().writeValueAsString(error); + + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + logger.info("Handling OAuth POST request: " + request.getRequestURI()); + try { + String path = request.getRequestURI(); + + // Handle OAuth POST routes + if (path.endsWith("/token")) { + handleTokenRequest(request, response); + } + else if (path.endsWith("/register") && registrationHandler != null) { + handleRegisterRequest(request, response); + } + else if (path.endsWith("/revoke") && revocationHandler != null) { + handleRevokeRequest(request, response); + } + else if (path.endsWith("/authorize")) { + handleAuthorizeRequest(request, response); + } + else { + // Handle other POST requests using the parent class + super.doPost(request, response); + } + } + finally { + // Clear thread-local auth context after request is processed + AuthContext.clearCurrent(); + } + + } + + private void handleMetadataRequest(HttpServletRequest request, HttpServletResponse response) throws IOException { + try { + OAuthMetadata metadata = metadataHandler.handle().join(); + response.setContentType("application/json"); + response.setStatus(200); + getObjectMapper().writeValue(response.getOutputStream(), metadata); + } + catch (CompletionException ex) { + response.setStatus(500); + } + } + + private void handleAuthorizeRequest(HttpServletRequest request, HttpServletResponse response) throws IOException { + // Extract parameters + Map params = new HashMap<>(); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + params.put(key, values[0]); + } + }); + + try { + String redirectUrl = authorizationHandler.handle(params).join().getRedirectUrl(); + response.setHeader("Location", redirectUrl); + response.setHeader("Cache-Control", "no-store"); + response.setStatus(302); // Found + } + catch (CompletionException ex) { + response.setStatus(400); + } + } + + // private void handleAuthorizeRequest(HttpServletRequest request, HttpServletResponse + // response) throws IOException { + // if ("GET".equalsIgnoreCase(request.getMethod())) { + // // Render consent page + // String clientId = request.getParameter("client_id"); + // String redirectUri = request.getParameter("redirect_uri"); + // String scope = request.getParameter("scope"); + // String state = request.getParameter("state"); + // String codeChallenge = request.getParameter("code_challenge"); + // String codeChallengeMethod = request.getParameter("code_challenge_method"); + // String responseType = request.getParameter("response_type"); + + // response.setContentType("text/html"); + // response.getWriter() + // .write("" + "

Authorize Access

" + "

Client " + clientId + // + " is requesting access with scope: " + scope + "

" + // + "
" + "" + "" + // + "" + // + "" + // + "" + // + "" + // + "" + // + "" + "
" + ""); + // } + // else if ("POST".equalsIgnoreCase(request.getMethod())) { + // // Extract parameters from form + // Map params = new HashMap<>(); + // request.getParameterMap().forEach((key, values) -> { + // if (values.length > 0) { + // params.put(key, values[0]); + // } + // }); + + // try { + // String redirectUrl = authorizationHandler.handle(params).join().getRedirectUrl(); + // response.setHeader("Location", redirectUrl); + // response.setHeader("Cache-Control", "no-store"); + // response.setStatus(302); // Found + // } + // catch (CompletionException ex) { + // response.setStatus(400); + // } + // } + // } + + private void handleTokenRequest(HttpServletRequest request, HttpServletResponse response) throws IOException { + // Extract parameters + Map params = new HashMap<>(); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + params.put(key, values[0]); + } + }); + + System.out.println("TOKEN REQUEST PARAMS: " + params); + + try { + io.modelcontextprotocol.auth.OAuthToken token = tokenHandler.handle(params).join(); + response.setContentType("application/json"); + response.setHeader("Cache-Control", "no-store"); + response.setHeader("Pragma", "no-cache"); + response.setStatus(200); + getObjectMapper().writeValue(response.getOutputStream(), token); + } + catch (CompletionException ex) { + response.setStatus(400); + } + } + + private void handleRegisterRequest(HttpServletRequest request, HttpServletResponse response) throws IOException { + // Read request body + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + try { + OAuthClientMetadata clientMetadata = getObjectMapper().readValue(body.toString(), + OAuthClientMetadata.class); + + Object clientInfo = registrationHandler.handle(clientMetadata).join(); + response.setContentType("application/json"); + response.setStatus(201); // Created + getObjectMapper().writeValue(response.getOutputStream(), clientInfo); + } + catch (CompletionException ex) { + response.setStatus(400); + } + } + + private void handleRevokeRequest(HttpServletRequest request, HttpServletResponse response) throws IOException { + // Extract parameters + Map params = new HashMap<>(); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + params.put(key, values[0]); + } + }); + + try { + revocationHandler.handle(params).join(); + response.setStatus(200); + } + catch (CompletionException ex) { + response.setStatus(400); + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..e7d84decc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -9,6 +9,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.auth.middleware.AuthContext; +import io.modelcontextprotocol.server.auth.middleware.AuthContextProvider; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -242,7 +245,16 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); + + // Get auth context from transport if it supports it + AuthContext authContext = null; + if (transport instanceof AuthContextProvider authProvider) { + authContext = authProvider.getAuthContext(); + } + + exchangeSink.tryEmitValue( + new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get(), authContext)); + return this.initNotificationHandler.handle(); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java new file mode 100644 index 000000000..398cde5ca --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java @@ -0,0 +1,120 @@ +package io.modelcontextprotocol.auth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.net.URI; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.client.auth.AuthCallbackResult; +import io.modelcontextprotocol.client.auth.OAuthClientProvider; +import io.modelcontextprotocol.client.auth.TokenStorage; + +/** + * Tests for the OAuthClientProvider class. + */ +public class OAuthClientProviderTest { + + private OAuthClientMetadata clientMetadata; + + private TokenStorage mockStorage; + + private Function> mockRedirectHandler; + + private Function> mockCallbackHandler; + + private OAuthClientProvider clientProvider; + + private OAuthToken token; + + private OAuthClientInformation clientInfo; + + @SuppressWarnings("unchecked") + @BeforeEach + public void setup() throws Exception { + // Setup client metadata + clientMetadata = new OAuthClientMetadata(); + clientMetadata.setRedirectUris(List.of(new URI("https://example.com/callback"))); + clientMetadata.setScope("read write"); + + // Setup mock storage + mockStorage = mock(TokenStorage.class); + + // Setup mock handlers + mockRedirectHandler = mock(Function.class); + mockCallbackHandler = mock(Function.class); + + // Setup token and client info + token = new OAuthToken(); + token.setAccessToken("test-access-token"); + token.setRefreshToken("test-refresh-token"); + token.setExpiresIn(3600); + token.setScope("read write"); + + clientInfo = new OAuthClientInformation(); + clientInfo.setClientId("test-client-id"); + clientInfo.setClientSecret("test-client-secret"); + clientInfo.setRedirectUris(List.of(new URI("https://example.com/callback"))); + clientInfo.setScope("read write"); + + // Configure mocks + when(mockStorage.getTokens()).thenReturn(CompletableFuture.completedFuture(token)); + when(mockStorage.getClientInfo()).thenReturn(CompletableFuture.completedFuture(clientInfo)); + when(mockStorage.setTokens(any())).thenReturn(CompletableFuture.completedFuture(null)); + when(mockStorage.setClientInfo(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(mockRedirectHandler.apply(anyString())).thenReturn(CompletableFuture.completedFuture(null)); + + AuthCallbackResult callbackResult = new AuthCallbackResult("test-auth-code", "test-state"); + when(mockCallbackHandler.apply(any())).thenReturn(CompletableFuture.completedFuture(callbackResult)); + + // Create client provider + clientProvider = new OAuthClientProvider("https://auth.example.com", clientMetadata, mockStorage, + mockRedirectHandler, mockCallbackHandler, Duration.ofSeconds(30)); + } + + @Test + public void testInitialize() throws Exception { + // Test initialization + CompletableFuture initFuture = clientProvider.initialize(); + initFuture.get(); + + // Test access token retrieval + String accessToken = clientProvider.getAccessToken(); + assertNotNull(accessToken); + assertEquals("test-access-token", accessToken); + + // Test token retrieval + OAuthToken retrievedToken = clientProvider.getCurrentTokens(); + assertNotNull(retrievedToken); + assertEquals(token.getAccessToken(), retrievedToken.getAccessToken()); + assertEquals(token.getRefreshToken(), retrievedToken.getRefreshToken()); + } + + @Test + public void testEnsureToken() throws Exception { + // Initialize first + clientProvider.initialize().get(); + + // Test token validation + CompletableFuture tokenFuture = clientProvider.ensureToken(); + tokenFuture.get(); + + // Token should be valid and accessible + String accessToken = clientProvider.getAccessToken(); + assertNotNull(accessToken); + assertEquals("test-access-token", accessToken); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java new file mode 100644 index 000000000..e8de4e9f7 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java @@ -0,0 +1,113 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for the OAuth authentication flow. + */ +public class OAuthFlowTest { + + private OAuthAuthorizationServerProvider mockProvider; + + private OAuthClientInformation clientInfo; + + private AuthorizationCode authCode; + + private OAuthToken token; + + @BeforeEach + public void setup() throws Exception { + // Setup mock provider + mockProvider = mock(OAuthAuthorizationServerProvider.class); + + // Setup test client + clientInfo = new OAuthClientInformation(); + clientInfo.setClientId("test-client-id"); + clientInfo.setClientSecret("test-client-secret"); + clientInfo.setRedirectUris(List.of(new URI("https://example.com/callback"))); + clientInfo.setScope("read write"); + + // Setup test auth code + authCode = new AuthorizationCode(); + authCode.setCode("test-auth-code"); + authCode.setClientId(clientInfo.getClientId()); + authCode.setScopes(Arrays.asList("read", "write")); + authCode.setExpiresAt(Instant.now().plusSeconds(600).getEpochSecond()); + authCode.setCodeChallenge("test-code-challenge"); + authCode.setRedirectUri(clientInfo.getRedirectUris().get(0)); + authCode.setRedirectUriProvidedExplicitly(true); + + // Setup test token + token = new OAuthToken(); + token.setAccessToken("test-access-token"); + token.setRefreshToken("test-refresh-token"); + token.setExpiresIn(3600); + token.setScope("read write"); + + // Configure mock provider + when(mockProvider.getClient(clientInfo.getClientId())) + .thenReturn(CompletableFuture.completedFuture(clientInfo)); + + when(mockProvider.authorize(any(), any())) + .thenReturn(CompletableFuture.completedFuture("https://example.com/auth?code=test-auth-code")); + + when(mockProvider.loadAuthorizationCode(any(), any())).thenReturn(CompletableFuture.completedFuture(authCode)); + + when(mockProvider.exchangeAuthorizationCode(any(), any())).thenReturn(CompletableFuture.completedFuture(token)); + } + + @Test + public void testAuthorizationCodeFlow() throws Exception { + // Test client lookup + CompletableFuture clientFuture = mockProvider.getClient(clientInfo.getClientId()); + OAuthClientInformation retrievedClient = clientFuture.get(); + + assertNotNull(retrievedClient); + assertEquals(clientInfo.getClientId(), retrievedClient.getClientId()); + + // Test authorization + AuthorizationParams params = new AuthorizationParams(); + params.setState(UUID.randomUUID().toString()); + params.setScopes(Arrays.asList("read", "write")); + params.setCodeChallenge("test-code-challenge"); + params.setRedirectUri(clientInfo.getRedirectUris().get(0)); + params.setRedirectUriProvidedExplicitly(true); + + CompletableFuture authUrlFuture = mockProvider.authorize(clientInfo, params); + String authUrl = authUrlFuture.get(); + + assertNotNull(authUrl); + assertTrue(authUrl.startsWith("https://example.com/auth?code=")); + + // Test code exchange + CompletableFuture codeFuture = mockProvider.loadAuthorizationCode(clientInfo, + "test-auth-code"); + AuthorizationCode retrievedCode = codeFuture.get(); + + assertNotNull(retrievedCode); + assertEquals(authCode.getCode(), retrievedCode.getCode()); + + CompletableFuture tokenFuture = mockProvider.exchangeAuthorizationCode(clientInfo, retrievedCode); + OAuthToken retrievedToken = tokenFuture.get(); + + assertNotNull(retrievedToken); + assertEquals(token.getAccessToken(), retrievedToken.getAccessToken()); + assertEquals(token.getRefreshToken(), retrievedToken.getRefreshToken()); + assertEquals(token.getExpiresIn(), retrievedToken.getExpiresIn()); + } + +} \ No newline at end of file diff --git a/pom.xml b/pom.xml index c2327ee8d..42d1d906d 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ mcp-spring/mcp-spring-webflux mcp-spring/mcp-spring-webmvc mcp-test + examples/auth-example @@ -368,4 +369,4 @@ - + \ No newline at end of file