Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
}

SetRequestHandler<CreateMessageRequestParams, CreateMessageResult>(
SetRequestHandler(
RequestMethods.SamplingCreateMessage,
(request, cancellationToken) => samplingHandler(
request,
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
cancellationToken));
cancellationToken),
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
}

if (options.Capabilities?.Roots is { } rootsCapability)
Expand All @@ -55,9 +57,11 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
}

SetRequestHandler<ListRootsRequestParams, ListRootsResult>(
SetRequestHandler(
RequestMethods.RootsList,
(request, cancellationToken) => rootsHandler(request, cancellationToken));
rootsHandler,
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
McpJsonUtilities.JsonContext.Default.ListRootsResult);
}
}

Expand Down Expand Up @@ -88,21 +92,20 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
initializationCts.CancelAfter(_options.InitializationTimeout);

try
{
// Send initialize request
var initializeResponse = await SendRequestAsync<InitializeResult>(
new JsonRpcRequest
{
Method = RequestMethods.Initialize,
Params = new InitializeRequestParams()
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo
}
},
initializationCts.Token).ConfigureAwait(false);
try
{
// Send initialize request
var initializeResponse = await this.SendRequestAsync(
RequestMethods.Initialize,
new InitializeRequestParams
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo
},
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
McpJsonUtilities.JsonContext.Default.InitializeResult,
cancellationToken: initializationCts.Token).ConfigureAwait(false);

// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
Expand Down
250 changes: 136 additions & 114 deletions src/ModelContextProtocol/Client/McpClientExtensions.cs

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion src/ModelContextProtocol/Client/McpClientPrompt.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Protocol.Types;
using System.Text.Json;

namespace ModelContextProtocol.Client;

Expand All @@ -20,17 +21,19 @@ internal McpClientPrompt(IMcpClient client, Prompt prompt)
/// Retrieves a specific prompt with optional arguments.
/// </summary>
/// <param name="arguments">Optional arguments for the prompt</param>
/// <param name="serializerOptions">The serialization options governing argument serialization.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task containing the prompt's content and messages.</returns>
public async ValueTask<GetPromptResult> GetAsync(
IEnumerable<KeyValuePair<string, object?>>? arguments = null,
JsonSerializerOptions? serializerOptions = null,
CancellationToken cancellationToken = default)
{
IReadOnlyDictionary<string, object?>? argDict =
arguments as IReadOnlyDictionary<string, object?> ??
arguments?.ToDictionary();

return await _client.GetPromptAsync(ProtocolPrompt.Name, argDict, cancellationToken).ConfigureAwait(false);
return await _client.GetPromptAsync(ProtocolPrompt.Name, argDict, serializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <summary>Gets the name of the prompt.</summary>
Expand Down
7 changes: 4 additions & 3 deletions src/ModelContextProtocol/Client/McpClientTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ public sealed class McpClientTool : AIFunction
{
private readonly IMcpClient _client;

internal McpClientTool(IMcpClient client, Tool tool)
internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions serializerOptions)
{
_client = client;
ProtocolTool = tool;
JsonSerializerOptions = serializerOptions;
}

/// <summary>Gets the protocol <see cref="Tool"/> type for this instance.</summary>
Expand All @@ -29,7 +30,7 @@ internal McpClientTool(IMcpClient client, Tool tool)
public override JsonElement JsonSchema => ProtocolTool.InputSchema;

/// <inheritdoc/>
public override JsonSerializerOptions JsonSerializerOptions => McpJsonUtilities.DefaultOptions;
public override JsonSerializerOptions JsonSerializerOptions { get; }

/// <inheritdoc/>
protected async override Task<object?> InvokeCoreAsync(
Expand All @@ -39,7 +40,7 @@ internal McpClientTool(IMcpClient client, Tool tool)
arguments as IReadOnlyDictionary<string, object?> ??
arguments.ToDictionary();

CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, argDict, cancellationToken).ConfigureAwait(false);
CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, argDict, JsonSerializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResponse);
}
}
5 changes: 2 additions & 3 deletions src/ModelContextProtocol/IMcpEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ namespace ModelContextProtocol;
/// <summary>Represents a client or server MCP endpoint.</summary>
public interface IMcpEndpoint : IAsyncDisposable
{
/// <summary>Sends a generic JSON-RPC request to the connected endpoint.</summary>
/// <typeparam name="TResult">The expected response type.</typeparam>
/// <summary>Sends a JSON-RPC request to the connected endpoint.</summary>
/// <param name="request">The JSON-RPC request to send.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task containing the client's response.</returns>
Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class;
Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default);

/// <summary>Sends a message to the connected endpoint.</summary>
/// <param name="message">The message.</param>
Expand Down
3 changes: 0 additions & 3 deletions src/ModelContextProtocol/Logging/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ internal static partial class Log
[LoggerMessage(Level = LogLevel.Information, Message = "Request response received for {endpointName} with method {method}")]
internal static partial void RequestResponseReceived(this ILogger logger, string endpointName, string method);

