Skip to content

Commit

Permalink
add multi model selector for openai backends
Browse files Browse the repository at this point in the history
  • Loading branch information
juzeon committed Sep 14, 2024
1 parent 7077959 commit c3b3da9
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 5 deletions.
12 changes: 11 additions & 1 deletion app_chatbot.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type AskOptions struct {
Prompt string `json:"prompt"`
ImageURL string `json:"image_url"`
UploadFilePath string `json:"upload_file_path"`
Model string `json:"model"` // for openai models
}

const (
Expand Down Expand Up @@ -266,8 +267,17 @@ func (a *App) askOpenAI(options AskOptions) {
}},
})
}
models := strings.Split(backend.OpenaiShortModel, ",")
if options.Model == "" {
options.Model = models[0]
}
if !slices.Contains(models, options.Model) {
handleErr(errors.New("model " + options.Model + " is not in the model list of the corresponding backend"))
return
}
slog.Info("Ask OpenAI with model: " + options.Model)
stream, err := client.CreateChatCompletionStream(stopCtx, openai.ChatCompletionRequest{
Model: backend.OpenaiShortModel,
Model: options.Model,
Messages: messages,
Temperature: backend.OpenaiTemperature,
FrequencyPenalty: backend.FrequencyPenalty,
Expand Down
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Workspace struct {
PersistentInput bool `json:"persistent_input"`
Plugins []string `json:"plugins"`
DataReferences []DataReference `json:"data_references"`
Model string `json:"model"`
}
type DataReference struct {
UUID string `json:"uuid"`
Expand Down
11 changes: 10 additions & 1 deletion frontend/src/components/index/WorkspaceNav.vue
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@
import Conversation from "./Conversation.vue"
import SearchWorkspaceButton from "./SearchWorkspaceButton.vue"
import {main, sydney} from "../../../wailsjs/go/models"
import {main} from "../../../wailsjs/go/models"
import {generateRandomName, swal} from "../../helper"
import dayjs from "dayjs"
import {computed, ref} from "vue"
import {ExportWorkspace, ShareWorkspace} from "../../../wailsjs/go/main/App"
import Workspace = main.Workspace
import Preset = main.Preset
import DataReference = main.DataReference
import OpenAIBackend = main.OpenAIBackend
let props = defineProps<{
modelValue: boolean,
workspaces: Workspace[],
currentWorkspace: Workspace,
isAsking: boolean,
presets: Preset[],
open_ai_backends: OpenAIBackend[],
}>()
let sortedWorkspaces = computed(() => {
return props.workspaces.sort((a, b) => {
Expand Down Expand Up @@ -85,6 +87,7 @@ function addWorkspace() {
gpt_4_turbo: props.currentWorkspace.gpt_4_turbo,
persistent_input: props.currentWorkspace.persistent_input,
plugins: props.currentWorkspace.plugins,
model: props.currentWorkspace.model
}
props.workspaces.push(workspace)
switchWorkspace(workspace)
Expand All @@ -97,6 +100,12 @@ function switchWorkspace(workspace: Workspace) {
if (!workspace.data_references) {
workspace.data_references = []
}
if (!workspace.model && workspace.backend != 'Sydney') {
let backend = props.open_ai_backends.find(v => v.name === workspace.backend)
if (backend) {
workspace.model = backend.openai_short_model.split(',')[0]
}
}
emit('update:currentWorkspace', workspace)
emit('update:suggestedResponses', [])
emit('scrollChatContextToBottom')
Expand Down
18 changes: 16 additions & 2 deletions frontend/src/components/settings/OpenAIBackendCard.vue
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function deleteOpenaiBackend(backend: OpenAIBackend) {
return
}
activeOpenaiBackendName.value = props.open_ai_backends[0].name
emit('update:open_ai_backends',props.open_ai_backends.filter(v => v !== backend))
emit('update:open_ai_backends', props.open_ai_backends.filter(v => v !== backend))
}
function checkOpenaiEndpoint(val: string) {
Expand All @@ -74,6 +74,18 @@ function onChangeOpenAIMaxTokens(val: string) {
}
backend.max_tokens = i
}
function onChangeOpenAIShortModel(val: string) {
let backend = props.open_ai_backends.find(v => v.name === activeOpenaiBackendName.value)
if (!backend) {
return
}
let v = val.replaceAll(' ', '')
if (!v) {
v = 'gpt-4o-mini'
}
backend.openai_short_model = v
}
</script>

<template>
Expand Down Expand Up @@ -115,7 +127,9 @@ function onChangeOpenAIMaxTokens(val: string) {
<v-text-field label="Endpoint" :rules="[checkOpenaiEndpoint]" v-model="backend.openai_endpoint"
color="primary"></v-text-field>
<v-text-field label="Key" v-model="backend.openai_key" color="primary"></v-text-field>
<v-text-field label="Model" v-model="backend.openai_short_model" color="primary"></v-text-field>
<v-text-field label="Model List" @update:model-value="onChangeOpenAIShortModel"
:model-value="backend.openai_short_model" color="primary"
hint="Separated by comma(,)"></v-text-field>
<v-slider label="Temperature" v-model="backend.openai_temperature" min="0" max="2"
step="0.1" color="primary" thumb-label="always"></v-slider>
<v-slider label="Frequency Penalty" v-model="backend.frequency_penalty"
Expand Down
15 changes: 14 additions & 1 deletion frontend/src/pages/IndexPage.vue
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ let currentWorkspace = ref(<Workspace>{
use_classic: false,
gpt_4_turbo: false,
persistent_input: false,
model: '',
})
let chatContextTokenCount = ref(0)
Expand All @@ -68,6 +69,12 @@ watch(currentWorkspace, async () => {
chatContextTokenCount.value = await CountToken(currentWorkspace.value.context)
userInputTokenCount.value = await CountToken(currentWorkspace.value.input)
config.value.current_workspace_id = currentWorkspace.value.id
if (!currentWorkspace.value.model && currentWorkspace.value.backend != 'Sydney') {
let backend = config.value.open_ai_backends.find(v => v.name === currentWorkspace.value.backend)
if (backend) {
currentWorkspace.value.model = backend.openai_short_model.split(',')[0]
}
}
}, {deep: true})
let statusTokenCountText = computed(() => {
return 'Chat Context: ' + chatContextTokenCount.value + ' tokens; User Input: ' + userInputTokenCount.value + ' tokens'
Expand Down Expand Up @@ -272,6 +279,7 @@ async function startAsking(args: StartAskingArgs = {}) {
askOptions.openai_backend = currentWorkspace.value.backend
askOptions.image_url = uploadedImage.value?.bing_url ?? ''
askOptions.upload_file_path = selectedUploadFile.value ?? ''
askOptions.model = currentWorkspace.value.model ?? ''
await AskAI(askOptions)
}
Expand Down Expand Up @@ -491,6 +499,7 @@ function generateTitle() {
<workspace-nav v-if="!loading" :is-asking="isAsking" v-model="navDrawer"
v-model:current-workspace="currentWorkspace"
v-model:workspaces="config.workspaces" :presets="config.presets" @on-reset="onReset"
:open_ai_backends="config.open_ai_backends"
@update:suggested-responses="arr => suggestedResponses=arr"
@scroll-chat-context-to-bottom="scrollChatContextToBottom"></workspace-nav>
<div class="d-flex flex-column fill-height" v-if="!loading">
Expand All @@ -501,10 +510,14 @@ function generateTitle() {
<v-select v-model="currentWorkspace.backend" :items="backendList" color="primary" label="Backend"
density="compact"
class="mx-2"></v-select>
<v-select v-model="currentWorkspace.conversation_style" :disabled="currentWorkspace.backend!=='Sydney'"
<v-select v-model="currentWorkspace.conversation_style" v-if="currentWorkspace.backend==='Sydney'"
:items="modeList" color="primary" label="Mode"
density="compact"
class="mx-2"></v-select>
<v-select v-model="currentWorkspace.model" v-else
:items="config.open_ai_backends.find(v=>v.name===currentWorkspace.backend)?.openai_short_model.split(',') ?? []"
color="primary" label="Model" density="compact"
class="mx-2"></v-select>
<v-select :model-value="currentWorkspace.preset" @update:model-value="onPresetChange"
:items="config.presets.map(v=>v.name)" color="primary"
label="Preset"
Expand Down
4 changes: 4 additions & 0 deletions frontend/wailsjs/go/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export namespace main {
prompt: string;
image_url: string;
upload_file_path: string;
model: string;

static createFrom(source: any = {}) {
return new AskOptions(source);
Expand All @@ -20,6 +21,7 @@ export namespace main {
this.prompt = source["prompt"];
this.image_url = source["image_url"];
this.upload_file_path = source["upload_file_path"];
this.model = source["model"];
}
}
export class ChatFinishResult {
Expand Down Expand Up @@ -155,6 +157,7 @@ export namespace main {
persistent_input: boolean;
plugins: string[];
data_references: DataReference[];
model: string;

static createFrom(source: any = {}) {
return new Workspace(source);
Expand All @@ -177,6 +180,7 @@ export namespace main {
this.persistent_input = source["persistent_input"];
this.plugins = source["plugins"];
this.data_references = this.convertValues(source["data_references"], DataReference);
this.model = source["model"];
}

convertValues(a: any, classs: any, asMap: boolean = false): any {
Expand Down

0 comments on commit c3b3da9

Please sign in to comment.