Skip to content

Commit edd5b1d

Browse files
authored
feat: enforce SEP-1577 MUST requirements for sampling with tools (#646)
1 parent 8bd3fcb commit edd5b1d

File tree

4 files changed

+310
-78
lines changed

4 files changed

+310
-78
lines changed

crates/rmcp/src/model.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,85 @@ impl TaskAugmentedRequestParamsMeta for CreateMessageRequestParams {
15801580
}
15811581
}
15821582

1583+
impl CreateMessageRequestParams {
1584+
/// Validate the sampling request parameters per SEP-1577 spec requirements.
1585+
///
1586+
/// Checks:
1587+
/// - ToolUse content is only allowed in assistant messages
1588+
/// - ToolResult content is only allowed in user messages
1589+
/// - Messages with tool result content MUST NOT contain other content types
1590+
/// - Every assistant ToolUse must be balanced with a corresponding user ToolResult
1591+
pub fn validate(&self) -> Result<(), String> {
1592+
for msg in &self.messages {
1593+
for content in msg.content.iter() {
1594+
// ToolUse only in assistant messages, ToolResult only in user messages
1595+
match content {
1596+
SamplingMessageContent::ToolUse(_) if msg.role != Role::Assistant => {
1597+
return Err("ToolUse content is only allowed in assistant messages".into());
1598+
}
1599+
SamplingMessageContent::ToolResult(_) if msg.role != Role::User => {
1600+
return Err("ToolResult content is only allowed in user messages".into());
1601+
}
1602+
_ => {}
1603+
}
1604+
}
1605+
1606+
// Tool result messages MUST NOT contain other content types
1607+
let contents: Vec<_> = msg.content.iter().collect();
1608+
let has_tool_result = contents
1609+
.iter()
1610+
.any(|c| matches!(c, SamplingMessageContent::ToolResult(_)));
1611+
if has_tool_result
1612+
&& contents
1613+
.iter()
1614+
.any(|c| !matches!(c, SamplingMessageContent::ToolResult(_)))
1615+
{
1616+
return Err(
1617+
"SamplingMessage with tool result content MUST NOT contain other content types"
1618+
.into(),
1619+
);
1620+
}
1621+
}
1622+
1623+
// Every assistant ToolUse must be balanced with a user ToolResult
1624+
self.validate_tool_use_result_balance()?;
1625+
1626+
Ok(())
1627+
}
1628+
1629+
fn validate_tool_use_result_balance(&self) -> Result<(), String> {
1630+
let mut pending_tool_use_ids: Vec<String> = Vec::new();
1631+
for msg in &self.messages {
1632+
if msg.role == Role::Assistant {
1633+
for content in msg.content.iter() {
1634+
if let SamplingMessageContent::ToolUse(tu) = content {
1635+
pending_tool_use_ids.push(tu.id.clone());
1636+
}
1637+
}
1638+
} else if msg.role == Role::User {
1639+
for content in msg.content.iter() {
1640+
if let SamplingMessageContent::ToolResult(tr) = content {
1641+
if !pending_tool_use_ids.contains(&tr.tool_use_id) {
1642+
return Err(format!(
1643+
"ToolResult with toolUseId '{}' has no matching ToolUse",
1644+
tr.tool_use_id
1645+
));
1646+
}
1647+
pending_tool_use_ids.retain(|id| id != &tr.tool_use_id);
1648+
}
1649+
}
1650+
}
1651+
}
1652+
if !pending_tool_use_ids.is_empty() {
1653+
return Err(format!(
1654+
"ToolUse with id(s) {:?} not balanced with ToolResult",
1655+
pending_tool_use_ids
1656+
));
1657+
}
1658+
Ok(())
1659+
}
1660+
}
1661+
15831662
/// Deprecated: Use [`CreateMessageRequestParams`] instead (SEP-1319 compliance).
15841663
#[deprecated(since = "0.13.0", note = "Use CreateMessageRequestParams instead")]
15851664
pub type CreateMessageRequestParam = CreateMessageRequestParams;
@@ -2229,6 +2308,14 @@ impl CreateMessageResult {
22292308
pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence";
22302309
pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens";
22312310
pub const STOP_REASON_TOOL_USE: &str = "toolUse";
2311+
2312+
/// Validate the result per SEP-1577: role must be "assistant".
2313+
pub fn validate(&self) -> Result<(), String> {
2314+
if self.message.role != Role::Assistant {
2315+
return Err("CreateMessageResult role must be 'assistant'".into());
2316+
}
2317+
Ok(())
2318+
}
22322319
}
22332320

