Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,18 @@ private static List<Message> messagesForMetaStruct(List<Message> messages) {
final List<Message> result = new ArrayList<>(size);
final int maxContent = config.getAiGuardMaxContentSize();
boolean contentTruncated = false;
for (int i = 0; i < size; i++) {
Message source = messages.get(i);
final String content = source.getContent();
for (int i = messages.size() - size; i < messages.size(); i++) {
final Message source = messages.get(i);
String content = source.getContent();
if (content != null && content.length() > maxContent) {
contentTruncated = true;
source =
new Message(
source.getRole(),
content.substring(0, maxContent),
source.getToolCalls(),
source.getToolCallId());
content = content.substring(0, maxContent);
}
result.add(source);
List<ToolCall> toolCalls = source.getToolCalls();
if (toolCalls != null) {
toolCalls = new ArrayList<>(toolCalls);
}
result.add(new Message(source.getRole(), content, toolCalls, source.getToolCallId()));
}
if (contentTruncated) {
WafMetricCollector.get().aiGuardTruncated(CONTENT);
Expand Down Expand Up @@ -240,7 +239,7 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
span.setTag(BLOCKED_TAG, true);
throw new AIGuardAbortError(action, reason, tags);
}
return new Evaluation(action, reason);
return new Evaluation(action, reason, tags);
}
} catch (AIGuardAbortError e) {
span.addThrowable(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,18 @@ class AIGuardInternalTests extends DDSpecification {
1 * span.addThrowable(_ as AIGuard.AIGuardAbortError)
}

receivedMeta.messages == suite.messages
if (suite.tags) {
receivedMeta.attack_categories == suite.tags
}
assertMeta(receivedMeta, suite)
assertRequest(request, suite.messages)
if (throwAbortError) {
error instanceof AIGuard.AIGuardAbortError
error.action == suite.action
error.reason == suite.reason
error.tags == suite.tags
} else {
error == null
eval.action == suite.action
eval.reason == suite.reason
eval.tags == suite.tags
}
assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false')

Expand All @@ -221,19 +220,7 @@ class AIGuardInternalTests extends DDSpecification {
void 'test evaluate with API errors'() {
given:
final errors = [[status: 400, title: 'Bad request']]
Request request = null
final call = Mock(Call) {
execute() >> {
return mockResponse(request, 404, [errors: errors])
}
}
final client = Mock(OkHttpClient) {
newCall(_ as Request) >> {
request = (Request) it[0]
return call
}
}
final aiguard = new AIGuardInternal(URL, HEADERS, client)
final aiguard = mockClient(404, [errors: errors])

when:
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
Expand All @@ -247,19 +234,7 @@ class AIGuardInternalTests extends DDSpecification {

void 'test evaluate with invalid JSON'() {
given:
Request request = null
final call = Mock(Call) {
execute() >> {
return mockResponse(request, 200, [bad: 'This is an invalid response'])
}
}
final client = Mock(OkHttpClient) {
newCall(_ as Request) >> {
request = (Request) it[0]
return call
}
}
final aiguard = new AIGuardInternal(URL, HEADERS, client)
final aiguard = mockClient(200, [bad: 'This is an invalid response'])

when:
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
Expand All @@ -272,19 +247,7 @@ class AIGuardInternalTests extends DDSpecification {

void 'test evaluate with missing action'() {
given:
Request request = null
final call = Mock(Call) {
execute() >> {
return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]])
}
}
final client = Mock(OkHttpClient) {
newCall(_ as Request) >> {
request = (Request) it[0]
return call
}
}
final aiguard = new AIGuardInternal(URL, HEADERS, client)
final aiguard = mockClient(200, [data: [attributes: [reason: 'I miss something']]])

when:
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
Expand All @@ -297,19 +260,7 @@ class AIGuardInternalTests extends DDSpecification {

void 'test evaluate with non JSON response'() {
given:
Request request = null
final call = Mock(Call) {
execute() >> {
return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]])
}
}
final client = Mock(OkHttpClient) {
newCall(_ as Request) >> {
request = (Request) it[0]
return call
}
}
final aiguard = new AIGuardInternal(URL, HEADERS, client)
final aiguard = mockClient(200, 'I am no JSON')

when:
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
Expand All @@ -322,19 +273,7 @@ class AIGuardInternalTests extends DDSpecification {

void 'test evaluate with empty response'() {
given:
Request request = null
final call = Mock(Call) {
execute() >> {
return mockResponse(request, 200, null)
}
}
final client = Mock(OkHttpClient) {
newCall(_ as Request) >> {
request = (Request) it[0]
return call
}
}
final aiguard = new AIGuardInternal(URL, HEADERS, client)
final aiguard = mockClient(200, null)

when:
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
Expand All @@ -348,19 +287,7 @@ class AIGuardInternalTests extends DDSpecification {
void 'test message length truncation'() {
given:
final maxMessages = Config.get().getAiGuardMaxMessagesLength()
Request request = null
final call = Mock(Call) {
execute() >> {
return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]])
}
}
final client = Mock(OkHttpClient) {
newCall(_ as Request) >> {
request = (Request) it[0]
return call
}
}
final aiguard = new AIGuardInternal(URL, HEADERS, client)
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
final messages = (0..maxMessages)
.collect { AIGuard.Message.message('user', "This is a prompt: ${it}") }
.toList()
Expand All @@ -380,19 +307,7 @@ class AIGuardInternalTests extends DDSpecification {
void 'test message content truncation'() {
given:
final maxContent = Config.get().getAiGuardMaxContentSize()
Request request = null
final call = Mock(Call) {
execute() >> {
return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]])
}
}
final client = Mock(OkHttpClient) {
newCall(_ as Request) >> {
request = (Request) it[0]
return call
}
}
final aiguard = new AIGuardInternal(URL, HEADERS, client)
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
final message = AIGuard.Message.message("user", (0..maxContent).collect { 'A' }.join())

when:
Expand Down Expand Up @@ -426,23 +341,7 @@ class AIGuardInternalTests extends DDSpecification {

void 'test missing tool name'() {
given:
def request
final call = Mock(Call) {
execute() >> {
return mockResponse(
request,
200,
[data: [attributes: [action: 'ALLOW', reason: 'Just do it']]]
)
}
}
final client = Mock(OkHttpClient) {
newCall(_ as Request) >> {
request = (Request) it[0]
return call
}
}
final aiguard = new AIGuardInternal(URL, HEADERS, client)
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]])

when:
aiguard.evaluate([AIGuard.Message.tool('call_1', 'Content')], AIGuard.Options.DEFAULT)
Expand All @@ -460,6 +359,57 @@ class AIGuardInternalTests extends DDSpecification {
thrown(IllegalArgumentException)
}

// Verifies that the messages stored in the span meta-struct are not affected when the
// caller's message list (and the tool-call list of one of its messages) is mutated after
// evaluate() has captured them.
void 'test message immutability'() {
given:
// client stub answering HTTP 200 with an ALLOW verdict
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]])
// a single assistant message with no content, one tool call and no tool call id
final messages = [
new AIGuard.Message(
"assistant",
null,
[AIGuard.ToolCall.toolCall('call_1', 'execute_shell', '{"cmd": "ls -lah"}')],
null
)
]
// filled in by the setMetaStruct interaction below
Map<String, Object> receivedMeta

when:
aiguard.evaluate(messages, AIGuard.Options.DEFAULT)

then:
// when the span is finished, tamper with the caller-owned collections
1 * span.finish() >> {
// modify the messages before serialization:
// add a second tool call to the first message and append a new tool message
messages.first().toolCalls.add(
AIGuard.ToolCall.toolCall('call_2', 'execute_shell', '{"cmd": "rm -rf"}')
)
messages.add(AIGuard.Message.tool('call_1', 'dir1, dir2, dir3'))
}
// capture the map handed to the span as meta-struct (second argument of the call)
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> {
receivedMeta = it[1] as Map<String, Object>
return span
}
// the captured meta-struct must still hold exactly the original message and tool call,
// while the caller's collections now differ in size
final metaStructMessages = receivedMeta.messages as List<AIGuard.Message>
metaStructMessages.size() != messages.size()
metaStructMessages.size() == 1
metaStructMessages.first().toolCalls.size() != messages.first().toolCalls.size()
metaStructMessages.first().toolCalls.size() == 1
}

