diff --git a/packages/opencode/src/file/ripgrep-installer.ts b/packages/opencode/src/file/ripgrep-installer.ts
new file mode 100644
index 00000000000..8211039d246
--- /dev/null
+++ b/packages/opencode/src/file/ripgrep-installer.ts
@@ -0,0 +1,153 @@
+import path from "path"
+import fs from "fs/promises"
+import { z } from "zod"
+import { NamedError } from "../util/error"
+import { Global } from "../global"
+import { ZipReader, BlobReader, BlobWriter } from "@zip.js/zip.js"
+
+export namespace RipgrepInstaller {
+  const RIPGREP_VERSION = "14.1.1"
+
+  const PLATFORM_CONFIG = {
+    "arm64-darwin": { platform: "aarch64-apple-darwin", extension: "tar.gz" },
+    "arm64-linux": {
+      platform: "aarch64-unknown-linux-gnu",
+      extension: "tar.gz",
+    },
+    "x64-darwin": { platform: "x86_64-apple-darwin", extension: "tar.gz" },
+    "x64-linux": { platform: "x86_64-unknown-linux-musl", extension: "tar.gz" },
+    "x64-win32": { platform: "x86_64-pc-windows-msvc", extension: "zip" },
+  } as const
+
+  export const ExtractionFailedError = NamedError.create(
+    "RipgrepExtractionFailedError",
+    z.object({
+      filepath: z.string(),
+      stderr: z.string(),
+    }),
+  )
+
+  export const UnsupportedPlatformError = NamedError.create(
+    "RipgrepUnsupportedPlatformError",
+    z.object({
+      platform: z.string(),
+    }),
+  )
+
+  export const DownloadFailedError = NamedError.create(
+    "RipgrepDownloadFailedError",
+    z.object({
+      url: z.string(),
+      status: z.number(),
+    }),
+  )
+
+  export async function getExecutablePath(): Promise<string> {
+    // Check if ripgrep is already in PATH
+    const systemPath = Bun.which("rg")
+    if (systemPath) return systemPath
+
+    // Check local installation
+    const localPath = path.join(Global.Path.bin, "rg" + (process.platform === "win32" ? ".exe" : ""))
+    const file = Bun.file(localPath)
+
+    if (await file.exists()) {
+      return localPath
+    }
+
+    // Install ripgrep
+    await install(localPath)
+    return localPath
+  }
+
+  async function install(targetPath: string): Promise<void> {
+    const platformKey = `${process.arch}-${process.platform}` as keyof typeof PLATFORM_CONFIG
+    const config = PLATFORM_CONFIG[platformKey]
+
+    if (!config) {
+      throw new UnsupportedPlatformError({ platform: platformKey })
+    }
+
+    const archivePath = await download(config)
+    await extract(config, archivePath, targetPath)
+    await fs.unlink(archivePath)
+
+    if (!platformKey.endsWith("-win32")) {
+      await fs.chmod(targetPath, 0o755)
+    }
+  }
+
+  async function download(config: (typeof PLATFORM_CONFIG)[keyof typeof PLATFORM_CONFIG]): Promise<string> {
+    const filename = `ripgrep-${RIPGREP_VERSION}-${config.platform}.${config.extension}`
+    const url = `https://github.com/BurntSushi/ripgrep/releases/download/${RIPGREP_VERSION}/${filename}`
+
+    const response = await fetch(url)
+    if (!response.ok) {
+      throw new DownloadFailedError({ url, status: response.status })
+    }
+
+    const buffer = await response.arrayBuffer()
+    const archivePath = path.join(Global.Path.bin, filename)
+    await Bun.write(archivePath, buffer)
+
+    return archivePath
+  }
+
+  async function extract(
+    config: (typeof PLATFORM_CONFIG)[keyof typeof PLATFORM_CONFIG],
+    archivePath: string,
+    targetPath: string,
+  ): Promise<void> {
+    const platformKey = `${process.arch}-${process.platform}`
+
+    if (config.extension === "tar.gz") {
+      const args = ["tar", "-xzf", archivePath, "--strip-components=1"]
+
+      if (platformKey.endsWith("-darwin")) args.push("--include=*/rg")
+      if (platformKey.endsWith("-linux")) args.push("--wildcards", "*/rg")
+
+      const proc = Bun.spawn(args, {
+        cwd: Global.Path.bin,
+        stderr: "pipe",
+        stdout: "pipe",
+      })
+
+      await proc.exited
+      if (proc.exitCode !== 0) {
+        throw new ExtractionFailedError({
+          filepath: targetPath,
+          stderr: await Bun.readableStreamToText(proc.stderr),
+        })
+      }
+    } else if (config.extension === "zip") {
+      const zipFileReader = new ZipReader(new BlobReader(new Blob([await Bun.file(archivePath).arrayBuffer()])))
+      const entries = await zipFileReader.getEntries()
+
+      let rgEntry: any
+      for (const entry of entries) {
+        if (entry.filename.endsWith("rg.exe")) {
+          rgEntry = entry
+          break
+        }
+      }
+
+      if (!rgEntry) {
+        throw new ExtractionFailedError({
+          filepath: archivePath,
+          stderr: "rg.exe not found in zip archive",
+        })
+      }
+
+      const rgBlob = await rgEntry.getData(new BlobWriter())
+      if (!rgBlob) {
+        throw new ExtractionFailedError({
+          filepath: archivePath,
+          stderr: "Failed to extract rg.exe from zip archive",
+        })
+      }
+
+      await Bun.write(targetPath, await rgBlob.arrayBuffer())
+      await zipFileReader.close()
+    }
+  }
+}
\ No newline at end of file
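Reviewer note: a minimal sketch of how the extracted installer is meant to be consumed (the relative import assumes a caller under `packages/opencode/src`):

```ts
import { RipgrepInstaller } from "./file/ripgrep-installer"

// Prefers an rg already on PATH; otherwise downloads the pinned 14.1.1
// release into Global.Path.bin and returns that binary's path.
const rg = await RipgrepInstaller.getExecutablePath()
console.log("using ripgrep at", rg)
```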
"pipe", + }) + + await proc.exited + if (proc.exitCode !== 0) { + throw new ExtractionFailedError({ + filepath: targetPath, + stderr: await Bun.readableStreamToText(proc.stderr), + }) + } + } else if (config.extension === "zip") { + const zipFileReader = new ZipReader(new BlobReader(new Blob([await Bun.file(archivePath).arrayBuffer()]))) + const entries = await zipFileReader.getEntries() + + let rgEntry: any + for (const entry of entries) { + if (entry.filename.endsWith("rg.exe")) { + rgEntry = entry + break + } + } + + if (!rgEntry) { + throw new ExtractionFailedError({ + filepath: archivePath, + stderr: "rg.exe not found in zip archive", + }) + } + + const rgBlob = await rgEntry.getData(new BlobWriter()) + if (!rgBlob) { + throw new ExtractionFailedError({ + filepath: archivePath, + stderr: "Failed to extract rg.exe from zip archive", + }) + } + + await Bun.write(targetPath, await rgBlob.arrayBuffer()) + await zipFileReader.close() + } + } +} \ No newline at end of file diff --git a/packages/opencode/src/file/ripgrep.ts b/packages/opencode/src/file/ripgrep.ts index f21cbdef979..6af76115ee1 100644 --- a/packages/opencode/src/file/ripgrep.ts +++ b/packages/opencode/src/file/ripgrep.ts @@ -1,13 +1,10 @@ // Ripgrep utility functions -import path from "path" -import { Global } from "../global" -import fs from "fs/promises" import { z } from "zod" -import { NamedError } from "../util/error" import { lazy } from "../util/lazy" import { $ } from "bun" import { Fzf } from "./fzf" -import { ZipReader, BlobReader, BlobWriter } from "@zip.js/zip.js" +import { RipgrepInstaller } from "./ripgrep-installer" +import { TreeBuilder } from "./tree-builder" export namespace Ripgrep { const Stats = z.object({ @@ -86,116 +83,10 @@ export namespace Ripgrep { export type Begin = z.infer export type End = z.infer export type Summary = z.infer - const PLATFORM = { - "arm64-darwin": { platform: "aarch64-apple-darwin", extension: "tar.gz" }, - "arm64-linux": { - platform: "aarch64-unknown-linux-gnu", - extension: "tar.gz", - }, - "x64-darwin": { platform: "x86_64-apple-darwin", extension: "tar.gz" }, - "x64-linux": { platform: "x86_64-unknown-linux-musl", extension: "tar.gz" }, - "x64-win32": { platform: "x86_64-pc-windows-msvc", extension: "zip" }, - } as const - - export const ExtractionFailedError = NamedError.create( - "RipgrepExtractionFailedError", - z.object({ - filepath: z.string(), - stderr: z.string(), - }), - ) - - export const UnsupportedPlatformError = NamedError.create( - "RipgrepUnsupportedPlatformError", - z.object({ - platform: z.string(), - }), - ) - - export const DownloadFailedError = NamedError.create( - "RipgrepDownloadFailedError", - z.object({ - url: z.string(), - status: z.number(), - }), - ) const state = lazy(async () => { - let filepath = Bun.which("rg") - if (filepath) return { filepath } - filepath = path.join(Global.Path.bin, "rg" + (process.platform === "win32" ? 
".exe" : "")) - - const file = Bun.file(filepath) - if (!(await file.exists())) { - const platformKey = `${process.arch}-${process.platform}` as keyof typeof PLATFORM - const config = PLATFORM[platformKey] - if (!config) throw new UnsupportedPlatformError({ platform: platformKey }) - - const version = "14.1.1" - const filename = `ripgrep-${version}-${config.platform}.${config.extension}` - const url = `https://github.com/BurntSushi/ripgrep/releases/download/${version}/${filename}` - - const response = await fetch(url) - if (!response.ok) throw new DownloadFailedError({ url, status: response.status }) - - const buffer = await response.arrayBuffer() - const archivePath = path.join(Global.Path.bin, filename) - await Bun.write(archivePath, buffer) - if (config.extension === "tar.gz") { - const args = ["tar", "-xzf", archivePath, "--strip-components=1"] - - if (platformKey.endsWith("-darwin")) args.push("--include=*/rg") - if (platformKey.endsWith("-linux")) args.push("--wildcards", "*/rg") - - const proc = Bun.spawn(args, { - cwd: Global.Path.bin, - stderr: "pipe", - stdout: "pipe", - }) - await proc.exited - if (proc.exitCode !== 0) - throw new ExtractionFailedError({ - filepath, - stderr: await Bun.readableStreamToText(proc.stderr), - }) - } - if (config.extension === "zip") { - if (config.extension === "zip") { - const zipFileReader = new ZipReader(new BlobReader(new Blob([await Bun.file(archivePath).arrayBuffer()]))) - const entries = await zipFileReader.getEntries() - let rgEntry: any - for (const entry of entries) { - if (entry.filename.endsWith("rg.exe")) { - rgEntry = entry - break - } - } - - if (!rgEntry) { - throw new ExtractionFailedError({ - filepath: archivePath, - stderr: "rg.exe not found in zip archive", - }) - } - - const rgBlob = await rgEntry.getData(new BlobWriter()) - if (!rgBlob) { - throw new ExtractionFailedError({ - filepath: archivePath, - stderr: "Failed to extract rg.exe from zip archive", - }) - } - await Bun.write(filepath, await rgBlob.arrayBuffer()) - await zipFileReader.close() - } - } - await fs.unlink(archivePath) - if (!platformKey.endsWith("-win32")) await fs.chmod(filepath, 0o755) - } - - return { - filepath, - } + const filepath = await RipgrepInstaller.getExecutablePath() + return { filepath } }) export async function filepath() { @@ -221,112 +112,26 @@ export namespace Ripgrep { export async function tree(input: { cwd: string; limit?: number }) { const files = await Ripgrep.files({ cwd: input.cwd }) - interface Node { - path: string[] - children: Node[] - } - - function getPath(node: Node, parts: string[], create: boolean) { - if (parts.length === 0) return node - let current = node - for (const part of parts) { - let existing = current.children.find((x) => x.path.at(-1) === part) - if (!existing) { - if (!create) return - existing = { - path: current.path.concat(part), - children: [], - } - current.children.push(existing) - } - current = existing - } - return current - } - - const root: Node = { - path: [], - children: [], - } - for (const file of files) { - if (file.includes(".opencode")) continue - const parts = file.split(path.sep) - getPath(root, parts, true) - } - - function sort(node: Node) { - node.children.sort((a, b) => { - if (!a.children.length && b.children.length) return 1 - if (!b.children.length && a.children.length) return -1 - return a.path.at(-1)!.localeCompare(b.path.at(-1)!) 
-      })
-      for (const child of node.children) {
-        sort(child)
-      }
-    }
-    sort(root)
-
-    let current = [root]
-    const result: Node = {
-      path: [],
-      children: [],
-    }
-
-    let processed = 0
-    const limit = input.limit ?? 50
-    while (current.length > 0) {
-      const next = []
-      for (const node of current) {
-        if (node.children.length) next.push(...node.children)
-      }
-      const max = Math.max(...current.map((x) => x.children.length))
-      for (let i = 0; i < max && processed < limit; i++) {
-        for (const node of current) {
-          const child = node.children[i]
-          if (!child) continue
-          getPath(result, child.path, true)
-          processed++
-          if (processed >= limit) break
-        }
-      }
-      if (processed >= limit) {
-        for (const node of [...current, ...next]) {
-          const compare = getPath(result, node.path, false)
-          if (!compare) continue
-          if (compare?.children.length !== node.children.length) {
-            const diff = node.children.length - compare.children.length
-            compare.children.push({
-              path: compare.path.concat(`[${diff} truncated]`),
-              children: [],
-            })
-          }
-        }
-        break
-      }
-      current = next
-    }
-
-    const lines: string[] = []
+    return TreeBuilder.build(files, { limit: input.limit })
+  }
 
-    function render(node: Node, depth: number) {
-      const indent = "\t".repeat(depth)
-      lines.push(indent + node.path.at(-1) + (node.children.length ? "/" : ""))
-      for (const child of node.children) {
-        render(child, depth + 1)
-      }
+  export async function search(input: { cwd: string; pattern: string; glob?: string[]; limit?: number }) {
+    const args = buildSearchArgs(await filepath(), input)
+    const command = args.join(" ")
+
+    const result = await $`${{ raw: command }}`.cwd(input.cwd).quiet().nothrow()
+    if (result.exitCode !== 0) {
+      return []
     }
-    result.children.map((x) => render(x, 0))
-    return lines.join("\n")
+    return parseSearchResults(result.text())
   }
 
-  export async function search(input: { cwd: string; pattern: string; glob?: string[]; limit?: number }) {
-    const args = [`${await filepath()}`, "--json", "--hidden", "--glob='!.git/*'"]
+  function buildSearchArgs(execPath: string, input: { pattern: string; glob?: string[]; limit?: number }): string[] {
+    const args = [`${execPath}`, "--json", "--hidden", "--glob='!.git/*'"]
 
     if (input.glob) {
-      for (const g of input.glob) {
-        args.push(`--glob=${g}`)
-      }
+      input.glob.forEach((g) => args.push(`--glob=${g}`))
     }
 
     if (input.limit) {
@@ -334,20 +139,16 @@ export namespace Ripgrep {
     }
 
     args.push(input.pattern)
+    return args
+  }
 
-    const command = args.join(" ")
-    const result = await $`${{ raw: command }}`.cwd(input.cwd).quiet().nothrow()
-    if (result.exitCode !== 0) {
-      return []
-    }
-
-    const lines = result.text().trim().split("\n").filter(Boolean)
-    // Parse JSON lines from ripgrep output
-
+  function parseSearchResults(output: string): Match['data'][] {
+    const lines = output.trim().split("\n").filter(Boolean)
+
     return lines
       .map((line) => JSON.parse(line))
      .map((parsed) => Result.parse(parsed))
-      .filter((r) => r.type === "match")
+      .filter((r): r is Match => r.type === "match")
       .map((r) => r.data)
   }
 }
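Reviewer note: the public `Ripgrep.search` surface is unchanged by the split into `buildSearchArgs`/`parseSearchResults`; a hedged usage sketch (field access assumes ripgrep's `--json` event shape, which the `Match` schema mirrors):

```ts
import { Ripgrep } from "./file/ripgrep"

// Returns the `data` payload of each "match" event emitted by rg --json.
const matches = await Ripgrep.search({ cwd: process.cwd(), pattern: "TODO", glob: ["*.ts"], limit: 100 })
for (const m of matches) {
  console.log(`${m.path.text}:${m.line_number} ${m.lines.text.trim()}`)
}
```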
diff --git a/packages/opencode/src/file/tree-builder.ts b/packages/opencode/src/file/tree-builder.ts
new file mode 100644
index 00000000000..de3ac16e18c
--- /dev/null
+++ b/packages/opencode/src/file/tree-builder.ts
@@ -0,0 +1,194 @@
+import path from "path"
+
+export namespace TreeBuilder {
+  interface TreeNode {
+    path: string[]
+    children: Map<string, TreeNode>
+  }
+
+  interface SerializedNode {
+    path: string[]
+    children: SerializedNode[]
+  }
+
+  export interface TreeOptions {
+    limit?: number
+    sortFn?: (a: SerializedNode, b: SerializedNode) => number
+  }
+
+  const DEFAULT_LIMIT = 50
+  const DEFAULT_SORT = (a: SerializedNode, b: SerializedNode): number => {
+    const aIsFile = a.children.length === 0
+    const bIsFile = b.children.length === 0
+    if (aIsFile && !bIsFile) return 1
+    if (!aIsFile && bIsFile) return -1
+    return a.path.at(-1)!.localeCompare(b.path.at(-1)!)
+  }
+
+  export function build(files: string[], options: TreeOptions = {}): string {
+    const { limit = DEFAULT_LIMIT, sortFn = DEFAULT_SORT } = options
+
+    const root = buildTree(files)
+    const sortedRoot = sortTree(root, sortFn)
+    const truncatedRoot = applyLimit(sortedRoot, limit)
+
+    return render(truncatedRoot)
+  }
+
+  function buildTree(files: string[]): TreeNode {
+    const root: TreeNode = {
+      path: [],
+      children: new Map(),
+    }
+
+    for (const file of files) {
+      if (file.includes(".opencode")) continue
+      const parts = file.split(path.sep)
+      let current = root
+
+      for (const part of parts) {
+        let child = current.children.get(part)
+
+        if (!child) {
+          child = {
+            path: [...current.path, part],
+            children: new Map(),
+          }
+          current.children.set(part, child)
+        }
+        current = child
+      }
+    }
+
+    return root
+  }
+
+  function sortTree(node: TreeNode, sortFn: (a: SerializedNode, b: SerializedNode) => number): SerializedNode {
+    const sortedChildren = Array.from(node.children.values())
+      .map((child) => sortTree(child, sortFn))
+      .sort(sortFn)
+
+    return {
+      path: node.path,
+      children: sortedChildren,
+    }
+  }
+
+  function applyLimit(root: SerializedNode, limit: number): SerializedNode {
+    const result: SerializedNode = {
+      path: [],
+      children: [],
+    }
+
+    const nodeMap = new Map<string, SerializedNode>()
+    const queue = [root]
+    let processed = 0
+
+    while (queue.length > 0 && processed < limit) {
+      const batch = queue.splice(0, queue.length)
+
+      for (const node of batch) {
+        if (node.children.length) {
+          queue.push(...node.children)
+        }
+      }
+
+      for (const node of batch) {
+        if (processed >= limit) break
+
+        for (const child of node.children) {
+          if (processed >= limit) break
+
+          const parent = getOrCreateNode(result, nodeMap, child.path.slice(0, -1))
+          parent.children.push({
+            path: child.path,
+            children: [],
+          })
+          processed++
+        }
+      }
+    }
+
+    // Add truncation indicators
+    addTruncationIndicators(result, root, nodeMap)
+
+    return result
+  }
+
+  function getOrCreateNode(
+    root: SerializedNode,
+    nodeMap: Map<string, SerializedNode>,
+    path: string[],
+  ): SerializedNode {
+    const key = path.join('/')
+    const cached = nodeMap.get(key)
+    if (cached) return cached
+
+    let current = root
+    for (const part of path) {
+      let child = current.children.find((c) => c.path.at(-1) === part)
+      if (!child) {
+        child = {
+          path: [...current.path, part],
+          children: [],
+        }
+        current.children.push(child)
+      }
+      current = child
+    }
+
+    nodeMap.set(key, current)
+    return current
+  }
+
+  function addTruncationIndicators(
+    truncated: SerializedNode,
+    original: SerializedNode,
+    nodeMap: Map<string, SerializedNode>,
+  ): void {
+    const queue = [{ truncated, original }]
+
+    while (queue.length > 0) {
+      const { truncated: t, original: o } = queue.shift()!
+
+      if (t.children.length < o.children.length) {
+        const diff = o.children.length - t.children.length
+        t.children.push({
+          path: [...t.path, `[${diff} truncated]`],
+          children: [],
+        })
+      }
+
+      for (const tChild of t.children) {
+        const oChild = o.children.find((c) => c.path.join('/') === tChild.path.join('/'))
+        if (oChild && tChild.children.length > 0) {
+          queue.push({ truncated: tChild, original: oChild })
+        }
+      }
+    }
+  }
+
+  function render(node: SerializedNode): string {
+    const lines: string[] = []
+
+    function renderNode(node: SerializedNode, depth: number): void {
+      const indent = "\t".repeat(depth)
+      const name = node.path.at(-1) || ''
+      const suffix = node.children.length ? "/" : ""
+
+      if (name) {
+        lines.push(indent + name + suffix)
+      }
+
+      for (const child of node.children) {
+        renderNode(child, depth + 1)
+      }
+    }
+
+    for (const child of node.children) {
+      renderNode(child, 0)
+    }
+
+    return lines.join("\n")
+  }
+}
\ No newline at end of file
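Reviewer note: `TreeBuilder.build` is pure (paths in, rendered string out), which makes the truncation behavior easy to test in isolation; a minimal sketch:

```ts
import { TreeBuilder } from "./file/tree-builder"

// Tab-indented tree, directories sorted before files; when more than
// `limit` entries exist, "[N truncated]" placeholder nodes are appended.
const rendered = TreeBuilder.build(["src/index.ts", "src/util/log.ts", "README.md"], { limit: 50 })
console.log(rendered)
```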
diff --git a/packages/opencode/src/lsp/client.ts b/packages/opencode/src/lsp/client.ts
index c63e02592e2..c3580942ede 100644
--- a/packages/opencode/src/lsp/client.ts
+++ b/packages/opencode/src/lsp/client.ts
@@ -1,21 +1,20 @@
 import path from "path"
 import { createMessageConnection, StreamMessageReader, StreamMessageWriter } from "vscode-jsonrpc/node"
-import type { Diagnostic as VSCodeDiagnostic } from "vscode-languageserver-types"
 import { App } from "../app/app"
 import { Log } from "../util/log"
 import { LANGUAGE_EXTENSIONS } from "./language"
-import { Bus } from "../bus"
 import z from "zod"
 import type { LSPServer } from "./server"
 import { NamedError } from "../util/error"
-import { withTimeout } from "../util/timeout"
+import { TimeoutManager } from "./timeout-manager"
+import { DiagnosticsManager } from "./diagnostics-manager"
 
 export namespace LSPClient {
   const log = Log.create({ service: "lsp.client" })
 
   export type Info = NonNullable<Awaited<ReturnType<typeof create>>>
-  export type Diagnostic = VSCodeDiagnostic
+  export type Diagnostic = DiagnosticsManager.Diagnostic
 
   export const InitializeError = NamedError.create(
     "LSPInitializeError",
     z.object({
@@ -24,16 +23,6 @@ export namespace LSPClient {
     }),
   )
 
-  export const Event = {
-    Diagnostics: Bus.event(
-      "lsp.client.diagnostics",
-      z.object({
-        serverID: z.string(),
-        path: z.string(),
-      }),
-    ),
-  }
-
   export async function create(input: { serverID: string; server: LSPServer.Handle; root: string }) {
     const app = App.info()
     const l = log.clone().tag("serverID", input.serverID)
@@ -44,16 +33,13 @@ export namespace LSPClient {
       new StreamMessageWriter(input.server.process.stdin),
     )
 
-    const diagnostics = new Map<string, Diagnostic[]>()
+    const timeoutManager = new TimeoutManager.AdaptiveTimeout(TimeoutManager.DEFAULT_CONFIGS)
+    const diagnosticsManager = new DiagnosticsManager.Manager(input.serverID, {
+      suppressInitialEvents: true,
+    })
+
     connection.onNotification("textDocument/publishDiagnostics", (params) => {
-      const path = new URL(params.uri).pathname
-      l.info("textDocument/publishDiagnostics", {
-        path,
-      })
-      const exists = diagnostics.has(path)
-      diagnostics.set(path, params.diagnostics)
-      if (!exists && input.serverID === "typescript") return
-      Bus.publish(Event.Diagnostics, { path, serverID: input.serverID })
+      diagnosticsManager.onDiagnosticsUpdate(params)
     })
     connection.onRequest("window/workDoneProgress/create", (params) => {
       l.info("window/workDoneProgress/create", params)
@@ -65,7 +51,9 @@ export namespace LSPClient {
     connection.listen()
 
     l.info("sending initialize")
-    await withTimeout(
+
+    await timeoutManager.withTimeout(
+      'initialize',
       connection.sendRequest("initialize", {
         rootUri: "file://" + input.root,
         processId: input.server.process.pid,
@@ -96,7 +84,6 @@ export namespace LSPClient {
           },
         },
       }),
-      5_000,
     ).catch((err) => {
       l.error("initialize error", { error: err })
      throw new InitializeError(
@@ -128,7 +115,7 @@ export namespace LSPClient {
         const text = await file.text()
         const version = files[input.path]
         if (version !== undefined) {
-          diagnostics.delete(input.path)
+          diagnosticsManager.delete(input.path)
           await connection.sendNotification("textDocument/didClose", {
             textDocument: {
               uri: `file://` + input.path,
@@ -152,28 +139,14 @@ export namespace LSPClient {
       },
     },
     get diagnostics() {
-      return diagnostics
+      return diagnosticsManager
    },
     async waitForDiagnostics(input: { path: string }) {
       input.path = path.isAbsolute(input.path) ? input.path : path.resolve(app.path.cwd, input.path)
       log.info("waiting for diagnostics", input)
-      let unsub: () => void
-      return await withTimeout(
-        new Promise<void>((resolve) => {
-          unsub = Bus.subscribe(Event.Diagnostics, (event) => {
-            if (event.properties.path === input.path && event.properties.serverID === result.serverID) {
-              log.info("got diagnostics", input)
-              unsub?.()
-              resolve()
-            }
-          })
-        }),
-        3000,
-      )
-        .catch(() => {})
-        .finally(() => {
-          unsub?.()
-        })
+
+      const timeout = timeoutManager.getTimeout('diagnostics')
+      return await diagnosticsManager.waitForDiagnostics(input.path, timeout)
     },
     async shutdown() {
       l.info("shutting down")
diff --git a/packages/opencode/src/lsp/diagnostics-manager.ts b/packages/opencode/src/lsp/diagnostics-manager.ts
new file mode 100644
index 00000000000..6508a2fbda9
--- /dev/null
+++ b/packages/opencode/src/lsp/diagnostics-manager.ts
@@ -0,0 +1,145 @@
+import { z } from "zod"
+import { Bus } from "../bus"
+import type { Diagnostic as VSCodeDiagnostic } from "vscode-languageserver-types"
+import { Log } from "../util/log"
+
+export namespace DiagnosticsManager {
+  const log = Log.create({ service: "lsp.diagnostics" })
+
+  export type Diagnostic = VSCodeDiagnostic
+
+  export const Event = {
+    Updated: Bus.event(
+      "lsp.diagnostics.updated",
+      z.object({
+        serverID: z.string(),
+        path: z.string(),
+      }),
+    ),
+  }
+
+  export interface DiagnosticsStore {
+    get(path: string): Diagnostic[]
+    set(path: string, diagnostics: Diagnostic[]): void
+    has(path: string): boolean
+    delete(path: string): void
+    clear(): void
+    entries(): IterableIterator<[string, Diagnostic[]]>
+  }
+
+  export class Manager {
+    private store: Map<string, Diagnostic[]> = new Map()
+    private subscribers: Set<(path: string) => void> = new Set()
+
+    constructor(
+      private serverID: string,
+      private options: {
+        suppressInitialEvents?: boolean
+      } = {},
+    ) {}
+
+    onDiagnosticsUpdate(params: { uri: string; diagnostics: Diagnostic[] }): void {
+      const path = new URL(params.uri).pathname
+
+      log.info("diagnostics update", {
+        serverID: this.serverID,
+        path,
+        count: params.diagnostics.length,
+      })
+
+      const isNew = !this.store.has(path)
+      this.store.set(path, params.diagnostics)
+
+      // Suppress initial events for the TypeScript server
+      if (isNew && this.options.suppressInitialEvents && this.serverID === "typescript") {
+        return
+      }
+
+      // Notify subscribers
+      this.subscribers.forEach((callback) => callback(path))
+
+      // Publish bus event
+      Bus.publish(Event.Updated, {
+        path,
+        serverID: this.serverID,
+      })
+    }
+
+    subscribe(callback: (path: string) => void): () => void {
+      this.subscribers.add(callback)
+      return () => this.subscribers.delete(callback)
+    }
+
+    get(path: string): Diagnostic[] {
+      return this.store.get(path) || []
+    }
+
+    has(path: string): boolean {
+      return this.store.has(path)
+    }
+
+    delete(path: string): void {
+      this.store.delete(path)
+    }
+
+    clear(): void {
+      this.store.clear()
+    }
+
+    entries(): IterableIterator<[string, Diagnostic[]]> {
+      return this.store.entries()
+    }
+
+    async waitForDiagnostics(path: string, timeout: number): Promise<void> {
+      return new Promise<void>((resolve) => {
+        let unsubscribeBus: (() => void) | undefined
+        let unsubscribeLocal: (() => void) | undefined
+        let timer: NodeJS.Timeout | undefined
+
+        const cleanup = () => {
+          unsubscribeBus?.()
+          unsubscribeLocal?.()
+          if (timer) clearTimeout(timer)
+        }
+
+        // Set up timeout
+        timer = setTimeout(() => {
+          cleanup()
+          log.warn("diagnostics timeout", {
+            path,
+            serverID: this.serverID,
+            timeout,
+          })
+          resolve() // Resolve instead of reject for graceful degradation
+        }, timeout)
+
+        // Subscribe to both local and bus events for redundancy
+        unsubscribeLocal = this.subscribe((updatedPath) => {
+          if (updatedPath === path) {
+            log.info("got diagnostics (local)", {
+              path,
+              serverID: this.serverID,
+            })
+            cleanup()
+            resolve()
+          }
+        })
+
+        unsubscribeBus = Bus.subscribe(Event.Updated, (event) => {
+          if (event.properties.path === path && event.properties.serverID === this.serverID) {
+            log.info("got diagnostics (bus)", {
+              path,
+              serverID: this.serverID,
+            })
+            cleanup()
+            resolve()
+          }
+        })
+      })
+    }
+  }
+}
\ No newline at end of file
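Reviewer note: a sketch of the manager's wait/notify contract (the path and server ID are illustrative):

```ts
import { DiagnosticsManager } from "./lsp/diagnostics-manager"

const manager = new DiagnosticsManager.Manager("eslint")

// Start waiting first, then simulate a publishDiagnostics notification.
const done = manager.waitForDiagnostics("/tmp/a.ts", 3_000)
manager.onDiagnosticsUpdate({ uri: "file:///tmp/a.ts", diagnostics: [] })
await done // resolves via the local subscriber; a timeout also resolves, never rejects
```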
diff --git a/packages/opencode/src/lsp/timeout-manager.ts b/packages/opencode/src/lsp/timeout-manager.ts
new file mode 100644
index 00000000000..ca941dc9d39
--- /dev/null
+++ b/packages/opencode/src/lsp/timeout-manager.ts
@@ -0,0 +1,84 @@
+export namespace TimeoutManager {
+  export interface TimeoutConfig {
+    min: number
+    max: number
+    default: number
+  }
+
+  export class AdaptiveTimeout {
+    private responseTimes: Map<string, number[]> = new Map()
+    private readonly maxSamples = 100
+    private readonly bufferMultiplier = 1.5
+
+    constructor(private configs: Record<string, TimeoutConfig>) {}
+
+    getTimeout(operation: string): number {
+      const config = this.configs[operation]
+      if (!config) {
+        throw new Error(`No timeout config for operation: ${operation}`)
+      }
+
+      const times = this.responseTimes.get(operation) || []
+      if (times.length === 0) {
+        return config.default
+      }
+
+      // Calculate 95th percentile
+      const sorted = [...times].sort((a, b) => a - b)
+      const index = Math.floor(sorted.length * 0.95)
+      const p95 = sorted[index]
+
+      // Add buffer and clamp to min/max
+      const timeout = Math.round(p95 * this.bufferMultiplier)
+      return Math.max(config.min, Math.min(config.max, timeout))
+    }
+
+    trackResponseTime(operation: string, startTime: number): void {
+      const duration = Date.now() - startTime
+      let times = this.responseTimes.get(operation) || []
+
+      // Keep last N measurements
+      times.push(duration)
+      if (times.length > this.maxSamples) {
+        times = times.slice(-this.maxSamples)
+      }
+
+      this.responseTimes.set(operation, times)
+    }
+
+    async withTimeout<T>(operation: string, promise: Promise<T>, timeoutOverride?: number): Promise<T> {
+      const timeout = timeoutOverride || this.getTimeout(operation)
+      const startTime = Date.now()
+
+      try {
+        const result = await Promise.race([
+          promise,
+          new Promise<never>((_, reject) =>
+            setTimeout(() => reject(new Error(`Operation '${operation}' timed out after ${timeout}ms`)), timeout),
+          ),
+        ])
+
+        this.trackResponseTime(operation, startTime)
+        return result
+      } catch (error) {
+        if (error instanceof Error && error.message.includes('timed out')) {
+          throw error
+        }
+
+        // Track even failed operations for accurate timing
+        this.trackResponseTime(operation, startTime)
+        throw error
+      }
+    }
+  }
+
+  export const DEFAULT_CONFIGS = {
+    initialize: { min: 3000, max: 15000, default: 5000 },
+    diagnostics: { min: 1000, max: 10000, default: 3000 },
+    completion: { min: 500, max: 5000, default: 1500 },
+    hover: { min: 300, max: 3000, default: 1000 },
+  }
+}
\ No newline at end of file
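Reviewer note: how the adaptive timeout replaces the old hardcoded 5_000/3_000 values; a sketch of the p95 math:

```ts
import { TimeoutManager } from "./lsp/timeout-manager"

const timeouts = new TimeoutManager.AdaptiveTimeout(TimeoutManager.DEFAULT_CONFIGS)

// No samples yet: the configured default applies.
timeouts.getTimeout("diagnostics") // 3000

// After samples accumulate, the timeout is round(p95 * 1.5) clamped to
// [min, max]; e.g. steady ~800ms responses yield ~1200ms, not 3000ms.
timeouts.trackResponseTime("diagnostics", Date.now() - 800)
```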
diff --git a/packages/opencode/src/storage/migration-manager.ts b/packages/opencode/src/storage/migration-manager.ts
new file mode 100644
index 00000000000..d992663a7ec
--- /dev/null
+++ b/packages/opencode/src/storage/migration-manager.ts
@@ -0,0 +1,151 @@
+import path from "path"
+import fs from "fs/promises"
+import { Log } from "../util/log"
+import { MessageV2 } from "../session/message-v2"
+import { Identifier } from "../id/id"
+
+export namespace MigrationManager {
+  const log = Log.create({ service: "storage.migration" })
+
+  export type Migration = {
+    version: number
+    name: string
+    migrate: (storageDir: string) => Promise<void>
+  }
+
+  export class Manager {
+    private migrations: Migration[] = []
+
+    register(migration: Migration): void {
+      this.migrations.push(migration)
+      this.migrations.sort((a, b) => a.version - b.version)
+    }
+
+    async getCurrentVersion(storageDir: string): Promise<number> {
+      try {
+        const migrationFile = path.join(storageDir, "migration")
+        const content = await Bun.file(migrationFile).text()
+        return parseInt(content, 10)
+      } catch {
+        return 0
+      }
+    }
+
+    async setVersion(storageDir: string, version: number): Promise<void> {
+      const migrationFile = path.join(storageDir, "migration")
+      await Bun.write(migrationFile, version.toString())
+    }
+
+    async runMigrations(storageDir: string): Promise<void> {
+      const currentVersion = await this.getCurrentVersion(storageDir)
+
+      for (const migration of this.migrations) {
+        if (migration.version > currentVersion) {
+          log.info("running migration", {
+            version: migration.version,
+            name: migration.name,
+          })
+
+          try {
+            await migration.migrate(storageDir)
+            await this.setVersion(storageDir, migration.version)
+          } catch (error) {
+            log.error("migration failed", {
+              version: migration.version,
+              name: migration.name,
+              error,
+            })
+            throw error
+          }
+        }
+      }
+    }
+  }
+
+  // Default migrations
+  export const defaultMigrations: Migration[] = [
+    {
+      version: 1,
+      name: "migrate-v1-messages-to-v2",
+      migrate: async (storageDir: string) => {
+        try {
+          const files = new Bun.Glob("session/message/*/*.json").scanSync({
+            cwd: storageDir,
+            absolute: true,
+          })
+
+          for (const file of files) {
+            const content = await Bun.file(file).json()
+            if (!content.metadata) continue
+
+            log.info("migrating to v2 message", { file })
+
+            try {
+              const result = MessageV2.fromV1(content)
+              await Bun.write(
+                file,
+                JSON.stringify(
+                  {
+                    ...result.info,
+                    parts: result.parts,
+                  },
+                  null,
+                  2,
+                ),
+              )
+            } catch (e) {
+              await fs.rename(file, file.replace("storage", "broken"))
+            }
+          }
+        } catch {
+          // Ignore errors if directory doesn't exist
+        }
+      },
+    },
+    {
+      version: 2,
+      name: "split-message-parts",
+      migrate: async (storageDir: string) => {
+        const files = new Bun.Glob("session/message/*/*.json").scanSync({
+          cwd: storageDir,
+          absolute: true,
+        })
+
+        for (const file of files) {
+          try {
+            const { parts, ...info } = await Bun.file(file).json()
+            if (!parts) continue
+
+            for (const part of parts) {
+              const id = Identifier.ascending("part")
+              const partPath = path.join(storageDir, "session", "part", info.sessionID, info.id, id + ".json")
+
+              await fs.mkdir(path.dirname(partPath), { recursive: true })
+              await Bun.write(
+                partPath,
+                JSON.stringify({
+                  ...part,
+                  id,
+                  sessionID: info.sessionID,
+                  messageID: info.id,
+                  ...(part.type === "tool" ? { callID: part.id } : {}),
+                }),
+              )
+            }
+
+            await Bun.write(file, JSON.stringify(info, null, 2))
+          } catch (e) {
+            log.error("failed to migrate message parts", { file, error: e })
+          }
+        }
+      },
+    },
+  ]
+}
\ No newline at end of file
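Reviewer note: version gating in the manager; only migrations above the recorded version run, and the version file advances after each successful step:

```ts
import { MigrationManager } from "./storage/migration-manager"

const manager = new MigrationManager.Manager()
MigrationManager.defaultMigrations.forEach((m) => manager.register(m))

// If the "migration" file contains "1", only version 2 (and later) runs here.
await manager.runMigrations("/path/to/storage")
```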
diff --git a/packages/opencode/src/storage/storage.ts b/packages/opencode/src/storage/storage.ts
index f4efbfdfe5d..52ed8831375 100644
--- a/packages/opencode/src/storage/storage.ts
+++ b/packages/opencode/src/storage/storage.ts
@@ -4,8 +4,8 @@ import { Bus } from "../bus"
 import path from "path"
 import z from "zod"
 import fs from "fs/promises"
-import { MessageV2 } from "../session/message-v2"
-import { Identifier } from "../id/id"
+import { StreamHandler } from "./stream-handler"
+import { MigrationManager } from "./migration-manager"
 
 export namespace Storage {
   const log = Log.create({ service: "storage" })
@@ -14,96 +14,39 @@ export namespace Storage {
     Write: Bus.event("storage.write", z.object({ key: z.string(), content: z.any() })),
   }
 
-  type Migration = (dir: string) => Promise<void>
-
-  const MIGRATIONS: Migration[] = [
-    async (dir: string) => {
-      try {
-        const files = new Bun.Glob("session/message/*/*.json").scanSync({
-          cwd: dir,
-          absolute: true,
-        })
-        for (const file of files) {
-          const content = await Bun.file(file).json()
-          if (!content.metadata) continue
-          log.info("migrating to v2 message", { file })
-          try {
-            const result = MessageV2.fromV1(content)
-            await Bun.write(
-              file,
-              JSON.stringify(
-                {
-                  ...result.info,
-                  parts: result.parts,
-                },
-                null,
-                2,
-              ),
-            )
-          } catch (e) {
-            await fs.rename(file, file.replace("storage", "broken"))
-          }
-        }
-      } catch {}
-    },
-    async (dir: string) => {
-      const files = new Bun.Glob("session/message/*/*.json").scanSync({
-        cwd: dir,
-        absolute: true,
-      })
-      for (const file of files) {
-        try {
-          const { parts, ...info } = await Bun.file(file).json()
-          if (!parts) continue
-          for (const part of parts) {
-            const id = Identifier.ascending("part")
-            await Bun.write(
-              [dir, "session", "part", info.sessionID, info.id, id + ".json"].join("/"),
-              JSON.stringify({
-                ...part,
-                id,
-                sessionID: info.sessionID,
-                messageID: info.id,
-                ...(part.type === "tool" ? { callID: part.id } : {}),
-              }),
-            )
-          }
-          await Bun.write(file, JSON.stringify(info, null, 2))
-        } catch (e) {}
-      }
-    },
-    async (dir: string) => {
-      const files = new Bun.Glob("session/message/*/*.json").scanSync({
-        cwd: dir,
-        absolute: true,
-      })
-      for (const file of files) {
-        try {
-          const content = await Bun.file(file).json()
-          if (content.role === "assistant" && !content.mode) {
-            log.info("adding mode field to message", { file })
-            content.mode = "build"
-            await Bun.write(file, JSON.stringify(content, null, 2))
-          }
-        } catch (e) {}
-      }
-    },
-  ]
+  const streamHandler = new StreamHandler.JsonStreamHandler()
+  const atomicWriter = new StreamHandler.AtomicFileWriter(streamHandler)
+  const migrationManager = new MigrationManager.Manager()
+
+  // Register default migrations
+  MigrationManager.defaultMigrations.forEach((m) => migrationManager.register(m))
+
+  // Add the new mode migration (register() takes a Migration object, not a
+  // bare function, so it is wrapped and versioned 3 to run after the defaults)
+  migrationManager.register({
+    version: 3,
+    name: "add-mode-to-assistant-messages",
+    migrate: async (dir: string) => {
+      const files = new Bun.Glob("session/message/*/*.json").scanSync({
+        cwd: dir,
+        absolute: true,
+      })
+      for (const file of files) {
+        try {
+          const content = await Bun.file(file).json()
+          if (content.role === "assistant" && !content.mode) {
+            log.info("adding mode field to message", { file })
+            content.mode = "build"
+            await Bun.write(file, JSON.stringify(content, null, 2))
+          }
+        } catch (e) {}
+      }
+    },
+  })
 
   const state = App.state("storage", async () => {
     const app = App.info()
     const dir = path.normalize(path.join(app.path.data, "storage"))
     await fs.mkdir(dir, { recursive: true })
-    const migration = await Bun.file(path.join(dir, "migration"))
-      .json()
-      .then((x) => parseInt(x))
-      .catch(() => 0)
-    for (let index = migration; index < MIGRATIONS.length; index++) {
-      log.info("running migration", { index })
-      const migration = MIGRATIONS[index]
-      await migration(dir)
-      await Bun.write(path.join(dir, "migration"), (index + 1).toString())
-    }
+
+    // Run migrations
+    await migrationManager.runMigrations(dir)
+
     return {
       dir,
     }
@@ -121,18 +64,17 @@ export namespace Storage {
     await fs.rm(target, { recursive: true, force: true }).catch(() => {})
   }
 
-  export async function readJSON<T>(key: string) {
+  export async function readJSON<T>(key: string): Promise<T> {
     const dir = await state().then((x) => x.dir)
-    return Bun.file(path.join(dir, key + ".json")).json() as Promise<T>
+    const filePath = path.join(dir, key + ".json")
+    return streamHandler.read<T>(filePath)
   }
 
   export async function writeJSON<T>(key: string, content: T) {
     const dir = await state().then((x) => x.dir)
     const target = path.join(dir, key + ".json")
-    const tmp = target + Date.now() + ".tmp"
-    await Bun.write(tmp, JSON.stringify(content, null, 2))
-    await fs.rename(tmp, target).catch(() => {})
-    await fs.unlink(tmp).catch(() => {})
+
+    await atomicWriter.write(target, content)
 
     Bus.publish(Event.Write, { key, content })
   }
diff --git a/packages/opencode/src/storage/stream-handler.ts b/packages/opencode/src/storage/stream-handler.ts
new file mode 100644
index 00000000000..e364659644d
--- /dev/null
+++ b/packages/opencode/src/storage/stream-handler.ts
@@ -0,0 +1,117 @@
+export namespace StreamHandler {
+  export interface StreamConfig {
+    fileSizeThreshold?: number
+    chunkSize?: number
+  }
+
+  const DEFAULT_CONFIG: Required<StreamConfig> = {
+    fileSizeThreshold: 10 * 1024 * 1024, // 10MB
+    chunkSize: 1024 * 1024, // 1MB
+  }
+
+  export class JsonStreamHandler {
+    private config: Required<StreamConfig>
+
+    constructor(config: StreamConfig = {}) {
+      this.config = { ...DEFAULT_CONFIG, ...config }
+    }
+
+    async shouldStream(filePath: string): Promise<boolean> {
+      try {
+        const file = Bun.file(filePath)
+        const stats = await file.stat()
+        return stats.size > this.config.fileSizeThreshold
+      } catch {
+        return false
+      }
+    }
+
+    async read<T>(filePath: string): Promise<T> {
+      const file = Bun.file(filePath)
+
+      if (await this.shouldStream(filePath)) {
+        return this.readStream<T>(file)
+      }
+
+      return file.json() as Promise<T>
+    }
+
+    private async readStream<T>(file: BunFile): Promise<T> {
+      const stream = file.stream()
+      const reader = stream.getReader()
+      const decoder = new TextDecoder()
+
+      let buffer = ''
+
+      try {
+        while (true) {
+          const { done, value } = await reader.read()
+          if (done) break
+
+          buffer += decoder.decode(value, { stream: true })
+        }
+
+        // Final decode
+        buffer += decoder.decode()
+        return JSON.parse(buffer)
+      } finally {
+        reader.releaseLock()
+      }
+    }
+
+    async write(filePath: string, content: any): Promise<void> {
+      const jsonString = JSON.stringify(content, null, 2)
+
+      if (jsonString.length > this.config.fileSizeThreshold) {
+        await this.writeStream(filePath, jsonString)
+      } else {
+        await Bun.write(filePath, jsonString)
+      }
+    }
+
+    private async writeStream(filePath: string, jsonString: string): Promise<void> {
+      const file = Bun.file(filePath)
+      const writer = file.writer()
+
+      try {
+        let offset = 0
+
+        while (offset < jsonString.length) {
+          const chunk = jsonString.slice(offset, offset + this.config.chunkSize)
+          await writer.write(chunk)
+          offset += this.config.chunkSize
+        }
+
+        await writer.flush()
+      } finally {
+        await writer.end()
+      }
+    }
+  }
+
+  export class AtomicFileWriter {
+    constructor(private streamHandler: JsonStreamHandler) {}
+
+    async write(targetPath: string, content: any): Promise<void> {
+      const tmpPath = `${targetPath}.${Date.now()}.tmp`
+
+      try {
+        // Write to temporary file
+        await this.streamHandler.write(tmpPath, content)
+
+        // Atomic rename
+        const fs = await import('fs/promises')
+        await fs.rename(tmpPath, targetPath)
+      } catch (error) {
+        // Clean up temporary file on error
+        try {
+          const fs = await import('fs/promises')
+          await fs.unlink(tmpPath)
+        } catch {
+          // Ignore cleanup errors
+        }
+        throw error
+      }
+    }
+  }
+}
\ No newline at end of file
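Reviewer note: a sketch of the writer pairing; `AtomicFileWriter` stages to a `.tmp` sibling and renames, so readers never observe a partially written JSON file (the path is illustrative):

```ts
import { StreamHandler } from "./storage/stream-handler"

const handler = new StreamHandler.JsonStreamHandler() // streams files over 10MB
const atomic = new StreamHandler.AtomicFileWriter(handler)

await atomic.write("/tmp/session.json", { hello: "world" })
const data = await handler.read<{ hello: string }>("/tmp/session.json")
```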
throw new Error("oldString not found in content or was found multiple times") -} diff --git a/packages/opencode/src/tool/replace-strategy.ts b/packages/opencode/src/tool/replace-strategy.ts new file mode 100644 index 00000000000..29c4dbf16cd --- /dev/null +++ b/packages/opencode/src/tool/replace-strategy.ts @@ -0,0 +1,137 @@ +import { ReplacerRegistry, PatternAnalysis } from "./replacer-registry" +import * as Replacers from "./replacers" + +export class ReplaceStrategy { + private registry: ReplacerRegistry + + constructor() { + this.registry = new ReplacerRegistry() + this.registerDefaultReplacers() + } + + private registerDefaultReplacers(): void { + // Always try simple replacement first (highest priority) + this.registry.register({ + name: 'simple', + replacer: Replacers.SimpleReplacer, + priority: 0, + }) + + // Line-based replacers for multiline patterns + this.registry.register({ + name: 'lineTrimmed', + replacer: Replacers.LineTrimmedReplacer, + priority: 10, + condition: (find) => PatternAnalysis.isMultiline(find) || PatternAnalysis.hasWhitespaceVariation(find) + }) + + // Block anchor for larger multiline patterns + this.registry.register({ + name: 'blockAnchor', + replacer: Replacers.BlockAnchorReplacer, + priority: 20, + condition: (find) => PatternAnalysis.getLineCount(find) >= 3 + }) + + // Whitespace normalization + this.registry.register({ + name: 'whitespaceNormalized', + replacer: Replacers.WhitespaceNormalizedReplacer, + priority: 30, + condition: (find) => PatternAnalysis.hasWhitespaceVariation(find) || PatternAnalysis.isMultiline(find) + }) + + // Indentation flexible matching + this.registry.register({ + name: 'indentationFlexible', + replacer: Replacers.IndentationFlexibleReplacer, + priority: 40, + condition: (find) => PatternAnalysis.isMultiline(find) && PatternAnalysis.hasWhitespaceVariation(find) + }) + + // Escape sequence handling + this.registry.register({ + name: 'escapeNormalized', + replacer: Replacers.EscapeNormalizedReplacer, + priority: 50, + condition: (find) => PatternAnalysis.hasEscapeSequences(find) + }) + + // Trimmed boundary matching + this.registry.register({ + name: 'trimmedBoundary', + replacer: Replacers.TrimmedBoundaryReplacer, + priority: 60, + condition: (find) => PatternAnalysis.hasWhitespaceVariation(find) + }) + + // Context-aware matching for complex patterns + this.registry.register({ + name: 'contextAware', + replacer: Replacers.ContextAwareReplacer, + priority: 70, + condition: (find) => PatternAnalysis.getLineCount(find) >= 3 + }) + } + + replace(content: string, oldString: string, newString: string, replaceAll = false): string { + if (oldString === newString) { + throw new Error("oldString and newString must be different") + } + + // Try simple replacement first for performance + if (PatternAnalysis.isSimplePattern(oldString, content)) { + if (replaceAll) { + return content.replaceAll(oldString, newString) + } + const firstIndex = content.indexOf(oldString) + const lastIndex = content.lastIndexOf(oldString) + if (firstIndex === lastIndex && firstIndex !== -1) { + return content.substring(0, firstIndex) + newString + content.substring(firstIndex + oldString.length) + } + } + + // Get applicable replacers based on pattern characteristics + const replacers = this.registry.getReplacers(oldString, content) + + // Try each replacer in order + for (const replacer of replacers) { + const matches = this.collectMatches(replacer, content, oldString) + + if (matches.length > 0) { + return this.applyReplacement(content, matches, newString, 
+        return this.applyReplacement(content, matches, newString, replaceAll)
+      }
+    }
+
+    throw new Error("oldString not found in content or was found multiple times")
+  }
+
+  private collectMatches(replacer: Replacer, content: string, find: string): string[] {
+    const matches: string[] = []
+
+    for (const match of replacer(content, find)) {
+      if (content.includes(match)) {
+        matches.push(match)
+      }
+    }
+
+    return matches
+  }
+
+  private applyReplacement(content: string, matches: string[], newString: string, replaceAll: boolean): string {
+    for (const match of matches) {
+      if (replaceAll) {
+        return content.replaceAll(match, newString)
+      }
+
+      const firstIndex = content.indexOf(match)
+      const lastIndex = content.lastIndexOf(match)
+
+      if (firstIndex === lastIndex && firstIndex !== -1) {
+        return content.substring(0, firstIndex) + newString + content.substring(firstIndex + match.length)
+      }
+    }
+
+    throw new Error("Match found but could not apply replacement")
+  }
+}
\ No newline at end of file
diff --git a/packages/opencode/src/tool/replacer-registry.ts b/packages/opencode/src/tool/replacer-registry.ts
new file mode 100644
index 00000000000..39262765517
--- /dev/null
+++ b/packages/opencode/src/tool/replacer-registry.ts
@@ -0,0 +1,50 @@
+export type Replacer = (content: string, find: string) => Generator<string, void, unknown>
+
+export interface ReplacerConfig {
+  name: string
+  replacer: Replacer
+  priority: number
+  condition?: (find: string, content: string) => boolean
+}
+
+export class ReplacerRegistry {
+  private replacers: ReplacerConfig[] = []
+
+  register(config: ReplacerConfig): void {
+    this.replacers.push(config)
+    this.replacers.sort((a, b) => a.priority - b.priority)
+  }
+
+  getReplacers(find: string, content: string): Replacer[] {
+    return this.replacers
+      .filter((config) => !config.condition || config.condition(find, content))
+      .map((config) => config.replacer)
+  }
+
+  clear(): void {
+    this.replacers = []
+  }
+}
+
+// Pattern analysis utilities
+export namespace PatternAnalysis {
+  export function isMultiline(text: string): boolean {
+    return text.includes('\n')
+  }
+
+  export function hasWhitespaceVariation(text: string): boolean {
+    return text !== text.trim()
+  }
+
+  export function getLineCount(text: string): number {
+    return text.split('\n').length
+  }
+
+  export function hasEscapeSequences(text: string): boolean {
+    return /\\[ntr'"\\`$]/.test(text)
+  }
+
+  export function isSimplePattern(find: string, content: string): boolean {
+    return content.includes(find)
+  }
+}
\ No newline at end of file
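Reviewer note: the registry makes the replacer chain extensible; a sketch of registering a hypothetical custom replacer gated by a condition:

```ts
import { ReplacerRegistry, PatternAnalysis, type Replacer } from "./replacer-registry"

// Hypothetical replacer that only ever yields the literal pattern.
const ExactReplacer: Replacer = function* (_content, find) {
  yield find
}

const registry = new ReplacerRegistry()
registry.register({
  name: "exact",
  replacer: ExactReplacer,
  priority: 5, // between simple (0) and lineTrimmed (10)
  condition: (find) => !PatternAnalysis.isMultiline(find),
})
```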
diff --git a/packages/opencode/src/tool/replacers.ts b/packages/opencode/src/tool/replacers.ts
new file mode 100644
index 00000000000..e7f04ea7b0d
--- /dev/null
+++ b/packages/opencode/src/tool/replacers.ts
@@ -0,0 +1,292 @@
+import type { Replacer } from "./replacer-registry"
+
+export const SimpleReplacer: Replacer = function* (_content, find) {
+  yield find
+}
+
+export const LineTrimmedReplacer: Replacer = function* (content, find) {
+  const originalLines = content.split("\n")
+  const searchLines = find.split("\n")
+
+  if (searchLines[searchLines.length - 1] === "") {
+    searchLines.pop()
+  }
+
+  for (let i = 0; i <= originalLines.length - searchLines.length; i++) {
+    let matches = true
+
+    for (let j = 0; j < searchLines.length; j++) {
+      const originalTrimmed = originalLines[i + j].trim()
+      const searchTrimmed = searchLines[j].trim()
+
+      if (originalTrimmed !== searchTrimmed) {
+        matches = false
+        break
+      }
+    }
+
+    if (matches) {
+      let matchStartIndex = 0
+      for (let k = 0; k < i; k++) {
+        matchStartIndex += originalLines[k].length + 1
+      }
+
+      let matchEndIndex = matchStartIndex
+      for (let k = 0; k < searchLines.length; k++) {
+        matchEndIndex += originalLines[i + k].length + 1
+      }
+
+      yield content.substring(matchStartIndex, matchEndIndex)
+    }
+  }
+}
+
+export const BlockAnchorReplacer: Replacer = function* (content, find) {
+  const originalLines = content.split("\n")
+  const searchLines = find.split("\n")
+
+  if (searchLines.length < 3) {
+    return
+  }
+
+  if (searchLines[searchLines.length - 1] === "") {
+    searchLines.pop()
+  }
+
+  const firstLineSearch = searchLines[0].trim()
+  const lastLineSearch = searchLines[searchLines.length - 1].trim()
+
+  // Find blocks where first line matches the search first line
+  for (let i = 0; i < originalLines.length; i++) {
+    if (originalLines[i].trim() !== firstLineSearch) {
+      continue
+    }
+
+    // Look for the matching last line after this first line
+    for (let j = i + 2; j < originalLines.length; j++) {
+      if (originalLines[j].trim() === lastLineSearch) {
+        // Found a potential block from i to j
+        let matchStartIndex = 0
+        for (let k = 0; k < i; k++) {
+          matchStartIndex += originalLines[k].length + 1
+        }
+
+        let matchEndIndex = matchStartIndex
+        for (let k = 0; k <= j - i; k++) {
+          matchEndIndex += originalLines[i + k].length
+          if (k < j - i) {
+            matchEndIndex += 1 // Add newline character except for the last line
+          }
+        }
+
+        yield content.substring(matchStartIndex, matchEndIndex)
+        break // Only match the first occurrence of the last line
+      }
+    }
+  }
+}
+
+export const WhitespaceNormalizedReplacer: Replacer = function* (content, find) {
+  const normalizeWhitespace = (text: string) => text.replace(/\s+/g, " ").trim()
+  const normalizedFind = normalizeWhitespace(find)
+
+  // Handle single line matches
+  const lines = content.split("\n")
+  for (let i = 0; i < lines.length; i++) {
+    const line = lines[i]
+    if (normalizeWhitespace(line) === normalizedFind) {
+      yield line
+    }
+
+    // Also check for substring matches within lines
+    const normalizedLine = normalizeWhitespace(line)
+    if (normalizedLine.includes(normalizedFind)) {
+      // Find the actual substring in the original line that matches
+      const words = find.trim().split(/\s+/)
+      if (words.length > 0) {
+        const pattern = words.map((word) => word.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")).join("\\s+")
+        try {
+          const regex = new RegExp(pattern)
+          const match = line.match(regex)
+          if (match) {
+            yield match[0]
+          }
+        } catch (e) {
+          // Invalid regex pattern, skip
+        }
+      }
+    }
+  }
+
+  // Handle multi-line matches
+  const findLines = find.split("\n")
+  if (findLines.length > 1) {
+    for (let i = 0; i <= lines.length - findLines.length; i++) {
+      const block = lines.slice(i, i + findLines.length)
+      if (normalizeWhitespace(block.join("\n")) === normalizedFind) {
+        yield block.join("\n")
+      }
+    }
+  }
+}
+
+export const IndentationFlexibleReplacer: Replacer = function* (content, find) {
+  const removeIndentation = (text: string) => {
+    const lines = text.split("\n")
+    const nonEmptyLines = lines.filter((line) => line.trim().length > 0)
+    if (nonEmptyLines.length === 0) return text
+
+    const minIndent = Math.min(
+      ...nonEmptyLines.map((line) => {
+        const match = line.match(/^(\s*)/)
+        return match ? match[1].length : 0
+      }),
+    )
+
+    return lines.map((line) => (line.trim().length === 0 ? line : line.slice(minIndent))).join("\n")
+  }
+
+  const normalizedFind = removeIndentation(find)
+  const contentLines = content.split("\n")
+  const findLines = find.split("\n")
+
+  for (let i = 0; i <= contentLines.length - findLines.length; i++) {
+    const block = contentLines.slice(i, i + findLines.length).join("\n")
+    if (removeIndentation(block) === normalizedFind) {
+      yield block
+    }
+  }
+}
+
+export const EscapeNormalizedReplacer: Replacer = function* (content, find) {
+  const unescapeString = (str: string): string => {
+    return str.replace(/\\(n|t|r|'|"|`|\\|\n|\$)/g, (match, capturedChar) => {
+      switch (capturedChar) {
+        case "n":
+          return "\n"
+        case "t":
+          return "\t"
+        case "r":
+          return "\r"
+        case "'":
+          return "'"
+        case '"':
+          return '"'
+        case "`":
+          return "`"
+        case "\\":
+          return "\\"
+        case "\n":
+          return "\n"
+        case "$":
+          return "$"
+        default:
+          return match
+      }
+    })
+  }
+
+  const unescapedFind = unescapeString(find)
+
+  // Try direct match with unescaped find string
+  if (content.includes(unescapedFind)) {
+    yield unescapedFind
+  }
+
+  // Also try finding escaped versions in content that match unescaped find
+  const lines = content.split("\n")
+  const findLines = unescapedFind.split("\n")
+
+  for (let i = 0; i <= lines.length - findLines.length; i++) {
+    const block = lines.slice(i, i + findLines.length).join("\n")
+    const unescapedBlock = unescapeString(block)
+
+    if (unescapedBlock === unescapedFind) {
+      yield block
+    }
+  }
+}
+
+export const TrimmedBoundaryReplacer: Replacer = function* (content, find) {
+  const trimmedFind = find.trim()
+
+  if (trimmedFind === find) {
+    // Already trimmed, no point in trying
+    return
+  }
+
+  // Try to find the trimmed version
+  if (content.includes(trimmedFind)) {
+    yield trimmedFind
+  }
+
+  // Also try finding blocks where trimmed content matches
+  const lines = content.split("\n")
+  const findLines = find.split("\n")
+
+  for (let i = 0; i <= lines.length - findLines.length; i++) {
+    const block = lines.slice(i, i + findLines.length).join("\n")
+
+    if (block.trim() === trimmedFind) {
+      yield block
+    }
+  }
+}
+
+export const ContextAwareReplacer: Replacer = function* (content, find) {
+  const findLines = find.split("\n")
+  if (findLines.length < 3) {
+    // Need at least 3 lines to have meaningful context
+    return
+  }
+
+  // Remove trailing empty line if present
+  if (findLines[findLines.length - 1] === "") {
+    findLines.pop()
+  }
+
+  const contentLines = content.split("\n")
+
+  // Extract first and last lines as context anchors
+  const firstLine = findLines[0].trim()
+  const lastLine = findLines[findLines.length - 1].trim()
+
+  // Find blocks that start and end with the context anchors
+  for (let i = 0; i < contentLines.length; i++) {
+    if (contentLines[i].trim() !== firstLine) continue
+
+    // Look for the matching last line
+    for (let j = i + 2; j < contentLines.length; j++) {
+      if (contentLines[j].trim() === lastLine) {
+        // Found a potential context block
+        const blockLines = contentLines.slice(i, j + 1)
+        const block = blockLines.join("\n")
+
+        // Check if the middle content has reasonable similarity
+        // (simple heuristic: at least 50% of non-empty lines should match when trimmed)
+        if (blockLines.length === findLines.length) {
+          let matchingLines = 0
+          let totalNonEmptyLines = 0
+
+          for (let k = 1; k < blockLines.length - 1; k++) {
+            const blockLine = blockLines[k].trim()
+            const findLine = findLines[k].trim()
+
+            if (blockLine.length > 0 || findLine.length > 0) {
+              totalNonEmptyLines++
+              if (blockLine === findLine) {
+                matchingLines++
+              }
+            }
+          }
+
+          if (totalNonEmptyLines === 0 || matchingLines / totalNonEmptyLines >= 0.5) {
+            yield block
+            break // Only match the first occurrence
+          }
+        }
+        break
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/packages/opencode/src/trace/buffered-writer.ts b/packages/opencode/src/trace/buffered-writer.ts
new file mode 100644
index 00000000000..f0cac80887b
--- /dev/null
+++ b/packages/opencode/src/trace/buffered-writer.ts
@@ -0,0 +1,120 @@
+export namespace BufferedWriter {
+  export interface BufferConfig {
+    bufferSize?: number
+    flushInterval?: number
+    autoFlush?: boolean
+  }
+
+  const DEFAULT_CONFIG: Required<BufferConfig> = {
+    bufferSize: 50,
+    flushInterval: 1000,
+    autoFlush: true,
+  }
+
+  export class Writer {
+    private buffer: string[] = []
+    private flushTimer: Timer | null = null
+    private config: Required<BufferConfig>
+    private writer: ReturnType<BunFile["writer"]> | null = null
+
+    constructor(
+      private filePath: string,
+      config: BufferConfig = {},
+    ) {
+      this.config = { ...DEFAULT_CONFIG, ...config }
+      this.setupCleanup()
+    }
+
+    async write(data: string): Promise<void> {
+      if (!this.writer) {
+        const file = Bun.file(this.filePath)
+        this.writer = file.writer()
+      }
+
+      this.buffer.push(data)
+
+      // Flush if buffer is full
+      if (this.buffer.length >= this.config.bufferSize) {
+        await this.flush()
+      } else if (this.config.autoFlush && !this.flushTimer) {
+        // Set up periodic flush
+        this.flushTimer = setTimeout(() => this.flush(), this.config.flushInterval)
+      }
+    }
+
+    async flush(): Promise<void> {
+      if (this.buffer.length === 0 || !this.writer) {
+        return
+      }
+
+      try {
+        const data = this.buffer.join('')
+        await this.writer.write(data)
+        await this.writer.flush()
+        this.buffer = []
+      } catch (error) {
+        console.error('Failed to flush buffer:', error)
+      }
+
+      if (this.flushTimer) {
+        clearTimeout(this.flushTimer)
+        this.flushTimer = null
+      }
+    }
+
+    async close(): Promise<void> {
+      await this.flush()
+
+      if (this.writer) {
+        await this.writer.end()
+        this.writer = null
+      }
+
+      if (this.flushTimer) {
+        clearTimeout(this.flushTimer)
+        this.flushTimer = null
+      }
+    }
+
+    private setupCleanup(): void {
+      // Ensure buffer is flushed on process exit
+      const cleanup = () => {
+        this.flush().catch(console.error)
+      }
+
+      process.on('exit', cleanup)
+      process.on('SIGINT', () => {
+        cleanup()
+        process.exit()
+      })
+      process.on('SIGTERM', cleanup)
+      process.on('beforeExit', cleanup)
+    }
+  }
+
+  export class Manager {
+    private writers: Map<string, Writer> = new Map()
+
+    getWriter(filePath: string, config?: BufferConfig): Writer {
+      let writer = this.writers.get(filePath)
+
+      if (!writer) {
+        writer = new Writer(filePath, config)
+        this.writers.set(filePath, writer)
+      }
+
+      return writer
+    }
+
+    async flushAll(): Promise<void> {
+      const flushPromises = Array.from(this.writers.values()).map((writer) => writer.flush())
+      await Promise.all(flushPromises)
+    }
+
+    async closeAll(): Promise<void> {
+      const closePromises = Array.from(this.writers.values()).map((writer) => writer.close())
+      await Promise.all(closePromises)
+      this.writers.clear()
+    }
+  }
+}
\ No newline at end of file
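Reviewer note: a usage sketch of the buffered writer; entries hit disk when the buffer fills, when the interval fires, or on explicit flush (the log path is illustrative):

```ts
import { BufferedWriter } from "./trace/buffered-writer"

const manager = new BufferedWriter.Manager()
const log = manager.getWriter("/tmp/fetch.log", { bufferSize: 50, flushInterval: 1000 })

await log.write("GET /health HTTP/1.1\n")
await manager.flushAll() // or closeAll() during shutdown
```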
export function init() { + let interceptor: RequestInterceptor.FetchInterceptor | null = null + let writerManager: BufferedWriter.Manager | null = null + + export function init(): void { if (!Installation.isDev()) return - const writer = Bun.file(path.join(Global.Path.data, "log", "fetch.log")).writer() - - const originalFetch = globalThis.fetch - // @ts-expect-error - globalThis.fetch = async (input: RequestInfo | URL, init?: RequestInit) => { - const url = typeof input === "string" ? input : input instanceof URL ? input.toString() : input.url - const method = init?.method || "GET" - - const urlObj = new URL(url) - - writer.write(`\n${method} ${urlObj.pathname}${urlObj.search} HTTP/1.1\n`) - writer.write(`Host: ${urlObj.host}\n`) - - if (init?.headers) { - if (init.headers instanceof Headers) { - init.headers.forEach((value, key) => { - writer.write(`${key}: ${value}\n`) - }) - } else { - for (const [key, value] of Object.entries(init.headers)) { - writer.write(`${key}: ${value}\n`) - } - } - } - - if (init?.body) { - writer.write(`\n${init.body}`) - } - writer.flush() - const response = await originalFetch(input, init) - const clonedResponse = response.clone() - writer.write(`\nHTTP/1.1 ${response.status} ${response.statusText}\n`) - response.headers.forEach((value, key) => { - writer.write(`${key}: ${value}\n`) - }) - if (clonedResponse.body) { - clonedResponse.text().then(async (x) => { - writer.write(`\n${x}\n`) - }) - } - writer.flush() - - return response + + // Initialize writer manager + writerManager = new BufferedWriter.Manager() + + // Get buffered writer for fetch logs + const logPath = path.join(Global.Path.data, "log", "fetch.log") + const writer = writerManager.getWriter(logPath, { + bufferSize: 50, + flushInterval: 1000, + autoFlush: true, + }) + + // Create and install fetch interceptor + interceptor = new RequestInterceptor.FetchInterceptor(writer, { + logRequests: true, + logResponses: true, + logResponseBodies: true, + maxBodyLength: 1024 * 1024, // 1MB + }) + + interceptor.install() + } + + export async function shutdown(): Promise<void> { + if (interceptor) { + interceptor.uninstall() + interceptor = null + } + + if (writerManager) { + await writerManager.closeAll() + writerManager = null + } + } + + export async function flush(): Promise<void> { + if (writerManager) { + await writerManager.flushAll() } } } diff --git a/packages/opencode/src/trace/request-interceptor.ts b/packages/opencode/src/trace/request-interceptor.ts new file mode 100644 index 00000000000..8a8e383b8cb --- /dev/null +++ b/packages/opencode/src/trace/request-interceptor.ts @@ -0,0 +1,169 @@ +import { BufferedWriter } from "./buffered-writer" + +export namespace RequestInterceptor { + export interface InterceptorConfig { + logRequests?: boolean + logResponses?: boolean + logResponseBodies?: boolean + maxBodyLength?: number + } + + const DEFAULT_CONFIG: Required<InterceptorConfig> = { + logRequests: true, + logResponses: true, + logResponseBodies: true, + maxBodyLength: 1024 * 1024, // 1MB + } + + export class FetchInterceptor { + private originalFetch: typeof globalThis.fetch + private config: Required<InterceptorConfig> + + constructor( + private writer: BufferedWriter.Writer, + config: InterceptorConfig = {} + ) { + this.config = { ...DEFAULT_CONFIG, ...config } + this.originalFetch = globalThis.fetch + } + + install(): void { + // @ts-expect-error - Overriding global fetch + globalThis.fetch = this.interceptedFetch.bind(this) + } + + uninstall(): void { + globalThis.fetch = this.originalFetch + } + + private async interceptedFetch( + input: RequestInfo | URL, + init?: RequestInit + ): Promise<Response> { + const startTime = Date.now() + const url = typeof input === "string" ? input : input instanceof URL ? input.toString() : input.url + const method = init?.method || "GET" + + if (this.config.logRequests) { + await this.logRequest(url, method, init) + } + + try { + const response = await this.originalFetch(input, init) + const clonedResponse = response.clone() + + if (this.config.logResponses) { + await this.logResponse(response, startTime) + } + + if (this.config.logResponseBodies) { + this.logResponseBody(clonedResponse).catch(console.error) + } + + return response + } catch (error) { + await this.logError(url, method, error, startTime) + throw error + } + } + + private async logRequest(url: string, method: string, init?: RequestInit): Promise<void> { + const urlObj = new URL(url) + + let requestLog = `\n${method} ${urlObj.pathname}${urlObj.search} HTTP/1.1\n` + requestLog += `Host: ${urlObj.host}\n` + + if (init?.headers) { + const headers = this.normalizeHeaders(init.headers) + for (const [key, value] of Object.entries(headers)) { + requestLog += `${key}: ${value}\n` + } + } + + if (init?.body) { + const bodyContent = this.formatBody(init.body) + if (bodyContent) { + requestLog += `\n${bodyContent}` + } + } + + await this.writer.write(requestLog) + } + + private async logResponse(response: Response, startTime: number): Promise<void> { + const duration = Date.now() - startTime + + let responseLog = `\nHTTP/1.1 ${response.status} ${response.statusText} (${duration}ms)\n` + + response.headers.forEach((value, key) => { + responseLog += `${key}: ${value}\n` + }) + + await this.writer.write(responseLog) + } + + private async logResponseBody(response: Response): Promise<void> { + try { + const text = await response.text() + if (text && text.length <= this.config.maxBodyLength) { + await this.writer.write(`\n${text}\n`) + } else if (text.length > this.config.maxBodyLength) { + const truncated = text.slice(0, this.config.maxBodyLength) + await this.writer.write(`\n${truncated}\n[Response body truncated - ${text.length} bytes total]\n`) + } + } catch (error) { + await this.writer.write(`\n[Failed to read response body: ${error}]\n`) + } + } + + private async logError(url: string, method: string, error: unknown, startTime: number): Promise<void> { + const duration = Date.now() - startTime + const errorLog = `\nERROR ${method} ${url} (${duration}ms)\n${error}\n` + await this.writer.write(errorLog) + } + + private normalizeHeaders(headers: HeadersInit): Record<string, string> { + if (headers instanceof Headers) { + const result: Record<string, string> = {} + headers.forEach((value, key) => { + result[key] = value + }) + return result + } + + if (Array.isArray(headers)) { + const result: Record<string, string> = {} + for (const [key, value] of headers) { + result[key] = value + } + return result + } + + return headers as Record<string, string> + } + + private formatBody(body: BodyInit): string | null { + if (typeof body === 'string') { + return body + } + + if (body instanceof URLSearchParams) { + return body.toString() + } + + if (body instanceof FormData) { + return '[FormData]' + } + + if (body instanceof ArrayBuffer || body instanceof Uint8Array) { + return `[Binary data: ${body.byteLength} bytes]` + } + + if (body instanceof ReadableStream) { + return '[ReadableStream]' + } + + return '[Unknown body type]' + } + } +} \ No newline at end of file diff --git a/tests/auth/auth.integration.test.spec.ts b/tests/auth/auth.integration.test.spec.ts new file mode 100644 index 00000000000..da23be5c817
--- /dev/null +++ b/tests/auth/auth.integration.test.spec.ts @@ -0,0 +1,1571 @@ +/** + * Authentication Module - Integration Test Specifications + * Testing authentication flows with database, cache, and middleware integration + */ + +describe('Authentication Integration Tests', () => { + + // Test Database Setup + let testDb; + let testCache; + let testMailer; + + beforeAll(async () => { + testDb = await setupTestDatabase(); + testCache = await setupTestCache(); + testMailer = setupMailerMock(); + }); + + afterAll(async () => { + await testDb.close(); + await testCache.close(); + }); + + beforeEach(async () => { + await testDb.clean(); + await testCache.flush(); + testMailer.reset(); + }); + + describe('User Registration Flow', () => { + it('should successfully register a new user', async () => { + // Arrange + const registrationData = { + email: 'newuser@example.com', + password: 'SecurePass123!', + firstName: 'John', + lastName: 'Doe' + }; + + // Act + const result = await authService.register(registrationData); + + // Assert + expect(result.success).toBe(true); + expect(result.user).toBeDefined(); + expect(result.user.id).toBeDefined(); + expect(result.user.email).toBe(registrationData.email); + expect(result.user.password).toBeUndefined(); // Password should not be returned + + // Verify database record + const dbUser = await testDb.users.findByEmail(registrationData.email); + expect(dbUser).toBeDefined(); + expect(dbUser.passwordHash).toBeDefined(); + expect(dbUser.passwordHash).not.toBe(registrationData.password); + + // Verify verification email sent + expect(testMailer.sentEmails).toHaveLength(1); + expect(testMailer.sentEmails[0].to).toBe(registrationData.email); + expect(testMailer.sentEmails[0].subject).toContain('Verify your email'); + }); + + it('should reject duplicate email registration', async () => { + // Arrange + const email = 'existing@example.com'; + await testDb.users.create({ + email, + passwordHash: 'existing-hash', + firstName: 'Existing', + lastName: 'User' + }); + + const registrationData = { + email, + password: 'NewPass123!', + firstName: 'New', + lastName: 'User' + }; + + // Act & Assert + await expect(authService.register(registrationData)) + .rejects.toThrow('Email already registered'); + + // Verify no additional email sent + expect(testMailer.sentEmails).toHaveLength(0); + }); + + it('should handle database transaction rollback on error', async () => { + // Arrange + const registrationData = { + email: 'transactiontest@example.com', + password: 'SecurePass123!', + firstName: 'Transaction', + lastName: 'Test' + }; + + // Mock profile creation to fail + jest.spyOn(testDb.profiles, 'create').mockRejectedValueOnce(new Error('Profile creation failed')); + + // Act & Assert + await expect(authService.register(registrationData)) + .rejects.toThrow('Profile creation failed'); + + // Verify user was not created due to rollback + const user = await testDb.users.findByEmail(registrationData.email); + expect(user).toBeNull(); + }); + + it('should enforce rate limiting on registration attempts', async () => { + // Arrange + const ipAddress = '192.168.1.100'; + const registrationAttempts = 6; // Assuming limit is 5 + + // Act - Make multiple registration attempts + for (let i = 0; i < registrationAttempts; i++) { + const registrationData = { + email: `user${i}@example.com`, + password: 'SecurePass123!', + firstName: 'Test', + lastName: `User${i}` + }; + + if (i < 5) { + await authService.register(registrationData, { ipAddress }); + } else { + // Assert - 6th attempt 
should be rate limited + await expect(authService.register(registrationData, { ipAddress })) + .rejects.toThrow('Too many registration attempts. Please try again later.'); + } + } + }); + + it('should sanitize user input to prevent XSS', async () => { + // Arrange + const registrationData = { + email: 'xsstest@example.com', + password: 'SecurePass123!', + firstName: 'John<script>alert("XSS")</script>', + lastName: 'Doe<img src=x onerror=alert(1)>' + }; + + // Act + const result = await authService.register(registrationData); + + // Assert + const dbUser = await testDb.users.findById(result.user.id); + expect(dbUser.firstName).toBe('John'); // Script tags removed + expect(dbUser.lastName).toBe('Doe'); // IMG tag removed + expect(dbUser.firstName).not.toContain('<script>'); + }); + }); + + describe('Input Sanitization', () => { + describe('sanitizeInput()', () => { + it('should strip markup and injection payloads', () => { + // Arrange + const maliciousInputs = [ + { input: '<script>alert("XSS")</script>', expected: 'alert("XSS")' }, + { input: 'user\'; DROP TABLE users; --', expected: 'user DROP TABLE users --' }, + { input: '{{constructor.constructor("alert(1)")()}}', expected: 'constructor.constructor("alert(1)")()' } + ]; + + // Act & Assert + maliciousInputs.forEach(({ input, expected }) => { + expect(sanitizeInput(input)).toBe(expected); + }); + }); + + it('should preserve safe characters', () => { + // Arrange + const safeInput = 'John.Doe-123_test@example.com'; + + // Act + const sanitized = sanitizeInput(safeInput); + + // Assert + expect(sanitized).toBe(safeInput); + }); + + it('should handle null and undefined', () => { + // Act & Assert + expect(sanitizeInput(null)).toBe(''); + expect(sanitizeInput(undefined)).toBe(''); + }); + + it('should trim whitespace', () => { + // Arrange + const input = ' user@example.com '; + + // Act + const sanitized = sanitizeInput(input); + + // Assert + expect(sanitized).toBe('user@example.com'); + }); + }); + }); + + describe('Rate Limiting Service', () => { + describe('checkRateLimit()', () => { + beforeEach(() => { + // Reset rate limiter state + jest.clearAllMocks(); + }); + + it('should allow requests within rate limit', () => { + // Arrange + const userId = 'user123'; + const action = 'login'; + const limit = 5; + const windowMs = 60000; // 1 minute + + // Act & Assert + for (let i = 0; i < limit; i++) { + const result = checkRateLimit(userId, action, { limit, windowMs }); + expect(result.allowed).toBe(true); + expect(result.remaining).toBe(limit - i - 1); + } + }); + + it('should block requests exceeding rate limit', () => { + // Arrange + const userId = 'user123'; + const action = 'login'; + const limit = 3; + const windowMs = 60000; + + // Act - make requests up to limit + for (let i = 0; i < limit; i++) { + checkRateLimit(userId, action, { limit, windowMs }); + } + + // Act - exceed limit + const result = checkRateLimit(userId, action, { limit, windowMs }); + + // Assert + expect(result.allowed).toBe(false); + expect(result.remaining).toBe(0); + expect(result.retryAfter).toBeGreaterThan(0); + }); + + it('should reset after time window', () => { + // Arrange + const userId = 'user123'; + const action = 'login'; + const limit = 2; + const windowMs = 100; // 100ms for testing + + // Act - exhaust limit + for (let i = 0; i < limit; i++) { + checkRateLimit(userId, action, { limit, windowMs }); + } + + // Wait for window to expire + jest.advanceTimersByTime(windowMs + 1); + + // Act - should allow again + const result = checkRateLimit(userId, action, { limit, windowMs }); + + // Assert + expect(result.allowed).toBe(true); + expect(result.remaining).toBe(limit - 1); + }); + + it('should track different actions separately', () => { + // Arrange + const userId = 'user123'; + const loginLimit = 3; + const resetLimit = 1; + const windowMs = 60000;
+ + // Act - exhaust login limit + for (let i = 0; i < loginLimit; i++) { + checkRateLimit(userId, 'login', { limit: loginLimit, windowMs }); + } + + // Act - password reset should still be allowed + const resetResult = checkRateLimit(userId, 'passwordReset', { limit: resetLimit, windowMs }); + + // Assert + expect(resetResult.allowed).toBe(true); + }); + + it('should track different users separately', () => { + // Arrange + const user1 = 'user123'; + const user2 = 'user456'; + const action = 'login'; + const limit = 2; + const windowMs = 60000; + + // Act - exhaust limit for user1 + for (let i = 0; i < limit; i++) { + checkRateLimit(user1, action, { limit, windowMs }); + } + + // Act - user2 should still be allowed + const result = checkRateLimit(user2, action, { limit, windowMs }); + + // Assert + expect(result.allowed).toBe(true); + expect(result.remaining).toBe(limit - 1); + }); + + it('should handle IP-based rate limiting', () => { + // Arrange + const ipAddress = '192.168.1.1'; + const action = 'login'; + const limit = 10; + const windowMs = 60000; + + // Act + const result = checkRateLimit(ipAddress, action, { limit, windowMs, keyType: 'ip' }); + + // Assert + expect(result.allowed).toBe(true); + expect(result.remaining).toBe(limit - 1); + }); + }); + }); + + describe('Session Management Service', () => { + describe('createSession()', () => { + it('should create a session with required properties', () => { + // Arrange + const userId = 'user123'; + const deviceInfo = { userAgent: 'Mozilla/5.0', ip: '192.168.1.1' }; + + // Act + const session = createSession(userId, deviceInfo); + + // Assert + expect(session.sessionId).toBeDefined(); + expect(session.sessionId).toMatch(/^[a-f0-9-]{36}$/); // UUID format + expect(session.userId).toBe(userId); + expect(session.createdAt).toBeInstanceOf(Date); + expect(session.expiresAt).toBeInstanceOf(Date); + expect(session.deviceInfo).toEqual(deviceInfo); + expect(session.isActive).toBe(true); + }); + + it('should set correct expiration time', () => { + // Arrange + const userId = 'user123'; + const ttl = 3600000; // 1 hour in ms + + // Act + const session = createSession(userId, {}, { ttl }); + + // Assert + const expectedExpiry = new Date(Date.now() + ttl); + expect(session.expiresAt.getTime()).toBeCloseTo(expectedExpiry.getTime(), -2); + }); + + it('should generate unique session IDs', () => { + // Arrange + const userId = 'user123'; + const sessions = new Set(); + + // Act + for (let i = 0; i < 1000; i++) { + const session = createSession(userId, {}); + sessions.add(session.sessionId); + } + + // Assert + expect(sessions.size).toBe(1000); + }); + }); + + describe('validateSession()', () => { + it('should validate an active session', () => { + // Arrange + const session = { + sessionId: 'valid-session-id', + userId: 'user123', + expiresAt: new Date(Date.now() + 3600000), + isActive: true + }; + + // Act + const isValid = validateSession(session); + + // Assert + expect(isValid).toBe(true); + }); + + it('should reject expired session', () => { + // Arrange + const session = { + sessionId: 'expired-session-id', + userId: 'user123', + expiresAt: new Date(Date.now() - 1000), + isActive: true + }; + + // Act + const isValid = validateSession(session); + + // Assert + expect(isValid).toBe(false); + }); + + it('should reject inactive session', () => { + // Arrange + const session = { + sessionId: 'inactive-session-id', + userId: 'user123', + expiresAt: new Date(Date.now() + 3600000), + isActive: false + }; + + // Act + const isValid = 
validateSession(session); + + // Assert + expect(isValid).toBe(false); + }); + + it('should handle null/undefined session', () => { + // Act & Assert + expect(validateSession(null)).toBe(false); + expect(validateSession(undefined)).toBe(false); + }); + }); + + describe('extendSession()', () => { + it('should extend session expiration time', () => { + // Arrange + const originalExpiry = new Date(Date.now() + 1800000); // 30 minutes + const session = { + sessionId: 'test-session', + userId: 'user123', + expiresAt: originalExpiry, + isActive: true + }; + const extensionTime = 3600000; // 1 hour + + // Act + const extendedSession = extendSession(session, extensionTime); + + // Assert + expect(extendedSession.expiresAt.getTime()).toBeGreaterThan(originalExpiry.getTime()); + expect(extendedSession.lastActivity).toBeDefined(); + }); + + it('should not extend expired session', () => { + // Arrange + const session = { + sessionId: 'expired-session', + userId: 'user123', + expiresAt: new Date(Date.now() - 1000), + isActive: true + }; + + // Act & Assert + expect(() => extendSession(session, 3600000)) + .toThrow('Cannot extend expired session'); + }); + + it('should not extend beyond maximum session lifetime', () => { + // Arrange + const session = { + sessionId: 'test-session', + userId: 'user123', + createdAt: new Date(Date.now() - 86400000), // Created 24 hours ago + expiresAt: new Date(Date.now() + 3600000), + isActive: true + }; + const maxLifetime = 86400000; // 24 hours + + // Act + const extendedSession = extendSession(session, 3600000, { maxLifetime }); + + // Assert + const maxExpiry = new Date(session.createdAt.getTime() + maxLifetime); + expect(extendedSession.expiresAt.getTime()).toBeLessThanOrEqual(maxExpiry.getTime()); + }); + }); + }); + + describe('Security Utilities', () => { + describe('generateSecureRandom()', () => { + it('should generate random bytes of specified length', () => { + // Arrange + const lengths = [16, 32, 64, 128]; + + // Act & Assert + lengths.forEach(length => { + const random = generateSecureRandom(length); + expect(random).toHaveLength(length); + expect(random).toBeInstanceOf(Buffer); + }); + }); + + it('should generate cryptographically secure random values', () => { + // Arrange + const samples = []; + const sampleSize = 1000; + const byteLength = 32; + + // Act + for (let i = 0; i < sampleSize; i++) { + samples.push(generateSecureRandom(byteLength).toString('hex')); + } + + // Assert - all should be unique + const uniqueSamples = new Set(samples); + expect(uniqueSamples.size).toBe(sampleSize); + }); + + it('should throw error for invalid length', () => { + // Act & Assert + expect(() => generateSecureRandom(0)).toThrow('Length must be positive'); + expect(() => generateSecureRandom(-1)).toThrow('Length must be positive'); + expect(() => generateSecureRandom(null)).toThrow('Length is required'); + }); + }); + + describe('constantTimeCompare()', () => { + it('should return true for identical strings', () => { + // Arrange + const str1 = 'secretValue123'; + const str2 = 'secretValue123'; + + // Act + const result = constantTimeCompare(str1, str2); + + // Assert + expect(result).toBe(true); + }); + + it('should return false for different strings', () => { + // Arrange + const str1 = 'secretValue123'; + const str2 = 'secretValue124'; + + // Act + const result = constantTimeCompare(str1, str2); + + // Assert + expect(result).toBe(false); + }); + + it('should have constant execution time', () => { + // Arrange + const base = 'a'.repeat(100); + const earlyDiff = 'b' 
+ 'a'.repeat(99); + const lateDiff = 'a'.repeat(99) + 'b'; + + // Act + const times = []; + + for (let i = 0; i < 1000; i++) { + const start = performance.now(); + constantTimeCompare(base, earlyDiff); + times.push(performance.now() - start); + } + const avgEarlyDiff = times.reduce((a, b) => a + b) / times.length; + + times.length = 0; + for (let i = 0; i < 1000; i++) { + const start = performance.now(); + constantTimeCompare(base, lateDiff); + times.push(performance.now() - start); + } + const avgLateDiff = times.reduce((a, b) => a + b) / times.length; + + // Assert - timing should be similar + const timeDifference = Math.abs(avgEarlyDiff - avgLateDiff); + expect(timeDifference).toBeLessThan(avgEarlyDiff * 0.1); // Within 10% + }); + + it('should handle different length strings', () => { + // Arrange + const str1 = 'short'; + const str2 = 'muchlongerstring'; + + // Act + const result = constantTimeCompare(str1, str2); + + // Assert + expect(result).toBe(false); + }); + }); + }); + + describe('CSRF Protection', () => { + describe('generateCSRFToken()', () => { + it('should generate a valid CSRF token', () => { + // Arrange + const sessionId = 'session123'; + + // Act + const token = generateCSRFToken(sessionId); + + // Assert + expect(token).toBeDefined(); + expect(token).toHaveLength(64); // 32 bytes in hex + expect(token).toMatch(/^[a-f0-9]{64}$/); + }); + + it('should generate different tokens for different sessions', () => { + // Arrange + const session1 = 'session123'; + const session2 = 'session456'; + + // Act + const token1 = generateCSRFToken(session1); + const token2 = generateCSRFToken(session2); + + // Assert + expect(token1).not.toBe(token2); + }); + + it('should generate consistent token for same session', () => { + // Arrange + const sessionId = 'session123'; + + // Act + const token1 = generateCSRFToken(sessionId); + const token2 = generateCSRFToken(sessionId); + + // Assert + expect(token1).toBe(token2); + }); + }); + + describe('validateCSRFToken()', () => { + it('should validate correct CSRF token', () => { + // Arrange + const sessionId = 'session123'; + const token = generateCSRFToken(sessionId); + + // Act + const isValid = validateCSRFToken(token, sessionId); + + // Assert + expect(isValid).toBe(true); + }); + + it('should reject token from different session', () => { + // Arrange + const session1 = 'session123'; + const session2 = 'session456'; + const token = generateCSRFToken(session1); + + // Act + const isValid = validateCSRFToken(token, session2); + + // Assert + expect(isValid).toBe(false); + }); + + it('should reject tampered token', () => { + // Arrange + const sessionId = 'session123'; + const validToken = generateCSRFToken(sessionId); + const tamperedToken = validToken.slice(0, -2) + 'ff'; + + // Act + const isValid = validateCSRFToken(tamperedToken, sessionId); + + // Assert + expect(isValid).toBe(false); + }); + + it('should handle missing token', () => { + // Arrange + const sessionId = 'session123'; + + // Act & Assert + expect(validateCSRFToken(null, sessionId)).toBe(false); + expect(validateCSRFToken(undefined, sessionId)).toBe(false); + expect(validateCSRFToken('', sessionId)).toBe(false); + }); + }); + }); +}); + +/** + * Mock/Stub Requirements: + * + * 1. Crypto module mocks for: + * - bcrypt or argon2 for password hashing + * - crypto.randomBytes for token generation + * - crypto.timingSafeEqual for constant-time comparison + * + * 2. JWT library mocks for: + * - sign() method + * - verify() method + * - decode() method + * + * 3. 
Time/Date mocks: + * - jest.useFakeTimers() for rate limiting tests + * - Date.now() mocks for expiration testing + * + * 4. Performance mocks: + * - performance.now() for timing attack tests + * + * Expected Assertions: + * - All tests should use specific matchers (toBe, toEqual, toMatch, etc.) + * - Error assertions should check both error type and message + * - Timing assertions should use appropriate tolerances + * - Collection assertions should verify size and uniqueness + */ \ No newline at end of file
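
A note on the fake-timer requirements listed above: with Jest's modern fake timers, jest.advanceTimersByTime() advances both pending timers and the mocked Date.now(), which is what the rate-limit window and session-expiration specs rely on. A minimal setup sketch, assuming Jest globals and this hook placement (not the project's actual test harness):

beforeEach(() => {
  // Modern fake timers also mock Date, so advanceTimersByTime() moves Date.now()
  jest.useFakeTimers()
  // Pin the clock so expiresAt comparisons are deterministic (illustrative date)
  jest.setSystemTime(new Date('2024-01-01T00:00:00Z'))
})

afterEach(() => {
  jest.useRealTimers()
  jest.restoreAllMocks()
})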
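
The checkRateLimit() specs above pin down a fixed-window limiter: per-subject, per-action buckets, a remaining counter, and a retryAfter hint once the limit is hit. A minimal sketch consistent with those expectations; the in-memory Map store and the bucket-key format are illustrative assumptions, not the service's actual implementation:

// Hypothetical fixed-window limiter matching the checkRateLimit() specs above.
interface RateLimitOptions {
  limit: number
  windowMs: number
  keyType?: 'user' | 'ip'
}

interface RateLimitResult {
  allowed: boolean
  remaining: number
  retryAfter?: number
}

const windows = new Map<string, { count: number; resetAt: number }>()

export function checkRateLimit(key: string, action: string, opts: RateLimitOptions): RateLimitResult {
  // Separate buckets per subject (user or IP) and per action
  const id = `${opts.keyType ?? 'user'}:${key}:${action}`
  const now = Date.now()
  let entry = windows.get(id)

  // Open a fresh window when none exists or the previous one has expired
  if (!entry || now >= entry.resetAt) {
    entry = { count: 0, resetAt: now + opts.windowMs }
    windows.set(id, entry)
  }

  if (entry.count >= opts.limit) {
    return { allowed: false, remaining: 0, retryAfter: entry.resetAt - now }
  }

  entry.count++
  return { allowed: true, remaining: opts.limit - entry.count }
}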
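
The CSRF specs above require a token that is deterministic per session, 64 hex characters (32 bytes), distinct across sessions, and tamper-evident. An HMAC over the session ID has exactly those properties, so here is a sketch under that assumption; the secret's source and the module layout are illustrative, not the codebase's actual design:

import { createHmac, timingSafeEqual } from 'node:crypto'

// Assumption: a server-side secret provisioned out of band; never hard-code one in production
const CSRF_SECRET = process.env.CSRF_SECRET ?? 'dev-only-secret'

export function generateCSRFToken(sessionId: string): string {
  // HMAC-SHA256 digest is 32 bytes -> 64 hex characters, deterministic per session
  return createHmac('sha256', CSRF_SECRET).update(sessionId).digest('hex')
}

export function validateCSRFToken(token: string | null | undefined, sessionId: string): boolean {
  // Reject missing or malformed tokens before comparing
  if (!token || !/^[a-f0-9]{64}$/.test(token)) return false
  const expected = generateCSRFToken(sessionId)
  // Constant-time comparison so validation does not leak matching prefixes
  return timingSafeEqual(Buffer.from(token, 'hex'), Buffer.from(expected, 'hex'))
}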
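
Looping back to the Replacer generators earlier in this diff: each strategy yields candidate blocks in content that it considers equivalent to find, and some caller is expected to pick a candidate and splice in the replacement. That caller is not shown in the diff, so this is only a hedged sketch of one possible driver; the strategy ordering and the single-occurrence guard are assumptions:

// Hypothetical driver for the Replacer generators in this diff.
type Replacer = (content: string, find: string) => Generator<string, void, unknown>

export function applyEdit(content: string, find: string, replace: string, replacers: Replacer[]): string {
  for (const replacer of replacers) {
    for (const candidate of replacer(content, find)) {
      const first = content.indexOf(candidate)
      if (first === -1) continue
      // Skip ambiguous candidates that occur more than once in the content
      if (content.indexOf(candidate, first + candidate.length) !== -1) continue
      return content.slice(0, first) + replace + content.slice(first + candidate.length)
    }
  }
  throw new Error('No unambiguous match found for edit')
}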