Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions packages/opencode/src/session/revert.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { splitWhen } from "remeda"
import { Storage } from "../storage/storage"
import { Bus } from "../bus"
import { SessionPrompt } from "./prompt"
import { Todo } from "./todo"

export namespace SessionRevert {
const log = Log.create({ service: "session.revert" })
Expand Down Expand Up @@ -57,6 +58,8 @@ export namespace SessionRevert {
revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
await Snapshot.revert(patches)
if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
const todos = getTodos(all, revert.messageID)
await Todo.update({ sessionID: input.sessionID, todos })
return Session.update(input.sessionID, (draft) => {
draft.revert = revert
})
Expand All @@ -70,6 +73,9 @@ export namespace SessionRevert {
const session = await Session.get(input.sessionID)
if (!session.revert) return session
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
const all = await Session.messages({ sessionID: input.sessionID })
const todos = getTodos(all)
await Todo.update({ sessionID: input.sessionID, todos })
const next = await Session.update(input.sessionID, (draft) => {
draft.revert = undefined
})
Expand Down Expand Up @@ -101,8 +107,20 @@ export namespace SessionRevert {
})
}
}
const todos = getTodos(preserve)
await Todo.update({ sessionID, todos })
await Session.update(sessionID, (draft) => {
draft.revert = undefined
})
}

function getTodos(messages: MessageV2.WithParts[], until?: string): Todo.Info[] {
const idx = until ? messages.findIndex((m) => m.info.id === until) : messages.length
const parts = messages.slice(0, idx === -1 ? undefined : idx).flatMap((m) => m.parts)
const last = parts.findLast(
(p): p is MessageV2.ToolPart & { state: MessageV2.ToolStateCompleted } =>
p.type === "tool" && p.tool === "todowrite" && p.state.status === "completed",
)
return (last?.state.metadata?.todos as Todo.Info[]) ?? []
}
}
Loading