
feat(chore): Adding tags and usage for LLM Models #1894

Merged
merged 12 commits on Jul 18, 2024
src/main/kotlin/com/sourcegraph/cody/agent/protocol/ChatModelsResponse.kt
@@ -5,12 +5,14 @@ import javax.swing.Icon

data class ChatModelsResponse(val models: List<ChatModelProvider>) {
data class ChatModelProvider(
val default: Boolean,
val codyProOnly: Boolean,
val provider: String?,
val title: String?,
val model: String,
val deprecated: Boolean = false
val tags: MutableList<String>? = mutableListOf(),
val usage: MutableList<String>? = mutableListOf(),
@Deprecated("No longer provided by agent") val default: Boolean = false,
@Deprecated("No longer provided by agent") val codyProOnly: Boolean = false,
@Deprecated("No longer provided by agent") val deprecated: Boolean = false
) {
fun getIcon(): Icon? =
when (provider) {
@@ -34,5 +36,9 @@ data class ChatModelsResponse(val models: List<ChatModelProvider>) {
provider?.let { append(" by $provider") }
}
}

public fun isCodyProOnly(): Boolean = tags?.contains("pro") ?: codyProOnly

public fun isDeprecated(): Boolean = tags?.contains("deprecated") ?: deprecated
}
}
6 changes: 3 additions & 3 deletions src/main/kotlin/com/sourcegraph/cody/chat/AgentChatSession.kt
@@ -356,11 +356,11 @@ private constructor(
val chatModelProvider =
state.llm?.let {
ChatModelsResponse.ChatModelProvider(
default = it.model == null,
codyProOnly = false,
provider = it.provider,
title = it.title,
model = it.model ?: "")
model = it.model ?: "",
usage = it.usage.toMutableList(),
tags = it.tags.toMutableList())
}

val connectionId = createNewPanel(project) { it.server.chatNew() }
10 changes: 5 additions & 5 deletions src/main/kotlin/com/sourcegraph/cody/chat/ui/LlmDropdown.kt
@@ -49,8 +49,8 @@ class LlmDropdown(
private fun updateModelsInUI(models: List<ChatModelsResponse.ChatModelProvider>) {
if (project.isDisposed) return

val availableModels = models.filterNot { it.deprecated }
availableModels.sortedBy { it.codyProOnly }.forEach(::addItem)
val availableModels = models.filterNot { it.isDeprecated() }
availableModels.sortedBy { it.isCodyProOnly() }.forEach(::addItem)

val selectedFromState = chatModelProviderFromState
val selectedFromHistory = HistoryService.getInstance(project).getDefaultLlm()
@@ -59,8 +59,8 @@ class LlmDropdown(
?: availableModels.find { it.model == selectedFromHistory?.model }

selectedItem =
if (selectedModel?.codyProOnly == true && isCurrentUserFree())
availableModels.find { it.default }
if (selectedModel?.isCodyProOnly() == true && isCurrentUserFree())
availableModels.getOrNull(0)
Contributor:

Out of curiosity, if we do not have default anymore, isn't there any tag which would replace it? Like a recommended LLM?

Contributor (author):

I've just replaced it with the fallback convention of taking the head of the list. The model structs are immutable; having a mutable field tracking the user's selection was a bit tricky and easy to get wrong (multiple defaults were possible), and we always fell back to just grabbing the head of the list if we couldn't find one anyway.

I could see adding a default tag, but I would think to use it only for what the server originally sent as the default, not for the user's current selection.
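A minimal sketch of the two conventions discussed here, assuming a hypothetical "default" tag on ChatModelProvider; this is illustration only and not part of the PR, which simply falls back to the head of the list:

import com.sourcegraph.cody.agent.protocol.ChatModelsResponse

// Sketch only: prefer a model carrying a hypothetical "default" tag (what the
// server originally sent as the default), otherwise fall back to the head of
// the list, which is the convention this PR adopts.
fun pickFallbackModel(
    availableModels: List<ChatModelsResponse.ChatModelProvider>
): ChatModelsResponse.ChatModelProvider? =
    availableModels.firstOrNull { it.tags?.contains("default") == true }
        ?: availableModels.firstOrNull()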

else selectedModel

val isEnterpriseAccount =
Contributor (@pkukielka, Jul 17, 2024):

@jamesmcnamara

  1. We need to remove this line and just set isVisible = true always.
  2. Also, I noticed that there is now an enterprise tag present; shouldn't we filter the LLMs displayed for enterprise users based on whether that tag is present? Or maybe for an enterprise account we are guaranteed to get a correct list?

Contributor:

I pushed the first change to your branch already.

I'm not sure if we need the second.

Contributor (author, @jamesmcnamara, Jul 18, 2024):

I updated this because we don't want it to be visible for enterprise users that don't have multiple models enabled. However, I'm still struggling with the failing integration tests and I'm not sure how to debug them. Do you know any good tricks for figuring out why they keep timing out? I tried opening the IDE to the test file and trying the document code action, but it seemed to work fine.

To your second point, I think the current idea is that an enterprise user can use any model that the server provided (but most likely they will all be enterprise models).

Also, unrelated, but do you know how to access the generated Kotlin code from the Cody repo? I don't see it referenced anywhere.

Contributor:

The problem with the integration tests was that updated recordings were missing (they require a run of ./gradlew :recordingIntegrationTest). I updated them.
Interestingly, it seems we were not reporting that fact properly; I will make sure we fix that.

> Also, unrelated, but do you know how to access the generated Kotlin code from the Cody repo? I don't see it referenced anywhere.

Right now they just work as a manual reference, but that will change very soon, as there is a PR from @RXminuS in flight which addresses that.
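A minimal sketch of the visibility rule described above (hide the dropdown for enterprise accounts that only receive a single model from the server); the helper name and parameters are assumptions for illustration, not the merged LlmDropdown code:

// Sketch only, assuming the caller already knows the account type and how many
// models the server returned.
fun shouldShowLlmDropdown(isEnterpriseAccount: Boolean, availableModelCount: Int): Boolean {
    // Free and Pro accounts always see the dropdown; enterprise accounts only see
    // it when the server actually offers a choice of models.
    return !isEnterpriseAccount || availableModelCount > 1
}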

@@ -84,7 +84,7 @@ class LlmDropdown(
if (project.isDisposed) return
val modelProvider = anObject as? ChatModelsResponse.ChatModelProvider
if (modelProvider != null) {
if (modelProvider.codyProOnly && isCurrentUserFree()) {
if (modelProvider.isCodyProOnly() && isCurrentUserFree()) {
BrowserOpener.openInBrowser(project, "https://sourcegraph.com/cody/subscription")
return
}
47 changes: 47 additions & 0 deletions src/main/kotlin/com/sourcegraph/cody/config/migration/ChatTagsLlmMigration.kt
@@ -0,0 +1,47 @@
package com.sourcegraph.cody.config.migration

import com.intellij.openapi.project.Project
import com.sourcegraph.cody.agent.CodyAgentService
import com.sourcegraph.cody.agent.protocol.ChatModelsParams
import com.sourcegraph.cody.agent.protocol.ChatModelsResponse
import com.sourcegraph.cody.agent.protocol.ModelUsage
import com.sourcegraph.cody.history.HistoryService
import com.sourcegraph.cody.history.state.AccountData
import com.sourcegraph.cody.history.state.LLMState
import java.util.concurrent.TimeUnit

object ChatTagsLlmMigration {

fun migrate(project: Project) {
CodyAgentService.withAgent(project) { agent ->
val chatModels = agent.server.chatModels(ChatModelsParams(ModelUsage.CHAT.value))
val models =
chatModels.completeOnTimeout(null, 10, TimeUnit.SECONDS).get()?.models ?: return@withAgent
migrateHistory(HistoryService.getInstance(project).state.accountData, models)
}
}

fun migrateHistory(
accountData: List<AccountData>,
models: List<ChatModelsResponse.ChatModelProvider>,
) {
accountData.forEach { accData ->
accData.chats
.mapNotNull { it.llm }
.forEach { llm ->
val model = models.find { it.model == llm.model }
llm.usage = model?.usage ?: mutableListOf("chat", "edit")
llm.tags = model?.tags ?: mutableListOf()

addTagIf(llm, "deprecated", model?.deprecated)
addTagIf(llm, "pro", model?.codyProOnly)
}
}
}

fun addTagIf(llm: LLMState, tag: String, condition: Boolean?) {
if (condition ?: false && !llm.tags.contains(tag)) {
llm.tags.add(tag)
}
}
}
src/main/kotlin/com/sourcegraph/cody/config/migration/SettingsMigration.kt
@@ -64,6 +64,7 @@ class SettingsMigration : Activity {
}

DeprecatedChatLlmMigration.migrate(project)
ChatTagsLlmMigration.migrate(project)
}

private fun migrateOrphanedChatsToActiveAccount(project: Project) {
src/main/kotlin/com/sourcegraph/cody/history/state/LLMState.kt
@@ -13,12 +13,18 @@ class LLMState : BaseState() {

@get:OptionTag(tag = "provider", nameAttribute = "") var provider: String? by string()

@get:OptionTag(tag = "tags", nameAttribute = "") var tags: MutableList<String> by list()

@get:OptionTag(tag = "usage", nameAttribute = "") var usage: MutableList<String> by list()

Comment on lines +16 to +19
Contributor (author):

I don't know the implications of this. Where is it getting this state from? Is this going to break immediately?

Contributor (@pkukielka, Jul 9, 2024):

This is state saved in the IntelliJ XML settings files.
You will need to add a migration which converts the old format to the new one.
For reference, you can look at DeprecatedChatLlmMigration.

In practice you will most likely have to:

  1. Keep the ChatModelProvider::deprecated field but mark it as @Deprecated.
  2. In the migration function (see DeprecatedChatLlmMigration::migrate for reference):
  • read all the old chats
  • for each chat, read the content of the deprecated field
  • set the tags appropriately
  • reset deprecated to null (as it won't be used anymore; we will remove it in the future)

Contributor (author):

Alright. I created ChatTagsLlmMigration, which I believe accomplishes this.

Contributor:

Thank you! It looks good to me now; I just added one question about the default, but otherwise I think it is mergeable.

companion object {
fun fromChatModel(chatModelProvider: ChatModelsResponse.ChatModelProvider): LLMState {
return LLMState().also {
it.model = chatModelProvider.model
it.title = chatModelProvider.title
it.provider = chatModelProvider.provider
it.tags = chatModelProvider.tags ?: mutableListOf()
it.usage = chatModelProvider.usage ?: mutableListOf()
}
}
}
@@ -36,7 +36,7 @@ class LlmComboBoxRenderer(private val llmDropdown: LlmDropdown) : DefaultListCel
val displayNameLabel = JLabel(chatModelProvider.displayName())
textBadgePanel.add(displayNameLabel, BorderLayout.CENTER)
textBadgePanel.border = BorderFactory.createEmptyBorder(0, 5, 0, 0)
if (chatModelProvider.codyProOnly && llmDropdown.isCurrentUserFree()) {
if (chatModelProvider.isCodyProOnly() && llmDropdown.isCurrentUserFree()) {
textBadgePanel.add(JLabel(Icons.LLM.ProSticker), BorderLayout.EAST)
}
val isInline = llmDropdown.parentDialog != null
112 changes: 102 additions & 10 deletions src/test/kotlin/com/sourcegraph/cody/config/SettingsMigrationTest.kt
@@ -3,6 +3,7 @@ package com.sourcegraph.cody.config
import com.intellij.testFramework.fixtures.BasePlatformTestCase
import com.intellij.testFramework.registerServiceInstance
import com.sourcegraph.cody.agent.protocol.ChatModelsResponse
import com.sourcegraph.cody.config.migration.ChatTagsLlmMigration
import com.sourcegraph.cody.config.migration.DeprecatedChatLlmMigration
import com.sourcegraph.cody.config.migration.SettingsMigration
import com.sourcegraph.cody.history.HistoryService
@@ -99,7 +100,11 @@ class SettingsMigrationTest : BasePlatformTestCase() {
it.defaultLlm =
LLMState.fromChatModel(
ChatModelsResponse.ChatModelProvider(
true, false, "Cyberdyne", "Terminator", "T-800"))
provider = "Cyberdyne",
title = "Terminator",
model = "T-800",
default = true,
codyProOnly = false))
it.defaultEnhancedContext =
EnhancedContextState().also {
it.isEnabled = true
@@ -144,7 +149,7 @@ class SettingsMigrationTest : BasePlatformTestCase() {
it.llm =
LLMState.fromChatModel(
ChatModelsResponse.ChatModelProvider(
false, true, "Uni of IL", "HAL", "HAL 9000"))
"Uni of IL", "HAL", "HAL 9000", codyProOnly = true))
it.messages =
mutableListOf(
MessageState().also {
@@ -190,21 +195,20 @@
fun `test DeprecatedChatLlmMigration`() {
Contributor (@pkukielka, Jul 9, 2024):

That is incorrect, I think.
That is a migration function, so it specifically migrates data from one old format to another.
It needs to keep working after your changes.
Then your migration function can pick up the state after all existing migrations and migrate it further (so you will have to add a new test for it).

Contributor (author):

Understood. I reverted those changes and added a new test for my migration.

fun createLlmModel(
version: String,
isDefault: Boolean,
isDeprecated: Boolean
isDefault: Boolean = false,
isDeprecated: Boolean = false,
): ChatModelsResponse.ChatModelProvider {
return ChatModelsResponse.ChatModelProvider(
isDefault,
false,
"Anthropic",
"Claude $version",
"anthropic/claude-$version",
isDeprecated)
default = isDefault,
deprecated = isDeprecated)
}

val claude20 = createLlmModel("2.0", isDefault = false, isDeprecated = true)
val claude21 = createLlmModel("2.1", isDefault = false, isDeprecated = true)
val claude30 = createLlmModel("3.0", isDefault = true, isDeprecated = false)
val claude20 = createLlmModel("2.0", isDeprecated = true)
val claude21 = createLlmModel("2.1", isDeprecated = true)
val claude30 = createLlmModel("3.0", isDefault = true)
val models = listOf(claude20, claude21, claude30)

val accountData =
@@ -251,4 +255,92 @@
}
}
}

fun `test ChatTagsLlmMigration`() {
fun createLlmModel(
version: String,
isDeprecated: Boolean = false,
isCodyPro: Boolean = false,
usage: List<String> = listOf("chat", "edit"),
tags: List<String> = listOf()
): ChatModelsResponse.ChatModelProvider {
return ChatModelsResponse.ChatModelProvider(
"Anthropic",
"Claude $version",
"anthropic/claude-$version",
usage = usage.toMutableList(),
tags = tags.toMutableList(),
deprecated = isDeprecated,
codyProOnly = isCodyPro)
}

val claude20Old = createLlmModel("2.0", isDeprecated = true)
val claude20New = createLlmModel("2.0", tags = listOf("deprecated", "free"))

// This will be included as an old style model in the agent response to simulate
// an upgrade that runs before the agent upgrades
val claude21Old = createLlmModel("2.1", isDeprecated = true, isCodyPro = true)

val claude30Old = createLlmModel("3.0")
val claude30New = createLlmModel("3.0", tags = listOf("pro", "other"), usage = listOf("edit"))
val models = listOf(claude20New, claude21Old, claude30New)

val accountData =
mutableListOf(
AccountData().also {
it.accountId = "first"
it.chats =
mutableListOf(
ChatState("chat1").also {
it.messages = mutableListOf()
it.llm = LLMState.fromChatModel(claude20Old)
},
ChatState("chat2").also {
it.messages = mutableListOf()
it.llm = LLMState.fromChatModel(claude21Old)
})
},
AccountData().also {
it.accountId = "second"
it.chats =
mutableListOf(
ChatState("chat1").also {
it.messages = mutableListOf()
it.llm = LLMState.fromChatModel(claude20Old)
},
ChatState("chat2").also {
it.messages = mutableListOf()
it.llm = LLMState.fromChatModel(claude30Old)
})
})

fun getTagsAndUsage(chat: ChatState): Pair<List<String>, List<String>> {
val llm = chat.llm ?: return Pair(listOf<String>(), listOf<String>())
return Pair(llm.tags.toList(), llm.usage.toList())
}

ChatTagsLlmMigration.migrateHistory(accountData, models)
assertEquals(2, accountData.size)
accountData.forEach { ad ->
ad.chats.forEach { chat ->
when (chat.llm?.model) {
claude20Old.model -> {
val (tags, usage) = getTagsAndUsage(chat)
assertEquals(listOf("deprecated", "free"), tags)
assertEquals(listOf("chat", "edit"), usage)
}
claude21Old.model -> {
val (tags, usage) = getTagsAndUsage(chat)
assertEquals(listOf("deprecated", "pro"), tags)
assertEquals(listOf("chat", "edit"), usage)
}
claude30Old.model -> {
val (tags, usage) = getTagsAndUsage(chat)
assertEquals(listOf("pro", "other"), tags)
assertEquals(listOf("edit"), usage)
}
}
}
}
}
}