diff --git a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java index 1acfc8ef34b..d2484de3772 100644 --- a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java +++ b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java @@ -134,19 +134,18 @@ private static List messagesForMetaStruct(List messages) { final List result = new ArrayList<>(size); final int maxContent = config.getAiGuardMaxContentSize(); boolean contentTruncated = false; - for (int i = 0; i < size; i++) { - Message source = messages.get(i); - final String content = source.getContent(); + for (int i = messages.size() - size; i < messages.size(); i++) { + final Message source = messages.get(i); + String content = source.getContent(); if (content != null && content.length() > maxContent) { contentTruncated = true; - source = - new Message( - source.getRole(), - content.substring(0, maxContent), - source.getToolCalls(), - source.getToolCallId()); + content = content.substring(0, maxContent); } - result.add(source); + List toolCalls = source.getToolCalls(); + if (toolCalls != null) { + toolCalls = new ArrayList<>(toolCalls); + } + result.add(new Message(source.getRole(), content, toolCalls, source.getToolCallId())); } if (contentTruncated) { WafMetricCollector.get().aiGuardTruncated(CONTENT); @@ -240,7 +239,7 @@ public Evaluation evaluate(final List messages, final Options options) span.setTag(BLOCKED_TAG, true); throw new AIGuardAbortError(action, reason, tags); } - return new Evaluation(action, reason); + return new Evaluation(action, reason, tags); } } catch (AIGuardAbortError e) { span.addThrowable(e); diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy index 12bc6b196c8..f7c9de745ca 100644 --- a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy +++ b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy @@ -198,19 +198,18 @@ class AIGuardInternalTests extends DDSpecification { 1 * span.addThrowable(_ as AIGuard.AIGuardAbortError) } - receivedMeta.messages == suite.messages - if (suite.tags) { - receivedMeta.attack_categories == suite.tags - } + assertMeta(receivedMeta, suite) assertRequest(request, suite.messages) if (throwAbortError) { error instanceof AIGuard.AIGuardAbortError error.action == suite.action error.reason == suite.reason + error.tags == suite.tags } else { error == null eval.action == suite.action eval.reason == suite.reason + eval.tags == suite.tags } assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false') @@ -221,19 +220,7 @@ class AIGuardInternalTests extends DDSpecification { void 'test evaluate with API errors'() { given: final errors = [[status: 400, title: 'Bad request']] - Request request = null - final call = Mock(Call) { - execute() >> { - return mockResponse(request, 404, [errors: errors]) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) + final aiguard = mockClient(404, [errors: errors]) when: aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) @@ -247,19 +234,7 @@ class AIGuardInternalTests extends DDSpecification { void 'test evaluate with invalid JSON'() { given: - Request request = null - final call = Mock(Call) { - execute() >> { - return mockResponse(request, 200, [bad: 'This is an invalid response']) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) + final aiguard = mockClient(200, [bad: 'This is an invalid response']) when: aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) @@ -272,19 +247,7 @@ class AIGuardInternalTests extends DDSpecification { void 'test evaluate with missing action'() { given: - Request request = null - final call = Mock(Call) { - execute() >> { - return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]]) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) + final aiguard = mockClient(200, [data: [attributes: [reason: 'I miss something']]]) when: aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) @@ -297,19 +260,7 @@ class AIGuardInternalTests extends DDSpecification { void 'test evaluate with non JSON response'() { given: - Request request = null - final call = Mock(Call) { - execute() >> { - return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]]) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) + final aiguard = mockClient(200, 'I am no JSON') when: aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) @@ -322,19 +273,7 @@ class AIGuardInternalTests extends DDSpecification { void 'test evaluate with empty response'() { given: - Request request = null - final call = Mock(Call) { - execute() >> { - return mockResponse(request, 200, null) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) + final aiguard = mockClient(200, null) when: aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) @@ -348,19 +287,7 @@ class AIGuardInternalTests extends DDSpecification { void 'test message length truncation'() { given: final maxMessages = Config.get().getAiGuardMaxMessagesLength() - Request request = null - final call = Mock(Call) { - execute() >> { - return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]]) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) + final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) final messages = (0..maxMessages) .collect { AIGuard.Message.message('user', "This is a prompt: ${it}") } .toList() @@ -380,19 +307,7 @@ class AIGuardInternalTests extends DDSpecification { void 'test message content truncation'() { given: final maxContent = Config.get().getAiGuardMaxContentSize() - Request request = null - final call = Mock(Call) { - execute() >> { - return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]]) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) + final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) final message = AIGuard.Message.message("user", (0..maxContent).collect { 'A' }.join()) when: @@ -426,23 +341,7 @@ class AIGuardInternalTests extends DDSpecification { void 'test missing tool name'() { given: - def request - final call = Mock(Call) { - execute() >> { - return mockResponse( - request, - 200, - [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]] - ) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) + final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]]) when: aiguard.evaluate([AIGuard.Message.tool('call_1', 'Content')], AIGuard.Options.DEFAULT) @@ -460,6 +359,57 @@ class AIGuardInternalTests extends DDSpecification { thrown(IllegalArgumentException) } + void 'test message immutability'() { + given: + final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]]) + final messages = [ + new AIGuard.Message( + "assistant", + null, + [AIGuard.ToolCall.toolCall('call_1', 'execute_shell', '{"cmd": "ls -lah"}')], + null + ) + ] + Map receivedMeta + + when: + aiguard.evaluate(messages, AIGuard.Options.DEFAULT) + + then: + 1 * span.finish() >> { + // modify the messages before serialization + messages.first().toolCalls.add( + AIGuard.ToolCall.toolCall('call_2', 'execute_shell', '{"cmd": "rm -rf"}') + ) + messages.add(AIGuard.Message.tool('call_1', 'dir1, dir2, dir3')) + } + 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> { + receivedMeta = it[1] as Map + return span + } + final metaStructMessages = receivedMeta.messages as List + metaStructMessages.size() != messages.size() + metaStructMessages.size() == 1 + metaStructMessages.first().toolCalls.size() != messages.first().toolCalls.size() + metaStructMessages.first().toolCalls.size() == 1 + } + + private AIGuardInternal mockClient(final int status, final Object response) { + def request + final call = Stub(Call) { + execute() >> { + return mockResponse(request, status, response) + } + } + final client = Stub(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + return new AIGuardInternal(URL, HEADERS, client) + } + private static assertTelemetry(final String metric, final String...tags) { final metrics = WafMetricCollector.get().with { prepareMetrics() @@ -475,6 +425,16 @@ class AIGuardInternalTests extends DDSpecification { return true } + private static assertMeta(final Map meta, final TestSuite suite) { + if (suite.tags) { + assert meta.attack_categories == suite.tags + } + final receivedMessages = snakeCaseJson(meta.messages) + final expectedMessages = snakeCaseJson(suite.messages) + JSONAssert.assertEquals(expectedMessages, receivedMessages, JSONCompareMode.NON_EXTENSIBLE) + return true + } + private static assertRequest(final Request request, final List messages) { assert request.url() == URL assert request.method() == 'POST' @@ -556,7 +516,8 @@ class AIGuardInternalTests extends DDSpecification { ", reason='" + reason + '\'' + ", blocking=" + blocking + ", target='" + target + '\'' + - ", messages=" + messages + + ", messages=" + messages + '\'' + + ", tags=" + tags + '}' } } diff --git a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java index 83a63789dc8..1a8c0a89dd7 100644 --- a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java +++ b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java @@ -143,16 +143,19 @@ public static class Evaluation { final Action action; final String reason; + final List tags; /** * Creates a new evaluation result. * * @param action the recommended action for the evaluated content * @param reason human-readable explanation for the decision + * @param tags list of tags associated with the evaluation (e.g. indirect-prompt-injection) */ - public Evaluation(final Action action, final String reason) { + public Evaluation(final Action action, final String reason, final List tags) { this.action = action; this.reason = reason; + this.tags = tags; } /** @@ -172,6 +175,15 @@ public Action getAction() { public String getReason() { return reason; } + + /** + * Returns the list of tags associated with the evaluation (e.g. indirect-prompt-injection) + * + * @return list of tags. + */ + public List getTags() { + return tags; + } } /** diff --git a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java index bdb5a1869c4..44d88c68878 100644 --- a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java +++ b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java @@ -1,6 +1,7 @@ package datadog.trace.api.aiguard.noop; import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW; +import static java.util.Collections.emptyList; import datadog.trace.api.aiguard.AIGuard.Evaluation; import datadog.trace.api.aiguard.AIGuard.Message; @@ -12,6 +13,6 @@ public final class NoOpEvaluator implements Evaluator { @Override public Evaluation evaluate(final List messages, final Options options) { - return new Evaluation(ALLOW, "AI Guard is not enabled"); + return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList()); } }