diff --git a/src/main/kotlin/com/sourcegraph/cody/agent/CodyAgentServer.kt b/src/main/kotlin/com/sourcegraph/cody/agent/CodyAgentServer.kt index a42052ce62..c86a03db45 100644 --- a/src/main/kotlin/com/sourcegraph/cody/agent/CodyAgentServer.kt +++ b/src/main/kotlin/com/sourcegraph/cody/agent/CodyAgentServer.kt @@ -103,6 +103,9 @@ interface CodyAgentServer { @JsonRequest("chat/models") fun chatModels(params: ChatModelsParams): CompletableFuture + @JsonRequest("chat/setModel") + fun chatSetModel(params: ChatSetModelParams): CompletableFuture + @JsonRequest("chat/restore") fun chatRestore(params: ChatRestoreParams): CompletableFuture @JsonRequest("attribution/search") diff --git a/src/main/kotlin/com/sourcegraph/cody/agent/protocol/ChatSetModelParams.kt b/src/main/kotlin/com/sourcegraph/cody/agent/protocol/ChatSetModelParams.kt new file mode 100644 index 0000000000..48a4e29008 --- /dev/null +++ b/src/main/kotlin/com/sourcegraph/cody/agent/protocol/ChatSetModelParams.kt @@ -0,0 +1,3 @@ +package com.sourcegraph.cody.agent.protocol + +data class ChatSetModelParams(val id: String, val model: String) diff --git a/src/main/kotlin/com/sourcegraph/cody/chat/AgentChatSession.kt b/src/main/kotlin/com/sourcegraph/cody/chat/AgentChatSession.kt index e7eabeaed5..d7f16d7452 100644 --- a/src/main/kotlin/com/sourcegraph/cody/chat/AgentChatSession.kt +++ b/src/main/kotlin/com/sourcegraph/cody/chat/AgentChatSession.kt @@ -28,6 +28,7 @@ private constructor( private val project: Project, newSessionId: CompletableFuture, private val internalId: String = UUID.randomUUID().toString(), + private var selectedModel: String = "" ) : ChatSession { /** @@ -43,11 +44,13 @@ private constructor( init { cancellationToken.get().dispose() + sessionId.get().thenAccept { sessionId -> + chatPanel.addModelDropdown(project, sessionId, selectedModel) + } } fun restoreAgentSession(agent: CodyAgent) { // todo serialize model - val model = "anthropic/claude-2.0" val messagesToReload = messages .toList() @@ -55,7 +58,8 @@ private constructor( .fold(emptyList()) { acc, msg -> if (acc.lastOrNull()?.speaker == msg.speaker) acc else acc.plus(msg) } - val restoreParams = ChatRestoreParams(model, messagesToReload, UUID.randomUUID().toString()) + val restoreParams = + ChatRestoreParams(selectedModel, messagesToReload, UUID.randomUUID().toString()) val newSessionId = agent.server.chatRestore(restoreParams) sessionId.getAndSet(newSessionId) } @@ -177,9 +181,21 @@ private constructor( if (messages.lastOrNull()?.id == message.id) { messages.removeLast() } - messages.add(message) - chatPanel.addOrUpdateMessage(message) - HistoryService.getInstance(project).updateChatMessages(internalId, messages) + if (messages.size == 0) { + selectedModel = chatPanel.modelDropdown.selectedItem?.toString() ?: "" + sessionId.get().thenAccept { sessionId -> + CodyAgentService.applyAgentOnBackgroundThread(project) { agent -> + agent.server.chatSetModel(ChatSetModelParams(sessionId, selectedModel)).get() + messages.add(message) + chatPanel.addOrUpdateMessage(message) + HistoryService.getInstance(project).updateChatMessages(internalId, messages, selectedModel) + } + } + } else { + messages.add(message) + chatPanel.addOrUpdateMessage(message) + HistoryService.getInstance(project).updateChatMessages(internalId, messages) + } } @RequiresEdt @@ -193,6 +209,7 @@ private constructor( companion object { private val logger = LoggerFactory.getLogger(AgentChatSession::class.java) + private const val defaultModel = "anthropic/claude-2.0" @RequiresEdt fun createNew(project: Project): AgentChatSession { @@ -233,7 +250,7 @@ private constructor( @RequiresEdt fun createFromState(project: Project, state: ChatState): AgentChatSession { val sessionId = createNewPanel(project) { it.server.chatNew() } - val chatSession = AgentChatSession(project, sessionId, state.internalId!!) + val chatSession = AgentChatSession(project, sessionId, state.internalId!!, state.model ?: defaultModel) for (message in state.messages) { val parsed = when (val speaker = message.speaker) { diff --git a/src/main/kotlin/com/sourcegraph/cody/chat/ui/ChatPanel.kt b/src/main/kotlin/com/sourcegraph/cody/chat/ui/ChatPanel.kt index 8edb7e633c..0c117dcc87 100644 --- a/src/main/kotlin/com/sourcegraph/cody/chat/ui/ChatPanel.kt +++ b/src/main/kotlin/com/sourcegraph/cody/chat/ui/ChatPanel.kt @@ -2,11 +2,14 @@ package com.sourcegraph.cody.chat.ui import com.intellij.icons.AllIcons import com.intellij.openapi.project.Project +import com.intellij.openapi.ui.ComboBox import com.intellij.openapi.ui.VerticalFlowLayout import com.intellij.util.IconUtil import com.intellij.util.concurrency.annotations.RequiresEdt import com.sourcegraph.cody.PromptPanel +import com.sourcegraph.cody.agent.CodyAgentService import com.sourcegraph.cody.agent.protocol.ChatMessage +import com.sourcegraph.cody.agent.protocol.ChatModelsParams import com.sourcegraph.cody.chat.ChatSession import com.sourcegraph.cody.context.ui.EnhancedContextPanel import com.sourcegraph.cody.ui.ChatScrollPane @@ -22,6 +25,7 @@ class ChatPanel(project: Project, chatSession: ChatSession) : JPanel(VerticalFlowLayout(VerticalFlowLayout.CENTER, 0, 0, true, false)) { val promptPanel: PromptPanel = PromptPanel(project, chatSession) + val modelDropdown = ComboBox() private val messagesPanel = MessagesPanel(project, chatSession) private val chatPanel = ChatScrollPane(messagesPanel) @@ -49,10 +53,31 @@ class ChatPanel(project: Project, chatSession: ChatSession) : add(lowerPanel, BorderLayout.SOUTH) } + fun addModelDropdown(project: Project, sessionId: String, selectedModel: String) { + CodyAgentService.applyAgentOnBackgroundThread(project) { agent -> + agent.server.isCurrentUserPro().thenApply { isUserPro -> + if (isUserPro == true) { + add(modelDropdown, BorderLayout.NORTH) + if (selectedModel == "") { + val chatModels = agent.server.chatModels(ChatModelsParams(sessionId)) + chatModels.thenApply { response -> + response.models.forEach { model -> modelDropdown.addItem(model.model) } + } + } else { + modelDropdown.addItem(selectedModel) + } + } + } + } + } + fun isEnhancedContextEnabled(): Boolean = contextView.isEnhancedContextEnabled.get() @RequiresEdt fun addOrUpdateMessage(message: ChatMessage, shouldAddBlinkingCursor: Boolean = true) { + if (messagesPanel.componentCount == 1) { + modelDropdown.isEnabled = false + } messagesPanel.addOrUpdateMessage(message, shouldAddBlinkingCursor) } diff --git a/src/main/kotlin/com/sourcegraph/cody/history/HistoryService.kt b/src/main/kotlin/com/sourcegraph/cody/history/HistoryService.kt index 95d2ea41e9..1f653f5ade 100644 --- a/src/main/kotlin/com/sourcegraph/cody/history/HistoryService.kt +++ b/src/main/kotlin/com/sourcegraph/cody/history/HistoryService.kt @@ -23,6 +23,13 @@ class HistoryService(private val project: Project) : synchronized(listeners) { listeners += listener } } + @Synchronized + fun updateChatMessages(internalId: String, chatMessages: List, selectedModel: String) { + val found = getOrCreateChat(internalId) + found.model = selectedModel + updateChatMessages(internalId, chatMessages) + } + @Synchronized fun updateChatMessages(internalId: String, chatMessages: List) { val found = getOrCreateChat(internalId) diff --git a/src/main/kotlin/com/sourcegraph/cody/history/state/ChatState.kt b/src/main/kotlin/com/sourcegraph/cody/history/state/ChatState.kt index 376d096469..7c04d21d46 100644 --- a/src/main/kotlin/com/sourcegraph/cody/history/state/ChatState.kt +++ b/src/main/kotlin/com/sourcegraph/cody/history/state/ChatState.kt @@ -16,6 +16,8 @@ class ChatState : BaseState() { @get:OptionTag(tag = "updatedAt", nameAttribute = "") var updatedAt: String? by string() + @get:OptionTag(tag = "model", nameAttribute = "") var model: String? by string() + @get:OptionTag(tag = "accountId", nameAttribute = "") var accountId: String? by string() @get:OptionTag(tag = "enhancedContext", nameAttribute = "")