Skip to content

Commit

Permalink
#5211 - AI assistant prototype
Browse files Browse the repository at this point in the history
- UI style changes
- Some classes renameed
- Added time info in assistant messages
  • Loading branch information
reckart committed Jan 3, 2025
1 parent 2c20c88 commit aba1495
Show file tree
Hide file tree
Showing 24 changed files with 449 additions and 157 deletions.
14 changes: 14 additions & 0 deletions inception/inception-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
</dependency>

<!-- DEPENDENCIES FOR TESTING -->
<dependency>
Expand Down Expand Up @@ -267,6 +271,16 @@
<artifactId>inception-documents</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-log</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-search-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@
import java.util.List;

import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.inception.assistant.model.MAssistantTextMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MTextMessage;

public interface AssistantService
{
List<MAssistantTextMessage> getAllChatMessages(String aSessionOwner, Project aProject);
List<MTextMessage> getAllChatMessages(String aSessionOwner, Project aProject);

List<MAssistantTextMessage> getChatMessages(String aSessionOwner, Project aProject);
List<MTextMessage> getChatMessages(String aSessionOwner, Project aProject);

void processUserMessage(String aSessionOwner, Project aProject,
MAssistantTextMessage aMessage);
MTextMessage aMessage);

void clearConversation(String aSessionOwner, Project aProject);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
*/
package de.tudarmstadt.ukp.inception.assistant;

import static de.tudarmstadt.ukp.inception.assistant.model.MAssistantChatRoles.ASSISTANT;
import static de.tudarmstadt.ukp.inception.assistant.model.MAssistantChatRoles.SYSTEM;
import static de.tudarmstadt.ukp.inception.assistant.model.MChatRoles.ASSISTANT;
import static de.tudarmstadt.ukp.inception.assistant.model.MChatRoles.SYSTEM;
import static java.lang.Math.floorDiv;
import static java.lang.String.join;
import static java.util.Arrays.asList;
Expand Down Expand Up @@ -48,17 +48,14 @@
import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.security.model.User;
import de.tudarmstadt.ukp.inception.assistant.config.AssistantProperties;
import de.tudarmstadt.ukp.inception.assistant.model.MAssistantClearCommand;
import de.tudarmstadt.ukp.inception.assistant.model.MAssistantMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MAssistantTextMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MRemoveConversationCommand;
import de.tudarmstadt.ukp.inception.assistant.model.MTextMessage;
import de.tudarmstadt.ukp.inception.assistant.retriever.RetrieverExtensionPoint;
import de.tudarmstadt.ukp.inception.project.api.event.AfterProjectRemovedEvent;
import de.tudarmstadt.ukp.inception.project.api.event.BeforeProjectRemovedEvent;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaChatMessage;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaChatRequest;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaChatResponse;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaClient;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaOptions;

public class AssistantServiceImpl
implements AssistantService
Expand Down Expand Up @@ -125,36 +122,36 @@ public void onAfterProjectRemoved(AfterProjectRemovedEvent aEvent)
}

@Override
public List<MAssistantTextMessage> getAllChatMessages(String aSessionOwner, Project aProject)
public List<MTextMessage> getAllChatMessages(String aSessionOwner, Project aProject)
{
var state = getState(aSessionOwner, aProject);

return state.getMessages().stream() //
.filter(MAssistantTextMessage.class::isInstance) //
.map(MAssistantTextMessage.class::cast) //
.filter(MTextMessage.class::isInstance) //
.map(MTextMessage.class::cast) //
.toList();
}

@Override
public List<MAssistantTextMessage> getChatMessages(String aSessionOwner, Project aProject)
public List<MTextMessage> getChatMessages(String aSessionOwner, Project aProject)
{
var state = getState(aSessionOwner, aProject);

// In dev mode, we also record internal messages, so we need to filter them out again here
return state.getMessages().stream() //
.filter(MAssistantTextMessage.class::isInstance) //
.map(MAssistantTextMessage.class::cast) //
.filter(MTextMessage.class::isInstance) //
.map(MTextMessage.class::cast) //
.filter(msg -> !msg.internal()) //
.toList();
}

void recordMessage(String aSessionOwner, Project aProject, MAssistantMessage aMessage)
void recordMessage(String aSessionOwner, Project aProject, MMessage aMessage)
{
var state = getState(aSessionOwner, aProject);
state.addMessage(aMessage);
}

void dispatchMessage(String aSessionOwner, Project aProject, MAssistantMessage aMessage)
void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage)
{
// LOG.trace("Dispatching assistant message: {}", aMessage);
var topic = AssistantWebsocketController.getChannel(aProject);
Expand All @@ -169,20 +166,21 @@ public void clearConversation(String aSessionOwner, Project aProject)
&& Objects.equals(aProject.getId(), key.projectId));
}

dispatchMessage(aSessionOwner, aProject, new MAssistantClearCommand());
dispatchMessage(aSessionOwner, aProject, new MRemoveConversationCommand());
}