/**
 * Creates an {@code AIGuardInternal} backed by a stubbed {@code OkHttpClient}.
 *
 * The stubbed client remembers the request given to {@code newCall} and returns a stubbed
 * {@code Call}; executing that call delegates to {@code mockResponse} with the remembered
 * request plus the supplied status and payload.
 *
 * @param status status code forwarded to {@code mockResponse}
 * @param response payload forwarded to {@code mockResponse} (may be {@code null})
 * @return an instance built from {@code URL}, {@code HEADERS} and the stubbed client
 */
private AIGuardInternal mockClient(final int status, final Object response) {
  // last request seen by the stubbed client; stays null until newCall is invoked
  Request captured = null
  final stubbedCall = Stub(Call) {
    execute() >> { mockResponse(captured, status, response) }
  }
  final stubbedClient = Stub(OkHttpClient) {
    newCall(_ as Request) >> { Request outgoing ->
      captured = outgoing
      stubbedCall
    }
  }
  return new AIGuardInternal(URL, HEADERS, stubbedClient)
}

private static assertTelemetry(final String metric, final String...tags) {
final metrics = WafMetricCollector.get().with {
prepareMetrics()
Expand All @@ -475,6 +425,16 @@ class AIGuardInternalTests extends DDSpecification {
return true
}

/**
 * Checks the meta-struct captured from the span against the expectations of a test suite.
 *
 * The attack categories are only compared when the suite declares a non-empty tag list.
 * Messages are compared through their {@code snakeCaseJson} form using JSONassert's
 * {@code NON_EXTENSIBLE} mode.
 *
 * @param meta meta-struct map captured from the span
 * @param suite suite holding the expected tags and messages
 * @return {@code true} when no assertion failed, so the call can be used as a Spock condition
 */
private static assertMeta(final Map<String, Object> meta, final TestSuite suite) {
  if (suite.tags) {
    assert meta['attack_categories'] == suite.tags
  }
  final actualJson = snakeCaseJson(meta['messages'])
  final expectedJson = snakeCaseJson(suite.messages)
  JSONAssert.assertEquals(expectedJson, actualJson, JSONCompareMode.NON_EXTENSIBLE)
  true
}

private static assertRequest(final Request request, final List<AIGuard.Message> messages) {
assert request.url() == URL
assert request.method() == 'POST'
Expand Down Expand Up @@ -556,7 +516,8 @@ class AIGuardInternalTests extends DDSpecification {
", reason='" + reason + '\'' +
", blocking=" + blocking +
", target='" + target + '\'' +
", messages=" + messages +
", messages=" + messages + '\'' +
", tags=" + tags +
'}'
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,19 @@ public static class Evaluation {

final Action action;
final String reason;
final List<String> tags;

/**
* Creates a new evaluation result.
*
* @param action the recommended action for the evaluated content
* @param reason human-readable explanation for the decision
* @param tags list of tags associated with the evaluation (e.g. indirect-prompt-injection)
*/
public Evaluation(final Action action, final String reason) {
public Evaluation(final Action action, final String reason, final List<String> tags) {
this.action = action;
this.reason = reason;
this.tags = tags;
}

/**
Expand All @@ -172,6 +175,15 @@ public Action getAction() {
public String getReason() {
return reason;
}

/**
 * Gets the tags attached to this evaluation result (e.g. indirect-prompt-injection).
 *
 * @return the tag list held by this evaluation, handed back as-is without copying.
 */
public List<String> getTags() {
return this.tags;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package datadog.trace.api.aiguard.noop;

import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW;
import static java.util.Collections.emptyList;

import datadog.trace.api.aiguard.AIGuard.Evaluation;
import datadog.trace.api.aiguard.AIGuard.Message;
Expand All @@ -12,6 +13,6 @@ public final class NoOpEvaluator implements Evaluator {

@Override
public Evaluation evaluate(final List<Message> messages, final Options options) {
return new Evaluation(ALLOW, "AI Guard is not enabled");
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList());
}
}