[LoggerMessage(Level = LogLevel.Error, Message = "Request response type conversion error for {endpointName} with method {method}: expected {expectedType}")]
internal static partial void RequestResponseTypeConversionError(this ILogger logger, string endpointName, string method, Type expectedType);

[LoggerMessage(Level = LogLevel.Error, Message = "Request invalid response type for {endpointName} with method {method}")]
internal static partial void RequestInvalidResponseType(this ILogger logger, string endpointName, string method);

Expand Down
140 changes: 137 additions & 3 deletions src/ModelContextProtocol/McpEndpointExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,147 @@
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization.Metadata;

namespace ModelContextProtocol;

/// <summary>Provides extension methods for interacting with an <see cref="IMcpEndpoint"/>.</summary>
public static class McpEndpointExtensions
{
/// <summary>
/// Sends a JSON-RPC request and attempts to deserialize the result to <typeparamref name="TResult"/>.
/// </summary>
/// <typeparam name="TParameters">The type of the request parameters to serialize from.</typeparam>
/// <typeparam name="TResult">The type of the result to deserialize to.</typeparam>
/// <param name="endpoint">The MCP client or server instance.</param>
/// <param name="method">The JSON-RPC method name to invoke.</param>
/// <param name="parameters">Object representing the request parameters.</param>
/// <param name="requestId">The request id for the request.</param>
/// <param name="serializerOptions">The options governing request serialization.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task that represents the asynchronous operation. The task result contains the deserialized result.</returns>
public static Task<TResult> SendRequestAsync<TParameters, TResult>(
this IMcpEndpoint endpoint,
string method,
TParameters parameters,
JsonSerializerOptions? serializerOptions = null,
RequestId? requestId = null,
CancellationToken cancellationToken = default)
where TResult : notnull
{
serializerOptions ??= McpJsonUtilities.DefaultOptions;
serializerOptions.MakeReadOnly();

JsonTypeInfo<TParameters> paramsTypeInfo = serializerOptions.GetTypeInfo<TParameters>();
JsonTypeInfo<TResult> resultTypeInfo = serializerOptions.GetTypeInfo<TResult>();
return SendRequestAsync(endpoint, method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken);
}

/// <summary>
/// Sends a JSON-RPC request and attempts to deserialize the result to <typeparamref name="TResult"/>.
/// </summary>
/// <typeparam name="TParameters">The type of the request parameters to serialize from.</typeparam>
/// <typeparam name="TResult">The type of the result to deserialize to.</typeparam>
/// <param name="endpoint">The MCP client or server instance.</param>
/// <param name="method">The JSON-RPC method name to invoke.</param>
/// <param name="parameters">Object representing the request parameters.</param>
/// <param name="parametersTypeInfo">The type information for request parameter serialization.</param>
/// <param name="resultTypeInfo">The type information for request parameter deserialization.</param>
/// <param name="requestId">The request id for the request.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task that represents the asynchronous operation. The task result contains the deserialized result.</returns>
internal static async Task<TResult> SendRequestAsync<TParameters, TResult>(
this IMcpEndpoint endpoint,
string method,
TParameters parameters,
JsonTypeInfo<TParameters> parametersTypeInfo,
JsonTypeInfo<TResult> resultTypeInfo,
RequestId? requestId = null,
CancellationToken cancellationToken = default)
where TResult : notnull
{
Throw.IfNull(endpoint);
Throw.IfNullOrWhiteSpace(method);
Throw.IfNull(parametersTypeInfo);
Throw.IfNull(resultTypeInfo);

JsonRpcRequest jsonRpcRequest = new()
{
Method = method,
Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo),
};

if (requestId is { } id)
{
jsonRpcRequest.Id = id;
}

JsonRpcResponse response = await endpoint.SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false);
return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response.");
}

/// <summary>
/// Sends a notification to the server with parameters.
/// </summary>
/// <param name="client">The client.</param>
/// <param name="method">The notification method name.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
public static Task SendNotificationAsync(this IMcpEndpoint client, string method, CancellationToken cancellationToken = default)
{
Throw.IfNull(client);
Throw.IfNullOrWhiteSpace(method);
return client.SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken);
}

/// <summary>
/// Sends a notification to the server with parameters.
/// </summary>
/// <param name="endpoint">The MCP client or server instance.</param>
/// <param name="method">The JSON-RPC method name to invoke.</param>
/// <param name="parameters">Object representing the request parameters.</param>
/// <param name="serializerOptions">The options governing request serialization.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
public static Task SendNotificationAsync<TParameters>(
this IMcpEndpoint endpoint,
string method,
TParameters parameters,
JsonSerializerOptions? serializerOptions = null,
CancellationToken cancellationToken = default)
{
serializerOptions ??= McpJsonUtilities.DefaultOptions;
serializerOptions.MakeReadOnly();

JsonTypeInfo<TParameters> parametersTypeInfo = serializerOptions.GetTypeInfo<TParameters>();
return SendNotificationAsync(endpoint, method, parameters, parametersTypeInfo, cancellationToken);
}

