Skip to content

Commit

Permalink
Merge pull request #275 from devoxx/issue-274
Browse files Browse the repository at this point in the history
Fix #274 + Load history conversation into chat memory
  • Loading branch information
stephanj authored Sep 5, 2024
2 parents df43af0 + 9a3342b commit 44fe023
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 43 deletions.
55 changes: 36 additions & 19 deletions src/main/java/com/devoxx/genie/service/ChatMemoryService.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package com.devoxx.genie.service;
import com.devoxx.genie.model.conversation.Conversation;
import com.devoxx.genie.model.request.ChatMessageContext;
import com.devoxx.genie.ui.listener.ChatMemorySizeListener;
import com.devoxx.genie.ui.topic.AppTopics;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.project.Project;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import org.jetbrains.annotations.NotNull;
Expand All @@ -15,15 +19,15 @@

public class ChatMemoryService implements ChatMemorySizeListener {

private final Map<Project, MessageWindowChatMemory> projectMemories = new ConcurrentHashMap<>();
private final Map<String, MessageWindowChatMemory> projectMemories = new ConcurrentHashMap<>();
private final InMemoryChatMemoryStore inMemoryChatMemoryStore = new InMemoryChatMemoryStore();

public static ChatMemoryService getInstance() {
return ApplicationManager.getApplication().getService(ChatMemoryService.class);
}

public void init(Project project) {
createChatMemory(project, DevoxxGenieSettingsServiceProvider.getInstance().getChatMemorySize());
public void init(@NotNull Project project) {
createChatMemory(project.getLocationHash(), DevoxxGenieSettingsServiceProvider.getInstance().getChatMemorySize());
createChangeListener();
}

Expand All @@ -33,51 +37,64 @@ private void createChangeListener() {
.subscribe(AppTopics.CHAT_MEMORY_SIZE_TOPIC, this);
}

public void clear(Project project) {
projectMemories.get(project).clear();
public void clear(@NotNull Project project) {
projectMemories.get(project.getLocationHash()).clear();
}

public void add(Project project, ChatMessage chatMessage) {
projectMemories.get(project).add(chatMessage);
public void add(@NotNull Project project, ChatMessage chatMessage) {
projectMemories.get(project.getLocationHash()).add(chatMessage);
}

public void remove(@NotNull ChatMessageContext chatMessageContext) {
Project project = chatMessageContext.getProject();
List<ChatMessage> messages = projectMemories.get(project).messages();
List<ChatMessage> messages = projectMemories.get(project.getLocationHash()).messages();
messages.remove(chatMessageContext.getAiMessage());
messages.remove(chatMessageContext.getUserMessage());
projectMemories.get(project).clear();

// Remove the conversation from the storage service
projectMemories.get(project.getLocationHash()).clear();
messages.forEach(message -> add(project, message));
}

public void removeLast(Project project) {
List<ChatMessage> messages = projectMemories.get(project).messages();
public void removeLast(@NotNull Project project) {
List<ChatMessage> messages = projectMemories.get(project.getLocationHash()).messages();
if (!messages.isEmpty()) {
messages.remove(messages.size() - 1);
projectMemories.get(project).clear();
projectMemories.get(project.getLocationHash()).clear();
messages.forEach(message -> add(project, message));
}
}

public List<ChatMessage> messages(Project project) {
return projectMemories.get(project).messages();
public List<ChatMessage> messages(@NotNull Project project) {
return projectMemories.get(project.getLocationHash()).messages();
}

public boolean isEmpty(Project project) {
return projectMemories.get(project).messages().isEmpty();
public boolean isEmpty(@NotNull Project project) {
return projectMemories.get(project.getLocationHash()).messages().isEmpty();
}

@Override
public void onChatMemorySizeChanged(int chatMemorySize) {
projectMemories.forEach((project, memory) -> createChatMemory(project, chatMemorySize));
}

private void createChatMemory(Project project, int chatMemorySize) {
private void createChatMemory(@NotNull String projectHash, int chatMemorySize) {
MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
.id("devoxxgenie-" + project.getLocationHash())
.id("devoxxgenie-" + projectHash)
.chatMemoryStore(inMemoryChatMemoryStore)
.maxMessages(chatMemorySize)
.build();
projectMemories.put(project, chatMemory);
projectMemories.put(projectHash, chatMemory);
}

public void restoreConversation(@NotNull Project project, @NotNull Conversation conversation) {
clear(project);
for (com.devoxx.genie.model.conversation.ChatMessage message : conversation.getMessages()) {
if (message.isUser()) {
add(project, new UserMessage(message.getContent()));
} else {
add(project, new AiMessage(message.getContent()));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,22 @@ public void loadState(@NotNull State state) {
public void addConversation(@NotNull Project project, Conversation conversation) {
if (conversation != null &&
(!conversation.getMessages().isEmpty() || (conversation.getTitle() != null && !conversation.getTitle().trim().isEmpty()))) {
String projectId = project.getLocationHash();
myState.conversations.computeIfAbsent(projectId, k -> new ArrayList<>()).add(conversation);
myState.conversations.computeIfAbsent(project.getLocationHash(), k -> new ArrayList<>()).add(conversation);
saveState();
}
}

public @NotNull List<Conversation> getConversations(@NotNull Project project) {
String projectId = project.getLocationHash();
return new ArrayList<>(myState.conversations.getOrDefault(projectId, new ArrayList<>()));
return new ArrayList<>(myState.conversations.getOrDefault(project.getLocationHash(), new ArrayList<>()));
}

public void removeConversation(@NotNull Project project, Conversation conversation) {
String projectId = project.getLocationHash();
myState.conversations.getOrDefault(projectId, new ArrayList<>()).remove(conversation);
myState.conversations.getOrDefault(project.getLocationHash(), new ArrayList<>()).remove(conversation);
saveState();
}

public void clearAllConversations(@NotNull Project project) {
String projectId = project.getLocationHash();
myState.conversations.remove(projectId);
myState.conversations.remove(project.getLocationHash());
saveState();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ public void startNewConversation() {
FileListManager.getInstance().clear();
ChatMemoryService.getInstance().clear(project);

// TODO Set title based on first question
chatService.startNewConversation("");

SwingUtilities.invokeLater(() -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.devoxx.genie.ui.panel;

import com.devoxx.genie.model.conversation.Conversation;
import com.devoxx.genie.service.ChatMemoryService;
import com.devoxx.genie.service.ConversationStorageService;
import com.devoxx.genie.ui.component.JHoverButton;
import com.devoxx.genie.ui.listener.ConversationSelectionListener;
Expand Down Expand Up @@ -72,6 +73,7 @@ public void loadConversations() {
titleLabel.addMouseListener( new java.awt.event.MouseAdapter() {
public void mouseClicked(java.awt.event.MouseEvent evt) {
conversationSelectionListener.onConversationSelected(conversation);
updateChatMemory(conversation);
}
});
titleLabel.setBorder(JBUI.Borders.empty(5, 8));
Expand All @@ -98,6 +100,10 @@ public void mouseClicked(java.awt.event.MouseEvent evt) {
return rowPanel;
}

private void updateChatMemory(Conversation conversation) {
ChatMemoryService.getInstance().restoreConversation(project, conversation);
}

private @NotNull String formatTimeSince(String timestamp) {
try {
LocalDateTime messageTime = LocalDateTime.parse(timestamp);
Expand Down
15 changes: 9 additions & 6 deletions src/main/java/com/devoxx/genie/ui/panel/PromptOutputPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import javax.swing.*;
import java.awt.*;
import java.util.ResourceBundle;
import java.util.UUID;

import static javax.swing.ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER;
import static javax.swing.ScrollPaneConstants.VERTICAL_SCROLLBAR_AS_NEEDED;
Expand Down Expand Up @@ -133,23 +134,25 @@ public void updateHelpText() {

public void displayConversation(Project project, Conversation conversation) {
SwingUtilities.invokeLater(() -> {
container.removeAll();
String conversationId = UUID.randomUUID().toString();
for (ChatMessage message : conversation.getMessages()) {
conversation.setId(conversationId);
ChatMessageContext chatMessageContext = createChatMessageContext(project, conversation, message);
if (message.isUser()) {
addUserPrompt(createChatMessageContext(project, message, conversation));
addUserPrompt(chatMessageContext);
} else {
addChatResponse(createChatMessageContext(project, message, conversation));
addChatResponse(chatMessageContext);
}
}
scrollToBottom();
});
}

private ChatMessageContext createChatMessageContext(Project project,
@NotNull ChatMessage message,
@NotNull Conversation conversation) {
@NotNull Conversation conversation,
@NotNull ChatMessage message) {
return ChatMessageContext.builder()
.name(String.valueOf(System.currentTimeMillis()))
.name(conversation.getId())
.project(project)
.userPrompt(message.isUser() ? message.getContent() : "")
.aiMessage(message.isUser() ? null : AiMessage.aiMessage(message.getContent()))
Expand Down
26 changes: 19 additions & 7 deletions src/main/java/com/devoxx/genie/ui/panel/UserPromptPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;

import static com.devoxx.genie.ui.util.DevoxxGenieIconsUtil.DevoxxIcon;
import static com.devoxx.genie.ui.util.DevoxxGenieIconsUtil.TrashIcon;
Expand Down Expand Up @@ -38,7 +39,11 @@ public UserPromptPanel(JPanel container,

// User prompt setup
JEditorPane htmlJEditorPane =
JEditorPaneUtils.createHtmlJEditorPane(chatMessageContext.getUserPrompt(), null, StyleSheetsFactory.createParagraphStyleSheet());
JEditorPaneUtils.createHtmlJEditorPane(
chatMessageContext.getUserPrompt(),
null,
StyleSheetsFactory.createParagraphStyleSheet()
);

add(headerPanel, BorderLayout.NORTH);
add(htmlJEditorPane, BorderLayout.CENTER);
Expand Down Expand Up @@ -72,12 +77,19 @@ public UserPromptPanel(JPanel container,
*
* @param chatMessageContext the chat message context
*/
private void removeChat(ChatMessageContext chatMessageContext) {

// Get all container components and delete by name
stream(container.getComponents())
.filter(c -> c.getName() != null && c.getName().equals(chatMessageContext.getName()))
.forEach(container::remove);
private void removeChat(@NotNull ChatMessageContext chatMessageContext) {
String nameToRemove = chatMessageContext.getName();
java.util.List<Component> componentsToRemove = new ArrayList<>();

for (Component c : container.getComponents()) {
if (c.getName() != null && c.getName().equals(nameToRemove)) {
componentsToRemove.add(c);
}
}

for (Component c : componentsToRemove) {
container.remove(c);
}

// Repaint the container
container.revalidate();
Expand Down
4 changes: 2 additions & 2 deletions src/main/resources/application.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#Wed Sep 04 13:52:11 CEST 2024
version=0.2.17
#Thu Sep 05 09:00:29 CEST 2024
version=0.2.17

0 comments on commit 44fe023

Please sign in to comment.