diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index 35c7b9a607e..b5efff833b6 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -8,6 +8,7 @@ import { splitWhen } from "remeda" import { Storage } from "../storage/storage" import { Bus } from "../bus" import { SessionPrompt } from "./prompt" +import { Todo } from "./todo" export namespace SessionRevert { const log = Log.create({ service: "session.revert" }) @@ -57,6 +58,8 @@ export namespace SessionRevert { revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track()) await Snapshot.revert(patches) if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot) + const todos = getTodos(all, revert.messageID) + await Todo.update({ sessionID: input.sessionID, todos }) return Session.update(input.sessionID, (draft) => { draft.revert = revert }) @@ -70,6 +73,9 @@ export namespace SessionRevert { const session = await Session.get(input.sessionID) if (!session.revert) return session if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot) + const all = await Session.messages({ sessionID: input.sessionID }) + const todos = getTodos(all) + await Todo.update({ sessionID: input.sessionID, todos }) const next = await Session.update(input.sessionID, (draft) => { draft.revert = undefined }) @@ -101,8 +107,20 @@ export namespace SessionRevert { }) } } + const todos = getTodos(preserve) + await Todo.update({ sessionID, todos }) await Session.update(sessionID, (draft) => { draft.revert = undefined }) } + + function getTodos(messages: MessageV2.WithParts[], until?: string): Todo.Info[] { + const idx = until ? messages.findIndex((m) => m.info.id === until) : messages.length + const parts = messages.slice(0, idx === -1 ? undefined : idx).flatMap((m) => m.parts) + const last = parts.findLast( + (p): p is MessageV2.ToolPart & { state: MessageV2.ToolStateCompleted } => + p.type === "tool" && p.tool === "todowrite" && p.state.status === "completed", + ) + return (last?.state.metadata?.todos as Todo.Info[]) ?? [] + } }