Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/opencode/src/session/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { SessionPrompt } from "./prompt"
import { fn } from "@/util/fn"
import { Command } from "../command"
import { Snapshot } from "@/snapshot"
import { Todo } from "./todo"

export namespace Session {
const log = Log.create({ service: "session" })
Expand Down Expand Up @@ -66,6 +67,7 @@ export namespace Session {
partID: z.string().optional(),
snapshot: z.string().optional(),
diff: z.string().optional(),
todos: Todo.Info.array().optional(),
})
.optional(),
})
Expand Down
44 changes: 34 additions & 10 deletions packages/opencode/src/session/revert.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@ import { Snapshot } from "../snapshot"
import { MessageV2 } from "./message-v2"
import { Session } from "."
import { Log } from "../util/log"
import { splitWhen } from "remeda"

import { Storage } from "../storage/storage"
import { Bus } from "../bus"
import { SessionLock } from "./lock"
import { Todo } from "./todo"
import { splitWhen } from "remeda"

export namespace SessionRevert {
const log = Log.create({ service: "session.revert" })

function extractTodos(part: MessageV2.Part): Todo.Info[] | undefined {
if (part.type !== "tool") return undefined
if (part.tool !== "todowrite") return undefined
if (part.state.status !== "completed") return undefined
const metadata = part.state.metadata as { todos?: Todo.Info[] } | undefined
return metadata?.todos
}

export const RevertInput = z.object({
sessionID: Identifier.schema("session"),
messageID: Identifier.schema("message"),
Expand All @@ -30,6 +40,7 @@ export namespace SessionRevert {
const session = await Session.get(input.sessionID)

let revert: Session.Info["revert"]
let todosBefore: Todo.Info[] | undefined
const patches: Snapshot.Patch[] = []
for (const msg of all) {
if (msg.info.role === "user") lastUser = msg.info
Expand All @@ -42,23 +53,33 @@ export namespace SessionRevert {
continue
}

if (!revert) {
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
// if no useful parts left in message, same as reverting whole message
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
revert = {
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
partID,
}
const matchesMessage = msg.info.id === input.messageID && !input.partID
const matchesPart = part.id === input.partID
const isTarget = matchesMessage || matchesPart

if (isTarget) {
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
revert = {
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
partID,
}
}
if (!isTarget) {
const todos = extractTodos(part)
if (todos) {
todosBefore = todos
}
remaining.push(part)
}

remaining.push(part)
}
}

if (revert) {
const session = await Session.get(input.sessionID)
revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
revert.todos = todosBefore ?? []
await Todo.update({ sessionID: input.sessionID, todos: revert.todos })
await Snapshot.revert(patches)
if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
return Session.update(input.sessionID, (draft) => {
Expand Down Expand Up @@ -108,6 +129,9 @@ export namespace SessionRevert {
})
}
}
const todos = session.revert.todos ?? []
await Todo.update({ sessionID, todos })

await Session.update(sessionID, (draft) => {
draft.revert = undefined
})
Expand Down
Loading