Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,13 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult,
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err))
}

if err := resolved.Validate(res.Content); err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err))
}
err = resolved.ApplyDefaults(&res.Content)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err))
}
}

return res, nil
Expand All @@ -341,6 +344,9 @@ func validateElicitSchema(wireSchema any) (*jsonschema.Schema, error) {
if err := remarshal(wireSchema, &schema); err != nil {
return nil, err
}
if schema == nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is possible? wireShema is non-nil, so wouldn't a nil schema be an error here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TL;DR: I think we either need a reflect.IsNil() call on wireSchema to determine if it is really nil, or we need to check if the result of remarshall is nil when err is != nil (which feels gross).

In more detail:

In the server, as I understand, the wireschema is **jsonschema. we pass in a nil value for this which gets wrapped in another pointer. so, in the server, wireschema is non-nil but *wireschema is nil. This means we end up passing a non-nil wireschema here, but it is unmarshalling a nil jsonschema.

This means that when we get back here since we passed &schema into remarshall now schema is nil.

It might make remarshall more robust to have a reflect isNil call and return an err, but then we'd need to add in the reflect call.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the server the schema should be a *jsonschema.Schema, not a **jsonschema.Schema. Where are you seeing this?

Let's merge this as-is, and I will take a look.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, I looked into it: the test uses a typed nil (one of the downsides of having any for the schema).

I think the code, as written, is fine.

return nil, nil
}

// The root schema must be of type "object" if specified
if schema.Type != "" && schema.Type != "object" {
Expand Down Expand Up @@ -369,7 +375,6 @@ func validateElicitProperty(propName string, propSchema *jsonschema.Schema) erro
if len(propSchema.Properties) > 0 {
return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName)
}

// Validate based on the property type - only primitives are supported
switch propSchema.Type {
case "string":
Expand Down Expand Up @@ -439,7 +444,7 @@ func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema
}
}

return nil
return validateDefaultProperty[string](propName, propSchema)
}

// validateElicitNumberProperty validates number and integer-type properties.
Expand All @@ -450,19 +455,28 @@ func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema
}
}

intDefaultError := validateDefaultProperty[int](propName, propSchema)
floatDefaultError := validateDefaultProperty[float64](propName, propSchema)
if intDefaultError != nil && floatDefaultError != nil {
return fmt.Errorf("elicit schema property %q has default value that cannot be interpreted as an int or float", propName)
}

return nil
}

// validateElicitBooleanProperty validates boolean-type properties.
func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error {
// Validate default value if specified - must be a valid boolean
return validateDefaultProperty[bool](propName, propSchema)
}

func validateDefaultProperty[T any](propName string, propSchema *jsonschema.Schema) error {
// Validate default value if specified - must be a valid T
if propSchema.Default != nil {
var defaultValue bool
var defaultValue T
if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil {
return fmt.Errorf("elicit schema property %q has invalid default value, must be a boolean: %v", propName, err)
return fmt.Errorf("elicit schema property %q has invalid default value, must be a %T: %v", propName, defaultValue, err)
}
}

return nil
}

Expand Down
126 changes: 125 additions & 1 deletion mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,37 @@ func TestElicitationSchemaValidation(t *testing.T) {
"enabled": {Type: "boolean", Default: json.RawMessage(`"not-a-boolean"`)},
},
},
expectedError: "elicit schema property \"enabled\" has invalid default value, must be a boolean",
expectedError: "elicit schema property \"enabled\" has invalid default value, must be a bool",
},
{
name: "string with invalid default",
schema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"enabled": {Type: "string", Default: json.RawMessage("true")},
},
},
expectedError: "elicit schema property \"enabled\" has invalid default value, must be a string",
},
{
name: "integer with invalid default",
schema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"enabled": {Type: "integer", Default: json.RawMessage("true")},
},
},
expectedError: "elicit schema property \"enabled\" has default value that cannot be interpreted as an int or float",
},
{
name: "number with invalid default",
schema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"enabled": {Type: "number", Default: json.RawMessage("true")},
},
},
expectedError: "elicit schema property \"enabled\" has default value that cannot be interpreted as an int or float",
},
{
name: "enum with mismatched enumNames length",
Expand Down Expand Up @@ -1459,6 +1489,100 @@ func TestElicitationCapabilityDeclaration(t *testing.T) {
})
}

func TestElicitationDefaultValues(t *testing.T) {
ctx := context.Background()
ct, st := NewInMemoryTransports()

s := NewServer(testImpl, nil)
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
defer ss.Close()

c := NewClient(testImpl, &ClientOptions{
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
return &ElicitResult{Action: "accept", Content: map[string]any{"default": "response"}}, nil
},
})
cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
defer cs.Close()

testcases := []struct {
name string
schema *jsonschema.Schema
expected map[string]any
}{
{
name: "boolean with default",
schema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"key": {Type: "boolean", Default: json.RawMessage("true")},
},
},
expected: map[string]any{"key": true, "default": "response"},
},
{
name: "string with default",
schema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"key": {Type: "string", Default: json.RawMessage("\"potato\"")},
},
},
expected: map[string]any{"key": "potato", "default": "response"},
},
{
name: "integer with default",
schema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"key": {Type: "integer", Default: json.RawMessage("123")},
},
},
expected: map[string]any{"key": float64(123), "default": "response"},
},
{
name: "number with default",
schema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"key": {Type: "number", Default: json.RawMessage("89.7")},
},
},
expected: map[string]any{"key": float64(89.7), "default": "response"},
},
{
name: "enum with default",
schema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"key": {Type: "string", Enum: []any{"one", "two"}, Default: json.RawMessage("\"one\"")},
},
},
expected: map[string]any{"key": "one", "default": "response"},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
res, err := ss.Elicit(ctx, &ElicitParams{
Message: "Test schema with defaults: " + tc.name,
RequestedSchema: tc.schema,
})
if err != nil {
t.Fatalf("expected no error for default schema %q, got: %v", tc.name, err)
}
if diff := cmp.Diff(tc.expected, res.Content); diff != "" {
t.Errorf("%s: did not get expected value, -want +got:\n%s", tc.name, diff)
}
})
}
}

func TestKeepAliveFailure(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down
32 changes: 31 additions & 1 deletion mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,37 @@ func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*Eli
if err := ss.checkInitialized(methodElicit); err != nil {
return nil, err
}
return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params)))

res, err := handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params)))
if err != nil {
return nil, err
}

if params.RequestedSchema == nil {
return res, nil
}
schema, err := validateElicitSchema(params.RequestedSchema)
if err != nil {
return nil, err
}
if schema == nil {
return res, nil
}

resolved, err := schema.Resolve(nil)
if err != nil {
fmt.Printf(" resolve err: %s", err)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this okay?

return nil, err
}
if err := resolved.Validate(res.Content); err != nil {
return nil, fmt.Errorf("elicitation result content does not match requested schema: %v", err)
}
err = resolved.ApplyDefaults(&res.Content)
if err != nil {
return nil, fmt.Errorf("failed to apply schema defalts to elicitation result: %v", err)
}

return res, nil
}

// Log sends a log message to the client.
Expand Down
Loading