diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9eed688a..719f9c7a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -41,6 +41,7 @@ jobs: clean \ ktlintCheck \ build \ + -x :conformance-test:test \ koverLog koverHtmlReport \ publishToMavenLocal \ -Pversion=0.0.1-SNAPSHOT diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml new file mode 100644 index 00000000..c736ad3d --- /dev/null +++ b/.github/workflows/conformance.yml @@ -0,0 +1,50 @@ +name: Conformance Tests + +on: + workflow_dispatch: + pull_request: + branches: [ main ] + push: + branches: [ main ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + # Cancel only when the run is NOT on `main` branch + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + run-conformance: + runs-on: macos-latest-xlarge + name: Run Conformance Tests + timeout-minutes: 20 + env: + JAVA_OPTS: "-Xmx8g -Dfile.encoding=UTF-8 -Djava.awt.headless=true -Dkotlin.daemon.jvm.options=-Xmx6g" + steps: + - uses: actions/checkout@v6 + + - name: Set up JDK 21 + uses: actions/setup-java@v5 + with: + java-version: '21' + distribution: 'temurin' + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '24.11.1' + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v5 + with: + add-job-summary: 'always' + cache-read-only: true + + - name: Run Conformance Tests + run: ./gradlew --no-daemon :conformance-test:test + + - name: Upload Conformance Results + if: ${{ !cancelled() }} + uses: actions/upload-artifact@v5 + with: + name: conformance-results + path: conformance-test/results/ diff --git a/.github/workflows/gradle-publish.yml b/.github/workflows/gradle-publish.yml index d9b28012..bb81dbe5 100644 --- a/.github/workflows/gradle-publish.yml +++ b/.github/workflows/gradle-publish.yml @@ -34,7 +34,7 @@ jobs: uses: gradle/actions/setup-gradle@v5 - name: Clean Build with Gradle - run: ./gradlew clean build + run: ./gradlew clean build -x :conformance-test:test - name: Publish to Maven Central Portal id: publish diff --git a/.gitignore b/.gitignore index ea62a956..a8c18ddb 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,6 @@ dist ### SWE agents ### .claude/ .junie/ + +### Conformance test results ### +conformance-test/results/ diff --git a/conformance-test/build.gradle.kts b/conformance-test/build.gradle.kts new file mode 100644 index 00000000..66e57d15 --- /dev/null +++ b/conformance-test/build.gradle.kts @@ -0,0 +1,38 @@ +import org.gradle.api.tasks.testing.logging.TestExceptionFormat + +plugins { + kotlin("jvm") +} + +dependencies { + testImplementation(project(":kotlin-sdk")) + testImplementation(kotlin("test")) + testImplementation(libs.kotlin.logging) + testImplementation(libs.ktor.client.cio) + testImplementation(libs.ktor.server.cio) + testImplementation(libs.ktor.server.websockets) + testRuntimeOnly(libs.slf4j.simple) +} + +tasks.test { + useJUnitPlatform() + + testLogging { + events("passed", "skipped", "failed") + showStandardStreams = true + showExceptions = true + showCauses = true + showStackTraces = true + exceptionFormat = TestExceptionFormat.FULL + } + + doFirst { + systemProperty("test.classpath", classpath.asPath) + + println("\n" + "=".repeat(60)) + println("MCP CONFORMANCE TESTS") + println("=".repeat(60)) + println("These tests validate compliance with the MCP specification.") + println("=".repeat(60) + "\n") + } +} diff --git a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt new file mode 100644 index 00000000..09ea50bf --- /dev/null +++ b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt @@ -0,0 +1,85 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject + +private val logger = KotlinLogging.logger {} + +fun main(args: Array) { + require(args.isNotEmpty()) { + "Server URL must be provided as an argument" + } + + val serverUrl = args.last() + logger.info { "Connecting to test server at: $serverUrl" } + + val httpClient = HttpClient(CIO) { + install(SSE) + } + val transport: Transport = StreamableHttpClientTransport(httpClient, serverUrl) + + val client = Client( + clientInfo = Implementation( + name = "kotlin-conformance-client", + version = "1.0.0", + ), + ) + + var exitCode = 0 + + runBlocking { + try { + client.connect(transport) + logger.info { "✅ Connected to server successfully" } + + try { + val tools = client.listTools() + logger.info { "Available tools: ${tools.tools.map { it.name }}" } + + if (tools.tools.isNotEmpty()) { + val toolName = tools.tools.first().name + logger.info { "Calling tool: $toolName" } + + val result = client.callTool( + CallToolRequest( + params = CallToolRequestParams( + name = toolName, + arguments = buildJsonObject { + put("input", JsonPrimitive("test")) + }, + ), + ), + ) + logger.info { "Tool result: ${result.content}" } + } + } catch (e: Exception) { + logger.debug(e) { "Error during tool operations (may be expected for some scenarios)" } + } + + logger.info { "✅ Client operations completed successfully" } + } catch (e: Exception) { + logger.error(e) { "❌ Client failed" } + exitCode = 1 + } finally { + try { + transport.close() + } catch (e: Exception) { + logger.warn(e) { "Error closing transport" } + } + httpClient.close() + } + } + + kotlin.system.exitProcess(exitCode) +} diff --git a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt new file mode 100644 index 00000000..01bdd403 --- /dev/null +++ b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt @@ -0,0 +1,422 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.application.install +import io.ktor.server.cio.CIO +import io.ktor.server.engine.embeddedServer +import io.ktor.server.request.header +import io.ktor.server.request.receiveText +import io.ktor.server.response.header +import io.ktor.server.response.respond +import io.ktor.server.response.respondText +import io.ktor.server.response.respondTextWriter +import io.ktor.server.routing.delete +import io.ktor.server.routing.get +import io.ktor.server.routing.post +import io.ktor.server.routing.routing +import io.ktor.server.websocket.WebSockets +import io.ktor.server.websocket.webSocket +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.WebSocketMcpServerTransport +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.RequestId +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.launch +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.put +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +private val logger = KotlinLogging.logger {} +private val serverTransports = ConcurrentHashMap() +private val jsonFormat = Json { ignoreUnknownKeys = true } + +private const val SESSION_CREATION_TIMEOUT_MS = 2000L +private const val REQUEST_TIMEOUT_MS = 10_000L +private const val MESSAGE_QUEUE_CAPACITY = 256 + +private fun isInitializeRequest(json: JsonElement): Boolean = + json is JsonObject && json["method"]?.jsonPrimitive?.contentOrNull == "initialize" + +fun main(args: Array) { + val port = args.getOrNull(0)?.toIntOrNull() ?: 3000 + + logger.info { "Starting MCP Conformance Server on port $port" } + + embeddedServer(CIO, port = port, host = "127.0.0.1") { + install(WebSockets) + + routing { + webSocket("/ws") { + logger.info { "WebSocket connection established" } + val transport = WebSocketMcpServerTransport(this) + val server = createConformanceServer() + + try { + server.createSession(transport) + } catch (e: Exception) { + logger.error(e) { "Error in WebSocket session" } + throw e + } + } + + get("/mcp") { + val sessionId = call.request.header("mcp-session-id") + ?: run { + call.respond(HttpStatusCode.BadRequest, "Missing mcp-session-id header") + return@get + } + val transport = serverTransports[sessionId] + ?: run { + call.respond(HttpStatusCode.BadRequest, "Invalid mcp-session-id") + return@get + } + transport.stream(call) + } + + post("/mcp") { + val sessionId = call.request.header("mcp-session-id") + val requestBody = call.receiveText() + + logger.debug { "Received request with sessionId: $sessionId" } + logger.trace { "Request body: $requestBody" } + + val jsonElement = try { + jsonFormat.parseToJsonElement(requestBody) + } catch (e: Exception) { + logger.error(e) { "Failed to parse request body as JSON" } + call.respond( + HttpStatusCode.BadRequest, + jsonFormat.encodeToString( + JsonObject.serializer(), + buildJsonObject { + put("jsonrpc", "2.0") + put( + "error", + buildJsonObject { + put("code", -32700) + put("message", "Parse error: ${e.message}") + }, + ) + put("id", JsonNull) + }, + ), + ) + return@post + } + + val transport = sessionId?.let { serverTransports[it] } + if (transport != null) { + logger.debug { "Using existing transport for session: $sessionId" } + transport.handleRequest(call, jsonElement) + } else { + if (isInitializeRequest(jsonElement)) { + val newSessionId = UUID.randomUUID().toString() + logger.info { "Creating new session with ID: $newSessionId" } + + val newTransport = HttpServerTransport(newSessionId) + serverTransports[newSessionId] = newTransport + + val mcpServer = createConformanceServer() + call.response.header("mcp-session-id", newSessionId) + + val sessionReady = CompletableDeferred() + CoroutineScope(Dispatchers.IO).launch { + try { + mcpServer.createSession(newTransport) + sessionReady.complete(Unit) + } catch (e: Exception) { + logger.error(e) { "Failed to create session" } + serverTransports.remove(newSessionId) + newTransport.close() + sessionReady.completeExceptionally(e) + } + } + + val sessionCreated = withTimeoutOrNull(SESSION_CREATION_TIMEOUT_MS) { + sessionReady.await() + } + + if (sessionCreated == null) { + logger.error { "Session creation timed out" } + serverTransports.remove(newSessionId) + call.respond( + HttpStatusCode.InternalServerError, + jsonFormat.encodeToString( + JsonObject.serializer(), + buildJsonObject { + put("jsonrpc", "2.0") + put( + "error", + buildJsonObject { + put("code", -32000) + put("message", "Session creation timed out") + }, + ) + put("id", JsonNull) + }, + ), + ) + return@post + } + + newTransport.handleRequest(call, jsonElement) + } else { + logger.warn { "Invalid request: no session ID or not an initialization request" } + call.respond( + HttpStatusCode.BadRequest, + jsonFormat.encodeToString( + JsonObject.serializer(), + buildJsonObject { + put("jsonrpc", "2.0") + put( + "error", + buildJsonObject { + put("code", -32000) + put("message", "Bad Request: No valid session ID provided") + }, + ) + put("id", JsonNull) + }, + ), + ) + } + } + } + + delete("/mcp") { + val sessionId = call.request.header("mcp-session-id") + val transport = sessionId?.let { serverTransports[it] } + if (transport != null) { + logger.info { "Terminating session: $sessionId" } + serverTransports.remove(sessionId) + transport.close() + call.respond(HttpStatusCode.OK) + } else { + logger.warn { "Invalid session termination request: $sessionId" } + call.respond(HttpStatusCode.BadRequest, "Invalid or missing session ID") + } + } + } + }.start(wait = true) +} + +private fun createConformanceServer(): Server { + val server = Server( + Implementation( + name = "kotlin-conformance-server", + version = "1.0.0", + ), + ServerOptions( + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = true), + resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), + prompts = ServerCapabilities.Prompts(listChanged = true), + ), + ), + ) + + server.addTool( + name = "test-tool", + description = "A test tool for conformance testing", + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "input", + buildJsonObject { + put("type", "string") + put("description", "Test input parameter") + }, + ) + }, + required = listOf("input"), + ), + ) { request -> + val input = (request.params.arguments?.get("input") as? JsonPrimitive)?.content ?: "no input" + CallToolResult( + content = listOf(TextContent("Tool executed with input: $input")), + ) + } + + server.addResource( + uri = "test://test-resource", + name = "Test Resource", + description = "A test resource for conformance testing", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents("Test resource content", request.params.uri, "text/plain"), + ), + ) + } + + server.addPrompt( + name = "test-prompt", + description = "A test prompt for conformance testing", + arguments = listOf( + PromptArgument( + name = "arg", + description = "Test argument", + required = false, + ), + ), + ) { + GetPromptResult( + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent("Test prompt content"), + ), + ), + description = "Test prompt description", + ) + } + + return server +} + +private class HttpServerTransport(private val sessionId: String) : AbstractTransport() { + private val logger = KotlinLogging.logger {} + private val pendingResponses = ConcurrentHashMap>() + private val messageQueue = Channel(MESSAGE_QUEUE_CAPACITY) + + suspend fun stream(call: ApplicationCall) { + logger.debug { "Starting SSE stream for session $sessionId" } + call.response.apply { + header("Cache-Control", "no-cache") + header("Connection", "keep-alive") + } + call.respondTextWriter(ContentType.Text.EventStream) { + try { + while (true) { + val msg = messageQueue.receiveCatching().getOrNull() ?: break + write("event: message\ndata: ${McpJson.encodeToString(msg)}\n\n") + flush() + } + } catch (e: Exception) { + logger.warn(e) { "SSE stream terminated for session $sessionId" } + } finally { + logger.debug { "SSE stream closed for session $sessionId" } + } + } + } + + suspend fun handleRequest(call: ApplicationCall, requestBody: JsonElement) { + try { + val message = McpJson.decodeFromJsonElement(requestBody) + logger.debug { "Handling ${message::class.simpleName}: $requestBody" } + + when (message) { + is JSONRPCRequest -> { + val idKey = when (val id = message.id) { + is RequestId.NumberId -> id.value.toString() + is RequestId.StringId -> id.value + } + val responseDeferred = CompletableDeferred() + pendingResponses[idKey] = responseDeferred + + _onMessage.invoke(message) + + val response = withTimeoutOrNull(REQUEST_TIMEOUT_MS) { responseDeferred.await() } + if (response != null) { + call.respondText(McpJson.encodeToString(response), ContentType.Application.Json) + } else { + pendingResponses.remove(idKey) + logger.warn { "Timeout for request $idKey" } + call.respondText( + McpJson.encodeToString( + JSONRPCError( + message.id, + RPCError(RPCError.ErrorCode.REQUEST_TIMEOUT, "Request timed out"), + ), + ), + ContentType.Application.Json, + ) + } + } + + else -> call.respond(HttpStatusCode.Accepted) + } + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + logger.error(e) { "Error handling request" } + if (!call.response.isCommitted) { + call.respondText( + McpJson.encodeToString( + JSONRPCError( + RequestId(0), + RPCError(RPCError.ErrorCode.INTERNAL_ERROR, "Internal error: ${e.message}"), + ), + ), + ContentType.Application.Json, + HttpStatusCode.InternalServerError, + ) + } + } + } + + override suspend fun start() { + logger.debug { "Started transport for session $sessionId" } + } + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + when (message) { + is JSONRPCResponse -> { + val idKey = when (val id = message.id) { + is RequestId.NumberId -> id.value.toString() + is RequestId.StringId -> id.value + } + pendingResponses.remove(idKey)?.complete(message) ?: run { + logger.warn { "No pending response for ID $idKey, queueing" } + messageQueue.send(message) + } + } + + else -> messageQueue.send(message) + } + } + + override suspend fun close() { + logger.debug { "Closing transport for session $sessionId" } + messageQueue.close() + pendingResponses.clear() + _onClose.invoke() + } +} diff --git a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTest.kt b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTest.kt new file mode 100644 index 00000000..4137943e --- /dev/null +++ b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTest.kt @@ -0,0 +1,328 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.DynamicTest +import org.junit.jupiter.api.TestFactory +import org.junit.jupiter.api.TestInstance +import java.io.BufferedReader +import java.io.InputStreamReader +import java.lang.management.ManagementFactory +import java.net.HttpURLConnection +import java.net.ServerSocket +import java.net.URI +import java.util.concurrent.TimeUnit +import kotlin.properties.Delegates + +private val logger = KotlinLogging.logger {} + +enum class TransportType { + SSE, + WEBSOCKET, +} + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class ConformanceTest { + + private var serverProcess: Process? = null + private var serverPort: Int by Delegates.notNull() + private val serverErrorOutput = mutableListOf() + private val maxErrorLines = 500 + + companion object { + private val SERVER_SCENARIOS = listOf( + "server-initialize", + "tools-list", + "tools-call-simple-text", + "resources-list", + "prompts-list", + // TODO: Fix + // - resources-read-text + // - prompts-get-simple + ) + + private val CLIENT_SCENARIOS = listOf( + "initialize", + // TODO: Fix + // "tools-call", + ) + + private val SERVER_TRANSPORT_TYPES = listOf( + TransportType.SSE, + // TODO: Fix +// TransportType.WEBSOCKET, + ) + + private val CLIENT_TRANSPORT_TYPES = listOf( + TransportType.SSE, + TransportType.WEBSOCKET, + ) + + private const val DEFAULT_TEST_TIMEOUT_SECONDS = 30L + private const val DEFAULT_SERVER_STARTUP_TIMEOUT_SECONDS = 10 + private const val INITIAL_BACKOFF_MS = 50L + private const val MAX_BACKOFF_MS = 500L + private const val BACKOFF_MULTIPLIER = 1.5 + private const val CONNECTION_TIMEOUT_MS = 500 + private const val GRACEFUL_SHUTDOWN_SECONDS = 5L + private const val FORCE_SHUTDOWN_SECONDS = 2L + + private fun findFreePort(): Int = ServerSocket(0).use { it.localPort } + + private fun getRuntimeClasspath(): String = ManagementFactory.getRuntimeMXBean().classPath + + private fun getTestClasspath(): String = System.getProperty("test.classpath") ?: getRuntimeClasspath() + + private fun waitForServerReady( + url: String, + timeoutSeconds: Int = DEFAULT_SERVER_STARTUP_TIMEOUT_SECONDS, + ): Boolean { + val deadline = System.currentTimeMillis() + (timeoutSeconds * 1000) + var lastError: Exception? = null + var backoffMs = INITIAL_BACKOFF_MS + + while (System.currentTimeMillis() < deadline) { + try { + val connection = URI(url).toURL().openConnection() as HttpURLConnection + connection.requestMethod = "GET" + connection.connectTimeout = CONNECTION_TIMEOUT_MS + connection.readTimeout = CONNECTION_TIMEOUT_MS + connection.connect() + + val responseCode = connection.responseCode + connection.disconnect() + logger.debug { "Server responded with code: $responseCode" } + return true + } catch (e: Exception) { + lastError = e + Thread.sleep(backoffMs) + backoffMs = (backoffMs * BACKOFF_MULTIPLIER).toLong().coerceAtMost(MAX_BACKOFF_MS) + } + } + + logger.error { "Server did not start within $timeoutSeconds seconds. Last error: ${lastError?.message}" } + return false + } + } + + @BeforeAll + fun startServer() { + serverPort = findFreePort() + val serverUrl = "http://127.0.0.1:$serverPort/mcp" + + logger.info { "Starting conformance test server on port $serverPort" } + + val processBuilder = ProcessBuilder( + "java", + "-cp", + getRuntimeClasspath(), + "io.modelcontextprotocol.kotlin.sdk.conformance.ConformanceServerKt", + serverPort.toString(), + ) + + val process = processBuilder.start() + serverProcess = process + + // capture stderr in the background + Thread { + try { + BufferedReader(InputStreamReader(process.errorStream)).use { reader -> + reader.lineSequence().forEach { line -> + synchronized(serverErrorOutput) { + if (serverErrorOutput.size >= maxErrorLines) { + serverErrorOutput.removeAt(0) + } + serverErrorOutput.add(line) + } + logger.debug { "Server stderr: $line" } + } + } + } catch (e: Exception) { + logger.trace(e) { "Error reading server stderr" } + } + }.apply { + name = "server-stderr-reader" + isDaemon = true + }.start() + + logger.info { "Waiting for server to start..." } + val serverReady = waitForServerReady(serverUrl) + + if (!serverReady) { + val errorInfo = synchronized(serverErrorOutput) { + if (serverErrorOutput.isNotEmpty()) { + "\n\nServer error output:\n${serverErrorOutput.joinToString("\n")}" + } else { + "" + } + } + serverProcess?.destroyForcibly() + throw IllegalStateException( + "Server failed to start within $DEFAULT_SERVER_STARTUP_TIMEOUT_SECONDS seconds. " + + "Check if port $serverPort is available.$errorInfo", + ) + } + + logger.info { "Server started successfully at $serverUrl" } + } + + @AfterAll + fun stopServer() { + serverProcess?.also { process -> + logger.info { "Stopping conformance test server (PID: ${process.pid()})" } + + try { + process.destroy() + val terminated = process.waitFor(GRACEFUL_SHUTDOWN_SECONDS, TimeUnit.SECONDS) + + if (!terminated) { + logger.warn { "Server did not terminate gracefully, forcing shutdown..." } + process.destroyForcibly() + process.waitFor(FORCE_SHUTDOWN_SECONDS, TimeUnit.SECONDS) + } else { + logger.info { "Server stopped gracefully" } + } + } catch (e: Exception) { + logger.error(e) { "Error stopping server process" } + } finally { + serverProcess = null + } + } ?: logger.debug { "No server process to stop" } + } + + @TestFactory + fun `MCP Server Conformance Tests`(): List = SERVER_TRANSPORT_TYPES.flatMap { transportType -> + SERVER_SCENARIOS.map { scenario -> + DynamicTest.dynamicTest("Server [$transportType]: $scenario") { + runServerConformanceTest(scenario, transportType) + } + } + } + + @TestFactory + fun `MCP Client Conformance Tests`(): List = CLIENT_TRANSPORT_TYPES.flatMap { transportType -> + CLIENT_SCENARIOS.map { scenario -> + DynamicTest.dynamicTest("Client [$transportType]: $scenario") { + runClientConformanceTest(scenario, transportType) + } + } + } + + private fun runServerConformanceTest(scenario: String, transportType: TransportType) { + val processBuilder = when (transportType) { + TransportType.SSE -> { + val serverUrl = "http://127.0.0.1:$serverPort/mcp" + ProcessBuilder( + "npx", + "@modelcontextprotocol/conformance", + "server", + "--url", + serverUrl, + "--scenario", + scenario, + ).apply { + inheritIO() + } + } + + TransportType.WEBSOCKET -> { + val serverUrl = "ws://127.0.0.1:$serverPort/ws" + ProcessBuilder( + "npx", + "@modelcontextprotocol/conformance", + "server", + "--url", + serverUrl, + "--scenario", + scenario, + ).apply { + inheritIO() + } + } + } + + runConformanceTest("server", scenario, processBuilder, transportType) + } + + private fun runClientConformanceTest(scenario: String, transportType: TransportType) { + val testClasspath = getTestClasspath() + + val clientCommand = when (transportType) { + TransportType.SSE -> { + val serverUrl = "http://127.0.0.1:$serverPort/mcp" + listOf( + "java", + "-cp", + testClasspath, + "io.modelcontextprotocol.kotlin.sdk.conformance.ConformanceClientKt", + serverUrl, + ) + } + + TransportType.WEBSOCKET -> { + val serverUrl = "ws://127.0.0.1:$serverPort/ws" + listOf( + "java", + "-cp", + testClasspath, + "io.modelcontextprotocol.kotlin.sdk.conformance.WebSocketConformanceClientKt", + serverUrl, + ) + } + } + + val processBuilder = ProcessBuilder( + "npx", + "@modelcontextprotocol/conformance", + "client", + "--command", + clientCommand.joinToString(" "), + "--scenario", + scenario, + ).apply { + inheritIO() + } + + runConformanceTest("client", scenario, processBuilder, transportType) + } + + private fun runConformanceTest( + type: String, + scenario: String, + processBuilder: ProcessBuilder, + transportType: TransportType, + ) { + val capitalizedType = type.replaceFirstChar { it.uppercase() } + logger.info { "Running $type conformance test [$transportType]: $scenario" } + + val timeoutSeconds = System.getenv("CONFORMANCE_TEST_TIMEOUT_SECONDS")?.toLongOrNull() + ?: DEFAULT_TEST_TIMEOUT_SECONDS + + val process = processBuilder.start() + val completed = process.waitFor(timeoutSeconds, TimeUnit.SECONDS) + + if (!completed) { + logger.error { + "$capitalizedType conformance test [$transportType] '$scenario' timed out after $timeoutSeconds seconds" + } + process.destroyForcibly() + throw AssertionError( + "❌ $capitalizedType conformance test [$transportType] '$scenario' timed out after $timeoutSeconds seconds", + ) + } + + when (val exitCode = process.exitValue()) { + 0 -> logger.info { "✅ $capitalizedType conformance test [$transportType] '$scenario' passed!" } + + else -> { + logger.error { + "$capitalizedType conformance test [$transportType] '$scenario' failed with exit code: $exitCode" + } + throw AssertionError( + "❌ $capitalizedType conformance test [$transportType] '$scenario' failed (exit code: $exitCode). Check test output above for details.", + ) + } + } + } +} diff --git a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/WebSocketConformanceClient.kt b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/WebSocketConformanceClient.kt new file mode 100644 index 00000000..f1123192 --- /dev/null +++ b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/WebSocketConformanceClient.kt @@ -0,0 +1,104 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.websocket.WebSockets +import io.ktor.client.plugins.websocket.webSocket +import io.ktor.websocket.WebSocketSession +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL +import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject + +private val logger = KotlinLogging.logger {} + +class WebSocketClientTransport(override val session: WebSocketSession) : WebSocketMcpTransport() { + override suspend fun initializeSession() { + logger.debug { "WebSocket client session initialized" } + } +} + +fun main(args: Array) { + require(args.isNotEmpty()) { + "Server WebSocket URL must be provided as an argument" + } + + val serverUrl = args.last() + logger.info { "Connecting to WebSocket test server at: $serverUrl" } + + val httpClient = HttpClient(CIO) { + install(WebSockets) + } + + var exitCode = 0 + + runBlocking { + try { + httpClient.webSocket(serverUrl, request = { + headers.append("Sec-WebSocket-Protocol", MCP_SUBPROTOCOL) + }) { + val transport = WebSocketClientTransport(this) + + val client = Client( + clientInfo = Implementation( + name = "kotlin-conformance-client-websocket", + version = "1.0.0", + ), + ) + + try { + client.connect(transport) + logger.info { "✅ Connected to server successfully" } + + try { + val tools = client.listTools() + logger.info { "Available tools: ${tools.tools.map { it.name }}" } + + if (tools.tools.isNotEmpty()) { + val toolName = tools.tools.first().name + logger.info { "Calling tool: $toolName" } + + val result = client.callTool( + CallToolRequest( + params = CallToolRequestParams( + name = toolName, + arguments = buildJsonObject { + put("input", JsonPrimitive("test")) + }, + ), + ), + ) + logger.info { "Tool result: ${result.content}" } + } + } catch (e: Exception) { + logger.debug(e) { "Error during tool operations (may be expected for some scenarios)" } + } + + logger.info { "✅ Client operations completed successfully" } + } catch (e: Exception) { + logger.error(e) { "❌ Client failed" } + exitCode = 1 + } finally { + try { + transport.close() + } catch (e: Exception) { + logger.warn(e) { "Error closing transport" } + } + } + } + } catch (e: Exception) { + logger.error(e) { "❌ WebSocket connection failed" } + exitCode = 1 + } finally { + httpClient.close() + } + } + + kotlin.system.exitProcess(exitCode) +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 368bff38..d7ec54f8 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -23,4 +23,5 @@ include( ":kotlin-sdk-server", ":kotlin-sdk", ":kotlin-sdk-test", + ":conformance-test", )