/// <summary>
/// Sends a notification to the server with parameters.
/// </summary>
/// <param name="endpoint">The MCP client or server instance.</param>
/// <param name="method">The JSON-RPC method name to invoke.</param>
/// <param name="parameters">Object representing the request parameters.</param>
/// <param name="parametersTypeInfo">The type information for request parameter serialization.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
internal static Task SendNotificationAsync<TParameters>(
this IMcpEndpoint endpoint,
string method,
TParameters parameters,
JsonTypeInfo<TParameters> parametersTypeInfo,
CancellationToken cancellationToken = default)
{
Throw.IfNull(endpoint);
Throw.IfNullOrWhiteSpace(method);
Throw.IfNull(parametersTypeInfo);

JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo);
return endpoint.SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken);
}

/// <summary>Notifies the connected endpoint of progress.</summary>
/// <param name="endpoint">The endpoint issueing the notification.</param>
/// <param name="endpoint">The endpoint issuing the notification.</param>
/// <param name="progressToken">The <see cref="ProgressToken"/> identifying the operation.</param>
/// <param name="progress">The progress update to send.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
Expand All @@ -24,11 +158,11 @@ public static Task NotifyProgressAsync(
return endpoint.SendMessageAsync(new JsonRpcNotification()
{
Method = NotificationMethods.ProgressNotification,
Params = new ProgressNotification()
Params = JsonSerializer.SerializeToNode(new ProgressNotification
{
ProgressToken = progressToken,
Progress = progress,
},
}, McpJsonUtilities.JsonContext.Default.ProgressNotification),
}, cancellationToken);
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Text.Json.Serialization;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Protocol.Messages;

Expand All @@ -23,5 +24,5 @@ public record JsonRpcNotification : IJsonRpcMessage
/// Optional parameters for the notification.
/// </summary>
[JsonPropertyName("params")]
public object? Params { get; init; }
public JsonNode? Params { get; init; }
}
5 changes: 3 additions & 2 deletions src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Text.Json.Serialization;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Protocol.Messages;

Expand Down Expand Up @@ -29,5 +30,5 @@ public record JsonRpcRequest : IJsonRpcMessageWithId
/// Optional parameters for the method.
/// </summary>
[JsonPropertyName("params")]
public object? Params { get; init; }
public JsonNode? Params { get; init; }
}
4 changes: 2 additions & 2 deletions src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

using System.Text.Json.Nodes;
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Protocol.Messages;
Expand All @@ -23,5 +23,5 @@ public record JsonRpcResponse : IJsonRpcMessageWithId
/// The result of the method invocation.
/// </summary>
[JsonPropertyName("result")]
public required object? Result { get; init; }
public required JsonNode? Result { get; init; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ internal sealed class SseClientSessionTransport : TransportBase
private Task? _receiveTask;
private readonly ILogger _logger;
private readonly McpServerConfig _serverConfig;
private readonly JsonSerializerOptions _jsonOptions;
private readonly TaskCompletionSource<bool> _connectionEstablished;

private string EndpointName => $"Client (SSE) for ({_serverConfig.Id}: {_serverConfig.Name})";
Expand All @@ -50,7 +49,6 @@ public SseClientSessionTransport(SseClientTransportOptions transportOptions, Mcp
_httpClient = httpClient;
_connectionCts = new CancellationTokenSource();
_logger = (ILogger?)loggerFactory?.CreateLogger<SseClientTransport>() ?? NullLogger.Instance;
_jsonOptions = McpJsonUtilities.DefaultOptions;
_connectionEstablished = new TaskCompletionSource<bool>();
}

Expand Down Expand Up @@ -94,7 +92,7 @@ public override async Task SendMessageAsync(
throw new InvalidOperationException("Transport not connected");

using var content = new StringContent(
JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>()),
JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage),
Encoding.UTF8,
"application/json"
);
Expand Down Expand Up @@ -127,7 +125,7 @@ public override async Task SendMessageAsync(
}
else
{
JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, _jsonOptions.GetTypeInfo<JsonRpcResponse>()) ??
JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ??
throw new McpTransportException("Failed to initialize client");

_logger.TransportReceivedMessageParsed(EndpointName, messageId);
Expand Down Expand Up @@ -259,7 +257,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation

try
{
var message = JsonSerializer.Deserialize(data, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage);
if (message == null)
{
_logger.TransportMessageParseUnexpectedType(EndpointName, data);
Expand Down
Loading