diff --git a/packages/opencode/src/cli/cmd/tui/component/dialog-agent.tsx b/packages/opencode/src/cli/cmd/tui/component/dialog-agent.tsx index 365a22445b4..1675fa6962b 100644 --- a/packages/opencode/src/cli/cmd/tui/component/dialog-agent.tsx +++ b/packages/opencode/src/cli/cmd/tui/component/dialog-agent.tsx @@ -20,7 +20,7 @@ export function DialogAgent() { return ( { local.agent.set(option.value) diff --git a/packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx b/packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx index 4558914cb7e..2c1c8c5d91c 100644 --- a/packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx +++ b/packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx @@ -530,11 +530,20 @@ export function Prompt(props: PromptProps) { // Capture mode before it gets reset const currentMode = store.mode const variant = local.model.variant.current() + const agent = local.agent.current()?.name + if (!agent) { + toast.show({ + variant: "warning", + message: "No agent is available", + duration: 3000, + }) + return + } if (store.mode === "shell") { sdk.client.session.shell({ sessionID, - agent: local.agent.current().name, + agent, model: { providerID: selectedModel.providerID, modelID: selectedModel.modelID, @@ -555,7 +564,7 @@ export function Prompt(props: PromptProps) { sessionID, command: command.slice(1), arguments: args.join(" "), - agent: local.agent.current().name, + agent, model: `${selectedModel.providerID}/${selectedModel.modelID}`, messageID, variant, @@ -571,7 +580,7 @@ export function Prompt(props: PromptProps) { sessionID, ...selectedModel, messageID, - agent: local.agent.current().name, + agent, model: selectedModel, variant, parts: [ @@ -688,10 +697,12 @@ export function Prompt(props: PromptProps) { return } + const currentAgentName = createMemo(() => local.agent.current()?.name ?? "") + const highlight = createMemo(() => { if (keybind.leader) return theme.border if (store.mode === "shell") return theme.primary - return local.agent.color(local.agent.current().name) + return local.agent.color(currentAgentName()) }) const showVariant = createMemo(() => { @@ -702,7 +713,7 @@ export function Prompt(props: PromptProps) { }) const spinnerDef = createMemo(() => { - const color = local.agent.color(local.agent.current().name) + const color = local.agent.color(currentAgentName()) return { frames: createFrames({ color, @@ -933,7 +944,11 @@ export function Prompt(props: PromptProps) { /> - {store.mode === "shell" ? "Shell" : Locale.titlecase(local.agent.current().name)}{" "} + {store.mode === "shell" + ? "Shell" + : currentAgentName() + ? Locale.titlecase(currentAgentName()) + : "Agent"}{" "} diff --git a/packages/opencode/src/cli/cmd/tui/context/local.tsx b/packages/opencode/src/cli/cmd/tui/context/local.tsx index 63f1d9743bf..d1c8b6fba67 100644 --- a/packages/opencode/src/cli/cmd/tui/context/local.tsx +++ b/packages/opencode/src/cli/cmd/tui/context/local.tsx @@ -35,11 +35,21 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ const agent = iife(() => { const agents = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden)) - const [agentStore, setAgentStore] = createStore<{ - current: string - }>({ - current: agents()[0].name, + const [agentStore, setAgentStore] = createStore<{ current: string | undefined }>({ + current: undefined, + }) + + createEffect(() => { + const list = agents() + if (list.length === 0) { + if (agentStore.current !== undefined) setAgentStore("current", undefined) + return + } + if (!agentStore.current || !list.some((x) => x.name === agentStore.current)) { + setAgentStore("current", list[0].name) + } }) + const { theme } = useTheme() const colors = createMemo(() => [ theme.secondary, @@ -54,7 +64,10 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ return agents() }, current() { - return agents().find((x) => x.name === agentStore.current)! + const list = agents() + if (list.length === 0) return undefined + if (!agentStore.current) return list[0] + return list.find((x) => x.name === agentStore.current) ?? list[0] }, set(name: string) { if (!agents().some((x) => x.name === name)) @@ -66,11 +79,15 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ setAgentStore("current", name) }, move(direction: 1 | -1) { + const list = agents() + if (list.length === 0) return batch(() => { - let next = agents().findIndex((x) => x.name === agentStore.current) + direction - if (next < 0) next = agents().length - 1 - if (next >= agents().length) next = 0 - const value = agents()[next] + const current = agentStore.current + const index = current ? list.findIndex((x) => x.name === current) : -1 + let next = (index === -1 ? 0 : index) + direction + if (next < 0) next = list.length - 1 + if (next >= list.length) next = 0 + const value = list[next] setAgentStore("current", value.name) }) }, @@ -181,8 +198,8 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ const a = agent.current() return ( getFirstValidModel( - () => modelStore.model[a.name], - () => a.model, + () => (a ? modelStore.model[a.name] : undefined), + () => a?.model, fallbackModel, ) ?? undefined ) @@ -227,7 +244,9 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ if (next >= recent.length) next = 0 const val = recent[next] if (!val) return - setModelStore("model", agent.current().name, { ...val }) + const a = agent.current() + if (!a) return + setModelStore("model", a.name, { ...val }) }, cycleFavorite(direction: 1 | -1) { const favorites = modelStore.favorite.filter((item) => isModelValid(item)) @@ -253,7 +272,10 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ } const next = favorites[index] if (!next) return - setModelStore("model", agent.current().name, { ...next }) + const a = agent.current() + if (!a) return + setModelStore("model", a.name, { ...next }) + const uniq = uniqueBy([next, ...modelStore.recent], (x) => `${x.providerID}/${x.modelID}`) if (uniq.length > 10) uniq.pop() setModelStore( @@ -272,7 +294,10 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ }) return } - setModelStore("model", agent.current().name, model) + const a = agent.current() + if (!a) return + setModelStore("model", a.name, model) + if (options?.recent) { const uniq = uniqueBy([model, ...modelStore.recent], (x) => `${x.providerID}/${x.modelID}`) if (uniq.length > 10) uniq.pop() @@ -368,6 +393,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ // Automatically update model when agent changes createEffect(() => { const value = agent.current() + if (!value) return if (value.model) { if (isModelValid(value.model)) model.set({ diff --git a/packages/opencode/src/provider/auth.ts b/packages/opencode/src/provider/auth.ts index e6681ff0891..20509301a41 100644 --- a/packages/opencode/src/provider/auth.ts +++ b/packages/opencode/src/provider/auth.ts @@ -30,7 +30,7 @@ export namespace ProviderAuth { export async function methods() { const s = await state().then((x) => x.methods) - return mapValues(s, (x) => + return mapValues(s, (x: NonNullable) => x.methods.map( (y): Method => ({ type: y.type, @@ -78,42 +78,48 @@ export namespace ProviderAuth { code: z.string().optional(), }), async (input) => { + const clearPending = () => state().then((s) => delete s.pending[input.providerID]) const match = await state().then((s) => s.pending[input.providerID]) if (!match) throw new OauthMissing({ providerID: input.providerID }) - let result + if (match.method === "code" && !input.code) throw new OauthCodeMissing({ providerID: input.providerID }) - if (match.method === "code") { - if (!input.code) throw new OauthCodeMissing({ providerID: input.providerID }) - result = await match.callback(input.code) - } + return (async () => { + const result = await (match.method === "code" ? match.callback(input.code!) : match.callback()) - if (match.method === "auto") { - result = await match.callback() - } + if (!result || result.type !== "success") { + throw new OauthCallbackFailed({}) + } + + const providerID = + "provider" in result && typeof result.provider === "string" && result.provider + ? result.provider + : input.providerID - if (result?.type === "success") { if ("key" in result) { - await Auth.set(input.providerID, { + await Auth.set(providerID, { type: "api", key: result.key, }) } + if ("refresh" in result) { + const accountId = + "accountId" in result && typeof result.accountId === "string" ? result.accountId : undefined + const enterpriseUrl = + "enterpriseUrl" in result && typeof result.enterpriseUrl === "string" ? result.enterpriseUrl : undefined + const info: Auth.Info = { type: "oauth", access: result.access, refresh: result.refresh, expires: result.expires, + ...(accountId ? { accountId } : {}), + ...(enterpriseUrl ? { enterpriseUrl } : {}), } - if (result.accountId) { - info.accountId = result.accountId - } - await Auth.set(input.providerID, info) - } - return - } - throw new OauthCallbackFailed({}) + await Auth.set(providerID, info) + } + })().finally(clearPending) }, ) diff --git a/packages/opencode/test/provider/auth-extra-fields.test.ts b/packages/opencode/test/provider/auth-extra-fields.test.ts new file mode 100644 index 00000000000..de790b05a80 --- /dev/null +++ b/packages/opencode/test/provider/auth-extra-fields.test.ts @@ -0,0 +1,66 @@ +import { expect, mock, test } from "bun:test" + +mock.module("../../src/plugin", () => ({ + Plugin: { + list: async () => [ + { + auth: { + provider: "openai-test", + methods: [ + { + label: "Mock OAuth", + type: "oauth", + authorize: async () => { + return { + url: "https://example.com/oauth", + method: "auto", + instructions: "Complete auth in your browser", + callback: async () => { + return { + type: "success", + refresh: "refresh-token", + access: "access-token", + expires: 123, + accountId: "acct_123", + enterpriseUrl: "https://ghe.example.com", + } + }, + } + }, + }, + ], + }, + }, + ], + }, +})) + +const { tmpdir } = await import("../fixture/fixture") +const { Instance } = await import("../../src/project/instance") +const { ProviderAuth } = await import("../../src/provider/auth") +const { Auth } = await import("../../src/auth") + +test("ProviderAuth oauth callback persists accountId and enterpriseUrl", async () => { + await using tmp = await tmpdir() + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const auth = await ProviderAuth.authorize({ + providerID: "openai-test", + method: 0, + }) + expect(auth).toBeDefined() + + await ProviderAuth.callback({ + providerID: "openai-test", + method: 0, + }) + + const saved = await Auth.get("openai-test") + expect(saved?.type).toBe("oauth") + expect(saved && saved.type === "oauth" ? saved.accountId : undefined).toBe("acct_123") + expect(saved && saved.type === "oauth" ? saved.enterpriseUrl : undefined).toBe("https://ghe.example.com") + }, + }) +}) diff --git a/packages/plugin/src/index.ts b/packages/plugin/src/index.ts index e57eff579e6..712193bd8e4 100644 --- a/packages/plugin/src/index.ts +++ b/packages/plugin/src/index.ts @@ -115,6 +115,7 @@ export type AuthOuathResult = { url: string; instructions: string } & ( access: string expires: number accountId?: string + enterpriseUrl?: string } | { key: string } )) @@ -135,6 +136,7 @@ export type AuthOuathResult = { url: string; instructions: string } & ( access: string expires: number accountId?: string + enterpriseUrl?: string } | { key: string } ))