@Override
public void processUserMessage(String aSessionOwner, Project aProject,
MAssistantTextMessage aMessage)
MTextMessage aMessage)
{
var assistant = new ChatContext(properties, ollamaClient, aSessionOwner, aProject);

// Dispatch message early so the front-end can enter waiting state
dispatchMessage(aSessionOwner, aProject, aMessage);

var responseId = UUID.randomUUID();
try {
var systemMessages = generateSystemMessages(aSessionOwner, aProject, aMessage);
var transientMessages = generateTransientMessages(aSessionOwner, aProject, aMessage);
var systemMessages = generateSystemMessages();
var transientMessages = generateTransientMessages(assistant, aMessage);
var conversationMessages = getChatMessages(aSessionOwner, aProject);

// We record the message only now to ensure it is not included in the listMessages above
Expand All @@ -199,34 +197,15 @@ public void processUserMessage(String aSessionOwner, Project aProject,
transientMessages, conversationMessages, aMessage,
properties.getChat().getContextLength());

var request = OllamaChatRequest.builder() //
.withModel(properties.getChat().getModel()) //
.withStream(true) //
.withMessages(recentConversation.stream() //
.map(msg -> new OllamaChatMessage(msg.role(), msg.message())) //
.toList()) //
.withOption(OllamaOptions.NUM_CTX, properties.getChat().getContextLength()) //
.withOption(OllamaOptions.TOP_P, properties.getChat().getTopP()) //
.withOption(OllamaOptions.TOP_K, properties.getChat().getTopK()) //
.withOption(OllamaOptions.REPEAT_PENALTY,
properties.getChat().getRepeatPenalty()) //
.withOption(OllamaOptions.TEMPERATURE, properties.getChat().getTemperature()) //
.build();

var response = ollamaClient.generate(properties.getUrl(), request,
r -> handleStreamedMessageFragment(aSessionOwner, aProject, responseId, r));

var responseMessage = MAssistantTextMessage.builder() //
.withId(responseId) //
.withRole(ASSISTANT) //
.withMessage(response) //
.build();
var responseMessage = assistant.generate(recentConversation,
(id, r) -> handleStreamedMessageFragment(aSessionOwner, aProject, id, r));

recordMessage(aSessionOwner, aProject, responseMessage);
dispatchMessage(aSessionOwner, aProject, responseMessage);
}
catch (IOException e) {
var errorMessage = MAssistantTextMessage.builder() //
.withId(responseId) //
var errorMessage = MTextMessage.builder() //
.withActor("Error")
.withRole(SYSTEM) //
.withMessage("Error: " + e.getMessage()) //
.build();
Expand All @@ -238,8 +217,9 @@ public void processUserMessage(String aSessionOwner, Project aProject,
private void handleStreamedMessageFragment(String aSessionOwner, Project aProject,
UUID responseId, OllamaChatResponse r)
{
var responseMessage = MAssistantTextMessage.builder() //
var responseMessage = MTextMessage.builder() //
.withId(responseId) //
.withActor(properties.getNickname()) //
.withRole(ASSISTANT) //
.withMessage(r.getMessage().content()) //
.notDone() //
Expand All @@ -248,23 +228,21 @@ private void handleStreamedMessageFragment(String aSessionOwner, Project aProjec
dispatchMessage(aSessionOwner, aProject, responseMessage);
}

private List<MAssistantTextMessage> generateTransientMessages(String aSessionOwner,
Project aProject, MAssistantTextMessage aMessage)
private List<MTextMessage> generateTransientMessages(ChatContext aAssistant, MTextMessage aMessage)
{
var transientMessages = new ArrayList<MAssistantTextMessage>();
var transientMessages = new ArrayList<MTextMessage>();

for (var retriever : retrieverExtensionPoint.getExtensions(aProject)) {
transientMessages.addAll(retriever.retrieve(aSessionOwner, aProject, aMessage));
for (var retriever : retrieverExtensionPoint.getExtensions(aAssistant.getProject())) {
transientMessages.addAll(retriever.retrieve(aAssistant, aMessage));
}

return transientMessages;
}

private List<MAssistantTextMessage> generateSystemMessages(String aSessionOwner,
Project aProject, MAssistantTextMessage aMessage)
private List<MTextMessage> generateSystemMessages()
{
var primeDirectives = asList(
"You are Dominick, a helpful assistant within the annotation tool INCEpTION.",
"You are " + properties.getNickname() + ", a helpful assistant within the annotation tool INCEpTION.",
"INCEpTION always refers to the annotation tool, never anything else such as the movie.",
"Do not include references to INCEpTION unless the user explicitly asks about the environment itself.",
"If the source of an information is known, provide it in your response."
Expand All @@ -277,16 +255,16 @@ private List<MAssistantTextMessage> generateSystemMessages(String aSessionOwner,
// """
);

return asList(MAssistantTextMessage.builder() //
return asList(MTextMessage.builder() //
.withRole(SYSTEM).internal() //
.withMessage(join("\n\n", primeDirectives)) //
.build());
}

private List<MAssistantTextMessage> limitConversationToContextLength(
List<MAssistantTextMessage> aSystemMessages,
List<MAssistantTextMessage> aTransientMessages,
List<MAssistantTextMessage> aRecentMessages, MAssistantTextMessage aLatestUserMessage,
private List<MTextMessage> limitConversationToContextLength(
List<MTextMessage> aSystemMessages,
List<MTextMessage> aTransientMessages,
List<MTextMessage> aRecentMessages, MTextMessage aLatestUserMessage,
int aContextLength)
{
// We don't really know which tokenizer the LLM uses. In case
Expand All @@ -298,8 +276,8 @@ private List<MAssistantTextMessage> limitConversationToContextLength(
"Unknown encoding: " + properties.getChat().getEncoding()));
var limit = floorDiv(aContextLength * 90, 100);

var headMessages = new ArrayList<MAssistantTextMessage>();
var tailMessages = new LinkedList<MAssistantTextMessage>();
var headMessages = new ArrayList<MTextMessage>();
var tailMessages = new LinkedList<MTextMessage>();

var totalMessages = aSystemMessages.size() + aTransientMessages.size()
+ aRecentMessages.size() + 1;
Expand Down Expand Up @@ -402,24 +380,24 @@ private void clearState(String aSessionOwner)

private static class AssistentState
{
private LinkedList<MAssistantMessage> messages = new LinkedList<>();
private LinkedList<MMessage> messages = new LinkedList<>();

public List<MAssistantMessage> getMessages()
public List<MMessage> getMessages()
{
return new ArrayList<>(messages);
}

public void addMessage(MAssistantMessage aMessage)
public void addMessage(MMessage aMessage)
{
synchronized (messages) {
var found = false;
if (aMessage instanceof MAssistantTextMessage textMsg) {
if (aMessage instanceof MTextMessage textMsg) {
var i = messages.listIterator(messages.size());

// If a message with the same ID already exists, update it
while (i.hasPrevious() && !found) {
var m = i.previous();
if (m instanceof MAssistantTextMessage existingTextMsg) {
if (m instanceof MTextMessage existingTextMsg) {
if (Objects.equals(existingTextMsg.id(), textMsg.id())) {
if (textMsg.done()) {
i.set(textMsg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
package de.tudarmstadt.ukp.inception.assistant;

import static de.tudarmstadt.ukp.inception.assistant.model.MAssistantChatRoles.USER;
import static de.tudarmstadt.ukp.inception.assistant.model.MChatRoles.USER;
import static de.tudarmstadt.ukp.inception.websocket.config.WebSocketConstants.PARAM_PROJECT;

import java.io.IOException;
Expand All @@ -39,7 +39,8 @@
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;

import de.tudarmstadt.ukp.inception.assistant.model.MAssistantTextMessage;
import de.tudarmstadt.ukp.clarin.webanno.security.UserDao;
import de.tudarmstadt.ukp.inception.assistant.model.MTextMessage;
import de.tudarmstadt.ukp.inception.project.api.ProjectService;
import jakarta.servlet.ServletContext;

Expand All @@ -52,17 +53,20 @@ public class AssistantWebsocketControllerImpl
{
private final AssistantService assistantService;
private final ProjectService projectService;
private final UserDao userService;

@Autowired
public AssistantWebsocketControllerImpl(ServletContext aServletContext,
SimpMessagingTemplate aMsgTemplate, AssistantService aAssistantService, ProjectService aProjectService)
SimpMessagingTemplate aMsgTemplate, AssistantService aAssistantService, ProjectService aProjectService
, UserDao aUserService)
{
assistantService = aAssistantService;
projectService = aProjectService;
userService = aUserService;
}

@SubscribeMapping(PROJECT_ASSISTANT_TOPIC_TEMPLATE)
public List<MAssistantTextMessage> onSubscribeToAssistantMessages(SimpMessageHeaderAccessor aHeaderAccessor,
public List<MTextMessage> onSubscribeToAssistantMessages(SimpMessageHeaderAccessor aHeaderAccessor,
Principal aPrincipal, //
@DestinationVariable(PARAM_PROJECT) long aProjectId)
throws IOException
Expand All @@ -79,12 +83,17 @@ public void onUserMessage(SimpMessageHeaderAccessor aHeaderAccessor,
throws IOException
{
var project = projectService.getProject(aProjectId);
var message = MAssistantTextMessage.builder().withRole(USER).withMessage(aMessage).build();
var user = userService.get(aPrincipal.getName());
var message = MTextMessage.builder() //
.withActor(user.getUiName()) //
.withRole(USER) //
.withMessage(aMessage) //
.build();
assistantService.processUserMessage(aPrincipal.getName(), project, message);
}

@SendTo(PROJECT_ASSISTANT_TOPIC_TEMPLATE)
public MAssistantTextMessage send(MAssistantTextMessage aUpdate)
public MTextMessage send(MTextMessage aUpdate)
{
return aUpdate;
}
Expand Down
Loading

0 comments on commit aba1495

Please sign in to comment.