diff --git a/.gitignore b/.gitignore
index 753a64b2..8cb40b8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,6 +28,11 @@ Thumbs.db
 # OpenCode
 .opencode/

+# Python cache
+__pycache__/
+*.py[cod]
+*$py.class
+
 # Generated prompt files (from scripts/generate-prompts.ts)
 lib/prompts/**/*.generated.ts

@@ -40,4 +45,4 @@ test-update.ts
 docs/
 SCHEMA_NOTES.md

-repomix-output.xml
\ No newline at end of file
+repomix-output.xml
diff --git a/README.md b/README.md
index 5b9bb57b..52940ec5 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ For model-facing behavior (prompts and tool calls), this capability is always ad

 ### Tool

-**Compress** — Exposes a single `compress` tool with one method: match a conversation range using `startString` and `endString`, then replace it with a technical summary.
+**Compress** — Exposes a single `compress` tool with one method: select a conversation range using injected `startId` and `endId` (`mNNNN` or `bN`), then replace it with a technical summary.

 The model can use that same method at different scales: tiny ranges for noise cleanup, focused ranges for preserving key findings, and full chapters for completed work.

diff --git a/lib/hooks.ts b/lib/hooks.ts
index 9bb91106..02169867 100644
--- a/lib/hooks.ts
+++ b/lib/hooks.ts
@@ -1,9 +1,10 @@
 import type { SessionState, WithParts } from "./state"
 import type { Logger } from "./logger"
 import type { PluginConfig } from "./config"
+import { assignMessageRefs } from "./message-ids"
 import { syncToolCache } from "./state/tool-cache"
 import { deduplicate, supersedeWrites, purgeErrors } from "./strategies"
-import { prune, insertCompressToolContext } from "./messages"
+import { prune, insertCompressToolContext, insertMessageIdContext } from "./messages"
 import { buildToolIdList, isIgnoredUserMessage } from "./messages/utils"
 import { checkSession } from "./state"
 import { renderSystemPrompt } from "./prompts"
@@ -104,6 +105,8 @@ export function createChatMessageTransformHandler(

         cacheSystemPromptTokens(state, output.messages)

+        assignMessageRefs(state, output.messages)
+
         syncToolCache(state, config, logger, output.messages)

         buildToolIdList(state, output.messages)
@@ -113,6 +116,8 @@ export function createChatMessageTransformHandler(

         prune(state, logger, config, output.messages)

+        insertMessageIdContext(state, output.messages)
+
         insertCompressToolContext(state, config, logger, output.messages)

         applyPendingManualTriggerPrompt(state, output.messages, logger)
diff --git a/lib/message-ids.ts b/lib/message-ids.ts
new file mode 100644
index 00000000..07247daa
--- /dev/null
+++ b/lib/message-ids.ts
@@ -0,0 +1,132 @@
+import type { SessionState, WithParts } from "./state"
+
+const MESSAGE_REF_REGEX = /^m(\d{4})$/
+const BLOCK_REF_REGEX = /^b([1-9]\d*)$/
+
+const MESSAGE_REF_WIDTH = 4
+const MESSAGE_REF_MIN_INDEX = 0
+export const MESSAGE_REF_MAX_INDEX = 9999
+
+export type ParsedBoundaryId =
+    | {
+          kind: "message"
+          ref: string
+          index: number
+      }
+    | {
+          kind: "compressed-block"
+          ref: string
+          blockId: number
+      }
+
+export function formatMessageRef(index: number): string {
+    if (
+        !Number.isInteger(index) ||
+        index < MESSAGE_REF_MIN_INDEX ||
+        index > MESSAGE_REF_MAX_INDEX
+    ) {
+        throw new Error(
+            `Message ID index out of bounds: ${index}. Supported range is 0-${MESSAGE_REF_MAX_INDEX}.`,
+        )
+    }
+    return `m${index.toString().padStart(MESSAGE_REF_WIDTH, "0")}`
+}
+
+export function formatBlockRef(blockId: number): string {
+    if (!Number.isInteger(blockId) || blockId < 1) {
+        throw new Error(`Invalid block ID: ${blockId}`)
+    }
+    return `b${blockId}`
+}
+
+export function parseMessageRef(ref: string): number | null {
+    const normalized = ref.trim().toLowerCase()
+    const match = normalized.match(MESSAGE_REF_REGEX)
+    if (!match) {
+        return null
+    }
+    const index = Number.parseInt(match[1], 10)
+    return Number.isInteger(index) ? index : null
+}
+
+export function parseBlockRef(ref: string): number | null {
+    const normalized = ref.trim().toLowerCase()
+    const match = normalized.match(BLOCK_REF_REGEX)
+    if (!match) {
+        return null
+    }
+    const id = Number.parseInt(match[1], 10)
+    return Number.isInteger(id) ? id : null
+}
+
+export function parseBoundaryId(id: string): ParsedBoundaryId | null {
+    const normalized = id.trim().toLowerCase()
+    const messageIndex = parseMessageRef(normalized)
+    if (messageIndex !== null) {
+        return {
+            kind: "message",
+            ref: formatMessageRef(messageIndex),
+            index: messageIndex,
+        }
+    }
+
+    const blockId = parseBlockRef(normalized)
+    if (blockId !== null) {
+        return {
+            kind: "compressed-block",
+            ref: formatBlockRef(blockId),
+            blockId,
+        }
+    }
+
+    return null
+}
+
+export function formatMessageIdMarker(ref: string): string {
+    return `Message ID: ${ref}`
+}
+
+export function assignMessageRefs(state: SessionState, messages: WithParts[]): number {
+    let assigned = 0
+
+    for (const message of messages) {
+        const rawMessageId = message.info.id
+        if (typeof rawMessageId !== "string" || rawMessageId.length === 0) {
+            continue
+        }
+
+        const existingRef = state.messageIds.byRawId.get(rawMessageId)
+        if (existingRef) {
+            if (state.messageIds.byRef.get(existingRef) !== rawMessageId) {
+                state.messageIds.byRef.set(existingRef, rawMessageId)
+            }
+            continue
+        }
+
+        const ref = allocateNextMessageRef(state)
+        state.messageIds.byRawId.set(rawMessageId, ref)
+        state.messageIds.byRef.set(ref, rawMessageId)
+        assigned++
+    }
+
+    return assigned
+}
+
+function allocateNextMessageRef(state: SessionState): string {
+    let candidate = Number.isInteger(state.messageIds.nextRef)
+        ? Math.max(MESSAGE_REF_MIN_INDEX, state.messageIds.nextRef)
+        : MESSAGE_REF_MIN_INDEX
+
+    while (candidate <= MESSAGE_REF_MAX_INDEX) {
+        const ref = formatMessageRef(candidate)
+        if (!state.messageIds.byRef.has(ref)) {
+            state.messageIds.nextRef = candidate + 1
+            return ref
+        }
+        candidate++
+    }
+
+    throw new Error(
+        `Message ID alias capacity exceeded. Cannot allocate more than ${formatMessageRef(MESSAGE_REF_MAX_INDEX)} aliases in this session.`,
+    )
+}
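Reviewer note: the two new calls sit on opposite sides of `prune` by design. A minimal sketch of the ordering contract, assuming the import paths below; the `declare` placeholders stand in for the handler's real values:

```ts
import { assignMessageRefs } from "./lib/message-ids"
import { insertMessageIdContext } from "./lib/messages"
import type { SessionState, WithParts } from "./lib/state"

declare const state: SessionState   // placeholder: the handler's session state
declare const messages: WithParts[] // placeholder: the transformed message list

// Refs are allocated BEFORE pruning, while every raw message ID is still
// visible, so aliases stay stable even for messages that are later pruned...
assignMessageRefs(state, messages)

// ...while markers are injected AFTER pruning, so only surviving messages
// carry a visible "Message ID: mNNNN" part.
insertMessageIdContext(state, messages)
```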
diff --git a/lib/messages/index.ts b/lib/messages/index.ts
index f2011dd6..e0ec4ef0 100644
--- a/lib/messages/index.ts
+++ b/lib/messages/index.ts
@@ -1,2 +1,3 @@
 export { prune } from "./prune"
 export { insertCompressToolContext } from "./inject/inject"
+export { insertMessageIdContext } from "./inject/inject"
diff --git a/lib/messages/inject/inject.ts b/lib/messages/inject/inject.ts
index 72543954..ae78efe0 100644
--- a/lib/messages/inject/inject.ts
+++ b/lib/messages/inject/inject.ts
@@ -1,6 +1,8 @@
 import type { SessionState, WithParts } from "../../state"
 import type { Logger } from "../../logger"
 import type { PluginConfig } from "../../config"
+import { formatMessageIdMarker } from "../../message-ids"
+import { createSyntheticTextPart, createSyntheticToolPart, isIgnoredUserMessage } from "../utils"
 import {
     addAnchor,
     applyAnchoredNudge,
@@ -56,3 +58,52 @@ export const insertCompressToolContext = (
         persistAnchors(state, logger)
     }
 }
+
+export const insertMessageIdContext = (state: SessionState, messages: WithParts[]): void => {
+    const { modelId } = getModelInfo(messages)
+    const toolModelId = modelId || ""
+
+    for (const message of messages) {
+        if (message.info.role === "user" && isIgnoredUserMessage(message)) {
+            continue
+        }
+
+        const messageRef = state.messageIds.byRawId.get(message.info.id)
+        if (!messageRef) {
+            continue
+        }
+
+        const marker = formatMessageIdMarker(messageRef)
+
+        if (message.info.role === "user") {
+            const hasMarker = message.parts.some(
+                (part) => part.type === "text" && part.text.trim() === marker,
+            )
+            if (!hasMarker) {
+                message.parts.push(createSyntheticTextPart(message, marker))
+            }
+            continue
+        }
+
+        if (message.info.role !== "assistant") {
+            continue
+        }
+
+        const hasMarker = message.parts.some((part) => {
+            if (part.type !== "tool") {
+                return false
+            }
+            if (part.tool !== "context_info") {
+                return false
+            }
+            return (
+                part.state?.status === "completed" &&
+                typeof part.state.output === "string" &&
+                part.state.output.trim() === marker
+            )
+        })
+        if (!hasMarker) {
+            message.parts.push(createSyntheticToolPart(message, marker, toolModelId))
+        }
+    }
+}
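A quick round-trip sketch of the helpers this file exports (import path assumes the repo root):

```ts
import { formatMessageRef, parseBoundaryId, parseMessageRef } from "./lib/message-ids"

console.assert(formatMessageRef(7) === "m0007") // zero-padded to four digits
console.assert(parseMessageRef("m0007") === 7)  // round-trips back to the index
console.assert(parseBoundaryId("B3")?.kind === "compressed-block") // parsing is case-insensitive
console.assert(parseBoundaryId("b0") === null)  // block IDs start at 1
console.assert(parseBoundaryId("m12") === null) // message refs must be exactly four digits
```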
diff --git a/lib/prompts/compress.md b/lib/prompts/compress.md
index dd1c2e8a..317aa2fa 100644
--- a/lib/prompts/compress.md
+++ b/lib/prompts/compress.md
@@ -31,6 +31,29 @@ USER INTENT FIDELITY
 When the compressed range includes user messages, preserve the user's intent with extra care. Do not change scope, constraints, priorities, acceptance criteria, or requested outcomes. Directly quote user messages when they are short enough to include safely. Direct quotes are preferred when they best preserve exact meaning.

+COMPRESSED BLOCK PLACEHOLDERS
+When the selected range includes previously compressed blocks, use this exact placeholder format when referencing one:
+
+- `{block_N}`
+
+Rules:
+
+- Include every required block placeholder exactly once.
+- Do not invent placeholders for blocks outside the selected range.
+- Treat `{block_N}` placeholders as RESERVED TOKENS. Do not emit `{block_N}` text anywhere except intentional placeholders.
+- If you need to mention a block in prose, use plain text like `compressed bN` (without curly braces).
+- Preflight check before finalizing: the set of `{block_N}` placeholders in your summary must exactly match the required set, with no duplicates.
+
+These placeholders are semantic references. They will be replaced with the full stored compressed block content when the tool processes your output.
+
+FLOW PRESERVATION WITH PLACEHOLDERS
+When you use compressed block placeholders, write the surrounding summary text so it still reads correctly AFTER placeholder expansion.
+
+- Treat each placeholder as a stand-in for a full conversation segment, not as a short label.
+- Ensure transitions before and after each placeholder preserve chronology and causality.
+- Do not write text that depends on the placeholder staying literal (for example, "as noted in {block_2}").
+- Your final meaning must be coherent once each placeholder is replaced with its full compressed block content.
+
 Yet be LEAN. Strip away the noise: failed attempts that led nowhere, verbose tool outputs, back-and-forth exploration. What remains should be pure signal - golden nuggets of detail that preserve full understanding with zero ambiguity.

 THE WAYS OF COMPRESS
@@ -43,7 +66,7 @@ Exploration exhausted and patterns understood
 Compress smaller ranges when:
 You need to discard dead-end noise without waiting for a whole chapter to close
 You need to preserve key findings from a narrow slice while freeing context quickly
-You can bound a stale range cleanly with unique boundaries
+You can bound a stale range cleanly with injected IDs

 Do NOT compress when:
 You may need exact code, error messages, or file contents from the range in the immediate next steps
@@ -52,40 +75,28 @@ You cannot identify reliable boundaries yet

 Before compressing, ask: _"Is this range closed enough to become summary-only right now?"_ Compression is irreversible. The summary replaces everything in the range.

-BOUNDARY MATCHING
-You specify boundaries by matching unique text strings in the conversation. CRITICAL: In code-centric conversations, strings repeat often. Provide sufficiently unique text to match exactly once. Be conservative and choose longer, highly specific boundaries when in doubt. If a match fails (not found or found multiple times), the tool will error - extend your boundary string with more surrounding context in order to make SURE the tool does NOT error.
-
-WHERE TO PICK STRINGS FROM (important for reliable matching):
-
-- Your own assistant text responses (MOST RELIABLE - always stored verbatim)
-- The user's own words in their messages
-- Tool result output text (distinctive substrings within the output)
-- Previous compress summaries
-- Tool input string values (LEAST RELIABLE - only single concrete field values, not keys or schema fields, may be transformed by AI SDK)
-
-NEVER USE GENERIC OR REPEATING STRINGS:
+BOUNDARY IDS
+You specify boundaries by ID.

-Tool status messages repeat identically across every invocation. These are ALWAYS ambiguous:
+Use the injected IDs visible in the conversation:

-- "Edit applied successfully." (appears in EVERY successful edit)
-- "File written successfully" or any tool success/error boilerplate
-- Common tool output patterns that are identical across calls
+- `mNNNN` IDs identify raw messages
+- `bN` IDs identify previously compressed blocks

-Instead, combine the generic output with surrounding unique context (a file path, a specific code snippet, or your own unique assistant text).
+Rules:

-Each boundary string you choose MUST be unique to ONE specific message. Before using a string, ask: "Could this exact text appear in any other place in this conversation?" If yes, extend it or pick a different string.
+- Pick `startId` and `endId` directly from injected IDs in context.
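For orientation while reviewing: user messages get a synthetic text part and assistant messages a completed `context_info` tool part, but both carry the same marker string. A tiny sketch, assuming the import path below:

```ts
import { formatMessageIdMarker, formatMessageRef } from "./lib/message-ids"

// The only shape the model ever sees, regardless of which part carries it:
const marker = formatMessageIdMarker(formatMessageRef(12))
console.assert(marker === "Message ID: m0012")
```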
+- IDs must exist in the current visible context.
+- `startId` must appear before `endId`.
+- Prefer boundaries that produce short, closed ranges.

-WHERE TO NEVER PICK STRINGS FROM:
+ID SOURCES

-- `` tags or any XML wrapper/meta-commentary around messages
-- Injected system instructions (plan mode text, max-steps warnings, mode-switch text, environment info)
-- Reasoning parts or chain-of-thought text
-- File/directory listing framing text (e.g. "Called the Read tool with the following input...")
-- Strings that span across message or part boundaries
-- Entire serialized JSON objects (key ordering may differ - pick a distinctive substring within instead)
+- User messages include a text marker with their `mNNNN` ID.
+- Assistant messages include a `context_info` tool marker with their `mNNNN` ID.
+- Compressed blocks are addressable by `bN` IDs.

-CRITICAL: AVOID USING TOOL INPUT VALUES
-NEVER use tool input schema keys or field names as boundary strings (e.g., "startString", "endString", "filePath", "content"). These may be transformed by the AI SDK and are not reliable. The ONLY acceptable use of tool input strings is a SINGLE concrete field VALUE (not the key), and even then, prefer using assistant text, user messages, or tool result outputs instead. When in doubt, choose boundaries from your own assistant responses or distinctive user message content.
+Do not invent IDs. Use only IDs that are present in context.

 PARALLEL COMPRESS EXECUTION
 When multiple independent ranges are ready and their boundaries do not overlap, launch MULTIPLE `compress` calls in parallel in a single response. This is the PREFERRED pattern over a single large-range compression when the work can be safely split. Run compression sequentially only when ranges overlap or when a later range depends on the result of an earlier compression.

@@ -96,8 +107,8 @@ THE FORMAT OF COMPRESS

 {
   topic: string, // Short label (3-5 words) - e.g., "Auth System Exploration"
   content: {
-    startString: string, // Unique text string marking the beginning of the range
-    endString: string, // Unique text string marking the end of the range
+    startId: string, // Boundary ID at range start: mNNNN or bN
+    endId: string, // Boundary ID at range end: mNNNN or bN
     summary: string // Complete technical summary replacing all content in the range
   }
 }
diff --git a/lib/prompts/nudge.md b/lib/prompts/nudge.md
index bef47653..a7e2fe4d 100644
--- a/lib/prompts/nudge.md
+++ b/lib/prompts/nudge.md
@@ -31,10 +31,7 @@ Do not jump to a single broad range when the same cleanup can be done safely wit

 If you are performing a critical atomic operation, do not interrupt it, but make sure to perform context management rapidly

-BE VERY MINDFUL of the startString and endString you use for compression for RELIABLE boundary matching. NEVER use generic tool outputs like "Edit applied successfully." or generic status message as boundaries. Use unique assistant text or distinctive content instead with enough surrounding context to ensure uniqueness.
-
-CRITICAL: AVOID USING TOOL INPUT VALUES AS BOUNDARIES
-NEVER use tool input schema keys or field names. The ONLY acceptable use of tool input strings is a SINGLE concrete field VALUE (not the key), and even then, prefer assistant text, user messages, or tool result outputs instead.
+Use injected boundary IDs for compression (`mNNNN` for messages, `bN` for compressed blocks). Pick IDs that are visible in context and ensure `startId` appears before `endId`.
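A hypothetical call under the new schema (IDs, topic, and summary are invented for illustration) showing how a `bN` boundary and a `{block_N}` placeholder interact:

```ts
// Sketch of a well-formed `compress` invocation; all values are illustrative.
const exampleArgs = {
    topic: "Auth System Exploration",
    content: {
        startId: "m0004", // injected marker of the first message in the range
        endId: "b2",      // a previously compressed block may bound the range
        summary:
            "Traced the login flow through middleware. {block_2} " +
            "Settled on refresh-token rotation; see src/auth/refresh.ts.",
    },
}
```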
 Ensure your summaries are inclusive of all parts of the range. If the compressed range includes user messages, preserve user intent exactly. Prefer direct quotes for short user messages to avoid semantic drift.
diff --git a/lib/state/persistence.ts b/lib/state/persistence.ts
index a80d3d5d..0b659cac 100644
--- a/lib/state/persistence.ts
+++ b/lib/state/persistence.ts
@@ -116,6 +116,7 @@ export async function loadSessionState(
             (s): s is CompressSummary =>
                 s !== null &&
                 typeof s === "object" &&
+                typeof s.blockId === "number" &&
                 typeof s.anchorMessageId === "string" &&
                 typeof s.summary === "string",
         )
diff --git a/lib/state/state.ts b/lib/state/state.ts
index f308a982..00848525 100644
--- a/lib/state/state.ts
+++ b/lib/state/state.ts
@@ -76,6 +76,11 @@ export function createSessionState(): SessionState {
         },
         toolParameters: new Map(),
         toolIdList: [],
+        messageIds: {
+            byRawId: new Map(),
+            byRef: new Map(),
+            nextRef: 0,
+        },
         lastCompaction: 0,
         currentTurn: 0,
         variant: undefined,
@@ -101,6 +106,11 @@ export function resetSessionState(state: SessionState): void {
     }
     state.toolParameters.clear()
    state.toolIdList = []
+    state.messageIds = {
+        byRawId: new Map(),
+        byRef: new Map(),
+        nextRef: 0,
+    }
     state.lastCompaction = 0
     state.currentTurn = 0
     state.variant = undefined
diff --git a/lib/state/types.ts b/lib/state/types.ts
index 99d90f53..95cc2513 100644
--- a/lib/state/types.ts
+++ b/lib/state/types.ts
@@ -22,6 +22,7 @@ export interface SessionStats {
 }

 export interface CompressSummary {
+    blockId: number
     anchorMessageId: string
     summary: string
 }
@@ -36,6 +37,12 @@ export interface PendingManualTrigger {
     prompt: string
 }

+export interface MessageIdState {
+    byRawId: Map<string, string>
+    byRef: Map<string, string>
+    nextRef: number
+}
+
 export interface SessionState {
     sessionId: string | null
     isSubAgent: boolean
@@ -47,6 +54,7 @@ export interface SessionState {
     stats: SessionStats
     toolParameters: Map<string, ToolParameterEntry>
     toolIdList: string[]
+    messageIds: MessageIdState
     lastCompaction: number
     currentTurn: number
     variant: string | undefined
diff --git a/lib/tools/compress-utils.ts b/lib/tools/compress-utils.ts
new file mode 100644
index 00000000..2542531f
--- /dev/null
+++ b/lib/tools/compress-utils.ts
@@ -0,0 +1,733 @@
+import type { SessionState, WithParts, CompressSummary } from "../state"
+import type { Logger } from "../logger"
+import type { PluginConfig } from "../config"
+import { formatBlockRef, parseBoundaryId } from "../message-ids"
+import { prune } from "../messages"
+import { isIgnoredUserMessage } from "../messages/utils"
+import { countAllMessageTokens, countTokens } from "../strategies/utils"
+
+const BLOCK_PLACEHOLDER_REGEX = /\{block_(\d+)\}/gi
+const COMPRESSED_BLOCK_HEADER_PREFIX_REGEX = /^\s*\[Compressed conversation b(\d+)\]/i
+
+export interface CompressToolArgs {
+    topic: string
+    content: {
+        startId: string
+        endId: string
+        summary: string
+    }
+}
+
+export interface BoundaryReference {
+    kind: "message" | "compressed-block"
+    transformedIndex: number
+    messageId?: string
+    blockId?: number
+    anchorMessageId?: string
+}
+
+export interface SearchContext {
+    transformedMessages: WithParts[]
+    rawMessagesById: Map<string, WithParts>
+    summaryByBlockId: Map<number, CompressSummary>
+}
+
+export interface RangeResolution {
+    startReference: BoundaryReference
+    endReference: BoundaryReference
+    messageIds: string[]
+    messageTokenById: Map<string, number>
+    toolIds: string[]
+    requiredBlockIds: number[]
+}
+
+export interface ParsedBlockPlaceholder {
+    raw: string
+    blockId: number
+    startIndex: number
+    endIndex: number
+}
+
+export interface InjectedSummaryResult {
+    expandedSummary: string
+    consumedBlockIds: number[]
+}
+
+export interface AppliedCompressionResult {
+    compressedTokens: number
+    messageIds: string[]
+}
+
+export function formatCompressedBlockHeader(blockId: number): string {
+    return `[Compressed conversation b${blockId}]`
+}
+
+export function formatBlockPlaceholder(blockId: number): string {
+    return `{block_${blockId}}`
+}
+
+export function validateCompressArgs(args: CompressToolArgs): void {
+    if (typeof args.topic !== "string" || args.topic.trim().length === 0) {
+        throw new Error("topic is required and must be a non-empty string")
+    }
+
+    if (typeof args.content?.startId !== "string" || args.content.startId.trim().length === 0) {
+        throw new Error("content.startId is required and must be a non-empty string")
+    }
+
+    if (parseBoundaryId(args.content.startId) === null) {
+        throw new Error("content.startId must be a valid message/block ID (mNNNN or bN)")
+    }
+
+    if (typeof args.content?.endId !== "string" || args.content.endId.trim().length === 0) {
+        throw new Error("content.endId is required and must be a non-empty string")
+    }
+
+    if (parseBoundaryId(args.content.endId) === null) {
+        throw new Error("content.endId must be a valid message/block ID (mNNNN or bN)")
+    }
+
+    if (typeof args.content?.summary !== "string" || args.content.summary.trim().length === 0) {
+        throw new Error("content.summary is required and must be a non-empty string")
+    }
+}
+
+export async function fetchSessionMessages(client: any, sessionId: string): Promise<WithParts[]> {
+    const response = await client.session.messages({
+        path: { id: sessionId },
+    })
+
+    const payload = (response?.data || response) as WithParts[]
+    return Array.isArray(payload) ? payload : []
+}
+
+export function buildSearchContext(
+    state: SessionState,
+    logger: Logger,
+    config: PluginConfig,
+    rawMessages: WithParts[],
+): SearchContext {
+    const transformedMessages = structuredClone(rawMessages) as WithParts[]
+    prune(state, logger, config, transformedMessages)
+
+    const rawMessagesById = new Map<string, WithParts>()
+    for (const msg of rawMessages) {
+        rawMessagesById.set(msg.info.id, msg)
+    }
+
+    const summaryByBlockId = new Map<number, CompressSummary>()
+    for (const summary of state.compressSummaries || []) {
+        summaryByBlockId.set(summary.blockId, summary)
+    }
+
+    return {
+        transformedMessages,
+        rawMessagesById,
+        summaryByBlockId,
+    }
+}
+
+export function resolveBoundaryIds(
+    context: SearchContext,
+    state: SessionState,
+    startId: string,
+    endId: string,
+): { startReference: BoundaryReference; endReference: BoundaryReference } {
+    const lookup = buildBoundaryReferenceLookup(context, state)
+    const issues: string[] = []
+    const parsedStartId = parseBoundaryId(startId)
+    const parsedEndId = parseBoundaryId(endId)
+
+    if (parsedStartId === null) {
+        issues.push("startId is invalid. Use an injected message ID (mNNNN) or block ID (bN).")
+    }
+
+    if (parsedEndId === null) {
+        issues.push("endId is invalid. Use an injected message ID (mNNNN) or block ID (bN).")
+    }
+
+    if (issues.length > 0) {
+        throwCombinedIssues(issues)
+    }
+
+    if (!parsedStartId || !parsedEndId) {
+        throw new Error("Invalid boundary ID(s)")
+    }
+
+    const startReference = lookup.get(parsedStartId.ref)
+    const endReference = lookup.get(parsedEndId.ref)
+
+    if (!startReference) {
+        issues.push(
+            `startId ${parsedStartId.ref} is not available in the current conversation context. Choose an injected ID visible in context.`,
+        )
+    }
+
+    if (!endReference) {
+        issues.push(
+            `endId ${parsedEndId.ref} is not available in the current conversation context. Choose an injected ID visible in context.`,
+        )
+    }
+
+    if (issues.length > 0) {
+        throwCombinedIssues(issues)
+    }
+
+    if (!startReference || !endReference) {
+        throw new Error("Failed to resolve boundary IDs")
+    }
+
+    if (startReference.transformedIndex > endReference.transformedIndex) {
+        throw new Error(
+            `startId ${parsedStartId.ref} appears after endId ${parsedEndId.ref} in the conversation. Start must come before end.`,
+        )
+    }
+
+    return { startReference, endReference }
+}
+
+function buildBoundaryReferenceLookup(
+    context: SearchContext,
+    state: SessionState,
+): Map<string, BoundaryReference> {
+    const lookup = new Map<string, BoundaryReference>()
+
+    for (let index = 0; index < context.transformedMessages.length; index++) {
+        const message = context.transformedMessages[index]
+        if (!message) {
+            continue
+        }
+        if (message.info.role === "user" && isIgnoredUserMessage(message)) {
+            continue
+        }
+
+        const text = buildSearchableMessageText(message)
+        const reference = resolveBoundaryReference(
+            message,
+            index,
+            text,
+            context.summaryByBlockId,
+            context.rawMessagesById.has(message.info.id),
+        )
+
+        if (reference.kind === "compressed-block") {
+            if (reference.blockId === undefined) {
+                continue
+            }
+            const blockRef = formatBlockRef(reference.blockId)
+            if (!lookup.has(blockRef)) {
+                lookup.set(blockRef, reference)
+            }
+            continue
+        }
+
+        if (!reference.messageId) {
+            continue
+        }
+        const messageRef = state.messageIds.byRawId.get(reference.messageId)
+        if (!messageRef) {
+            continue
+        }
+
+        if (!lookup.has(messageRef)) {
+            lookup.set(messageRef, reference)
+        }
+    }
+
+    return lookup
+}
+
+export function resolveRange(
+    context: SearchContext,
+    startReference: BoundaryReference,
+    endReference: BoundaryReference,
+): RangeResolution {
+    const messageIds: string[] = []
+    const messageSeen = new Set<string>()
+    const toolIds: string[] = []
+    const toolSeen = new Set<string>()
+    const requiredBlockIds: number[] = []
+    const requiredBlockSeen = new Set<number>()
+    const messageTokenById = new Map<string, number>()
+
+    for (
+        let index = startReference.transformedIndex;
+        index <= endReference.transformedIndex;
+        index++
+    ) {
+        const message = context.transformedMessages[index]
+        if (!message) {
+            continue
+        }
+        if (message.info.role === "user" && isIgnoredUserMessage(message)) {
+            continue
+        }
+
+        const text = buildSearchableMessageText(message)
+        const reference = resolveBoundaryReference(
+            message,
+            index,
+            text,
+            context.summaryByBlockId,
+            context.rawMessagesById.has(message.info.id),
+        )
+
+        if (reference.kind === "compressed-block") {
+            if (reference.blockId !== undefined && !requiredBlockSeen.has(reference.blockId)) {
+                requiredBlockSeen.add(reference.blockId)
+                requiredBlockIds.push(reference.blockId)
+            }
+            continue
+        }
+
+        if (!context.rawMessagesById.has(message.info.id)) {
+            continue
+        }
+
+        const messageId = message.info.id
+        if (!messageSeen.has(messageId)) {
+            messageSeen.add(messageId)
+            messageIds.push(messageId)
+        }
+
+        const rawMessage = context.rawMessagesById.get(messageId)
+        if (!rawMessage) {
+            continue
+        }
+
+        if (!messageTokenById.has(messageId)) {
+            messageTokenById.set(messageId, countAllMessageTokens(rawMessage))
+        }
+
+        const parts = Array.isArray(rawMessage.parts) ? rawMessage.parts : []
+        for (const part of parts) {
+            if (part.type !== "tool" || !part.callID) {
+                continue
+            }
+            if (toolSeen.has(part.callID)) {
+                continue
+            }
+            toolSeen.add(part.callID)
+            toolIds.push(part.callID)
+        }
+    }
+
+    if (messageIds.length === 0) {
+        throw new Error(
+            "Failed to map boundary matches back to raw messages. Choose boundaries that include original conversation messages.",
+        )
+    }
+
+    return {
+        startReference,
+        endReference,
+        messageIds,
+        messageTokenById,
+        toolIds,
+        requiredBlockIds,
+    }
+}
+
+export function resolveAnchorMessageId(startReference: BoundaryReference): string {
+    if (startReference.kind === "compressed-block") {
+        if (!startReference.anchorMessageId) {
+            throw new Error("Failed to map boundary matches back to raw messages")
+        }
+        return startReference.anchorMessageId
+    }
+
+    if (!startReference.messageId) {
+        throw new Error("Failed to map boundary matches back to raw messages")
+    }
+    return startReference.messageId
+}
+
+export function parseBlockPlaceholders(summary: string): ParsedBlockPlaceholder[] {
+    const placeholders: ParsedBlockPlaceholder[] = []
+    const regex = new RegExp(BLOCK_PLACEHOLDER_REGEX)
+
+    let match: RegExpExecArray | null
+    while ((match = regex.exec(summary)) !== null) {
+        const full = match[0]
+        const parsed = Number.parseInt(match[1], 10)
+        if (!Number.isInteger(parsed)) {
+            continue
+        }
+
+        placeholders.push({
+            raw: full,
+            blockId: parsed,
+            startIndex: match.index,
+            endIndex: match.index + full.length,
+        })
+    }
+
+    return placeholders
+}
+
+export function validateSummaryPlaceholders(
+    placeholders: ParsedBlockPlaceholder[],
+    requiredBlockIds: number[],
+    startReference: BoundaryReference,
+    endReference: BoundaryReference,
+    summaryByBlockId: Map<number, CompressSummary>,
+): void {
+    const issues: string[] = []
+
+    const boundaryOptionalIds = new Set<number>()
+    if (startReference.kind === "compressed-block") {
+        if (startReference.blockId === undefined) {
+            issues.push("Failed to map boundary matches back to raw messages")
+        } else {
+            boundaryOptionalIds.add(startReference.blockId)
+        }
+    }
+    if (endReference.kind === "compressed-block") {
+        if (endReference.blockId === undefined) {
+            issues.push("Failed to map boundary matches back to raw messages")
+        } else {
+            boundaryOptionalIds.add(endReference.blockId)
+        }
+    }
+
+    const strictRequiredIds = requiredBlockIds.filter((id) => !boundaryOptionalIds.has(id))
+    const requiredSet = new Set(requiredBlockIds)
+    const placeholderIds = placeholders.map((p) => p.blockId)
+    const placeholderSet = new Set<number>()
+    const duplicateIds = new Set<number>()
+
+    for (const id of placeholderIds) {
+        if (placeholderSet.has(id)) {
+            duplicateIds.add(id)
+            continue
+        }
+        placeholderSet.add(id)
+    }
+
+    const missing = strictRequiredIds.filter((id) => !placeholderSet.has(id))
+    if (missing.length > 0) {
+        issues.push(
+            `Missing required block placeholders: ${missing.map(formatBlockPlaceholder).join(", ")}`,
+        )
+    }
+
+    const unknown = placeholderIds.filter((id) => !summaryByBlockId.has(id))
+    if (unknown.length > 0) {
+        const uniqueUnknown = [...new Set(unknown)]
+        issues.push(
+            `Unknown block placeholders: ${uniqueUnknown.map(formatBlockPlaceholder).join(", ")}`,
+        )
+    }
+
+    const invalid = placeholderIds.filter((id) => !requiredSet.has(id))
+    if (invalid.length > 0) {
+        const uniqueInvalid = [...new Set(invalid)]
+        issues.push(
+            `Invalid block placeholders for selected range: ${uniqueInvalid.map(formatBlockPlaceholder).join(", ")}`,
+        )
+    }
+
+    if (duplicateIds.size > 0) {
+        issues.push(
+            `Duplicate block placeholders are not allowed: ${[...duplicateIds].map(formatBlockPlaceholder).join(", ")}`,
+        )
+    }
+
+    if (issues.length > 0) {
+        throwCombinedIssues(issues)
+    }
+}
+
+export function injectBlockPlaceholders(
+    summary: string,
+    placeholders: ParsedBlockPlaceholder[],
+    summaryByBlockId: Map<number, CompressSummary>,
+    startReference: BoundaryReference,
+    endReference: BoundaryReference,
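The two maps in `MessageIdState` are intended as mutual inverses (and `assignMessageRefs` repairs `byRef` when it drifts). A small invariant check that could be useful in tests; import path assumed:

```ts
import type { MessageIdState } from "./lib/state/types"

// Every rawId -> ref entry must map back, and vice versa.
function checkMessageIdInvariants(ids: MessageIdState): void {
    for (const [rawId, ref] of ids.byRawId) {
        if (ids.byRef.get(ref) !== rawId) {
            throw new Error(`byRef[${ref}] does not map back to ${rawId}`)
        }
    }
    for (const [ref, rawId] of ids.byRef) {
        if (ids.byRawId.get(rawId) !== ref) {
            throw new Error(`byRawId[${rawId}] does not map back to ${ref}`)
        }
    }
}
```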
+): InjectedSummaryResult {
+    let cursor = 0
+    let expanded = summary
+    const consumed: number[] = []
+    const consumedSeen = new Set<number>()
+
+    if (placeholders.length > 0) {
+        expanded = ""
+        for (const placeholder of placeholders) {
+            const target = summaryByBlockId.get(placeholder.blockId)
+            if (!target) {
+                throw new Error(
+                    `Compressed block not found: ${formatBlockPlaceholder(placeholder.blockId)}`,
+                )
+            }
+
+            expanded += summary.slice(cursor, placeholder.startIndex)
+            expanded += stripCompressedBlockHeader(target.summary)
+            cursor = placeholder.endIndex
+
+            if (!consumedSeen.has(placeholder.blockId)) {
+                consumedSeen.add(placeholder.blockId)
+                consumed.push(placeholder.blockId)
+            }
+        }
+
+        expanded += summary.slice(cursor)
+    }
+
+    expanded = injectBoundarySummaryIfMissing(
+        expanded,
+        startReference,
+        "start",
+        summaryByBlockId,
+        consumed,
+        consumedSeen,
+    )
+    expanded = injectBoundarySummaryIfMissing(
+        expanded,
+        endReference,
+        "end",
+        summaryByBlockId,
+        consumed,
+        consumedSeen,
+    )
+
+    return {
+        expandedSummary: expanded,
+        consumedBlockIds: consumed,
+    }
+}
+
+export function allocateBlockId(summaries: CompressSummary[]): number {
+    if (summaries.length === 0) {
+        return 1
+    }
+
+    let max = 0
+    for (const summary of summaries) {
+        if (summary.blockId > max) {
+            max = summary.blockId
+        }
+    }
+    return max + 1
+}
+
+export function addCompressedBlockHeader(blockId: number, summary: string): string {
+    const header = formatCompressedBlockHeader(blockId)
+    const body = summary.trim()
+    if (body.length === 0) {
+        return header
+    }
+    return `${header}\n${body}`
+}
+
+export function applyCompressionState(
+    state: SessionState,
+    range: RangeResolution,
+    anchorMessageId: string,
+    blockId: number,
+    summary: string,
+    consumedBlockIds: number[],
+): AppliedCompressionResult {
+    const consumed = new Set(consumedBlockIds)
+    state.compressSummaries = (state.compressSummaries || []).filter(
+        (s) => !consumed.has(s.blockId),
+    )
+    state.compressSummaries.push({
+        blockId,
+        anchorMessageId,
+        summary,
+    })
+
+    let compressedTokens = 0
+    for (const messageId of range.messageIds) {
+        if (state.prune.messages.has(messageId)) {
+            continue
+        }
+
+        const tokenCount = range.messageTokenById.get(messageId) || 0
+        state.prune.messages.set(messageId, tokenCount)
+        compressedTokens += tokenCount
+    }
+
+    state.stats.pruneTokenCounter += compressedTokens
+    state.stats.totalPruneTokens += state.stats.pruneTokenCounter
+    state.stats.pruneTokenCounter = 0
+
+    return {
+        compressedTokens,
+        messageIds: range.messageIds,
+    }
+}
+
+export function countSummaryTokens(summary: string): number {
+    return countTokens(summary)
+}
+
+function resolveBoundaryReference(
+    message: WithParts,
+    transformedIndex: number,
+    searchableText: string,
+    summaryByBlockId: Map<number, CompressSummary>,
+    isRawMessage: boolean,
+): BoundaryReference {
+    const leadingBlockId = extractLeadingBlockId(searchableText)
+    if (!isRawMessage && leadingBlockId !== null) {
+        const blockSummary = summaryByBlockId.get(leadingBlockId)
+        if (blockSummary) {
+            return {
+                kind: "compressed-block",
+                transformedIndex,
+                blockId: leadingBlockId,
+                anchorMessageId: blockSummary.anchorMessageId,
+            }
+        }
+    }
+
+    return {
+        kind: "message",
+        transformedIndex,
+        messageId: message.info.id,
+    }
+}
+
+function buildSearchableMessageText(message: WithParts): string {
+    const parts = Array.isArray(message.parts) ? message.parts : []
+    let content = ""
+
+    for (const part of parts) {
+        const p = part as Record<string, unknown>
+        if ((part as any).ignored) {
+            continue
+        }
+
+        switch (part.type) {
+            case "text":
+                if (typeof p.text === "string") {
+                    content += ` ${p.text}`
+                }
+                break
+
+            case "tool": {
+                if ((part as any).tool === "compress") {
+                    break
+                }
+
+                const state = p.state as Record<string, unknown> | undefined
+                if (!state) break
+
+                if (state.status === "completed" && state.output !== undefined) {
+                    content +=
+                        " " +
+                        (typeof state.output === "string"
+                            ? state.output
+                            : JSON.stringify(state.output))
+                } else if (state.status === "error" && state.error !== undefined) {
+                    content +=
+                        " " +
+                        (typeof state.error === "string"
+                            ? state.error
+                            : JSON.stringify(state.error))
+                }
+
+                if (state.input !== undefined) {
+                    content +=
+                        " " +
+                        (typeof state.input === "string"
+                            ? state.input
+                            : JSON.stringify(state.input))
+                }
+                break
+            }
+
+            case "compaction":
+                if (typeof p.summary === "string") {
+                    content += ` ${p.summary}`
+                }
+                break
+
+            case "subtask":
+                if (typeof p.summary === "string") {
+                    content += ` ${p.summary}`
+                }
+                if (typeof p.result === "string") {
+                    content += ` ${p.result}`
+                }
+                break
+
+            default:
+                break
+        }
+    }
+
+    return content
+}
+
+function extractLeadingBlockId(text: string): number | null {
+    const match = text.match(COMPRESSED_BLOCK_HEADER_PREFIX_REGEX)
+    if (!match) {
+        return null
+    }
+    const id = Number.parseInt(match[1], 10)
+    return Number.isInteger(id) ? id : null
+}
+
+function stripCompressedBlockHeader(summary: string): string {
+    const headerMatch = summary.match(/^\s*\[Compressed conversation b\d+\]/i)
+    if (!headerMatch) {
+        return summary
+    }
+
+    const afterHeader = summary.slice(headerMatch[0].length)
+    return afterHeader.replace(/^(?:\r?\n)+/, "")
+}
+
+function injectBoundarySummaryIfMissing(
+    summary: string,
+    reference: BoundaryReference,
+    position: "start" | "end",
+    summaryByBlockId: Map<number, CompressSummary>,
+    consumed: number[],
+    consumedSeen: Set<number>,
+): string {
+    if (reference.kind !== "compressed-block" || reference.blockId === undefined) {
+        return summary
+    }
+    if (consumedSeen.has(reference.blockId)) {
+        return summary
+    }
+
+    const target = summaryByBlockId.get(reference.blockId)
+    if (!target) {
+        throw new Error(`Compressed block not found: ${formatBlockPlaceholder(reference.blockId)}`)
+    }
+
+    const injectedBody = stripCompressedBlockHeader(target.summary)
+    const next =
+        position === "start"
+            ? mergeWithSpacing(injectedBody, summary)
+            : mergeWithSpacing(summary, injectedBody)
+
+    consumedSeen.add(reference.blockId)
+    consumed.push(reference.blockId)
+    return next
+}
+
+function mergeWithSpacing(left: string, right: string): string {
+    const l = left.trim()
+    const r = right.trim()
+
+    if (!l) {
+        return right
+    }
+    if (!r) {
+        return left
+    }
+    return `${l}\n\n${r}`
+}
+
+function throwCombinedIssues(issues: string[]): never {
+    if (issues.length === 1) {
+        throw new Error(issues[0])
+    }
+
+    throw new Error(issues.map((issue) => `- ${issue}`).join("\n"))
+}
diff --git a/lib/tools/compress.ts b/lib/tools/compress.ts
index 71fa6243..c8f49c94 100644
--- a/lib/tools/compress.ts
+++ b/lib/tools/compress.ts
@@ -1,23 +1,28 @@
 import { tool } from "@opencode-ai/plugin"
-import type { WithParts, CompressSummary } from "../state"
 import type { ToolContext } from "./types"
-import { ensureSessionInitialized } from "../state"
-import { saveSessionState } from "../state/persistence"
 import { COMPRESS_TOOL_SPEC } from "../prompts"
-import { getCurrentParams, countAllMessageTokens, countTokens } from "../strategies/utils"
-import type { AssistantMessage } from "@opencode-ai/sdk/v2"
+import { ensureSessionInitialized } from "../state"
 import {
-    findStringInMessages,
-    collectToolIdsInRange,
-    collectMessageIdsInRange,
-    findSummaryAnchorForBoundary,
-} from "./utils"
+    addCompressedBlockHeader,
+    allocateBlockId,
+    applyCompressionState,
+    buildSearchContext,
+    countSummaryTokens,
+    fetchSessionMessages,
+    injectBlockPlaceholders,
+    parseBlockPlaceholders,
+    resolveAnchorMessageId,
+    resolveBoundaryIds,
+    resolveRange,
+    validateCompressArgs,
+    validateSummaryPlaceholders,
+    type CompressToolArgs,
+} from "./compress-utils"
+import { getCurrentParams, getCurrentTokenUsage } from "../strategies/utils"
+import { saveSessionState } from "../state/persistence"
 import { sendCompressNotification } from "../ui/notification"
-import { prune as applyPruneTransforms } from "../messages/prune"
-import { clog, C } from "../compress-logger"

 const COMPRESS_TOOL_DESCRIPTION = COMPRESS_TOOL_SPEC
-const COMPRESS_SUMMARY_PREFIX = "[Compressed conversation block]\n\n"

 export function createCompressTool(ctx: ToolContext): ReturnType<typeof tool> {
     return tool({
@@ -28,26 +33,21 @@ export function createCompressTool(ctx: ToolContext): ReturnType<typeof tool> {
                 .describe("Short label (3-5 words) for display - e.g., 'Auth System Exploration'"),
             content: tool.schema
                 .object({
-                    startString: tool.schema
+                    startId: tool.schema
                         .string()
-                        .describe("Unique text from conversation marking the beginning of range"),
-                    endString: tool.schema
+                        .describe(
+                            "Message or block ID marking the beginning of range (e.g. m0000, b2)",
+                        ),
+                    endId: tool.schema
                         .string()
-                        .describe("Unique text marking the end of range"),
+                        .describe("Message or block ID marking the end of range (e.g. m0012, b5)"),
m0012, b5)"), summary: tool.schema .string() .describe("Complete technical summary replacing all content in range"), }) - .describe("Compression details: boundaries and replacement summary"), + .describe("Compression details: ID boundaries and replacement summary"), }, async execute(args, toolCtx) { - const invocationId = Date.now().toString(36) - const separator = "═".repeat(79) - clog.info( - C.COMPRESS, - `${separator}\nCOMPRESS INVOCATION START\nID: ${invocationId}\n${separator}`, - ) - await toolCtx.ask({ permission: "compress", patterns: ["*"], @@ -55,427 +55,94 @@ export function createCompressTool(ctx: ToolContext): ReturnType { metadata: {}, }) - const { topic, content } = args - const { startString, endString, summary } = content || {} + const compressArgs = args as CompressToolArgs + validateCompressArgs(compressArgs) - clog.info(C.COMPRESS, `Arguments`, { - topic, - startString: startString ? `"${startString.substring(0, 120)}"` : undefined, - startStringLength: startString?.length, - endString: endString ? `"${endString.substring(0, 120)}"` : undefined, - endStringLength: endString?.length, - summaryLength: summary?.length, + toolCtx.metadata({ + title: `Compress: ${compressArgs.topic}`, }) - if (!topic || typeof topic !== "string") { - clog.error(C.COMPRESS, `✗ Validation Failed\ntopic missing or not string`, { - topic, - }) - throw new Error("topic is required and must be a non-empty string") - } - if (!startString || typeof startString !== "string") { - clog.error(C.COMPRESS, `✗ Validation Failed\nstartString missing or not string`, { - startString: typeof startString, - }) - throw new Error("content.startString is required and must be a non-empty string") - } - if (!endString || typeof endString !== "string") { - clog.error(C.COMPRESS, `✗ Validation Failed\nendString missing or not string`, { - endString: typeof endString, - }) - throw new Error("content.endString is required and must be a non-empty string") - } - if (!summary || typeof summary !== "string") { - clog.error(C.COMPRESS, `✗ Validation Failed\nsummary missing or not string`, { - summary: typeof summary, - }) - throw new Error("content.summary is required and must be a non-empty string") - } - - const { client, state, logger } = ctx - const sessionId = toolCtx.sessionID - - clog.info(C.COMPRESS, `Session\nid: ${sessionId}`) - - try { - const messagesResponse = await client.session.messages({ - path: { id: sessionId }, - }) - const messages: WithParts[] = messagesResponse.data || messagesResponse - - clog.info(C.COMPRESS, `Messages\nfetched: ${messages.length} raw messages`) - - await ensureSessionInitialized( - client, - state, - sessionId, - logger, - messages, - ctx.config.manualMode.enabled, - ) - - clog.info(C.STATE, `State Snapshot (before boundary matching)`, { - sessionId: state.sessionId, - isSubAgent: state.isSubAgent, - summaries: state.compressSummaries.length, - pruned: { - tools: state.prune.tools.size, - messages: state.prune.messages.size, - }, - toolParameters: state.toolParameters.size, - turn: state.currentTurn, - }) - - const transformedMessages = structuredClone(messages) as WithParts[] - applyPruneTransforms(state, logger, ctx.config, transformedMessages) - - clog.info( - C.COMPRESS, - `Prune Transform\nraw: ${messages.length} messages\ntransformed: ${transformedMessages.length} messages`, - ) - - // Log message IDs for both raw and transformed to detect discrepancies - clog.debug(C.COMPRESS, `Message IDs`, { - raw: messages.map((m, i) => `${i}:${m.info.id}:${m.info.role}`), - transformed: 
-                        (m, i) => `${i}:${m.info.id}:${m.info.role}`,
-                    ),
-                })
-
-                clog.info(C.BOUNDARY, `Boundary Search: START STRING\nsearching...`)
-                const startResult = findStringInMessages(
-                    transformedMessages,
-                    startString,
-                    logger,
-                    "startString",
-                )
-                clog.info(C.BOUNDARY, `✓ Start boundary found`, {
-                    messageId: startResult.messageId,
-                    messageIndex: startResult.messageIndex,
-                })
-
-                clog.info(C.BOUNDARY, `Boundary Search: END STRING\nsearching...`)
-                const endResult = findStringInMessages(
-                    transformedMessages,
-                    endString,
-                    logger,
-                    "endString",
-                )
-                clog.info(C.BOUNDARY, `✓ End boundary found`, {
-                    messageId: endResult.messageId,
-                    messageIndex: endResult.messageIndex,
-                })
-
-                let rawStartIndex = messages.findIndex((m) => m.info.id === startResult.messageId)
-                let rawEndIndex = messages.findIndex((m) => m.info.id === endResult.messageId)
-
-                clog.info(C.COMPRESS, `Raw Index Mapping (direct)`, {
-                    start: { messageId: startResult.messageId, rawIndex: rawStartIndex },
-                    end: { messageId: endResult.messageId, rawIndex: rawEndIndex },
-                })
-
-                // If a boundary matched inside a synthetic compress summary message,
-                // resolve it back to the summary's anchor message in the raw messages
-                if (rawStartIndex === -1) {
-                    clog.warn(
-                        C.COMPRESS,
-                        `⚠ Start boundary not in raw messages\nTrying compressSummaries fallback...`,
-                        {
-                            messageId: startResult.messageId,
-                            summaries: state.compressSummaries.length,
-                        },
-                    )
-                    const s = findSummaryAnchorForBoundary(
-                        state.compressSummaries,
-                        startString,
-                        "startString",
-                    )
-                    if (s) {
-                        rawStartIndex = messages.findIndex((m) => m.info.id === s.anchorMessageId)
-                        clog.info(C.COMPRESS, `✓ Start resolved via summary anchor`, {
-                            anchorMessageId: s.anchorMessageId,
-                            rawStartIndex,
-                        })
-                    } else {
-                        clog.warn(
-                            C.COMPRESS,
-                            `✗ Start not found in any summary either\nCannot resolve boundary`,
-                        )
-                    }
-                }
-                if (rawEndIndex === -1) {
-                    clog.warn(
-                        C.COMPRESS,
-                        `⚠ End boundary not in raw messages\nTrying compressSummaries fallback...`,
-                        {
-                            messageId: endResult.messageId,
-                            summaries: state.compressSummaries.length,
-                        },
-                    )
-                    const s = findSummaryAnchorForBoundary(
-                        state.compressSummaries,
-                        endString,
-                        "endString",
-                    )
-                    if (s) {
-                        rawEndIndex = messages.findIndex((m) => m.info.id === s.anchorMessageId)
-                        clog.info(C.COMPRESS, `✓ End resolved via summary anchor`, {
-                            anchorMessageId: s.anchorMessageId,
-                            rawEndIndex,
-                        })
-                    } else {
-                        clog.warn(
-                            C.COMPRESS,
-                            `✗ End not found in any summary either\nCannot resolve boundary`,
-                        )
-                    }
-                }
-
-                if (rawStartIndex === -1 || rawEndIndex === -1) {
-                    clog.error(
-                        C.COMPRESS,
-                        `✗ Boundary Mapping Failed\nCannot map boundaries to raw`,
-                        {
-                            indices: { rawStartIndex, rawEndIndex },
-                            boundaries: {
-                                start: startResult.messageId,
-                                end: endResult.messageId,
-                            },
-                            context: {
-                                rawMessageIds: messages.map((m) => m.info.id),
-                                transformedMessageIds: transformedMessages.map((m) => m.info.id),
-                                summaries: state.compressSummaries.map((s) => ({
-                                    anchor: s.anchorMessageId,
-                                    preview: s.summary.substring(0, 80),
-                                })),
-                            },
-                        },
-                    )
-                    throw new Error(`Failed to map boundary matches back to raw messages`)
-                }
-
-                if (rawStartIndex > rawEndIndex) {
-                    clog.error(C.COMPRESS, `✗ Invalid Range\nStart appears after end`, {
-                        rawStartIndex,
-                        rawEndIndex,
-                        start: startResult.messageId,
-                        end: endResult.messageId,
-                    })
-                    throw new Error(
-                        "startString appears after endString in the conversation. Start must come before end.",
-                    )
-                }
-
-                const rangeSize = rawEndIndex - rawStartIndex + 1
-                clog.info(
-                    C.COMPRESS,
-                    `Final Range\n[${rawStartIndex}..${rawEndIndex}] → ${rangeSize} messages`,
-                )
-
-                const containedToolIds = collectToolIdsInRange(messages, rawStartIndex, rawEndIndex)
-                const containedMessageIds = collectMessageIdsInRange(
-                    messages,
-                    rawStartIndex,
-                    rawEndIndex,
-                )
-
-                clog.info(C.COMPRESS, `Range Contents`, {
-                    tools: containedToolIds.length,
-                    messages: containedMessageIds.length,
-                    samples: {
-                        toolIds: containedToolIds.slice(0, 5),
-                        messageIds: containedMessageIds.slice(0, 5),
-                    },
-                })
-
-                // Remove any existing summaries whose anchors are now inside this range
-                // This prevents duplicate injections when a larger compress subsumes a smaller one
-                const removedSummaries = state.compressSummaries.filter((s) =>
-                    containedMessageIds.includes(s.anchorMessageId),
-                )
-                if (removedSummaries.length > 0) {
-                    clog.info(
-                        C.COMPRESS,
-                        `Removing Subsumed Summaries\ncount: ${removedSummaries.length}`,
-                        {
-                            removed: removedSummaries.map((s) => ({
-                                anchor: s.anchorMessageId,
-                                preview: s.summary.substring(0, 60),
-                            })),
-                        },
-                    )
-                    state.compressSummaries = state.compressSummaries.filter(
-                        (s) => !containedMessageIds.includes(s.anchorMessageId),
-                    )
-                }
-
-                const anchorMessageId = messages[rawStartIndex]?.info.id || startResult.messageId
-                const compressSummary: CompressSummary = {
-                    anchorMessageId,
-                    summary: COMPRESS_SUMMARY_PREFIX + summary,
-                }
-                state.compressSummaries.push(compressSummary)
-
-                clog.info(C.COMPRESS, `Summary Creation`, {
-                    anchor: anchorMessageId,
-                    totalSummaries: state.compressSummaries.length,
-                })
-
-                const compressedMessageIds = containedMessageIds.filter(
-                    (id) => !state.prune.messages.has(id),
-                )
-                const compressedToolIds = containedToolIds.filter(
-                    (id) => !state.prune.tools.has(id),
-                )
-
-                clog.info(C.COMPRESS, `Prune Accounting`, {
-                    new: {
-                        messages: compressedMessageIds.length,
-                        tools: compressedToolIds.length,
-                    },
-                    alreadyPruned: {
-                        messages: containedMessageIds.length - compressedMessageIds.length,
-                        tools: containedToolIds.length - compressedToolIds.length,
-                    },
-                })
-
-                let estimatedCompressedTokens = 0
-                for (const msgId of compressedMessageIds) {
-                    const msg = messages.find((m) => m.info.id === msgId)
-                    if (msg) {
-                        const tokens = countAllMessageTokens(msg)
-                        estimatedCompressedTokens += tokens
-                        state.prune.messages.set(msgId, tokens)
-                    }
-                }
-                for (const id of compressedToolIds) {
-                    const entry = state.toolParameters.get(id)
-                    state.prune.tools.set(id, entry?.tokenCount ?? 0)
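A small sketch of the placeholder round-trip these helpers implement (import path assumed; the summary string is invented):

```ts
import { formatBlockPlaceholder, parseBlockPlaceholders } from "./lib/tools/compress-utils"

const summary = "Earlier auth work: {block_1}. The schema rework in {block_3} stayed valid."

// Placeholders are returned in order of appearance, duplicates preserved;
// validateSummaryPlaceholders is what later rejects duplicates and unknown IDs.
const ids = parseBlockPlaceholders(summary).map((p) => p.blockId)
console.assert(ids.join(",") === "1,3")
console.assert(formatBlockPlaceholder(3) === "{block_3}")
```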
-                }
-
-                // Use API-reported tokens from last assistant message (matches OpenCode UI)
-                let totalSessionTokens = 0
-                let hasApiTokenMetadata = false
-                for (let i = messages.length - 1; i >= 0; i--) {
-                    if (messages[i].info.role === "assistant") {
-                        const info = messages[i].info as AssistantMessage
-                        const input = info.tokens?.input || 0
-                        const output = info.tokens?.output || 0
-                        const reasoning = info.tokens?.reasoning || 0
-                        const cacheRead = info.tokens?.cache?.read || 0
-                        const cacheWrite = info.tokens?.cache?.write || 0
-                        const total = input + output + reasoning + cacheRead + cacheWrite
-                        if (total > 0) {
-                            totalSessionTokens = total
-                            hasApiTokenMetadata = true
-                            break
-                        }
-                    }
-                }
-
-                if (!hasApiTokenMetadata) {
-                    let estimatedContentTokens = 0
-                    for (const msg of messages) {
-                        estimatedContentTokens += countAllMessageTokens(msg)
-                    }
-                    totalSessionTokens = estimatedContentTokens
-                    clog.info(C.COMPRESS, `Token Accounting Fallback`, {
-                        totalSessionTokens,
-                    })
-                }
-
-                // Cap estimate — countAllMessageTokens can inflate beyond API count
-                if (totalSessionTokens > 0 && estimatedCompressedTokens > totalSessionTokens) {
-                    estimatedCompressedTokens = Math.round(totalSessionTokens * 0.95)
-                }
-
-                clog.info(C.COMPRESS, `Token Accounting`, {
-                    totalSessionTokens,
-                    estimatedCompressedTokens,
-                    compressedMessages: compressedMessageIds.length,
-                    compressedTools: compressedToolIds.length,
-                    pruneState: {
-                        tools: state.prune.tools.size,
-                        messages: state.prune.messages.size,
-                    },
-                })
-
-                state.stats.pruneTokenCounter += estimatedCompressedTokens
+            const rawMessages = await fetchSessionMessages(ctx.client, toolCtx.sessionID)
+            await ensureSessionInitialized(
+                ctx.client,
+                ctx.state,
+                toolCtx.sessionID,
+                ctx.logger,
+                rawMessages,
+                ctx.config.manualMode.enabled,
+            )

-                const currentParams = getCurrentParams(state, messages, logger)
-                const summaryTokens = countTokens(args.content.summary)
+            const searchContext = buildSearchContext(ctx.state, ctx.logger, ctx.config, rawMessages)

-                clog.info(C.COMPRESS, `Notification Values`, {
-                    totalSessionTokens,
-                    estimatedCompressedTokens,
-                    summaryTokens,
-                    reductionPercent:
-                        totalSessionTokens > 0
-                            ? `-${Math.round((estimatedCompressedTokens / totalSessionTokens) * 100)}%`
-                            : "N/A",
-                    messageCount: messages.length,
-                    compressedMessageIds: compressedMessageIds.length,
-                    compressedToolIds: compressedToolIds.length,
-                })
+            const { startReference, endReference } = resolveBoundaryIds(
+                searchContext,
+                ctx.state,
+                compressArgs.content.startId,
+                compressArgs.content.endId,
+            )

-                await sendCompressNotification(
-                    client,
-                    logger,
-                    ctx.config,
-                    state,
-                    sessionId,
-                    compressedToolIds,
-                    compressedMessageIds,
-                    topic,
-                    summary,
-                    summaryTokens,
-                    totalSessionTokens,
-                    estimatedCompressedTokens,
-                    messages.map((m) => m.info.id),
-                    messages.length,
-                    currentParams,
-                )
+            const range = resolveRange(searchContext, startReference, endReference)
+            const anchorMessageId = resolveAnchorMessageId(range.startReference)

-                state.stats.totalPruneTokens += state.stats.pruneTokenCounter
-                state.stats.pruneTokenCounter = 0
-                state.contextLimitAnchors = new Set()
+            const parsedPlaceholders = parseBlockPlaceholders(compressArgs.content.summary)
+            validateSummaryPlaceholders(
+                parsedPlaceholders,
+                range.requiredBlockIds,
+                range.startReference,
+                range.endReference,
+                searchContext.summaryByBlockId,
+            )

-                clog.info(C.COMPRESS, `Final Stats`, {
-                    totalPruneTokens: state.stats.totalPruneTokens,
-                })
+            const injected = injectBlockPlaceholders(
+                compressArgs.content.summary,
+                parsedPlaceholders,
+                searchContext.summaryByBlockId,
+                range.startReference,
+                range.endReference,
+            )

-                saveSessionState(state, logger).catch((err) => {
-                    clog.error(C.STATE, `✗ State Persistence Failed`, { error: err.message })
-                })
+            const blockId = allocateBlockId(ctx.state.compressSummaries)
+            const storedSummary = addCompressedBlockHeader(blockId, injected.expandedSummary)
+            const summaryTokens = countSummaryTokens(storedSummary)
+
+            const applied = applyCompressionState(
+                ctx.state,
+                range,
+                anchorMessageId,
+                blockId,
+                storedSummary,
+                injected.consumedBlockIds,
+            )

-                const result = `Compressed ${compressedMessageIds.length} messages (${compressedToolIds.length} tool calls) into summary (${summaryTokens} tokens). The content will be replaced with your summary.`
-                clog.info(
-                    C.COMPRESS,
-                    `${separator}\n✓ COMPRESS INVOCATION SUCCESS\nID: ${invocationId}\n\n${result}\n${separator}`,
-                )
-                void clog.flush()
+            await saveSessionState(ctx.state, ctx.logger)
+
+            const params = getCurrentParams(ctx.state, rawMessages, ctx.logger)
+            const totalSessionTokens = getCurrentTokenUsage(rawMessages)
+            const sessionMessageIds = rawMessages.map((msg) => msg.info.id)
+
+            await sendCompressNotification(
+                ctx.client,
+                ctx.logger,
+                ctx.config,
+                ctx.state,
+                toolCtx.sessionID,
+                range.toolIds,
+                applied.messageIds,
+                compressArgs.topic,
+                storedSummary,
+                summaryTokens,
+                totalSessionTokens,
+                applied.compressedTokens,
+                sessionMessageIds,
+                rawMessages.length,
+                params,
+            )

-                return result
-            } catch (err: unknown) {
-                const msg = err instanceof Error ? err.message : String(err)
-                const stack = err instanceof Error ? err.stack : undefined
-                const separator = "═".repeat(79)
-                clog.error(
-                    C.COMPRESS,
-                    `${separator}\n✗ COMPRESS INVOCATION FAILED\nID: ${invocationId}\n${separator}`,
-                    {
-                        error: msg,
-                        stack,
-                        context: {
-                            topic,
-                            startString: startString.substring(0, 120),
-                            endString: endString.substring(0, 120),
-                        },
-                    },
-                )
-                void clog.flush()
-                throw err
-            }
+            return `Compressed ${applied.messageIds.length} messages into ${formatBlock(blockId)}.`
         },
     })
 }
+
+function formatBlock(blockId: number): string {
+    return `[Compressed conversation b${blockId}]`
+}
diff --git a/lib/tools/utils.ts b/lib/tools/utils.ts
deleted file mode 100644
index f5edc376..00000000
--- a/lib/tools/utils.ts
+++ /dev/null
@@ -1,355 +0,0 @@
-import { partial_ratio } from "fuzzball"
-import type { CompressSummary, WithParts } from "../state"
-import type { Logger } from "../logger"
-import { isIgnoredUserMessage } from "../messages/utils"
-import { clog, C } from "../compress-logger"
-
-export interface FuzzyConfig {
-    minScore: number
-    minGap: number
-}
-
-export const DEFAULT_FUZZY_CONFIG: FuzzyConfig = {
-    minScore: 95,
-    minGap: 15,
-}
-
-interface MatchResult {
-    messageId: string
-    messageIndex: number
-    score: number
-    matchType: "exact" | "fuzzy"
-}
-
-export function findSummaryAnchorForBoundary(
-    summaries: CompressSummary[],
-    searchString: string,
-    stringType: "startString" | "endString",
-): CompressSummary | undefined {
-    const matches = summaries.filter((s) => s.summary.includes(searchString))
-
-    if (matches.length > 1) {
-        const sample = matches.slice(0, 8).map((s) => ({
-            anchorMessageId: s.anchorMessageId,
-            preview: s.summary.substring(0, 120),
-        }))
-
-        clog.error(C.BOUNDARY, `✗ Multiple Summary Matches (ambiguous)`, {
-            type: stringType,
-            count: matches.length,
-            matches: sample,
-            omitted: Math.max(0, matches.length - sample.length),
-            searchPreview: searchString.substring(0, 150),
-        })
-
-        throw new Error(
-            `Found multiple matches for ${stringType}. ` +
-                `Provide more surrounding context to uniquely identify the intended match.`,
-        )
-    }
-
-    return matches[0]
-}
-
-function summarizeMatches(
-    matches: MatchResult[],
-    limit = 8,
-): {
-    sample: Array<{ msgId: string; idx: number; score: number }>
-    total: number
-    omitted: number
-} {
-    const sample = matches.slice(0, limit).map((m) => ({
-        msgId: m.messageId,
-        idx: m.messageIndex,
-        score: m.score,
-    }))
-    return { sample, total: matches.length, omitted: Math.max(0, matches.length - sample.length) }
-}
-
-function extractMessageContent(msg: WithParts): string {
-    const parts = Array.isArray(msg.parts) ? msg.parts : []
-    let content = ""
-
-    for (const part of parts) {
-        const p = part as Record<string, unknown>
-        if ((part as any).ignored) {
-            continue
-        }
-
-        switch (part.type) {
-            case "text":
-                if (typeof p.text === "string") {
-                    content += " " + p.text
-                }
-                break
-
-            case "tool": {
-                const state = p.state as Record<string, unknown> | undefined
-                if (!state) break
-
-                // Include tool output (completed or error)
-                if (state.status === "completed" && typeof state.output === "string") {
-                    content += " " + state.output
-                } else if (state.status === "error" && typeof state.error === "string") {
-                    content += " " + state.error
-                }
-
-                // Include tool input
-                if (state.input) {
-                    content +=
-                        " " +
-                        (typeof state.input === "string"
-                            ? state.input
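Block IDs are monotonic within a session. A small sketch of `allocateBlockId`'s contract (import paths assumed; the summaries are invented):

```ts
import { allocateBlockId } from "./lib/tools/compress-utils"
import type { CompressSummary } from "./lib/state"

const summaries: CompressSummary[] = [
    { blockId: 1, anchorMessageId: "msg_a", summary: "[Compressed conversation b1] ..." },
    { blockId: 4, anchorMessageId: "msg_b", summary: "[Compressed conversation b4] ..." },
]

// Always max existing ID + 1, so IDs are never reused within a session,
// even after intermediate blocks (b2, b3 here) were consumed by a larger compress.
console.assert(allocateBlockId(summaries) === 5)
console.assert(allocateBlockId([]) === 1)
```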
+      const blockId = allocateBlockId(ctx.state.compressSummaries)
+      const storedSummary = addCompressedBlockHeader(blockId, injected.expandedSummary)
+      const summaryTokens = countSummaryTokens(storedSummary)
+
+      const applied = applyCompressionState(
+        ctx.state,
+        range,
+        anchorMessageId,
+        blockId,
+        storedSummary,
+        injected.consumedBlockIds,
+      )
-      const result = `Compressed ${compressedMessageIds.length} messages (${compressedToolIds.length} tool calls) into summary (${summaryTokens} tokens). The content will be replaced with your summary.`
-      clog.info(
-        C.COMPRESS,
-        `${separator}\n✓ COMPRESS INVOCATION SUCCESS\nID: ${invocationId}\n\n${result}\n${separator}`,
-      )
-      void clog.flush()
+      await saveSessionState(ctx.state, ctx.logger)
+
+      const params = getCurrentParams(ctx.state, rawMessages, ctx.logger)
+      const totalSessionTokens = getCurrentTokenUsage(rawMessages)
+      const sessionMessageIds = rawMessages.map((msg) => msg.info.id)
+
+      await sendCompressNotification(
+        ctx.client,
+        ctx.logger,
+        ctx.config,
+        ctx.state,
+        toolCtx.sessionID,
+        range.toolIds,
+        applied.messageIds,
+        compressArgs.topic,
+        storedSummary,
+        summaryTokens,
+        totalSessionTokens,
+        applied.compressedTokens,
+        sessionMessageIds,
+        rawMessages.length,
+        params,
+      )
-      return result
-    } catch (err: unknown) {
-      const msg = err instanceof Error ? err.message : String(err)
-      const stack = err instanceof Error ? err.stack : undefined
-      const separator = "═".repeat(79)
-      clog.error(
-        C.COMPRESS,
-        `${separator}\n✗ COMPRESS INVOCATION FAILED\nID: ${invocationId}\n${separator}`,
-        {
-          error: msg,
-          stack,
-          context: {
-            topic,
-            startString: startString.substring(0, 120),
-            endString: endString.substring(0, 120),
-          },
-        },
-      )
-      void clog.flush()
-      throw err
-    }
+      return `Compressed ${applied.messageIds.length} messages into ${formatBlock(blockId)}.`
     },
   })
 }
+
+function formatBlock(blockId: number): string {
+  return `[Compressed conversation b${blockId}]`
+}
diff --git a/lib/tools/utils.ts b/lib/tools/utils.ts
deleted file mode 100644
index f5edc376..00000000
--- a/lib/tools/utils.ts
+++ /dev/null
@@ -1,355 +0,0 @@
-import { partial_ratio } from "fuzzball"
-import type { CompressSummary, WithParts } from "../state"
-import type { Logger } from "../logger"
-import { isIgnoredUserMessage } from "../messages/utils"
-import { clog, C } from "../compress-logger"
-
-export interface FuzzyConfig {
-  minScore: number
-  minGap: number
-}
-
-export const DEFAULT_FUZZY_CONFIG: FuzzyConfig = {
-  minScore: 95,
-  minGap: 15,
-}
-
-interface MatchResult {
-  messageId: string
-  messageIndex: number
-  score: number
-  matchType: "exact" | "fuzzy"
-}
-
-export function findSummaryAnchorForBoundary(
-  summaries: CompressSummary[],
-  searchString: string,
-  stringType: "startString" | "endString",
-): CompressSummary | undefined {
-  const matches = summaries.filter((s) => s.summary.includes(searchString))
-
-  if (matches.length > 1) {
-    const sample = matches.slice(0, 8).map((s) => ({
-      anchorMessageId: s.anchorMessageId,
-      preview: s.summary.substring(0, 120),
-    }))
-
-    clog.error(C.BOUNDARY, `✗ Multiple Summary Matches (ambiguous)`, {
-      type: stringType,
-      count: matches.length,
-      matches: sample,
-      omitted: Math.max(0, matches.length - sample.length),
-      searchPreview: searchString.substring(0, 150),
-    })
-
-    throw new Error(
-      `Found multiple matches for ${stringType}. ` +
-        `Provide more surrounding context to uniquely identify the intended match.`,
-    )
-  }
-
-  return matches[0]
-}
-
-function summarizeMatches(
-  matches: MatchResult[],
-  limit = 8,
-): {
-  sample: Array<{ msgId: string; idx: number; score: number }>
-  total: number
-  omitted: number
-} {
-  const sample = matches.slice(0, limit).map((m) => ({
-    msgId: m.messageId,
-    idx: m.messageIndex,
-    score: m.score,
-  }))
-  return { sample, total: matches.length, omitted: Math.max(0, matches.length - sample.length) }
-}
-
-function extractMessageContent(msg: WithParts): string {
-  const parts = Array.isArray(msg.parts) ? msg.parts : []
-  let content = ""
-
-  for (const part of parts) {
-    const p = part as Record<string, unknown>
-    if ((part as any).ignored) {
-      continue
-    }
-
-    switch (part.type) {
-      case "text":
-        if (typeof p.text === "string") {
-          content += " " + p.text
-        }
-        break
-
-      case "tool": {
-        const state = p.state as Record<string, unknown> | undefined
-        if (!state) break
-
-        // Include tool output (completed or error)
-        if (state.status === "completed" && typeof state.output === "string") {
-          content += " " + state.output
-        } else if (state.status === "error" && typeof state.error === "string") {
-          content += " " + state.error
-        }
-
-        // Include tool input
-        if (state.input) {
-          content +=
-            " " +
-            (typeof state.input === "string" ? state.input : JSON.stringify(state.input))
-        }
-        break
-      }
-
-      case "compaction":
-        if (typeof p.summary === "string") {
-          content += " " + p.summary
-        }
-        break
-
-      case "subtask":
-        if (typeof p.summary === "string") {
-          content += " " + p.summary
-        }
-        if (typeof p.result === "string") {
-          content += " " + p.result
-        }
-        break
-    }
-  }
-
-  return content
-}
-
-function findExactMatches(messages: WithParts[], searchString: string): MatchResult[] {
-  const matches: MatchResult[] = []
-
-  for (let i = 0; i < messages.length; i++) {
-    const msg = messages[i]
-    if (isIgnoredUserMessage(msg)) {
-      continue
-    }
-    const content = extractMessageContent(msg)
-    if (content.includes(searchString)) {
-      matches.push({
-        messageId: msg.info.id,
-        messageIndex: i,
-        score: 100,
-        matchType: "exact",
-      })
-    }
-  }
-
-  return matches
-}
-
-function findFuzzyMatches(
-  messages: WithParts[],
-  searchString: string,
-  minScore: number,
-): MatchResult[] {
-  const matches: MatchResult[] = []
-
-  for (let i = 0; i < messages.length; i++) {
-    const msg = messages[i]
-    if (isIgnoredUserMessage(msg)) {
-      continue
-    }
-    const content = extractMessageContent(msg)
-    const score = partial_ratio(searchString, content)
-    if (score >= minScore) {
-      matches.push({
-        messageId: msg.info.id,
-        messageIndex: i,
-        score,
-        matchType: "fuzzy",
-      })
-    }
-  }
-
-  return matches
-}
-
-export function findStringInMessages(
-  messages: WithParts[],
-  searchString: string,
-  logger: Logger,
-  stringType: "startString" | "endString",
-  fuzzyConfig: FuzzyConfig = DEFAULT_FUZZY_CONFIG,
-): { messageId: string; messageIndex: number } {
-  clog.info(C.BOUNDARY, `Search Configuration`, {
-    type: stringType,
-    targetText: searchString.substring(0, 150),
-    targetLength: searchString.length,
-    messages: messages.length,
-    fuzzyMinScore: fuzzyConfig.minScore,
-    fuzzyMinGap: fuzzyConfig.minGap,
-  })
-
-  const searchableMessages = messages.length > 1 ? messages.slice(0, -1) : messages
-  const lastMessage = messages.length > 0 ? messages[messages.length - 1] : undefined
-
-  clog.debug(
-    C.BOUNDARY,
-    `Searching ${searchableMessages.length} messages\n(last message excluded: ${messages.length > 1})`,
-  )
-
-  const exactMatches = findExactMatches(searchableMessages, searchString)
-  const exactSummary = summarizeMatches(exactMatches)
-
-  clog.info(C.BOUNDARY, `Exact Match Results`, {
-    count: exactSummary.total,
-    matches: exactSummary.sample,
-    omitted: exactSummary.omitted,
-  })
-
-  if (exactMatches.length === 1) {
-    clog.info(C.BOUNDARY, `✓ Single exact match`, {
-      messageId: exactMatches[0].messageId,
-      messageIndex: exactMatches[0].messageIndex,
-    })
-    return { messageId: exactMatches[0].messageId, messageIndex: exactMatches[0].messageIndex }
-  }
-
-  if (exactMatches.length > 1) {
-    clog.error(C.BOUNDARY, `✗ Multiple Exact Matches (ambiguous)`, {
-      count: exactMatches.length,
-      matches: exactMatches.map((m) => ({ msgId: m.messageId, idx: m.messageIndex })),
-      searchPreview: searchString.substring(0, 150),
-    })
-    throw new Error(
-      `Found multiple matches for ${stringType}. ` +
-        `Provide more surrounding context to uniquely identify the intended match.`,
-    )
-  }
-
-  clog.info(C.BOUNDARY, `No exact match\nAttempting fuzzy search...`, {
-    minScore: fuzzyConfig.minScore,
-    minGap: fuzzyConfig.minGap,
-  })
-
-  const fuzzyMatches = findFuzzyMatches(searchableMessages, searchString, fuzzyConfig.minScore)
-  const fuzzySummary = summarizeMatches(fuzzyMatches)
-
-  clog.info(C.BOUNDARY, `Fuzzy Match Results`, {
-    count: fuzzySummary.total,
-    matches: fuzzySummary.sample,
-    omitted: fuzzySummary.omitted,
-  })
-
-  if (fuzzyMatches.length === 0) {
-    clog.warn(C.BOUNDARY, `⚠ No fuzzy matches\nTrying last message as last resort...`)
-
-    if (lastMessage && !isIgnoredUserMessage(lastMessage)) {
-      const lastMsgContent = extractMessageContent(lastMessage)
-      const lastMsgIndex = messages.length - 1
-      clog.debug(C.BOUNDARY, `Last message check`, {
-        messageId: lastMessage.info.id,
-        contentLength: lastMsgContent.length,
-      })
-      if (lastMsgContent.includes(searchString)) {
-        clog.info(C.BOUNDARY, `✓ Found in last message (last resort)`, {
-          messageId: lastMessage.info.id,
-          messageIndex: lastMsgIndex,
-        })
-        return {
-          messageId: lastMessage.info.id,
-          messageIndex: lastMsgIndex,
-        }
-      }
-      clog.warn(C.BOUNDARY, `✗ Not found in last message either`)
-    }
-
-    clog.error(C.BOUNDARY, `✗ NOT FOUND ANYWHERE`, {
-      searchString: searchString.substring(0, 200),
-      searchStringLen: searchString.length,
-      messageCount: messages.length,
-      messageRoles: messages.map((m, i) => `${i}:${m.info.role}`),
-    })
-    throw new Error(
-      `${stringType} not found in conversation. ` +
-        `Make sure the string exists and is spelled exactly as it appears.`,
-    )
-  }
-
-  fuzzyMatches.sort((a, b) => b.score - a.score)
-
-  const best = fuzzyMatches[0]
-  const secondBest = fuzzyMatches[1]
-
-  clog.info(C.BOUNDARY, `Fuzzy Ranking`, {
-    best: { msgId: best.messageId, idx: best.messageIndex, score: best.score },
-    secondBest: secondBest
-      ? { msgId: secondBest.messageId, idx: secondBest.messageIndex, score: secondBest.score }
-      : null,
-    gap: secondBest ? best.score - secondBest.score : "N/A",
-    requiredGap: fuzzyConfig.minGap,
-  })
-
-  // Check confidence gap - best must be significantly better than second best
-  if (secondBest && best.score - secondBest.score < fuzzyConfig.minGap) {
-    clog.error(C.BOUNDARY, `✗ Ambiguous Fuzzy Match (gap too small)`, {
-      best: best.score,
-      secondBest: secondBest.score,
-      gap: best.score - secondBest.score,
-      required: fuzzyConfig.minGap,
-    })
-    throw new Error(
-      `Found multiple matches for ${stringType}. ` +
-        `Provide more unique surrounding context to disambiguate.`,
-    )
-  }
-
-  clog.info(C.BOUNDARY, `✓ Fuzzy match accepted`, {
-    messageId: best.messageId,
-    messageIndex: best.messageIndex,
-    score: best.score,
-  })
-
-  return { messageId: best.messageId, messageIndex: best.messageIndex }
-}
-
-export function collectToolIdsInRange(
-  messages: WithParts[],
-  startIndex: number,
-  endIndex: number,
-): string[] {
-  const toolIds: string[] = []
-
-  for (let i = startIndex; i <= endIndex; i++) {
-    const msg = messages[i]
-    const parts = Array.isArray(msg.parts) ? msg.parts : []
-
-    for (const part of parts) {
-      if (part.type === "tool" && part.callID) {
-        if (!toolIds.includes(part.callID)) {
-          toolIds.push(part.callID)
-        }
-      }
-    }
-  }
-
-  return toolIds
-}
-
-export function collectMessageIdsInRange(
-  messages: WithParts[],
-  startIndex: number,
-  endIndex: number,
-): string[] {
-  const messageIds: string[] = []
-
-  for (let i = startIndex; i <= endIndex; i++) {
-    const msgId = messages[i].info.id
-    if (!messageIds.includes(msgId)) {
-      messageIds.push(msgId)
-    }
-  }
-
-  return messageIds
-}
diff --git a/scripts/opencode-dcp-stats b/scripts/opencode-dcp-stats
index 67a57f4f..44a4322f 100755
--- a/scripts/opencode-dcp-stats
+++ b/scripts/opencode-dcp-stats
@@ -8,11 +8,12 @@ Usage: opencode-dcp-stats [--sessions N] [--min-messages M] [--json] [--verbose]
 
 import json
 import argparse
-from pathlib import Path
 from datetime import datetime
 from collections import defaultdict
 from typing import Optional
 
+from opencode_api import APIError, add_api_arguments, create_client_from_args, list_sessions_across_projects
+
 # DCP tool names across versions (compress is canonical; others are legacy aliases)
 DCP_TOOLS = {
     "compress", "prune", "distill",
@@ -24,42 +25,9 @@
 CACHE_READ_COST_PER_1K = 0.00030  # $0.30 per 1M tokens
 INPUT_COST_PER_1K = 0.003  # $3.00 per 1M tokens
 
 
-def get_session_messages(storage: Path, session_id: str) -> list[dict]:
-    """Get all messages for a session, sorted by creation order."""
-    message_dir = storage / "message" / session_id
-    if not message_dir.exists():
-        return []
-
-    messages = []
-    for msg_file in message_dir.glob("*.json"):
-        try:
-            msg = json.loads(msg_file.read_text())
-            msg["_file"] = msg_file
-            msg["_id"] = msg_file.stem
-            messages.append(msg)
-        except (json.JSONDecodeError, IOError):
-            pass
-
-    return sorted(messages, key=lambda m: m.get("_id", ""))
-
-
-def get_message_parts(storage: Path, message_id: str) -> list[dict]:
-    """Get all parts for a message, sorted by creation order."""
-    parts_dir = storage / "part" / message_id
-    if not parts_dir.exists():
-        return []
-
-    parts = []
-    for part_file in parts_dir.glob("*.json"):
-        try:
-            part = json.loads(part_file.read_text())
-            part["_file"] = part_file
-            part["_id"] = part_file.stem
-            parts.append(part)
-        except (json.JSONDecodeError, IOError):
-            pass
-
-    return sorted(parts, key=lambda p: p.get("_id", ""))
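+# The API returns each message as an {"info": {...}, "parts": [...]} envelope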
+def get_session_messages(client, session: dict) -> list[dict]:
+    """Get all messages for a session."""
+    return client.get_session_messages(session["id"], directory=session.get("directory"))
 
 
 def is_ignored_message(message: dict, parts: list[dict]) -> bool:
@@ -83,22 +51,20 @@
     return True
 
 
-def count_real_user_messages(storage: Path, session_id: str) -> int:
+def count_real_user_messages(messages: list[dict]) -> int:
     """Count user messages that are not ignored (real user interactions)."""
-    messages = get_session_messages(storage, session_id)
     count = 0
-
+
     for msg in messages:
+        info = msg.get("info", {})
+        parts = msg.get("parts", [])
         # Only count user role messages
-        if msg.get("role") != "user":
+        if info.get("role") != "user":
             continue
-
-        msg_id = msg.get("_id", "")
-        parts = get_message_parts(storage, msg_id)
-
+
         if not is_ignored_message(msg, parts):
             count += 1
-
+
     return count
 
 
@@ -136,10 +102,9 @@
     return (cache_read / total_context) * 100
 
 
-def analyze_session(storage: Path, session_id: str) -> dict:
+def analyze_session(messages: list[dict], session_id: str) -> dict:
     """Analyze DCP impact for a single session."""
-    messages = get_session_messages(storage, session_id)
-
+
     result = {
         "session_id": session_id,
         "dcp_events": [],
@@ -164,14 +129,15 @@
     prev_step = None
     prev_dcp_tools = []
     steps_since_dcp = None  # None = no DCP yet, 0 = just had DCP, 1+ = steps after
-
+
     for i, msg in enumerate(messages):
-        msg_id = msg.get("_id", "")
-        parts = get_message_parts(storage, msg_id)
-
+        msg_info = msg.get("info", {})
+        msg_id = msg_info.get("id", "")
+        parts = msg.get("parts", [])
+
         step_finish = extract_step_finish(parts)
         dcp_tools = extract_dcp_tools(parts)
-
+
         if step_finish:
             result["total_steps"] += 1
             tokens = step_finish.get("tokens", {})
@@ -239,27 +205,24 @@
     return result
 
 
-def analyze_sessions(num_sessions: int = 20, min_messages: int = 5, output_json: bool = False, verbose: bool = False, session_id: str = None):
+def analyze_sessions(
+    client,
+    num_sessions: int = 20,
+    min_messages: int = 5,
+    output_json: bool = False,
+    verbose: bool = False,
+    session_id: str = None,
+    session_list_limit: int = 5000,
+):
     """Analyze DCP impact across recent sessions."""
-    storage = Path.home() / ".local/share/opencode/storage"
-    message_dir = storage / "message"
-    session_dir = storage / "session"
-
-    if not message_dir.exists():
-        print("Error: OpenCode storage not found at", storage)
-        return
-
     # Get sessions to analyze
     if session_id:
         # Analyze specific session
-        session_path = message_dir / session_id
-        if not session_path.exists():
-            print(f"Error: Session {session_id} not found")
-            return
-        sessions = [session_path]
+        sessions = [client.get_session(session_id)]
     else:
-        sessions = sorted(message_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)[:num_sessions]
-
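+        # list_sessions_across_projects returns sessions newest-first (sorted
+        # by updated time), so slicing yields the most recent N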
grand_totals["by_tool"][tool][key] += stats[key] - + # Aggregate hit rates by distance for dist, rates in result["hit_rates_by_distance"].items(): grand_totals["hit_rates_by_distance"][dist].extend(rates) @@ -474,16 +427,26 @@ def main(): help="Output as JSON") parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed per-event breakdown") + add_api_arguments(parser) args = parser.parse_args() - - analyze_sessions( - num_sessions=args.sessions, - min_messages=args.min_messages, - output_json=args.json, - verbose=args.verbose, - session_id=args.session - ) + + try: + with create_client_from_args(args) as client: + analyze_sessions( + client, + num_sessions=args.sessions, + min_messages=args.min_messages, + output_json=args.json, + verbose=args.verbose, + session_id=args.session, + session_list_limit=args.session_list_limit, + ) + except APIError as err: + print(f"Error: {err}") + return 1 + + return 0 if __name__ == "__main__": - main() + raise SystemExit(main()) diff --git a/scripts/opencode-find-session b/scripts/opencode-find-session index 5b7e2087..dfa6e712 100755 --- a/scripts/opencode-find-session +++ b/scripts/opencode-find-session @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Find OpenCode session IDs by title search. +Find OpenCode session IDs by title search using the OpenCode API. Returns matching session IDs ordered by last usage time. Usage: opencode-find-session [--exact] [--json] @@ -8,46 +8,28 @@ Usage: opencode-find-session [--exact] [--json] import json import argparse -from pathlib import Path from datetime import datetime +from opencode_api import APIError, add_api_arguments, create_client_from_args, list_sessions_across_projects -def get_all_sessions(storage: Path) -> list[dict]: - """Get all sessions with their metadata.""" - session_dir = storage / "session" - message_dir = storage / "message" - - if not session_dir.exists(): - return [] - + +def get_all_sessions(client, session_list_limit: int) -> list[dict]: + """Get all sessions with normalized metadata from API.""" + api_sessions = list_sessions_across_projects(client, per_project_limit=session_list_limit) sessions = [] - - for app_dir in session_dir.iterdir(): - if not app_dir.is_dir(): - continue - - for session_file in app_dir.glob("*.json"): - try: - session = json.loads(session_file.read_text()) - session_id = session_file.stem - - # Get last modified time from message directory - msg_path = message_dir / session_id - if msg_path.exists(): - mtime = msg_path.stat().st_mtime - else: - mtime = session_file.stat().st_mtime - - sessions.append({ - "id": session_id, - "title": session.get("title", "Untitled"), - "created_at": session.get("createdAt"), - "last_used": mtime, - "last_used_iso": datetime.fromtimestamp(mtime).isoformat() - }) - except (json.JSONDecodeError, IOError): - pass - + for session in api_sessions: + time_data = session.get("time", {}) + updated_ms = time_data.get("updated") or time_data.get("created") or 0 + last_used = updated_ms / 1000 if updated_ms else 0 + sessions.append( + { + "id": session.get("id", ""), + "title": session.get("title", "Untitled"), + "created_at": time_data.get("created"), + "last_used": last_used, + "last_used_iso": datetime.fromtimestamp(last_used).isoformat() if last_used else None, + } + ) return sessions @@ -101,6 +83,7 @@ def main(): ) parser.add_argument( "search_term", + nargs="?", type=str, help="Text to search for in session titles" ) @@ -119,16 +102,19 @@ def main(): action="store_true", help="Show all sessions (ignore search term)" ) + 
+        updated_ms = time_data.get("updated") or time_data.get("created") or 0
+        last_used = updated_ms / 1000 if updated_ms else 0
+        sessions.append(
+            {
+                "id": session.get("id", ""),
+                "title": session.get("title", "Untitled"),
+                "created_at": time_data.get("created"),
+                "last_used": last_used,
+                "last_used_iso": datetime.fromtimestamp(last_used).isoformat() if last_used else None,
+            }
+        )
     return sessions
 
 
@@ -101,6 +83,7 @@
     )
     parser.add_argument(
         "search_term",
+        nargs="?",
         type=str,
         help="Text to search for in session titles"
     )
@@ -119,16 +102,19 @@
         action="store_true",
         help="Show all sessions (ignore search term)"
     )
+    add_api_arguments(parser)
     args = parser.parse_args()
-
-    storage = Path.home() / ".local/share/opencode/storage"
-
-    if not storage.exists():
-        print("Error: OpenCode storage not found at", storage)
+
+    if not args.all and not args.search_term:
+        parser.error("search_term is required unless --all is used")
+
+    try:
+        with create_client_from_args(args) as client:
+            sessions = get_all_sessions(client, args.session_list_limit)
+    except APIError as err:
+        print(f"Error: {err}")
         return 1
-
-    sessions = get_all_sessions(storage)
-
+
     if args.all:
         results = sorted(sessions, key=lambda s: s["last_used"], reverse=True)
     else:
diff --git a/scripts/opencode-get-message b/scripts/opencode-get-message
new file mode 100755
index 00000000..0eebab23
--- /dev/null
+++ b/scripts/opencode-get-message
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+"""
+Get full OpenCode message payload(s) by message ID from the OpenCode API.
+
+Usage:
+    opencode-get-message <message-id> [<message-id> ...]
+    opencode-get-message --session <session-id> <message-id> [<message-id> ...]
+"""
+
+import argparse
+import json
+import sys
+
+from opencode_api import APIError, add_api_arguments, create_client_from_args, list_sessions_across_projects
+
+
+def normalize_message_payload(payload: dict) -> dict:
+    """Normalize API response into expected output shape."""
+    info = payload.get("info", {})
+    parts = payload.get("parts", [])
+    return {"info": info, "parts": parts}
+
+
+def not_found_result(message_id: str) -> dict:
+    return {"id": message_id, "error": "message_not_found"}
+
+
+def get_message_for_session(client, session: dict, message_id: str) -> dict:
+    """Get a message payload within a known session."""
+    session_id = session.get("id", "")
+    directory = session.get("directory")
+    try:
+        payload = client.get_session_message(session_id, message_id, directory=directory)
+        return normalize_message_payload(payload)
+    except APIError as err:
+        if err.status_code == 404:
+            return not_found_result(message_id)
+        raise
+
+
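+# Fallback for when no session id is given: scan recent sessions, newest
+# first, until every requested message id has been found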
+def find_messages_without_session(client, message_ids: list[str], scan_sessions: int, session_list_limit: int) -> list[dict]:
+    """Search recent sessions for requested message IDs."""
+    wanted = set(message_ids)
+    found: dict[str, dict] = {}
+
+    sessions = list_sessions_across_projects(client, per_project_limit=session_list_limit)
+    if scan_sessions > 0:
+        sessions = sessions[:scan_sessions]
+
+    for session in sessions:
+        if not wanted:
+            break
+        messages = client.get_session_messages(session.get("id", ""), directory=session.get("directory"))
+        for message in messages:
+            info = message.get("info", {})
+            mid = info.get("id")
+            if mid in wanted:
+                found[mid] = normalize_message_payload(message)
+                wanted.remove(mid)
+
+    return [found.get(message_id, not_found_result(message_id)) for message_id in message_ids]
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Get full OpenCode message payload by message ID")
+    parser.add_argument("--session", "-s", type=str, default=None, help="Session ID for direct lookup")
+    parser.add_argument(
+        "--scan-sessions",
+        type=int,
+        default=200,
+        help="When --session is omitted, scan this many recent sessions for message IDs",
+    )
+    parser.add_argument("--db", default=None, help=argparse.SUPPRESS)
+    parser.add_argument("message_ids", nargs="+", help="One or more message IDs")
+    add_api_arguments(parser)
+    args = parser.parse_args()
+
+    if args.db:
+        print("Warning: --db is deprecated and ignored; this script now uses the OpenCode API", file=sys.stderr)
+
+    try:
+        with create_client_from_args(args) as client:
+            if args.session:
+                session = client.get_session(args.session)
+                results = [get_message_for_session(client, session, message_id) for message_id in args.message_ids]
+            else:
+                results = find_messages_without_session(
+                    client,
+                    args.message_ids,
+                    scan_sessions=args.scan_sessions,
+                    session_list_limit=args.session_list_limit,
+                )
+    except APIError as err:
+        print(f"Error: {err}")
+        return 1
+
+    output = results[0] if len(results) == 1 else results
+    print(json.dumps(output, indent=2))
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/scripts/opencode-session-timeline b/scripts/opencode-session-timeline
index 45c44d8b..94c4f247 100755
--- a/scripts/opencode-session-timeline
+++ b/scripts/opencode-session-timeline
@@ -8,10 +8,11 @@ Usage: opencode-session-timeline [--session ID] [--json] [--no-color]
 
 import json
 import argparse
-from pathlib import Path
 from typing import Optional
 from datetime import datetime
 
+from opencode_api import APIError, add_api_arguments, create_client_from_args, list_sessions_across_projects
+
 # DCP tool names across versions (compress is canonical; others are legacy aliases)
 DCP_TOOLS = {
     "compress", "prune", "distill",
@@ -54,46 +55,22 @@
     return f"{hours}h{minutes}m"
 
 
-def get_session_messages(storage: Path, session_id: str) -> list[dict]:
+def get_session_messages(client, session: dict) -> list[dict]:
     """Get all messages for a session, sorted by creation order."""
-    message_dir = storage / "message" / session_id
-    if not message_dir.exists():
-        return []
-
-    messages = []
-    for msg_file in message_dir.glob("*.json"):
-        try:
-            msg = json.loads(msg_file.read_text())
-            msg["_file"] = str(msg_file)
-            msg["_id"] = msg_file.stem
-            # Extract timing info
-            time_info = msg.get("time", {})
-            msg["_created"] = time_info.get("created")
-            msg["_completed"] = time_info.get("completed")
-            messages.append(msg)
-        except (json.JSONDecodeError, IOError):
-            pass
-
-    return sorted(messages, key=lambda m: m.get("_id", ""))
-
-
-def get_message_parts(storage: Path, message_id: str) -> list[dict]:
-    """Get all parts for a message, sorted by creation order."""
-    parts_dir = storage / "part" / message_id
-    if not parts_dir.exists():
-        return []
-
-    parts = []
-    for part_file in parts_dir.glob("*.json"):
-        try:
-            part = json.loads(part_file.read_text())
-            part["_file"] = str(part_file)
-            part["_id"] = part_file.stem
-            parts.append(part)
-        except (json.JSONDecodeError, IOError):
-            pass
-
-    return sorted(parts, key=lambda p: p.get("_id", ""))
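+    # Flatten the API envelope into the _id/_created/_completed/_parts shape
+    # the rest of this script already expects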
+    messages = client.get_session_messages(session["id"], directory=session.get("directory"))
+    normalized = []
+    for message in messages:
+        info = message.get("info", {})
+        time_info = info.get("time", {})
+        normalized.append(
+            {
+                "_id": info.get("id", ""),
+                "_created": time_info.get("created"),
+                "_completed": time_info.get("completed"),
+                "_parts": message.get("parts", []),
+            }
+        )
+    return normalized
 
 
 def extract_step_data(parts: list[dict]) -> Optional[dict]:
@@ -131,42 +108,22 @@
     }
 
 
-def get_most_recent_session(storage: Path) -> Optional[str]:
-    """Get the most recent session ID."""
-    message_dir = storage / "message"
-    if not message_dir.exists():
-        return None
-
-    sessions = sorted(message_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
-    return sessions[0].name if sessions else None
+def get_most_recent_session(client, session_list_limit: int) -> Optional[dict]:
+    """Get the most recent session across all projects."""
+    sessions = list_sessions_across_projects(client, per_project_limit=session_list_limit)
+    return sessions[0] if sessions else None
 
 
-def get_session_title(storage: Path, session_id: str) -> str:
-    """Get session title from metadata."""
-    session_dir = storage / "session"
-    if not session_dir.exists():
-        return "Unknown"
-
-    for s_dir in session_dir.iterdir():
-        s_file = s_dir / f"{session_id}.json"
-        if s_file.exists():
-            try:
-                sess = json.loads(s_file.read_text())
-                return sess.get("title", "Untitled")
-            except (json.JSONDecodeError, IOError):
-                pass
-    return "Unknown"
-
-
-def analyze_session(storage: Path, session_id: str) -> dict:
+def analyze_session(client, session: dict) -> dict:
     """Analyze a single session step by step."""
-    messages = get_session_messages(storage, session_id)
-    title = get_session_title(storage, session_id)
+    session_id = session["id"]
+    messages = get_session_messages(client, session)
+    title = session.get("title", "Unknown")
 
     steps = []
 
     for msg in messages:
         msg_id = msg.get("_id", "")
-        parts = get_message_parts(storage, msg_id)
+        parts = msg.get("_parts", [])
 
         step_data = extract_step_data(parts)
         if step_data:
@@ -381,23 +338,24 @@
         "--no-color", action="store_true",
         help="Disable colored output"
     )
+    add_api_arguments(parser)
     args = parser.parse_args()
-
-    storage = Path.home() / ".local/share/opencode/storage"
-
-    if not storage.exists():
-        print("Error: OpenCode storage not found at", storage)
+
+    try:
+        with create_client_from_args(args) as client:
+            if args.session is None:
+                session = get_most_recent_session(client, args.session_list_limit)
+                if session is None:
+                    print("Error: No sessions found")
+                    return 1
+            else:
+                session = client.get_session(args.session)
+
+            result = analyze_session(client, session)
+    except APIError as err:
+        print(f"Error: {err}")
         return 1
 
-    session_id = args.session
-    if session_id is None:
-        session_id = get_most_recent_session(storage)
-        if session_id is None:
-            print("Error: No sessions found")
-            return 1
-
-    result = analyze_session(storage, session_id)
-
     if args.json:
         # Remove non-serializable fields
         print(json.dumps(result, indent=2, default=str))
diff --git a/scripts/opencode-token-stats b/scripts/opencode-token-stats
index 3a7d6dba..aaf2c443 100755
--- a/scripts/opencode-token-stats
+++ b/scripts/opencode-token-stats
@@ -6,30 +6,18 @@ Usage: opencode-token-stats [--sessions N] [--json]
 
 import json
 import argparse
-from pathlib import Path
 from datetime import datetime
 
-def analyze_sessions(num_sessions=10, output_json=False, session_id=None):
-    storage = Path.home() / ".local/share/opencode/storage"
-    message_dir = storage / "message"
-    part_dir = storage / "part"
-    session_dir = storage / "session"
-
-    if not message_dir.exists():
-        print("Error: OpenCode storage not found at", storage)
-        return
+from opencode_api import APIError, add_api_arguments, create_client_from_args, list_sessions_across_projects
 
+def analyze_sessions(client, num_sessions=10, output_json=False, session_id=None, session_list_limit=5000):
     # Get sessions to analyze
     if session_id:
         # Analyze specific session
-        session_path = message_dir / session_id
-        if not session_path.exists():
-            print(f"Error: Session {session_id} not found")
-            return
-        sessions = [session_path]
+        sessions = [client.get_session(session_id)]
     else:
-        # Get recent sessions sorted by modification time
-        sessions = sorted(message_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)[:num_sessions]
+        # Get recent sessions sorted by API updated time across projects
+        sessions = list_sessions_across_projects(client, per_project_limit=session_list_limit)[:num_sessions]
 
     results = []
     grand_totals = {
@@ -39,8 +27,9 @@
         "reasons": {"tool-calls": 0, "stop": 0, "other": 0}
     }
 
-    for session_path in sessions:
-        session_id = session_path.name
+    for session in sessions:
+        session_id = session.get("id", "")
+        directory = session.get("directory")
         totals = {
             "input": 0, "output": 0, "reasoning": 0,
             "cache_read": 0, "cache_write": 0,
@@ -49,47 +38,31 @@
         }
 
         # Get messages for this session
-        msg_files = list(session_path.glob("*.json"))
-
-        for msg_file in msg_files:
-            msg_id = msg_file.stem
-            parts_path = part_dir / msg_id
-            if parts_path.exists():
-                for part_file in parts_path.glob("*.json"):
-                    try:
-                        part = json.loads(part_file.read_text())
-                        if part.get("type") == "step-finish" and "tokens" in part:
-                            t = part["tokens"]
-                            totals["input"] += t.get("input", 0)
-                            totals["output"] += t.get("output", 0)
-                            totals["reasoning"] += t.get("reasoning", 0)
-                            cache = t.get("cache", {})
-                            totals["cache_read"] += cache.get("read", 0)
-                            totals["cache_write"] += cache.get("write", 0)
-                            totals["cost"] += part.get("cost", 0)
-                            totals["steps"] += 1
-
-                            reason = part.get("reason", "other")
-                            if reason in totals["reasons"]:
-                                totals["reasons"][reason] += 1
-                            else:
-                                totals["reasons"]["other"] += 1
-                    except (json.JSONDecodeError, KeyError):
-                        pass
+        messages = client.get_session_messages(session_id, directory=directory)
+
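+        # Per-step token and cost accounting lives on "step-finish" parts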
+        for message in messages:
+            for part in message.get("parts", []):
+                if part.get("type") != "step-finish" or "tokens" not in part:
+                    continue
+                t = part["tokens"]
+                totals["input"] += t.get("input", 0)
+                totals["output"] += t.get("output", 0)
+                totals["reasoning"] += t.get("reasoning", 0)
+                cache = t.get("cache", {})
+                totals["cache_read"] += cache.get("read", 0)
+                totals["cache_write"] += cache.get("write", 0)
+                totals["cost"] += part.get("cost", 0)
+                totals["steps"] += 1
+
+                reason = part.get("reason", "other")
+                if reason in totals["reasons"]:
+                    totals["reasons"][reason] += 1
+                else:
+                    totals["reasons"]["other"] += 1
 
         # Get session metadata (title, timestamps)
-        title = "Unknown"
-        created = None
-        for s_dir in session_dir.iterdir():
-            s_file = s_dir / f"{session_id}.json"
-            if s_file.exists():
-                try:
-                    sess = json.loads(s_file.read_text())
-                    title = sess.get("title", "Untitled")[:60]
-                    created = sess.get("createdAt")
-                except (json.JSONDecodeError, KeyError):
-                    pass
-                break
+        title = session.get("title", "Untitled")[:60]
+        created = session.get("time", {}).get("created")
 
         # Calculate derived metrics
         total_tokens = totals["input"] + totals["output"] + totals["cache_read"]
@@ -187,9 +160,23 @@
     parser.add_argument("--sessions", "-n", type=int, default=10, help="Number of recent sessions to analyze (default: 10)")
     parser.add_argument("--session", "-s", type=str, default=None, help="Analyze specific session ID")
     parser.add_argument("--json", "-j", action="store_true", help="Output as JSON instead of formatted text")
+    add_api_arguments(parser)
     args = parser.parse_args()
 
-    analyze_sessions(num_sessions=args.sessions, output_json=args.json, session_id=args.session)
+    try:
+        with create_client_from_args(args) as client:
+            analyze_sessions(
+                client,
+                num_sessions=args.sessions,
+                output_json=args.json,
+                session_id=args.session,
+                session_list_limit=args.session_list_limit,
+            )
+    except APIError as err:
+        print(f"Error: {err}")
+        return 1
+
+    return 0
 
 
 if __name__ == "__main__":
-    main()
+    raise SystemExit(main())
diff --git a/scripts/opencode_api.py b/scripts/opencode_api.py
new file mode 100644
index 00000000..1923d93b
--- /dev/null
+++ b/scripts/opencode_api.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python3
+"""Shared helpers for querying the OpenCode HTTP API from scripts."""
+
+from __future__ import annotations
+
+import base64
+import json
+import os
+import re
+import selectors
+import subprocess
+import time
+from dataclasses import dataclass
+from typing import Any, TextIO, cast
+from urllib.error import HTTPError, URLError
+from urllib.parse import urlencode
+from urllib.request import Request, urlopen
+
+
+DEFAULT_HOSTNAME = "127.0.0.1"
+DEFAULT_PORT = 0
+DEFAULT_SERVER_TIMEOUT = 8.0
+DEFAULT_REQUEST_TIMEOUT = 30.0
+DEFAULT_SESSION_LIST_LIMIT = 5000
+
+
+class APIError(RuntimeError):
+    """OpenCode API request error."""
+
+    def __init__(self, message: str, *, status_code: int | None = None):
+        super().__init__(message)
+        self.status_code = status_code
+
+
+@dataclass
+class ManagedServer:
+    process: subprocess.Popen[str]
+    url: str
+
+
+def _auth_header(username: str, password: str) -> str:
+    token = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("ascii")
+    return f"Basic {token}"
+
+
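+# `opencode serve` announces its URL on stdout with a line that starts
+# "opencode server listening"; parse the URL out of that line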
+def _parse_server_url(line: str) -> str | None:
+    if not line.startswith("opencode server listening"):
+        return None
+    match = re.search(r"on\s+(https?://\S+)", line)
+    if not match:
+        return None
+    return match.group(1)
+
+
+def _start_server(hostname: str, port: int, timeout_seconds: float) -> ManagedServer:
+    process = subprocess.Popen(
+        ["opencode", "serve", f"--hostname={hostname}", f"--port={port}"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        bufsize=1,
+        env=os.environ.copy(),
+    )
+    if process.stdout is None:
+        process.kill()
+        raise APIError("Failed to read opencode server output")
+
+    selector = selectors.DefaultSelector()
+    selector.register(process.stdout, selectors.EVENT_READ)
+
+    deadline = time.monotonic() + timeout_seconds
+    output: list[str] = []
+    url: str | None = None
+
+    while time.monotonic() < deadline:
+        if process.poll() is not None:
+            break
+        for key, _ in selector.select(timeout=0.2):
+            stream = cast(TextIO, key.fileobj)
+            line = stream.readline()
+            if not line:
+                continue
+            line = line.rstrip("\n")
+            output.append(line)
+            parsed = _parse_server_url(line)
+            if parsed:
+                url = parsed
+                break
+        if url:
+            break
+
+    selector.close()
+
+    if url:
+        return ManagedServer(process=process, url=url)
+
+    if process.poll() is None:
+        process.kill()
+        process.wait(timeout=2)
+    details = "\n".join(output[-20:]).strip()
+    if details:
+        raise APIError(f"Timed out waiting for opencode server startup. Last output:\n{details}")
+    raise APIError("Timed out waiting for opencode server startup")
+
+
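+# Stdlib-only client: urllib for HTTP, optional Basic auth, and context-manager
+# cleanup that tears down any server this process spawned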
+class OpencodeAPI:
+    def __init__(
+        self,
+        *,
+        url: str | None,
+        username: str,
+        password: str | None,
+        request_timeout: float,
+        server_hostname: str,
+        server_port: int,
+        server_timeout: float,
+    ):
+        self._managed_server: ManagedServer | None = None
+        if url:
+            self.base_url = url.rstrip("/")
+        else:
+            self._managed_server = _start_server(server_hostname, server_port, server_timeout)
+            self.base_url = self._managed_server.url.rstrip("/")
+
+        self.request_timeout = request_timeout
+        self.headers = {"Accept": "application/json"}
+        if password:
+            self.headers["Authorization"] = _auth_header(username, password)
+
+    def close(self):
+        if self._managed_server is None:
+            return
+        process = self._managed_server.process
+        self._managed_server = None
+        if process.poll() is None:
+            process.terminate()
+            try:
+                process.wait(timeout=2)
+            except subprocess.TimeoutExpired:
+                process.kill()
+                process.wait(timeout=2)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        self.close()
+
+    def get_json(self, path: str, query: dict[str, Any] | None = None) -> Any:
+        params = {k: v for k, v in (query or {}).items() if v is not None}
+        url = f"{self.base_url}{path}"
+        if params:
+            url = f"{url}?{urlencode(params)}"
+
+        request = Request(url, headers=self.headers, method="GET")
+        try:
+            with urlopen(request, timeout=self.request_timeout) as response:
+                body = response.read().decode("utf-8")
+                if not body:
+                    return None
+                return json.loads(body)
+        except HTTPError as err:
+            body = err.read().decode("utf-8", errors="replace")
+            message = f"GET {path} failed with HTTP {err.code}"
+            if body:
+                message = f"{message}: {body}"
+            raise APIError(message, status_code=err.code) from err
+        except URLError as err:
+            raise APIError(f"GET {path} failed: {err}") from err
+
+    def health(self) -> dict[str, Any]:
+        return self.get_json("/global/health")
+
+    def list_projects(self) -> list[dict[str, Any]]:
+        return self.get_json("/project")
+
+    def list_sessions(
+        self,
+        *,
+        directory: str | None = None,
+        roots: bool | None = None,
+        start: int | None = None,
+        search: str | None = None,
+        limit: int | None = None,
+    ) -> list[dict[str, Any]]:
+        return self.get_json(
+            "/session",
+            {
+                "directory": directory,
+                "roots": str(roots).lower() if roots is not None else None,
+                "start": start,
+                "search": search,
+                "limit": limit,
+            },
+        )
+
+    def get_session(self, session_id: str, *, directory: str | None = None) -> dict[str, Any]:
+        return self.get_json(f"/session/{session_id}", {"directory": directory})
+
+    def get_session_messages(
+        self,
+        session_id: str,
+        *,
+        directory: str | None = None,
+        limit: int | None = None,
+    ) -> list[dict[str, Any]]:
+        return self.get_json(
+            f"/session/{session_id}/message",
+            {
+                "directory": directory,
+                "limit": limit,
+            },
+        )
+
+    def get_session_message(
+        self,
+        session_id: str,
+        message_id: str,
+        *,
+        directory: str | None = None,
+    ) -> dict[str, Any]:
+        return self.get_json(
+            f"/session/{session_id}/message/{message_id}",
+            {"directory": directory},
+        )
+
+
+def add_api_arguments(parser):
+    parser.add_argument("--url", type=str, default=None, help="OpenCode server URL (default: start local server)")
+    parser.add_argument("--username", type=str, default=os.environ.get("OPENCODE_SERVER_USERNAME", "opencode"))
+    parser.add_argument("--password", type=str, default=os.environ.get("OPENCODE_SERVER_PASSWORD"))
+    parser.add_argument("--hostname", type=str, default=DEFAULT_HOSTNAME, help="Hostname for spawned local server")
+    parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Port for spawned local server (0 = auto)")
+    parser.add_argument(
+        "--server-timeout",
+        type=float,
+        default=DEFAULT_SERVER_TIMEOUT,
+        help="Seconds to wait for spawned server startup",
+    )
+    parser.add_argument(
+        "--request-timeout",
+        type=float,
+        default=DEFAULT_REQUEST_TIMEOUT,
+        help="HTTP request timeout in seconds",
+    )
+    parser.add_argument(
+        "--session-list-limit",
+        type=int,
+        default=DEFAULT_SESSION_LIST_LIMIT,
+        help="Max sessions fetched per project from /session",
+    )
+
+
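+# Build a client from parsed args; the immediate health() call is a cheap
+# connectivity/auth probe so callers fail fast with a clear APIError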
+def create_client_from_args(args) -> OpencodeAPI:
+    client = OpencodeAPI(
+        url=getattr(args, "url", None),
+        username=getattr(args, "username", "opencode"),
+        password=getattr(args, "password", None),
+        request_timeout=getattr(args, "request_timeout", DEFAULT_REQUEST_TIMEOUT),
+        server_hostname=getattr(args, "hostname", DEFAULT_HOSTNAME),
+        server_port=getattr(args, "port", DEFAULT_PORT),
+        server_timeout=getattr(args, "server_timeout", DEFAULT_SERVER_TIMEOUT),
+    )
+    client.health()
+    return client
+
+
+def list_sessions_across_projects(
+    client: OpencodeAPI,
+    *,
+    search: str | None = None,
+    roots: bool | None = None,
+    per_project_limit: int = DEFAULT_SESSION_LIST_LIMIT,
+) -> list[dict[str, Any]]:
+    sessions_by_id: dict[str, dict[str, Any]] = {}
+    projects = client.list_projects()
+
+    for project in projects:
+        directory = project.get("worktree")
+        if not directory:
+            continue
+        try:
+            sessions = client.list_sessions(
+                directory=directory,
+                roots=roots,
+                search=search,
+                limit=per_project_limit,
+            )
+        except APIError:
+            continue
+        for session in sessions:
+            session_id = session.get("id")
+            if not session_id:
+                continue
+            sessions_by_id[session_id] = session
+
+    results = list(sessions_by_id.values())
+    results.sort(key=lambda item: item.get("time", {}).get("updated", 0), reverse=True)
+    return results
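+
+
+# Example (hypothetical values): point the helpers at an already-running
+# server instead of letting them spawn one.
+#
+#   with OpencodeAPI(
+#       url="http://127.0.0.1:4096",
+#       username="opencode",
+#       password=None,
+#       request_timeout=30.0,
+#       server_hostname="127.0.0.1",
+#       server_port=0,
+#       server_timeout=8.0,
+#   ) as client:
+#       for session in list_sessions_across_projects(client)[:5]:
+#           print(session.get("id"), session.get("title"))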