22342321
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]

crates/rmcp/src/model/content.rs

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -129,67 +129,6 @@ impl ToolResultContent {
129129
}
130130
}
131131

132-
/// Assistant message content types (SEP-1577).
133-
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
134-
#[serde(tag = "type", rename_all = "snake_case")]
135-
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
136-
pub enum AssistantMessageContent {
137-
Text(RawTextContent),
138-
Image(RawImageContent),
139-
Audio(RawAudioContent),
140-
ToolUse(ToolUseContent),
141-
}
142-
143-
/// User message content types (SEP-1577).
144-
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
145-
#[serde(tag = "type", rename_all = "snake_case")]
146-
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
147-
pub enum UserMessageContent {
148-
Text(RawTextContent),
149-
Image(RawImageContent),
150-
Audio(RawAudioContent),
151-
ToolResult(ToolResultContent),
152-
}
153-
154-
impl AssistantMessageContent {
155-
/// Create a text content
156-
pub fn text(text: impl Into<String>) -> Self {
157-
Self::Text(RawTextContent {
158-
text: text.into(),
159-
meta: None,
160-
})
161-
}
162-
163-
/// Create a tool use content
164-
pub fn tool_use(
165-
id: impl Into<String>,
166-
name: impl Into<String>,
167-
input: super::JsonObject,
168-
) -> Self {
169-
Self::ToolUse(ToolUseContent::new(id, name, input))
170-
}
171-
}
172-
173-
impl UserMessageContent {
174-
/// Create a text content
175-
pub fn text(text: impl Into<String>) -> Self {
176-
Self::Text(RawTextContent {
177-
text: text.into(),
178-
meta: None,
179-
})
180-
}
181-
182-
/// Create a tool result content
183-
pub fn tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
184-
Self::ToolResult(ToolResultContent::new(tool_use_id, content))
185-
}
186-
187-
/// Create an error tool result content
188-
pub fn tool_result_error(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
189-
Self::ToolResult(ToolResultContent::error(tool_use_id, content))
190-
}
191-
}
192-
193132
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
194133
#[serde(tag = "type", rename_all = "snake_case")]
195134
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]

crates/rmcp/src/service/server.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,10 +384,37 @@ macro_rules! method {
384384
}
385385

386386
impl Peer<RoleServer> {
387+
/// Check if the client supports sampling tools capability.
388+
pub fn supports_sampling_tools(&self) -> bool {
389+
if let Some(client_info) = self.peer_info() {
390+
client_info
391+
.capabilities
392+
.sampling
393+
.as_ref()
394+
.and_then(|s| s.tools.as_ref())
395+
.is_some()
396+
} else {
397+
false
398+
}
399+
}
400+
387401
pub async fn create_message(
388402
&self,
389403
params: CreateMessageRequestParams,
390404
) -> Result<CreateMessageResult, ServiceError> {
405+
// MUST throw error when tools/toolChoice provided without capability
406+
if (params.tools.is_some() || params.tool_choice.is_some())
407+
&& !self.supports_sampling_tools()
408+
{
409+
return Err(ServiceError::McpError(ErrorData::invalid_params(
410+
"tools or toolChoice provided but client does not support sampling tools capability",
411+
None,
412+
)));
413+
}
414+
// Validate message structure
415+
params
416+
.validate()
417+
.map_err(|e| ServiceError::McpError(ErrorData::invalid_params(e, None)))?;
391418
let result = self
392419
.send_request(ServerRequest::CreateMessageRequest(CreateMessageRequest {
393420
method: Default::default(),

0 commit comments

Comments
 (0)