diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index 9b9a89ee1..0e677692e 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -13,10 +13,6 @@ true - - - - diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index e8ca56bc4..849ad2ffe 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1,12 +1,13 @@ -using ModelContextProtocol.Client; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; -using Microsoft.Extensions.AI; -using Microsoft.Extensions.Logging; using Moq; +using System.Reflection; namespace ModelContextProtocol.Tests.Server; @@ -43,7 +44,7 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); // Assert Assert.NotNull(server); @@ -53,21 +54,21 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_Transport() { // Arrange, Act & Assert - Assert.Throws(() => new McpServer(null!, _options, _loggerFactory.Object, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(null!, _options, _loggerFactory.Object, _serviceProvider)); } [Fact] public void Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws(() => new McpServer(_serverTransport.Object, null!, _loggerFactory.Object, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(_serverTransport.Object, null!, _loggerFactory.Object, _serviceProvider)); } [Fact] public async Task Constructor_Does_Not_Throw_For_Null_Logger() { // Arrange & Act - await using var server = new McpServer(_serverTransport.Object, _options, null, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, null, _serviceProvider); // Assert Assert.NotNull(server); @@ -77,25 +78,17 @@ public async Task Constructor_Does_Not_Throw_For_Null_Logger() public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() { // Arrange & Act - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, null); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, null); // Assert Assert.NotNull(server); } - [Fact] - public async Task Property_EndpointName_Return_Infos() - { - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); - server.ClientInfo = new Implementation { Name = "TestClient", Version = "1.1" }; - Assert.Equal("Server (TestServer 1.0), Client (TestClient 1.1)", server.EndpointName); - } - [Fact] public async Task StartAsync_Should_Throw_InvalidOperationException_If_Already_Initializing() { // Arrange - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); server.GetType().GetField("_isInitializing", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.SetValue(server, true); // Act & Assert @@ -106,8 +99,8 @@ public async Task StartAsync_Should_Throw_InvalidOperationException_If_Already_I public async Task StartAsync_Should_Do_Nothing_If_Already_Initialized() { // Arrange - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); - server.IsInitialized = true; + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + SetInitialized(server, true); await server.StartAsync(TestContext.Current.CancellationToken); @@ -119,7 +112,7 @@ public async Task StartAsync_Should_Do_Nothing_If_Already_Initialized() public async Task StartAsync_ShouldStartListening() { // Arrange - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); // Act await server.StartAsync(TestContext.Current.CancellationToken); @@ -132,7 +125,7 @@ public async Task StartAsync_ShouldStartListening() public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initialized_Notification() { await using var transport = new TestServerTransport(); - await using var server = new McpServer(transport, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(transport, _options, _loggerFactory.Object, _serviceProvider); await server.StartAsync(TestContext.Current.CancellationToken); @@ -152,8 +145,8 @@ await transport.SendMessageAsync(new JsonRpcNotification public async Task RequestSamplingAsync_Should_Throw_McpServerException_If_Client_Does_Not_Support_Sampling() { // Arrange - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); - server.ClientCapabilities = new ClientCapabilities(); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + SetClientCapabilities(server, new ClientCapabilities()); var action = () => server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -166,8 +159,8 @@ public async Task RequestSamplingAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = new McpServer(transport, _options, _loggerFactory.Object, _serviceProvider); - server.ClientCapabilities = new ClientCapabilities { Sampling = new SamplingCapability() }; + await using var server = McpServerFactory.Create(transport, _options, _loggerFactory.Object, _serviceProvider); + SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); await server.StartAsync(TestContext.Current.CancellationToken); @@ -184,8 +177,8 @@ public async Task RequestSamplingAsync_Should_SendRequest() public async Task RequestRootsAsync_Should_Throw_McpServerException_If_Client_Does_Not_Support_Roots() { // Arrange - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); - server.ClientCapabilities = new ClientCapabilities(); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert await Assert.ThrowsAsync("server", () => server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None)); @@ -196,8 +189,8 @@ public async Task RequestRootsAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = new McpServer(transport, _options, _loggerFactory.Object, _serviceProvider); - server.ClientCapabilities = new ClientCapabilities { Roots = new RootsCapability() }; + await using var server = McpServerFactory.Create(transport, _options, _loggerFactory.Object, _serviceProvider); + SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); await server.StartAsync(TestContext.Current.CancellationToken); // Act @@ -213,8 +206,8 @@ public async Task RequestRootsAsync_Should_SendRequest() [Fact] public async Task Throws_Exception_If_Not_Connected() { - await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); - server.ClientCapabilities = new ClientCapabilities { Roots = new RootsCapability() }; + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); _serverTransport.SetupGet(t => t.IsConnected).Returns(false); var action = async () => await server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None); @@ -522,7 +515,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s var options = CreateOptions(serverCapabilities); configureOptions?.Invoke(options); - await using var server = new McpServer(transport, options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(transport, options, _loggerFactory.Object, _serviceProvider); await server.StartAsync(); @@ -595,6 +588,20 @@ public async Task AsSamplingChatClient_HandlesRequestResponse() Assert.Equal(ChatRole.Assistant, response.Messages[0].Role); } + private static void SetClientCapabilities(IMcpServer server, ClientCapabilities capabilities) + { + PropertyInfo? property = server.GetType().GetProperty("ClientCapabilities", BindingFlags.Public | BindingFlags.Instance); + Assert.NotNull(property); + property.SetValue(server, capabilities); + } + + private static void SetInitialized(IMcpServer server, bool isInitialized) + { + PropertyInfo? property = server.GetType().GetProperty("IsInitialized", BindingFlags.Public | BindingFlags.Instance); + Assert.NotNull(property); + property.SetValue(server, isInitialized); + } + private sealed class TestServerForIChatClient(bool supportsSampling) : IMcpServer { public ClientCapabilities? ClientCapabilities => diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 48f91dcc5..50188cfa1 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -1,10 +1,11 @@ -using System.Text.Json; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Tests.Utils; -using Microsoft.Extensions.Logging; +using System.Reflection; +using System.Text.Json; namespace ModelContextProtocol.Tests; @@ -299,8 +300,10 @@ public async Task ConnectTwice_Throws() defaultOptions, loggerFactory: loggerFactory, cancellationToken: TestContext.Current.CancellationToken); - var mcpClient = (McpClient)client; - var transport = (SseClientTransport)mcpClient.Transport; + + PropertyInfo? transportProperty = client.GetType().GetProperty("Transport", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(transportProperty); + var transport = (SseClientTransport)transportProperty.GetValue(client)!; // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 26aaecd07..0749eed71 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -1,9 +1,10 @@ -using System.Net; +using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Tests.Utils; -using Microsoft.Extensions.Logging.Abstractions; +using System.Net; +using System.Reflection; namespace ModelContextProtocol.Tests.Transport; @@ -42,11 +43,16 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() // Assert Assert.NotNull(transport); - Assert.Equal(TimeSpan.FromSeconds(2), transport.Options.ConnectionTimeout); - Assert.Equal(3, transport.Options.MaxReconnectAttempts); - Assert.Equal(TimeSpan.FromMilliseconds(50), transport.Options.ReconnectDelay); - Assert.NotNull(transport.Options.AdditionalHeaders); - Assert.Equal("header", transport.Options.AdditionalHeaders["test"]); + + PropertyInfo? getOptions = transport.GetType().GetProperty("Options", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(getOptions); + var options = (SseClientTransportOptions)getOptions.GetValue(transport)!; + + Assert.Equal(TimeSpan.FromSeconds(2), options.ConnectionTimeout); + Assert.Equal(3, options.MaxReconnectAttempts); + Assert.Equal(TimeSpan.FromMilliseconds(50), options.ReconnectDelay); + Assert.NotNull(options.AdditionalHeaders); + Assert.Equal("header", options.AdditionalHeaders["test"]); } [Fact] @@ -137,7 +143,6 @@ public async Task SendMessageAsync_Throws_Exception_If_MessageEndpoint_Not_Set() await using var transport = new SseClientTransport(_transportOptions, _serverConfig, NullLoggerFactory.Instance); // Assert - Assert.True(string.IsNullOrEmpty(transport.MessageEndpoint?.ToString())); await Assert.ThrowsAsync(() => transport.SendMessageAsync(new JsonRpcRequest() { Method = "test" }, CancellationToken.None)); }