Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.model.tool;

import java.util.ArrayList;
import java.util.List;

import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -190,4 +191,34 @@ void whenEqualsAndHashCodeAreConsistent() {
assertThat(result1.hashCode()).isEqualTo(result2.hashCode());
}

@Test
void whenConversationHistoryIsImmutableList() {
List<Message> conversationHistory = List.of(new org.springframework.ai.chat.messages.UserMessage("Hello"),
new org.springframework.ai.chat.messages.UserMessage("Hi!"));

var result = DefaultToolExecutionResult.builder()
.conversationHistory(conversationHistory)
.returnDirect(false)
.build();

assertThat(result.conversationHistory()).hasSize(2);
assertThat(result.conversationHistory()).isEqualTo(conversationHistory);
}

@Test
void whenReturnDirectIsChangedMultipleTimes() {
var conversationHistory = new ArrayList<Message>();
conversationHistory.add(new org.springframework.ai.chat.messages.UserMessage("Test"));

var builder = DefaultToolExecutionResult.builder()
.conversationHistory(conversationHistory)
.returnDirect(true)
.returnDirect(false)
.returnDirect(true);

var result = builder.build();

assertThat(result.returnDirect()).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.model.tool;

import java.util.Collections;
import java.util.List;

import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -75,6 +76,32 @@ void whenTestMethodCalledDirectly() {
assertThat(result).isTrue();
}

@Test
void whenChatResponseHasEmptyGenerations() {
ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate();
ChatOptions promptOptions = ChatOptions.builder().build();
ChatResponse emptyResponse = new ChatResponse(Collections.emptyList());

boolean result = predicate.isToolExecutionRequired(promptOptions, emptyResponse);
assertThat(result).isTrue();
}

@Test
void whenChatOptionsHasModel() {
ModelCheckingPredicate predicate = new ModelCheckingPredicate();

ChatOptions optionsWithModel = ChatOptions.builder().model("gpt-4").build();

ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test"))));

boolean result = predicate.isToolExecutionRequired(optionsWithModel, chatResponse);
assertThat(result).isTrue();

ChatOptions optionsWithoutModel = ChatOptions.builder().build();
result = predicate.isToolExecutionRequired(optionsWithoutModel, chatResponse);
assertThat(result).isFalse();
}

/**
* Test implementation of {@link ToolExecutionEligibilityPredicate} that always
* returns true.
Expand All @@ -88,4 +115,13 @@ public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) {

}

private static class ModelCheckingPredicate implements ToolExecutionEligibilityPredicate {

@Override
public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) {
return promptOptions.getModel() != null && !promptOptions.getModel().isEmpty();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.ai.chat.messages.UserMessage;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;

/**
* Unit tests for {@link ToolExecutionResult}.
Expand Down Expand Up @@ -80,4 +81,63 @@ void whenMultipleToolCallsThenMultipleGenerations() {
assertThat(generations.get(1).getMetadata().getFinishReason()).isEqualTo(ToolExecutionResult.FINISH_REASON);
}

@Test
void whenEmptyConversationHistoryThenThrowsException() {
var toolExecutionResult = ToolExecutionResult.builder().conversationHistory(List.of()).build();

assertThatThrownBy(() -> ToolExecutionResult.buildGenerations(toolExecutionResult))
.isInstanceOf(ArrayIndexOutOfBoundsException.class);
}

@Test
void whenToolResponseWithEmptyResponseListThenEmptyGenerations() {
var toolExecutionResult = ToolExecutionResult.builder()
.conversationHistory(
List.of(new AssistantMessage("Processing request"), new ToolResponseMessage(List.of())))
.build();

var generations = ToolExecutionResult.buildGenerations(toolExecutionResult);

assertThat(generations).isEmpty();
}

@Test
void whenToolResponseWithNullContentThenGenerationWithNullText() {
var toolExecutionResult = ToolExecutionResult.builder()
.conversationHistory(
List.of(new ToolResponseMessage(List.of(new ToolResponseMessage.ToolResponse("1", "tool", null)))))
.build();

var generations = ToolExecutionResult.buildGenerations(toolExecutionResult);

assertThat(generations).hasSize(1);
assertThat(generations.get(0).getOutput().getText()).isNull();
}

@Test
void whenToolResponseWithEmptyStringContentThenGenerationWithEmptyText() {
var toolExecutionResult = ToolExecutionResult.builder()
.conversationHistory(
List.of(new ToolResponseMessage(List.of(new ToolResponseMessage.ToolResponse("1", "tool", "")))))
.build();

var generations = ToolExecutionResult.buildGenerations(toolExecutionResult);

assertThat(generations).hasSize(1);
assertThat(generations.get(0).getOutput().getText()).isEmpty();
assertThat((String) generations.get(0).getMetadata().get(ToolExecutionResult.METADATA_TOOL_NAME))
.isEqualTo("tool");
}

@Test
void whenBuilderCalledWithoutConversationHistoryThenThrowsException() {
var toolExecutionResult = ToolExecutionResult.builder().build();

assertThatThrownBy(() -> ToolExecutionResult.buildGenerations(toolExecutionResult))
.isInstanceOf(ArrayIndexOutOfBoundsException.class);

assertThat(toolExecutionResult.conversationHistory()).isNotNull();
assertThat(toolExecutionResult.conversationHistory()).isEmpty();
}

}