Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 96 additions & 14 deletions crates/goose/src/conversation/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rmcp::model::{
AnnotateAble, Content, ImageContent, PromptMessage, PromptMessageContent, PromptMessageRole,
RawContent, RawImageContent, RawTextContent, ResourceContents, Role, TextContent,
};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value;
use std::collections::HashSet;
use std::fmt;
Expand All @@ -13,6 +13,42 @@ use utoipa::ToSchema;

use crate::conversation::tool_result_serde;

/// Strip Unicode Tags Block code points (U+E0000..=U+E007F) from `text`.
///
/// The input is NFC-normalized first; the result is the normalized text with
/// every character in the tag range dropped.
fn sanitize_unicode_tags(text: &str) -> String {
    text.nfc()
        .filter(|ch| !('\u{E0000}'..='\u{E007F}').contains(ch))
        .collect()
}

/// Custom deserializer for a `Vec<MessageContent>` that sanitizes Unicode Tags in text content.
///
/// The list is deserialized normally; afterwards the text of every
/// `MessageContent::Text` item is replaced by the output of
/// `sanitize_unicode_tags` whenever that output differs from the stored text.
/// Other content variants are returned unchanged. Each replacement is logged
/// at `info` level.
///
/// # Errors
///
/// Returns the deserializer's error if the underlying `Vec<MessageContent>`
/// cannot be deserialized.
fn deserialize_sanitized_content<'de, D>(deserializer: D) -> Result<Vec<MessageContent>, D::Error>
where
    D: Deserializer<'de>,
{
    let mut content: Vec<MessageContent> = Vec::deserialize(deserializer)?;

    for message_content in &mut content {
        if let MessageContent::Text(text_content) = message_content {
            let sanitized = sanitize_unicode_tags(&text_content.text);
            if text_content.text != sanitized {
                // Count the tag characters directly instead of subtracting byte
                // lengths: NFC normalization inside `sanitize_unicode_tags` can
                // make the text longer (composition-excluded characters), so
                // `original.len() - sanitized.len()` could underflow `usize`.
                // The count is 0 when only normalization changed the text.
                let removed_count = text_content
                    .text
                    .chars()
                    .filter(|c| matches!(c, '\u{E0000}'..='\u{E007F}'))
                    .count();
                tracing::info!(
                    original = %text_content.text,
                    sanitized = %sanitized,
                    removed_count = removed_count,
                    "Unicode Tags sanitized during Message deserialization"
                );
                text_content.text = sanitized;
            }
        }
    }

    Ok(content)
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[derive(ToSchema)]
Expand Down Expand Up @@ -346,6 +382,7 @@ pub struct Message {
pub role: Role,
#[serde(default = "default_created")]
pub created: i64,
#[serde(deserialize_with = "deserialize_sanitized_content")]
pub content: Vec<MessageContent>,
}

Expand Down Expand Up @@ -410,20 +447,10 @@ impl Message {
self
}

/// NFC-normalize `text`, then drop every Unicode Tags Block character
/// (U+E0000..=U+E007F) from the normalized string.
fn sanitize_unicode_tags(text: &str) -> String {
    let mut cleaned: String = text.nfc().collect();

    // Keep everything except Unicode Tags Block characters
    cleaned.retain(|ch| !('\u{E0000}'..='\u{E007F}').contains(&ch));
    cleaned
}

/// Add text content to the message
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
let raw_text = text.into();
let sanitized_text = Self::sanitize_unicode_tags(&raw_text);
let sanitized_text = sanitize_unicode_tags(&raw_text);

self.with_content(MessageContent::Text(
RawTextContent {
Expand Down Expand Up @@ -587,14 +614,14 @@ mod tests {
#[test]
fn test_sanitize_unicode_tags() {
// Three code points from the Tags block (U+E0000..=U+E007F) sit between "Hello" and "world".
let malicious = "Hello\u{E0041}\u{E0042}\u{E0043}world"; // Invisible "ABC"
// NOTE(review): the two `cleaned` bindings below look like the before/after lines of a diff;
// presumably only the `super::` form exists in the final file — confirm.
let cleaned = Message::sanitize_unicode_tags(malicious);
let cleaned = super::sanitize_unicode_tags(malicious);
// Only the tag characters are expected to disappear; the visible text is kept.
assert_eq!(cleaned, "Helloworld");
}

#[test]
fn test_no_sanitize_unicode_tags() {
// Input mixes ASCII, CJK and an emoji, none of which fall in U+E0000..=U+E007F.
let clean_text = "Hello world 世界 🌍";
// NOTE(review): the two `cleaned` bindings below look like the before/after lines of a diff;
// presumably only the `super::` form exists in the final file — confirm.
let cleaned = Message::sanitize_unicode_tags(clean_text);
let cleaned = super::sanitize_unicode_tags(clean_text);
// Sanitizing must return the text unchanged.
assert_eq!(cleaned, clean_text);
}

Expand Down Expand Up @@ -872,4 +899,59 @@ mod tests {
assert_eq!(ids.len(), 1);
assert!(ids.contains("req1"));
}

#[test]
fn test_message_deserialization_sanitizes_text_content() {
    // Visible text with three Unicode Tags characters hidden in the middle.
    let tagged_text = "Hello\u{E0041}\u{E0042}\u{E0043}world";
    // One text item carrying the tagged string plus one image item.
    let payload = format!(
        r#"{{
"id": "test-id",
"role": "user",
"created": 1640995200,
"content": [
{{
"type": "text",
"text": "{}"
}},
{{
"type": "image",
"data": "base64data",
"mimeType": "image/png"
}}
]
}}"#,
        tagged_text
    );

    let parsed: Message = serde_json::from_str(&payload).unwrap();

    // The text item comes back without the tag characters.
    assert_eq!(parsed.as_concat_text(), "Helloworld");

    // The image item passes through deserialization untouched.
    match &parsed.content[1] {
        MessageContent::Image(img) => {
            assert_eq!(img.data, "base64data");
            assert_eq!(img.mime_type, "image/png");
        }
        _ => panic!("Expected ImageContent"),
    }
}

#[test]
fn test_legitimate_unicode_preserved_during_message_deserialization() {
    // A single text item containing ASCII, CJK characters and an emoji.
    let payload = r#"{
"id": "test-id",
"role": "user",
"created": 1640995200,
"content": [{
"type": "text",
"text": "Hello world 世界 🌍"
}]
}"#;

    let parsed: Message = serde_json::from_str(payload).unwrap();
    let text = parsed.as_concat_text();

    // Ordinary Unicode must survive deserialization unchanged.
    assert_eq!(text, "Hello world 世界 🌍");
}
}
Loading