diff --git a/.githooks/pre-commit b/.githooks/pre-commit index d6ea4dd2..f4c1929f 100755 --- a/.githooks/pre-commit +++ b/.githooks/pre-commit @@ -7,4 +7,4 @@ if ! command -v golangci-lint >/dev/null 2>&1; then exit 1 fi -golangci-lint run ./... +output=$(golangci-lint run ./... 2>&1) || { echo "$output" >&2; exit 1; } diff --git a/.roborev.toml b/.roborev.toml index 809ab342..c9770664 100644 --- a/.roborev.toml +++ b/.roborev.toml @@ -26,11 +26,23 @@ security findings: embedded in prompts (all locally generated; a compromised local agent already has full shell access) - "Secret exposure" from hook stderr logged or forwarded to the model +- "Untrusted remap data" or "history tampering" from the remap endpoint (the + client is the user's own post-rewrite hook, data comes from the user's own + git repo, and the daemon is localhost-only) - .githooks/ tracked as a supply-chain risk (source templates only; installed hooks are frozen copies in .git/hooks/, unaffected by branch switches) +- "Argument injection" on internal helpers (e.g., git.GetPatchID, + git.GetCommitInfo) that receive SHAs already validated or resolved at the + call site — validation belongs at trust boundaries, not every internal call - Race conditions in metadata handoff between CLI and daemon (correctness concern, not exploitable by external attacker) +- "Markerless" hook blocks or interpreter mismatch — hook install/upgrade uses + marker-based detection, every generated hook includes a marker comment, and + append logic refuses to add shell snippets to non-shell hooks +- Symlink-following in hook read/write under .git/hooks/ (controlled by the + local user; a compromised local filesystem is out of scope) + ## Config loading and filesystem fallback Config loading (loadGuidelines, loadCIRepoConfig) reads .roborev.toml from diff --git a/cmd/roborev/hook_test.go b/cmd/roborev/hook_test.go index 6f8acd54..2731c68d 100644 --- a/cmd/roborev/hook_test.go +++ b/cmd/roborev/hook_test.go @@ -3,6 +3,7 @@ package main import ( "errors" "fmt" + "io/fs" "net" "net/http" "net/http/httptest" @@ -56,9 +57,8 @@ func TestUninstallHookCmd(t *testing.T) { }) t.Run("hook with roborev only - removes file", func(t *testing.T) { - hookContent := "#!/bin/bash\n# roborev auto-commit hook\nroborev enqueue\n" repo := testutil.NewTestRepo(t) - repo.WriteHook(hookContent) + repo.WriteHook(generateHookContent()) defer repo.Chdir()() cmd := uninstallHookCmd() @@ -74,9 +74,10 @@ func TestUninstallHookCmd(t *testing.T) { }) t.Run("hook with roborev and other commands - preserves others", func(t *testing.T) { - hookContent := "#!/bin/bash\necho 'before'\nroborev enqueue\necho 'after'\n" repo := testutil.NewTestRepo(t) - repo.WriteHook(hookContent) + mixed := "#!/bin/sh\necho 'before'\necho 'after'\n" + + generateHookContent() + repo.WriteHook(mixed) defer repo.Chdir()() cmd := uninstallHookCmd() @@ -85,15 +86,15 @@ func TestUninstallHookCmd(t *testing.T) { t.Fatalf("uninstall-hook failed: %v", err) } - // Hook should exist with roborev line removed + // Hook should exist with roborev snippet removed content, err := os.ReadFile(repo.HookPath) if err != nil { t.Fatalf("Failed to read hook: %v", err) } contentStr := string(content) - if strings.Contains(strings.ToLower(contentStr), "roborev") { - t.Error("Hook should not contain roborev") + if strings.Contains(contentStr, "enqueue --quiet") { + t.Error("Hook should not contain generated roborev snippet") } if !strings.Contains(contentStr, "echo 'before'") { t.Error("Hook should still contain 'echo before'") @@ -103,10 +104,16 @@ func TestUninstallHookCmd(t *testing.T) { } }) - t.Run("hook with capitalized RoboRev", func(t *testing.T) { - hookContent := "#!/bin/bash\n# RoboRev hook\nRoboRev enqueue\n" + t.Run("also removes post-rewrite hook", func(t *testing.T) { repo := testutil.NewTestRepo(t) - repo.WriteHook(hookContent) + repo.WriteHook(generateHookContent()) + // Also install post-rewrite hook + prPath := filepath.Join(repo.HooksDir, "post-rewrite") + os.WriteFile( + prPath, + []byte(generatePostRewriteHookContent()), + 0755, + ) defer repo.Chdir()() cmd := uninstallHookCmd() @@ -115,9 +122,34 @@ func TestUninstallHookCmd(t *testing.T) { t.Fatalf("uninstall-hook failed: %v", err) } - // Hook should be removed (only had RoboRev content) if _, err := os.Stat(repo.HookPath); !os.IsNotExist(err) { - t.Error("Hook file should have been removed") + t.Error("post-commit hook should have been removed") + } + if _, err := os.Stat(prPath); !os.IsNotExist(err) { + t.Error("post-rewrite hook should have been removed") + } + }) + + t.Run("removes post-rewrite even without post-commit", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + // Only install post-rewrite, no post-commit + prPath := filepath.Join(repo.HooksDir, "post-rewrite") + os.MkdirAll(repo.HooksDir, 0755) + os.WriteFile( + prPath, + []byte(generatePostRewriteHookContent()), + 0755, + ) + defer repo.Chdir()() + + cmd := uninstallHookCmd() + err := cmd.Execute() + if err != nil { + t.Fatalf("uninstall-hook failed: %v", err) + } + + if _, err := os.Stat(prPath); !os.IsNotExist(err) { + t.Error("post-rewrite hook should have been removed") } }) } @@ -253,7 +285,7 @@ func TestHookNeedsUpgrade(t *testing.T) { t.Run("outdated hook", func(t *testing.T) { repo := testutil.NewTestRepo(t) repo.WriteHook("#!/bin/sh\n# roborev post-commit hook\nroborev enqueue\n") - if !hookNeedsUpgrade(repo.Root) { + if !hookNeedsUpgrade(repo.Root, "post-commit", hookVersionMarker) { t.Error("should detect outdated hook") } }) @@ -261,14 +293,14 @@ func TestHookNeedsUpgrade(t *testing.T) { t.Run("current hook", func(t *testing.T) { repo := testutil.NewTestRepo(t) repo.WriteHook("#!/bin/sh\n# roborev post-commit hook v2 - auto-reviews every commit\nroborev enqueue\n") - if hookNeedsUpgrade(repo.Root) { + if hookNeedsUpgrade(repo.Root, "post-commit", hookVersionMarker) { t.Error("should not flag current hook as outdated") } }) t.Run("no hook", func(t *testing.T) { repo := testutil.NewTestRepo(t) - if hookNeedsUpgrade(repo.Root) { + if hookNeedsUpgrade(repo.Root, "post-commit", hookVersionMarker) { t.Error("should not flag missing hook as outdated") } }) @@ -276,10 +308,68 @@ func TestHookNeedsUpgrade(t *testing.T) { t.Run("non-roborev hook", func(t *testing.T) { repo := testutil.NewTestRepo(t) repo.WriteHook("#!/bin/sh\necho hello\n") - if hookNeedsUpgrade(repo.Root) { + if hookNeedsUpgrade(repo.Root, "post-commit", hookVersionMarker) { t.Error("should not flag non-roborev hook as outdated") } }) + + t.Run("post-rewrite outdated", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + hooksDir := filepath.Join(repo.Root, ".git", "hooks") + os.MkdirAll(hooksDir, 0755) + os.WriteFile(filepath.Join(hooksDir, "post-rewrite"), + []byte("#!/bin/sh\n# roborev hook\nroborev remap\n"), 0755) + if !hookNeedsUpgrade(repo.Root, "post-rewrite", postRewriteHookVersionMarker) { + t.Error("should detect outdated post-rewrite hook") + } + }) + + t.Run("post-rewrite current", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + hooksDir := filepath.Join(repo.Root, ".git", "hooks") + os.MkdirAll(hooksDir, 0755) + os.WriteFile(filepath.Join(hooksDir, "post-rewrite"), + []byte("#!/bin/sh\n# roborev post-rewrite hook v1\nroborev remap\n"), 0755) + if hookNeedsUpgrade(repo.Root, "post-rewrite", postRewriteHookVersionMarker) { + t.Error("should not flag current post-rewrite hook") + } + }) +} + +func TestHookMissing(t *testing.T) { + t.Run("missing post-rewrite with roborev post-commit", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + repo.WriteHook("#!/bin/sh\n# roborev post-commit hook v2\nroborev enqueue\n") + if !hookMissing(repo.Root, "post-rewrite") { + t.Error("should detect missing post-rewrite when post-commit has roborev") + } + }) + + t.Run("no post-commit hook at all", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if hookMissing(repo.Root, "post-rewrite") { + t.Error("should not warn when post-commit is not installed") + } + }) + + t.Run("post-rewrite exists with roborev", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + repo.WriteHook("#!/bin/sh\n# roborev post-commit hook v2\nroborev enqueue\n") + hooksDir := filepath.Join(repo.Root, ".git", "hooks") + os.WriteFile(filepath.Join(hooksDir, "post-rewrite"), + []byte("#!/bin/sh\n# roborev post-rewrite hook v1\nroborev remap\n"), 0755) + if hookMissing(repo.Root, "post-rewrite") { + t.Error("should not warn when post-rewrite has roborev content") + } + }) + + t.Run("non-roborev post-commit", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + repo.WriteHook("#!/bin/sh\necho hello\n") + if hookMissing(repo.Root, "post-rewrite") { + t.Error("should not warn when post-commit is not roborev") + } + }) } func TestIsTransportError(t *testing.T) { @@ -456,3 +546,823 @@ func TestInitNoDaemon_Success(t *testing.T) { t.Errorf("should not show 'Setup incomplete' on success, got:\n%s", output) } } + +func TestInstallHookCmdCreatesPostRewriteHook(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("test checks Unix exec bits, skipping on Windows") + } + + repo := testutil.NewTestRepo(t) + repo.RemoveHooksDir() + defer repo.Chdir()() + + installCmd := installHookCmd() + installCmd.SetArgs([]string{}) + if err := installCmd.Execute(); err != nil { + t.Fatalf("install-hook failed: %v", err) + } + + prHookPath := filepath.Join(repo.HooksDir, "post-rewrite") + content, err := os.ReadFile(prHookPath) + if err != nil { + t.Fatalf("post-rewrite hook not created: %v", err) + } + + if !strings.Contains(string(content), "remap --quiet") { + t.Error("post-rewrite hook should contain 'remap --quiet'") + } + if !strings.Contains(string(content), postRewriteHookVersionMarker) { + t.Error("post-rewrite hook should contain version marker") + } +} + +func TestInstallOrUpgradeHook(t *testing.T) { + t.Run("appends to existing non-roborev hook", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + existing := "#!/bin/sh\necho 'custom logic'\n" + if err := os.WriteFile(hookPath, []byte(existing), 0755); err != nil { + t.Fatal(err) + } + + err := installOrUpgradeHook( + repo.HooksDir, "post-commit", + hookVersionMarker, generateHookContent, false, + ) + if err != nil { + t.Fatalf("installOrUpgradeHook: %v", err) + } + + content, _ := os.ReadFile(hookPath) + contentStr := string(content) + if !strings.Contains(contentStr, "echo 'custom logic'") { + t.Error("original content should be preserved") + } + if !strings.Contains(contentStr, hookVersionMarker) { + t.Error("roborev snippet should be appended") + } + }) + + t.Run("skips current version", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + current := generatePostRewriteHookContent() + if err := os.WriteFile(hookPath, []byte(current), 0755); err != nil { + t.Fatal(err) + } + + err := installOrUpgradeHook( + repo.HooksDir, "post-rewrite", + postRewriteHookVersionMarker, + generatePostRewriteHookContent, false, + ) + if err != nil { + t.Fatalf("installOrUpgradeHook: %v", err) + } + + content, _ := os.ReadFile(hookPath) + if string(content) != current { + t.Error("current hook should not be modified") + } + }) + + t.Run("upgrades outdated roborev hook", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + // Simulates an older generated hook (has marker but no version) + outdated := "#!/bin/sh\n" + + "# roborev post-commit hook\n" + + "ROBOREV=\"/usr/local/bin/roborev\"\n" + + "\"$ROBOREV\" enqueue --quiet 2>/dev/null\n" + if err := os.WriteFile(hookPath, []byte(outdated), 0755); err != nil { + t.Fatal(err) + } + + err := installOrUpgradeHook( + repo.HooksDir, "post-commit", + hookVersionMarker, generateHookContent, false, + ) + if err != nil { + t.Fatalf("installOrUpgradeHook: %v", err) + } + + content, _ := os.ReadFile(hookPath) + contentStr := string(content) + if !strings.Contains(contentStr, hookVersionMarker) { + t.Error("should have new version marker") + } + // Old marker should be gone + if strings.Contains(contentStr, "# roborev post-commit hook\n") { + t.Error("old marker should be removed") + } + }) + + t.Run("upgrades mixed hook preserving user content", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + mixed := "#!/bin/sh\necho 'user code'\n# roborev post-rewrite hook\nROBOREV=\"/usr/bin/roborev\"\n\"$ROBOREV\" remap --quiet 2>/dev/null\n" + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + err := installOrUpgradeHook( + repo.HooksDir, "post-rewrite", + postRewriteHookVersionMarker, + generatePostRewriteHookContent, false, + ) + if err != nil { + t.Fatalf("installOrUpgradeHook: %v", err) + } + + content, _ := os.ReadFile(hookPath) + contentStr := string(content) + if !strings.Contains(contentStr, "echo 'user code'") { + t.Error("user content should be preserved") + } + if !strings.Contains(contentStr, postRewriteHookVersionMarker) { + t.Error("should have new version marker") + } + }) + + t.Run("appends to hooks with various shell shebangs", func(t *testing.T) { + shebangs := []string{ + "#!/bin/sh", "#!/usr/bin/env sh", + "#!/bin/bash", "#!/usr/bin/env bash", + "#!/bin/zsh", "#!/usr/bin/env zsh", + "#!/bin/ksh", "#!/usr/bin/env ksh", + "#!/bin/dash", "#!/usr/bin/env dash", + } + for _, shebang := range shebangs { + t.Run(shebang, func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + existing := shebang + "\necho 'custom'\n" + if err := os.WriteFile(hookPath, []byte(existing), 0755); err != nil { + t.Fatal(err) + } + + err := installOrUpgradeHook( + repo.HooksDir, "post-commit", + hookVersionMarker, generateHookContent, false, + ) + if err != nil { + t.Fatalf("should append to %s hook: %v", shebang, err) + } + + content, _ := os.ReadFile(hookPath) + if !strings.Contains(string(content), "echo 'custom'") { + t.Error("original content should be preserved") + } + if !strings.Contains(string(content), hookVersionMarker) { + t.Error("roborev content should be appended") + } + }) + } + }) + + t.Run("refuses to append to non-shell hook", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + pythonHook := "#!/usr/bin/env python3\nprint('hello')\n" + if err := os.WriteFile(hookPath, []byte(pythonHook), 0755); err != nil { + t.Fatal(err) + } + + err := installOrUpgradeHook( + repo.HooksDir, "post-commit", + hookVersionMarker, generateHookContent, false, + ) + if err == nil { + t.Fatal("expected error for non-shell hook") + } + if !strings.Contains(err.Error(), "non-shell interpreter") { + t.Errorf("unexpected error: %v", err) + } + + // Verify hook was not modified + content, _ := os.ReadFile(hookPath) + if string(content) != pythonHook { + t.Errorf("hook should be unchanged, got:\n%s", content) + } + }) + + t.Run("upgrade returns error on re-read failure", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + // Outdated hook triggers the upgrade path. + outdated := "#!/bin/sh\n" + + "# roborev post-commit hook\n" + + "ROBOREV=\"/usr/local/bin/roborev\"\n" + + "\"$ROBOREV\" enqueue --quiet 2>/dev/null\n" + if err := os.WriteFile( + hookPath, []byte(outdated), 0755, + ); err != nil { + t.Fatal(err) + } + + // Inject a non-ENOENT error on re-read after cleanup. + origReadFile := hookReadFile + hookReadFile = func(string) ([]byte, error) { + return nil, fs.ErrPermission + } + t.Cleanup(func() { hookReadFile = origReadFile }) + + err := installOrUpgradeHook( + repo.HooksDir, "post-commit", + hookVersionMarker, generateHookContent, false, + ) + if err == nil { + t.Fatal("expected error from re-read failure") + } + if !strings.Contains(err.Error(), "re-read") { + t.Errorf("error should mention re-read, got: %v", err) + } + if !errors.Is(err, fs.ErrPermission) { + t.Errorf( + "error should wrap ErrPermission, got: %v", err, + ) + } + }) + + t.Run("refuses upgrade of non-shell hook mentioning roborev", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + // Non-shell hook that mentions roborev but has no version marker. + pythonHook := "#!/usr/bin/env python3\n# reviewed by roborev\nprint('hello')\n" + if err := os.WriteFile(hookPath, []byte(pythonHook), 0755); err != nil { + t.Fatal(err) + } + + err := installOrUpgradeHook( + repo.HooksDir, "post-commit", + hookVersionMarker, generateHookContent, false, + ) + if err == nil { + t.Fatal("expected error for non-shell hook in upgrade path") + } + if !strings.Contains(err.Error(), "non-shell interpreter") { + t.Errorf("unexpected error: %v", err) + } + + // Hook must not be modified. + content, _ := os.ReadFile(hookPath) + if string(content) != pythonHook { + t.Errorf("hook should be unchanged, got:\n%s", content) + } + }) + + t.Run("force overwrites existing hook", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + existing := "#!/bin/sh\necho 'custom'\n" + if err := os.WriteFile(hookPath, []byte(existing), 0755); err != nil { + t.Fatal(err) + } + + err := installOrUpgradeHook( + repo.HooksDir, "post-commit", + hookVersionMarker, generateHookContent, true, + ) + if err != nil { + t.Fatalf("installOrUpgradeHook: %v", err) + } + + content, _ := os.ReadFile(hookPath) + contentStr := string(content) + if strings.Contains(contentStr, "echo 'custom'") { + t.Error("force should overwrite, not append") + } + if !strings.Contains(contentStr, hookVersionMarker) { + t.Error("should have roborev content") + } + }) +} + +func TestRemoveRoborevFromHook(t *testing.T) { + t.Run("generated hook is deleted entirely", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + hookContent := generatePostRewriteHookContent() + if err := os.WriteFile(hookPath, []byte(hookContent), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + if _, err := os.Stat(hookPath); !os.IsNotExist(err) { + t.Error("generated hook should have been deleted entirely") + } + }) + + t.Run("mixed hook preserves non-roborev content", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + mixed := "#!/bin/sh\necho 'custom logic'\n" + generatePostRewriteHookContent() + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + if strings.Contains(strings.ToLower(contentStr), "roborev") { + t.Errorf("roborev content should be removed, got:\n%s", contentStr) + } + if !strings.Contains(contentStr, "echo 'custom logic'") { + t.Error("custom content should be preserved") + } + // Verify no orphaned fi + if strings.Contains(contentStr, "\nfi\n") || strings.HasSuffix(strings.TrimSpace(contentStr), "fi") { + t.Errorf("should not have orphaned fi, got:\n%s", contentStr) + } + }) + + t.Run("custom line mentioning roborev before snippet", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + // Custom line mentions roborev but isn't the generated marker + mixed := "#!/bin/sh\necho 'notify roborev team'\n" + + generatePostRewriteHookContent() + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + // The custom line should be preserved + if !strings.Contains(contentStr, "notify roborev team") { + t.Error("custom line mentioning roborev should be preserved") + } + // The generated snippet should be removed + if strings.Contains(contentStr, "remap --quiet") { + t.Errorf("generated snippet should be removed, got:\n%s", contentStr) + } + }) + + t.Run("custom if-block after snippet is preserved", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + // Snippet followed by user's own if-block + mixed := "#!/bin/sh\n" + + generatePostRewriteHookContent() + + "if [ -f .notify ]; then\n" + + " echo 'send notification'\n" + + "fi\n" + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + // User's if-block must survive + if !strings.Contains(contentStr, "send notification") { + t.Errorf("user if-block should be preserved, got:\n%s", contentStr) + } + if !strings.Contains(contentStr, "if [ -f .notify ]") { + t.Errorf("user if-statement should be preserved, got:\n%s", contentStr) + } + // No orphaned fi + lines := strings.Split(contentStr, "\n") + ifCount, fiCount := 0, 0 + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "if ") { + ifCount++ + } + if trimmed == "fi" { + fiCount++ + } + } + if ifCount != fiCount { + t.Errorf("if/fi mismatch: %d if vs %d fi in:\n%s", + ifCount, fiCount, contentStr) + } + // Generated snippet should be gone + if strings.Contains(contentStr, "remap --quiet") { + t.Errorf("generated snippet should be removed, got:\n%s", contentStr) + } + }) + + t.Run("custom comment starting with # roborev is preserved", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + // User comment starts with "# roborev" but isn't a generated marker + mixed := "#!/bin/sh\n# roborev notes: this hook was customized\necho 'custom'\n" + + generatePostRewriteHookContent() + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "roborev notes") { + t.Error("custom comment should be preserved") + } + if !strings.Contains(contentStr, "echo 'custom'") { + t.Error("custom echo should be preserved") + } + if strings.Contains(contentStr, "remap --quiet") { + t.Errorf("generated snippet should be removed, got:\n%s", contentStr) + } + }) + + t.Run("post-snippet user line mentioning roborev is preserved", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + // Snippet followed by user logic that mentions roborev + mixed := "#!/bin/sh\n" + + generatePostRewriteHookContent() + + "echo 'roborev hook finished'\n" + + "logger 'roborev done'\n" + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "roborev hook finished") { + t.Errorf("user line mentioning roborev should be preserved, got:\n%s", contentStr) + } + if !strings.Contains(contentStr, "roborev done") { + t.Errorf("user logger line should be preserved, got:\n%s", contentStr) + } + if strings.Contains(contentStr, "remap --quiet") { + t.Errorf("generated snippet should be removed, got:\n%s", contentStr) + } + }) + + t.Run("user $ROBOREV line after snippet is preserved", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + // User has their own "$ROBOREV" line (e.g. version check) after snippet + mixed := "#!/bin/sh\n" + + generatePostRewriteHookContent() + + "\"$ROBOREV\" --version\n" + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "\"$ROBOREV\" --version") { + t.Errorf("user $ROBOREV line should be preserved, got:\n%s", contentStr) + } + if strings.Contains(contentStr, "remap --quiet") { + t.Errorf("generated snippet should be removed, got:\n%s", contentStr) + } + }) + + t.Run("user $ROBOREV enqueue/remap lines after snippet are preserved", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + // User added their own enqueue/remap invocations with different flags + mixed := "#!/bin/sh\n" + + generatePostRewriteHookContent() + + "\"$ROBOREV\" enqueue --dry-run\n" + + "\"$ROBOREV\" remap --verbose\n" + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "enqueue --dry-run") { + t.Errorf("user enqueue line should be preserved, got:\n%s", contentStr) + } + if !strings.Contains(contentStr, "remap --verbose") { + t.Errorf("user remap line should be preserved, got:\n%s", contentStr) + } + if strings.Contains(contentStr, "remap --quiet") { + t.Errorf("generated snippet should be removed, got:\n%s", contentStr) + } + }) + + t.Run("user --quietly/--quiet-mode lines are preserved", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + // User lines that start with --quiet but aren't the exact generated form + mixed := "#!/bin/sh\n" + + generatePostRewriteHookContent() + + "\"$ROBOREV\" enqueue --quietly\n" + + "\"$ROBOREV\" remap --quiet-mode\n" + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "enqueue --quietly") { + t.Errorf("user --quietly line should be preserved, got:\n%s", contentStr) + } + if !strings.Contains(contentStr, "remap --quiet-mode") { + t.Errorf("user --quiet-mode line should be preserved, got:\n%s", contentStr) + } + }) + + t.Run("v0 hook (plain roborev invocation) is removed", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + v0Hook := "#!/bin/sh\n" + + "# RoboRev post-commit hook - auto-reviews every commit\n" + + "roborev enqueue --sha HEAD 2>/dev/null &\n" + if err := os.WriteFile(hookPath, []byte(v0Hook), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + if _, err := os.Stat(hookPath); !os.IsNotExist(err) { + content, _ := os.ReadFile(hookPath) + t.Errorf("v0 hook should be deleted entirely, got:\n%s", content) + } + }) + + t.Run("v0.5 hook (early variable format) is removed", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + v05Hook := "#!/bin/sh\n" + + "# RoboRev post-commit hook - auto-reviews every commit\n" + + "ROBOREV=\"/usr/local/bin/roborev\"\n" + + "if [ ! -x \"$ROBOREV\" ]; then\n" + + " ROBOREV=$(command -v roborev) || exit 0\n" + + "fi\n" + + "\"$ROBOREV\" enqueue --quiet &\n" + if err := os.WriteFile(hookPath, []byte(v05Hook), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + if _, err := os.Stat(hookPath); !os.IsNotExist(err) { + content, _ := os.ReadFile(hookPath) + t.Errorf("v0.5 hook should be deleted entirely, got:\n%s", content) + } + }) + + t.Run("v1 hook (PATH-first format) is removed", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + v1Hook := "#!/bin/sh\n" + + "# RoboRev post-commit hook - auto-reviews every commit\n" + + "ROBOREV=$(command -v roborev 2>/dev/null)\n" + + "if [ -z \"$ROBOREV\" ] || [ ! -x \"$ROBOREV\" ]; then\n" + + " ROBOREV=\"/usr/local/bin/roborev\"\n" + + " [ ! -x \"$ROBOREV\" ] && exit 0\n" + + "fi\n" + + "\"$ROBOREV\" enqueue --quiet 2>/dev/null &\n" + if err := os.WriteFile(hookPath, []byte(v1Hook), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + if _, err := os.Stat(hookPath); !os.IsNotExist(err) { + content, _ := os.ReadFile(hookPath) + t.Errorf("v1 hook should be deleted entirely, got:\n%s", content) + } + }) + + t.Run("v1 mixed hook removes only roborev block", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-commit") + v1Block := "# RoboRev post-commit hook - auto-reviews every commit\n" + + "ROBOREV=$(command -v roborev 2>/dev/null)\n" + + "if [ -z \"$ROBOREV\" ] || [ ! -x \"$ROBOREV\" ]; then\n" + + " ROBOREV=\"/usr/local/bin/roborev\"\n" + + " [ ! -x \"$ROBOREV\" ] && exit 0\n" + + "fi\n" + + "\"$ROBOREV\" enqueue --quiet 2>/dev/null &\n" + mixed := "#!/bin/sh\necho 'custom'\n" + v1Block + if err := os.WriteFile(hookPath, []byte(mixed), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + contentStr := string(content) + if strings.Contains(strings.ToLower(contentStr), "roborev") { + t.Errorf("roborev content should be removed, got:\n%s", contentStr) + } + if !strings.Contains(contentStr, "echo 'custom'") { + t.Error("custom content should be preserved") + } + }) + + t.Run("no-op if hook has no roborev content", func(t *testing.T) { + repo := testutil.NewTestRepo(t) + if err := os.MkdirAll(repo.HooksDir, 0755); err != nil { + t.Fatal(err) + } + hookPath := filepath.Join(repo.HooksDir, "post-rewrite") + hookContent := "#!/bin/sh\necho 'unrelated'\n" + if err := os.WriteFile(hookPath, []byte(hookContent), 0755); err != nil { + t.Fatal(err) + } + + if err := removeRoborevFromHook(hookPath); err != nil { + t.Fatalf("removeRoborevFromHook: %v", err) + } + + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("hook should still exist: %v", err) + } + if string(content) != hookContent { + t.Errorf("hook should be unchanged, got:\n%s", content) + } + }) +} + +func TestInitInstallsPostRewriteHookOnUpgrade(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("test uses shell script stub, skipping on Windows") + } + root := initNoDaemonSetup(t) + + // Pre-install a current post-commit hook so init takes the + // "already installed" goto path + hooksDir := filepath.Join(root, ".git", "hooks") + if err := os.MkdirAll(hooksDir, 0755); err != nil { + t.Fatal(err) + } + hookContent := generateHookContent() + if err := os.WriteFile(filepath.Join(hooksDir, "post-commit"), []byte(hookContent), 0755); err != nil { + t.Fatal(err) + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + oldAddr := serverAddr + serverAddr = ts.URL + defer func() { serverAddr = oldAddr }() + + captureStdout(t, func() { + cmd := initCmd() + cmd.SetArgs([]string{"--no-daemon"}) + _ = cmd.Execute() + }) + + // Verify post-rewrite hook was installed despite post-commit + // taking the "already installed" path + prHookPath := filepath.Join(hooksDir, "post-rewrite") + content, err := os.ReadFile(prHookPath) + if err != nil { + t.Fatalf("post-rewrite hook should be installed even when "+ + "post-commit is already current: %v", err) + } + if !strings.Contains(string(content), "remap --quiet") { + t.Error("post-rewrite hook should contain 'remap --quiet'") + } +} + +func TestGeneratePostRewriteHookContent(t *testing.T) { + content := generatePostRewriteHookContent() + + if !strings.HasPrefix(content, "#!/bin/sh\n") { + t.Error("hook should start with #!/bin/sh") + } + if !strings.Contains(content, postRewriteHookVersionMarker) { + t.Error("hook should contain version marker") + } + if !strings.Contains(content, "remap --quiet") { + t.Error("hook should call remap --quiet") + } +} diff --git a/cmd/roborev/main.go b/cmd/roborev/main.go index 89fa8f6b..198f0d26 100644 --- a/cmd/roborev/main.go +++ b/cmd/roborev/main.go @@ -78,6 +78,7 @@ func main() { rootCmd.AddCommand(repoCmd()) rootCmd.AddCommand(skillsCmd()) rootCmd.AddCommand(syncCmd()) + rootCmd.AddCommand(remapCmd()) rootCmd.AddCommand(checkAgentsCmd()) rootCmd.AddCommand(configCmd()) rootCmd.AddCommand(updateCmd()) @@ -467,6 +468,10 @@ func initCmd() *cobra.Command { if existing, err := os.ReadFile(hookPath); err == nil { existingStr := string(existing) if !strings.Contains(strings.ToLower(existingStr), "roborev") { + if !isShellHook(existingStr) { + fmt.Printf(" Warning: %s uses a non-shell interpreter, skipping roborev hook\n", hookPath) + goto startDaemon + } // Append to existing hook hookContent = existingStr + "\n" + hookContent } else if strings.Contains(existingStr, hookVersionMarker) { @@ -499,6 +504,11 @@ func initCmd() *cobra.Command { fmt.Printf(" Installed post-commit hook\n") startDaemon: + // 4b. Install post-rewrite hook (for rebase review preservation) + // Runs on all paths (fresh install, upgrade, already-installed) + // so existing users get the new hook on next `roborev init`. + installPostRewriteHook(hooksDir) + // 5. Start daemon (or just register if --no-daemon) var initIncomplete bool if noDaemon { @@ -1565,12 +1575,17 @@ func statusCmd() *cobra.Command { w.Flush() } - // Check for outdated hook in current repo + // Check for outdated hooks in current repo if root, err := git.GetRepoRoot("."); err == nil { - if hookNeedsUpgrade(root) { + if hookNeedsUpgrade(root, "post-commit", hookVersionMarker) { fmt.Println() fmt.Println("Warning: post-commit hook is outdated -- run 'roborev init' to upgrade") } + if hookNeedsUpgrade(root, "post-rewrite", postRewriteHookVersionMarker) || + hookMissing(root, "post-rewrite") { + fmt.Println() + fmt.Println("Warning: post-rewrite hook is missing or outdated -- run 'roborev init' to install") + } } return nil @@ -2335,25 +2350,26 @@ func installHookCmd() *cobra.Command { if err != nil { return fmt.Errorf("get hooks path: %w", err) } - hookPath := filepath.Join(hooksDir, "post-commit") - // Check if hook already exists - if _, err := os.Stat(hookPath); err == nil && !force { - return fmt.Errorf("hook already exists at %s (use --force to overwrite)", hookPath) - } - - // Ensure hooks directory exists if err := os.MkdirAll(hooksDir, 0755); err != nil { return fmt.Errorf("create hooks directory: %w", err) } - hookContent := generateHookContent() + if err := installOrUpgradeHook( + hooksDir, "post-commit", + hookVersionMarker, generateHookContent, force, + ); err != nil { + return err + } - if err := os.WriteFile(hookPath, []byte(hookContent), 0755); err != nil { - return fmt.Errorf("write hook: %w", err) + if err := installOrUpgradeHook( + hooksDir, "post-rewrite", + postRewriteHookVersionMarker, + generatePostRewriteHookContent, force, + ); err != nil { + return err } - fmt.Printf("Installed post-commit hook at %s\n", hookPath) return nil }, } @@ -2363,10 +2379,77 @@ func installHookCmd() *cobra.Command { return cmd } +// installOrUpgradeHook handles the append/upgrade/skip logic for a +// single hook file, following the design doc's per-path behavior: +// - No existing hook: write fresh +// - Existing without roborev: append +// - Existing with current version: skip +// - Existing with old version: upgrade (remove old, append new) +// - --force: overwrite unconditionally +// +// hookReadFile is used to re-read the hook file after cleanup during +// upgrade. Replaceable in tests to simulate read failures. +var hookReadFile = os.ReadFile + +func installOrUpgradeHook( + hooksDir, hookName, versionMarker string, + generate func() string, force bool, +) error { + hookPath := filepath.Join(hooksDir, hookName) + hookContent := generate() + + existing, err := os.ReadFile(hookPath) + if err == nil && !force { + existingStr := string(existing) + if !strings.Contains(strings.ToLower(existingStr), "roborev") { + if !isShellHook(existingStr) { + return fmt.Errorf( + "%s hook uses a non-shell interpreter; "+ + "add the roborev snippet manually or use --force to overwrite", + hookName) + } + // No roborev content — append + hookContent = existingStr + "\n" + hookContent + } else if strings.Contains(existingStr, versionMarker) { + fmt.Printf("%s hook already installed (current)\n", hookName) + return nil + } else { + // Upgrade: remove old snippet, append new one + if !isShellHook(existingStr) { + return fmt.Errorf( + "%s hook uses a non-shell interpreter; "+ + "add the roborev snippet manually "+ + "or use --force to overwrite", + hookName) + } + if rmErr := removeRoborevFromHook(hookPath); rmErr != nil { + return fmt.Errorf("upgrade %s: %w", hookName, rmErr) + } + updated, readErr := hookReadFile(hookPath) + if readErr != nil && !os.IsNotExist(readErr) { + return fmt.Errorf("re-read %s after cleanup: %w", hookName, readErr) + } + if readErr == nil { + remaining := string(updated) + if remaining != "" && !strings.HasSuffix(remaining, "\n") { + remaining += "\n" + } + hookContent = remaining + hookContent + } + } + } + + if err := os.WriteFile(hookPath, []byte(hookContent), 0755); err != nil { + return fmt.Errorf("write %s hook: %w", hookName, err) + } + fmt.Printf("Installed %s hook at %s\n", hookName, hookPath) + return nil +} + func uninstallHookCmd() *cobra.Command { return &cobra.Command{ Use: "uninstall-hook", - Short: "Remove post-commit hook from current repository", + Short: "Remove roborev hooks from current repository", RunE: func(cmd *cobra.Command, args []string) error { root, err := git.GetRepoRoot(".") if err != nil { @@ -2377,63 +2460,183 @@ func uninstallHookCmd() *cobra.Command { if err != nil { return fmt.Errorf("get hooks path: %w", err) } - hookPath := filepath.Join(hooksDir, "post-commit") - // Check if hook exists - content, err := os.ReadFile(hookPath) - if os.IsNotExist(err) { - fmt.Println("No post-commit hook found") - return nil - } else if err != nil { - return fmt.Errorf("read hook: %w", err) + for _, hookName := range []string{ + "post-commit", "post-rewrite", + } { + if err := removeRoborevFromHook( + filepath.Join(hooksDir, hookName), + ); err != nil { + return err + } } - // Check if it contains roborev (case-insensitive) - hookStr := string(content) - if !strings.Contains(strings.ToLower(hookStr), "roborev") { - fmt.Println("Post-commit hook does not contain roborev") - return nil - } + return nil + }, + } +} - // Remove roborev lines from the hook - lines := strings.Split(hookStr, "\n") - var newLines []string - for _, line := range lines { - // Skip roborev-related lines (case-insensitive) - if strings.Contains(strings.ToLower(line), "roborev") { - continue - } - newLines = append(newLines, line) - } +// removeRoborevFromHook removes the roborev block from a hook file, +// or deletes it entirely if nothing else remains. Uses block-based +// removal: drops all lines from the first roborev comment marker +// through the end of that contiguous block (since roborev appends +// its snippet as a self-contained block at the end). +// isShellHook returns true if the hook content starts with a +// POSIX-compatible shell shebang (sh, bash, zsh, ksh, dash). +// Used to avoid appending shell snippets to non-shell hooks. +func isShellHook(content string) bool { + first, _, _ := strings.Cut(content, "\n") + first = strings.TrimSpace(first) + for _, sh := range []string{"sh", "bash", "zsh", "ksh", "dash"} { + if strings.HasPrefix(first, "#!/bin/"+sh) || + strings.HasPrefix(first, "#!/usr/bin/env "+sh) { + return true + } + } + return false +} - // Check if anything remains (besides shebang and empty lines) - hasContent := false - for _, line := range newLines { - trimmed := strings.TrimSpace(line) - if trimmed != "" && !strings.HasPrefix(trimmed, "#!") { - hasContent = true - break - } - } +// isRoborevMarker returns true if the line is a generated roborev hook +// marker comment. Only matches the known generated forms: +// +// # roborev post-commit hook ... +// # roborev post-rewrite hook ... +func isRoborevMarker(line string) bool { + trimmed := strings.TrimSpace(strings.ToLower(line)) + return strings.HasPrefix(trimmed, "# roborev post-commit hook") || + strings.HasPrefix(trimmed, "# roborev post-rewrite hook") +} - if hasContent { - // Write back the hook without roborev lines - newContent := strings.Join(newLines, "\n") - if err := os.WriteFile(hookPath, []byte(newContent), 0755); err != nil { - return fmt.Errorf("write hook: %w", err) - } - fmt.Printf("Removed roborev from post-commit hook at %s\n", hookPath) - } else { - // Remove the hook entirely - if err := os.Remove(hookPath); err != nil { - return fmt.Errorf("remove hook: %w", err) - } - fmt.Printf("Removed post-commit hook at %s\n", hookPath) +// hasCommandPrefix checks if line starts with prefix and the prefix +// is followed by end-of-string, whitespace, or a shell operator +// (e.g. redirection). This prevents "enqueue --quiet" from matching +// "enqueue --quietly". +func hasCommandPrefix(line, prefix string) bool { + if !strings.HasPrefix(line, prefix) { + return false + } + if len(line) == len(prefix) { + return true + } + next := line[len(prefix)] + return next == ' ' || next == '\t' || next == '>' || + next == '|' || next == '&' || next == ';' +} + +// isRoborevSnippetLine returns true if the line is part of a +// generated roborev hook snippet (current or legacy versions). +func isRoborevSnippetLine(line string) bool { + trimmed := strings.TrimSpace(line) + if trimmed == "" { + return false + } + return strings.HasPrefix(trimmed, "ROBOREV=") || + strings.HasPrefix(trimmed, "ROBOREV=$(") || + hasCommandPrefix(trimmed, "\"$ROBOREV\" enqueue --quiet") || + hasCommandPrefix(trimmed, "\"$ROBOREV\" remap --quiet") || + hasCommandPrefix(trimmed, "roborev enqueue") || + hasCommandPrefix(trimmed, "roborev remap") || + strings.HasPrefix(trimmed, "if [ ! -x \"$ROBOREV\"") || + strings.HasPrefix(trimmed, "if [ -z \"$ROBOREV\"") || + strings.HasPrefix(trimmed, "[ -z \"$ROBOREV\"") || + strings.HasPrefix(trimmed, "[ ! -x \"$ROBOREV\"") +} + +func removeRoborevFromHook(hookPath string) error { + content, err := os.ReadFile(hookPath) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return fmt.Errorf("read %s: %w", filepath.Base(hookPath), err) + } + + hookStr := string(content) + if !strings.Contains(strings.ToLower(hookStr), "roborev") { + return nil + } + + lines := strings.Split(hookStr, "\n") + + // Find the start: anchor on the generated marker comment line + // (e.g. "# roborev post-commit hook v2 - ..."), not just any + // line that mentions roborev. + blockStart := -1 + for i, line := range lines { + if isRoborevMarker(line) { + blockStart = i + break + } + } + if blockStart < 0 { + return nil + } + + // Find the end: scan forward from the marker, consuming only + // lines that are part of the generated snippet. Stop at the + // first line that isn't a snippet line AND isn't "fi" closing + // a snippet's if-block. + blockEnd := blockStart + inIfBlock := false + for i := blockStart + 1; i < len(lines); i++ { + trimmed := strings.TrimSpace(lines[i]) + if trimmed == "" { + // Blank lines between snippet lines are consumed; a + // blank line after the snippet ends the block. + if i+1 < len(lines) && isRoborevSnippetLine(lines[i+1]) { + blockEnd = i + continue } + break + } + if isRoborevSnippetLine(trimmed) { + blockEnd = i + if strings.HasPrefix(trimmed, "if ") { + inIfBlock = true + } + continue + } + // "fi" only belongs to the block if we saw an "if" inside it + if trimmed == "fi" && inIfBlock { + blockEnd = i + inIfBlock = false + continue + } + break + } - return nil - }, + // Keep everything before and after the block + remaining := make([]string, 0, len(lines)) + remaining = append(remaining, lines[:blockStart]...) + remaining = append(remaining, lines[blockEnd+1:]...) + + // Check if anything meaningful remains + hasContent := false + for _, line := range remaining { + trimmed := strings.TrimSpace(line) + if trimmed != "" && !strings.HasPrefix(trimmed, "#!") { + hasContent = true + break + } } + + hookName := filepath.Base(hookPath) + if hasContent { + newContent := strings.Join(remaining, "\n") + if !strings.HasSuffix(newContent, "\n") { + newContent += "\n" + } + if err := os.WriteFile(hookPath, []byte(newContent), 0755); err != nil { + return fmt.Errorf("write %s: %w", hookName, err) + } + fmt.Printf("Removed roborev from %s\n", hookName) + } else { + if err := os.Remove(hookPath); err != nil { + return fmt.Errorf("remove %s: %w", hookName, err) + } + fmt.Printf("Removed %s hook\n", hookName) + } + return nil } func skillsCmd() *cobra.Command { @@ -3205,19 +3408,45 @@ func resolveReasoningWithFast(reasoning string, fast bool, reasoningExplicitlySe // Bump this when the hook template changes to trigger upgrade warnings. const hookVersionMarker = "post-commit hook v2" -// hookNeedsUpgrade checks whether a repo's post-commit hook contains roborev -// but is outdated (missing the current version marker). -func hookNeedsUpgrade(repoPath string) bool { +const postRewriteHookVersionMarker = "post-rewrite hook v1" + +// hookNeedsUpgrade checks whether a repo's named hook contains roborev +// but is outdated (missing the given version marker). +func hookNeedsUpgrade(repoPath, hookName, versionMarker string) bool { hooksDir, err := git.GetHooksPath(repoPath) if err != nil { return false } - content, err := os.ReadFile(filepath.Join(hooksDir, "post-commit")) + content, err := os.ReadFile(filepath.Join(hooksDir, hookName)) if err != nil { return false } s := string(content) - return strings.Contains(strings.ToLower(s), "roborev") && !strings.Contains(s, hookVersionMarker) + return strings.Contains(strings.ToLower(s), "roborev") && + !strings.Contains(s, versionMarker) +} + +// hookMissing checks whether a repo has roborev installed (post-commit +// hook present) but is missing the named hook entirely. +func hookMissing(repoPath, hookName string) bool { + hooksDir, err := git.GetHooksPath(repoPath) + if err != nil { + return false + } + // Only warn if roborev is installed (post-commit hook exists) + pcContent, err := os.ReadFile(filepath.Join(hooksDir, "post-commit")) + if err != nil { + return false + } + if !strings.Contains(strings.ToLower(string(pcContent)), "roborev") { + return false + } + // Check if the target hook is missing or has no roborev content + content, err := os.ReadFile(filepath.Join(hooksDir, hookName)) + if err != nil { + return true // hook file doesn't exist + } + return !strings.Contains(strings.ToLower(string(content)), "roborev") } func generateHookContent() string { @@ -3247,3 +3476,79 @@ fi "$ROBOREV" enqueue --quiet 2>/dev/null `, roborevPath) } + +func installPostRewriteHook(hooksDir string) { + hookPath := filepath.Join(hooksDir, "post-rewrite") + hookContent := generatePostRewriteHookContent() + + if existing, err := os.ReadFile(hookPath); err == nil { + existingStr := string(existing) + if !strings.Contains(strings.ToLower(existingStr), "roborev") { + if !isShellHook(existingStr) { + fmt.Printf(" Warning: %s uses a non-shell interpreter, skipping\n", hookPath) + return + } + // No roborev content — append to existing hook + hookContent = existingStr + "\n" + hookContent + } else if strings.Contains(existingStr, postRewriteHookVersionMarker) { + fmt.Println(" Post-rewrite hook already installed") + return + } else { + // Upgrade: remove old roborev snippet, append new one. + // This preserves user content around the snippet. + if !isShellHook(existingStr) { + fmt.Printf(" Warning: %s uses a non-shell interpreter, skipping\n", hookPath) + return + } + if rmErr := removeRoborevFromHook(hookPath); rmErr != nil { + fmt.Printf(" Warning: %v\n", rmErr) + return + } + updated, readErr := os.ReadFile(hookPath) + if readErr != nil && !os.IsNotExist(readErr) { + fmt.Printf(" Warning: re-read %s after cleanup: %v\n", + hookPath, readErr) + return + } + if readErr == nil { + remaining := string(updated) + if remaining != "" && !strings.HasSuffix(remaining, "\n") { + remaining += "\n" + } + hookContent = remaining + hookContent + } + // If the file was deleted (snippet-only), hookContent + // is already the fresh generated content. + } + } + + if err := os.WriteFile(hookPath, []byte(hookContent), 0755); err != nil { + fmt.Printf(" Warning: could not install post-rewrite hook: %v\n", err) + return + } + fmt.Printf(" Installed post-rewrite hook\n") +} + +func generatePostRewriteHookContent() string { + roborevPath, err := os.Executable() + if err == nil { + if resolved, err := filepath.EvalSymlinks(roborevPath); err == nil { + roborevPath = resolved + } + } else { + roborevPath, _ = exec.LookPath("roborev") + if roborevPath == "" { + roborevPath = "roborev" + } + } + + return fmt.Sprintf(`#!/bin/sh +# roborev post-rewrite hook v1 - remaps reviews after rebase/amend +ROBOREV=%q +if [ ! -x "$ROBOREV" ]; then + ROBOREV=$(command -v roborev 2>/dev/null) + [ -z "$ROBOREV" ] || [ ! -x "$ROBOREV" ] && exit 0 +fi +"$ROBOREV" remap --quiet 2>/dev/null +`, roborevPath) +} diff --git a/cmd/roborev/refine_test.go b/cmd/roborev/refine_test.go index 3c6b3d33..df9d46c4 100644 --- a/cmd/roborev/refine_test.go +++ b/cmd/roborev/refine_test.go @@ -167,6 +167,10 @@ func (m *mockDaemonClient) GetCommentsForJob(jobID int64) ([]storage.Response, e return m.responses[jobID], nil } +func (m *mockDaemonClient) Remap(req daemon.RemapRequest) (*daemon.RemapResult, error) { + return &daemon.RemapResult{}, nil +} + // WithReview adds a review to the mock client, returning the client for chaining. func (m *mockDaemonClient) WithReview(sha string, jobID int64, output string, addressed bool) *mockDaemonClient { m.nextReviewID++ diff --git a/cmd/roborev/remap.go b/cmd/roborev/remap.go new file mode 100644 index 00000000..52906ef0 --- /dev/null +++ b/cmd/roborev/remap.go @@ -0,0 +1,119 @@ +package main + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" + + "github.com/roborev-dev/roborev/internal/daemon" + "github.com/roborev-dev/roborev/internal/git" + "github.com/spf13/cobra" +) + +// gitSHAPattern matches a full hex git SHA: 40 chars (SHA-1) +// or 64 chars (SHA-256). +var gitSHAPattern = regexp.MustCompile(`^[0-9a-f]{40}([0-9a-f]{24})?$`) + +func remapCmd() *cobra.Command { + var quiet bool + + cmd := &cobra.Command{ + Use: "remap", + Short: "Remap review jobs after a rebase", + Hidden: true, + Long: `Reads old-sha/new-sha pairs from stdin (one per line, +space-separated) and updates review jobs to point at the +new commits. Called automatically by the post-rewrite hook.`, + RunE: func(cmd *cobra.Command, args []string) error { + gitCwd, err := git.GetRepoRoot(".") + if err != nil { + return fmt.Errorf("not a git repository: %w", err) + } + repoRoot, err := git.GetMainRepoRoot(".") + if err != nil { + repoRoot = gitCwd + } + + // Parse stdin: "old_sha new_sha" per line + var mappings []daemon.RemapMapping + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + oldSHA, newSHA := fields[0], fields[1] + + if !gitSHAPattern.MatchString(oldSHA) || + !gitSHAPattern.MatchString(newSHA) { + continue + } + + oldPatchID := git.GetPatchID(gitCwd, oldSHA) + newPatchID := git.GetPatchID(gitCwd, newSHA) + + // Skip if either has no patch-id or they differ + if oldPatchID == "" || newPatchID == "" { + continue + } + if oldPatchID != newPatchID { + continue + } + + info, err := git.GetCommitInfo(gitCwd, newSHA) + if err != nil { + continue + } + + mappings = append(mappings, daemon.RemapMapping{ + OldSHA: oldSHA, + NewSHA: newSHA, + PatchID: newPatchID, + Author: info.Author, + Subject: info.Subject, + Timestamp: info.Timestamp.Format("2006-01-02T15:04:05Z07:00"), + }) + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("read stdin: %w", err) + } + + if len(mappings) == 0 { + if !quiet { + fmt.Println("No mappings to remap") + } + return nil + } + + addr := getDaemonAddr() + client := daemon.NewHTTPClient(addr) + + result, err := client.Remap(daemon.RemapRequest{ + RepoPath: repoRoot, + Mappings: mappings, + }) + if err != nil { + if !quiet { + fmt.Fprintf(os.Stderr, "remap failed: %v\n", err) + } + return nil // Don't fail the hook + } + + if !quiet { + fmt.Printf("Remapped %d review(s), skipped %d\n", + result.Remapped, result.Skipped) + } + return nil + }, + } + + cmd.Flags().BoolVar(&quiet, "quiet", false, "suppress output") + + return cmd +} diff --git a/cmd/roborev/remap_test.go b/cmd/roborev/remap_test.go new file mode 100644 index 00000000..166bf103 --- /dev/null +++ b/cmd/roborev/remap_test.go @@ -0,0 +1,70 @@ +package main + +import ( + "strings" + "testing" +) + +func TestRemapStdinParsing(t *testing.T) { + input := "abc123 def456\nfoo bar\n\n baz qux \n" + + lines := strings.Split(strings.TrimSpace(input), "\n") + var pairs [][2]string + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + pairs = append(pairs, [2]string{fields[0], fields[1]}) + } + + expected := [][2]string{ + {"abc123", "def456"}, + {"foo", "bar"}, + {"baz", "qux"}, + } + + if len(pairs) != len(expected) { + t.Fatalf("expected %d pairs, got %d", len(expected), len(pairs)) + } + for i, p := range pairs { + if p != expected[i] { + t.Errorf("pair %d: expected %v, got %v", i, expected[i], p) + } + } +} + +func TestGitSHAValidation(t *testing.T) { + sha256Valid := "abc123def456abc123def456abc123def456abc1" + + "aabbccddeeff00112233aabb" + tests := []struct { + input string + valid bool + }{ + {"abc123def456abc123def456abc123def456abc1", true}, + {"0000000000000000000000000000000000000000", true}, + {"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", true}, + {sha256Valid, true}, // SHA-256 (64 chars) + {"abc123", false}, // too short + {"ABC123DEF456ABC123DEF456ABC123DEF456ABC1", false}, // uppercase + {"--option", false}, // flag injection + {"-n1", false}, // short flag + {"abc123def456abc123def456abc123def456abc1x", false}, // 41 chars + {"abc123def456abc123def456abc123def456abc", false}, // 39 chars + {"zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", false}, // non-hex + {"abc123def456abc123def456abc123def456abc1 ", false}, // trailing space + {sha256Valid + "aa", false}, // 66 chars + {sha256Valid[:63], false}, // 63 chars + } + for _, tt := range tests { + got := gitSHAPattern.MatchString(tt.input) + if got != tt.valid { + t.Errorf("gitSHAPattern.MatchString(%q) = %v, want %v", + tt.input, got, tt.valid) + } + } +} diff --git a/internal/daemon/client.go b/internal/daemon/client.go index 3303f3ab..76a587a8 100644 --- a/internal/daemon/client.go +++ b/internal/daemon/client.go @@ -43,6 +43,9 @@ type Client interface { // GetCommentsForJob fetches comments for a job GetCommentsForJob(jobID int64) ([]storage.Response, error) + + // Remap updates git_ref for jobs whose commits were rewritten + Remap(req RemapRequest) (*RemapResult, error) } // DefaultPollInterval is the default polling interval for WaitForReview. @@ -393,3 +396,38 @@ func (c *HTTPClient) GetCommentsForJob(jobID int64) ([]storage.Response, error) return result.Responses, nil } + +// RemapResult is the response from POST /api/remap. +type RemapResult struct { + Remapped int `json:"remapped"` + Skipped int `json:"skipped"` +} + +// Remap sends rewritten commit mappings to the daemon so that +// review jobs are updated to point at the new SHAs. +func (c *HTTPClient) Remap(req RemapRequest) (*RemapResult, error) { + reqBody, err := json.Marshal(req) + if err != nil { + return nil, err + } + + resp, err := c.httpClient.Post( + c.addr+"/api/remap", "application/json", + bytes.NewReader(reqBody), + ) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("remap: %s: %s", resp.Status, body) + } + + var result RemapResult + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + return &result, nil +} diff --git a/internal/daemon/remap_integration_test.go b/internal/daemon/remap_integration_test.go new file mode 100644 index 00000000..16e1d10a --- /dev/null +++ b/internal/daemon/remap_integration_test.go @@ -0,0 +1,311 @@ +//go:build integration + +package daemon + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + gitpkg "github.com/roborev-dev/roborev/internal/git" + "github.com/roborev-dev/roborev/internal/storage" + "github.com/roborev-dev/roborev/internal/testutil" +) + +// gitHelper runs git commands in a repo directory. +type gitHelper struct { + t *testing.T + dir string +} + +func (g *gitHelper) run(args ...string) { + g.t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = g.dir + if out, err := cmd.CombinedOutput(); err != nil { + g.t.Fatalf("git %v: %s: %v", args, out, err) + } +} + +func (g *gitHelper) headSHA() string { + g.t.Helper() + cmd := exec.Command("git", "rev-parse", "HEAD") + cmd.Dir = g.dir + out, err := cmd.Output() + if err != nil { + g.t.Fatalf("git rev-parse HEAD: %v", err) + } + return strings.TrimSpace(string(out)) +} + +func (g *gitHelper) commitFile(name, content, msg string) { + g.t.Helper() + if err := os.WriteFile(filepath.Join(g.dir, name), []byte(content), 0644); err != nil { + g.t.Fatal(err) + } + g.run("add", name) + g.run("commit", "-m", msg) +} + +func newGitRepo(t *testing.T) *gitHelper { + t.Helper() + dir := t.TempDir() + g := &gitHelper{t: t, dir: dir} + g.run("init", "-b", "main") + g.run("config", "user.email", "test@test.com") + g.run("config", "user.name", "Test") + return g +} + +// TestRemapAfterRebase exercises the rebase remap flow by calling +// the HTTP handler directly with real git operations. This validates +// the full DB flow (patch-id matching, commit creation, git_ref +// update) without requiring a running daemon. End-to-end testing +// through the actual post-rewrite hook → CLI → daemon path is +// covered by manual testing (see design doc verification section). +func TestRemapAfterRebase(t *testing.T) { + server, db, _ := newTestServer(t) + + repo := newGitRepo(t) + repo.commitFile("base.txt", "base content", "initial commit") + + // Resolve symlinks for macOS /var -> /private/var + repoDir, err := filepath.EvalSymlinks(repo.dir) + if err != nil { + repoDir = repo.dir + } + + dbRepo, err := db.GetOrCreateRepo(repoDir) + if err != nil { + t.Fatalf("GetOrCreateRepo: %v", err) + } + + // Create feature branch with a commit + repo.run("checkout", "-b", "feature") + repo.commitFile("feature.txt", "feature content", "add feature") + oldSHA := repo.headSHA() + + patchID := gitpkg.GetPatchID(repo.dir, oldSHA) + if patchID == "" { + t.Fatal("expected non-empty patch-id for feature commit") + } + + // Enqueue and complete a review + commit, err := db.GetOrCreateCommit( + dbRepo.ID, oldSHA, "Test", "add feature", time.Now(), + ) + if err != nil { + t.Fatalf("GetOrCreateCommit: %v", err) + } + job, err := db.EnqueueJob(storage.EnqueueOpts{ + RepoID: dbRepo.ID, + CommitID: commit.ID, + GitRef: oldSHA, + Agent: "test", + PatchID: patchID, + }) + if err != nil { + t.Fatalf("EnqueueJob: %v", err) + } + _, err = db.ClaimJob("worker-int") + if err != nil { + t.Fatalf("ClaimJob: %v", err) + } + err = db.CompleteJob(job.ID, "test", "prompt", "LGTM - no issues found") + if err != nil { + t.Fatalf("CompleteJob: %v", err) + } + + // Advance main so rebase has work to do + repo.run("checkout", "main") + repo.commitFile("main2.txt", "more main", "advance main") + + // Rebase feature onto main + repo.run("checkout", "feature") + repo.run("rebase", "main") + newSHA := repo.headSHA() + + if oldSHA == newSHA { + t.Fatal("SHAs should differ after rebase") + } + + newPatchID := gitpkg.GetPatchID(repo.dir, newSHA) + if patchID != newPatchID { + t.Fatalf("patch-ids should match after clean rebase: %s != %s", + patchID, newPatchID) + } + + info, err := gitpkg.GetCommitInfo(repo.dir, newSHA) + if err != nil { + t.Fatalf("GetCommitInfo: %v", err) + } + + // POST /api/remap + reqData := RemapRequest{ + RepoPath: repoDir, + Mappings: []RemapMapping{ + { + OldSHA: oldSHA, + NewSHA: newSHA, + PatchID: patchID, + Author: info.Author, + Subject: info.Subject, + Timestamp: info.Timestamp.Format(time.RFC3339), + }, + }, + } + req := testutil.MakeJSONRequest(t, http.MethodPost, "/api/remap", reqData) + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var result map[string]int + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatal(err) + } + if result["remapped"] != 1 { + t.Errorf("expected remapped=1, got %d", result["remapped"]) + } + + // Verify the job now points at the new SHA + updatedJob, err := db.GetJobByID(job.ID) + if err != nil { + t.Fatalf("GetJobByID: %v", err) + } + if updatedJob.GitRef != newSHA { + t.Errorf("expected git_ref=%s, got %s", newSHA, updatedJob.GitRef) + } + + // Verify the review is reachable via the new SHA + review, err := db.GetReviewByCommitSHA(newSHA) + if err != nil { + t.Fatalf("GetReviewByCommitSHA(%s): %v", newSHA, err) + } + if review == nil { + t.Fatal("review should be reachable via new SHA after remap") + } + if !strings.Contains(review.Output, "LGTM") { + t.Errorf("unexpected review output: %s", review.Output) + } +} + +// TestRemapAfterAmendMessageOnly exercises the message-only amend flow. +func TestRemapAfterAmendMessageOnly(t *testing.T) { + server, db, _ := newTestServer(t) + + repo := newGitRepo(t) + repo.commitFile("file.txt", "content", "original message") + oldSHA := repo.headSHA() + + repoDir, err := filepath.EvalSymlinks(repo.dir) + if err != nil { + repoDir = repo.dir + } + + dbRepo, err := db.GetOrCreateRepo(repoDir) + if err != nil { + t.Fatalf("GetOrCreateRepo: %v", err) + } + + patchID := gitpkg.GetPatchID(repo.dir, oldSHA) + if patchID == "" { + t.Fatal("expected non-empty patch-id") + } + + commit, err := db.GetOrCreateCommit( + dbRepo.ID, oldSHA, "Test", "original message", time.Now(), + ) + if err != nil { + t.Fatalf("GetOrCreateCommit: %v", err) + } + job, err := db.EnqueueJob(storage.EnqueueOpts{ + RepoID: dbRepo.ID, + CommitID: commit.ID, + GitRef: oldSHA, + Agent: "test", + PatchID: patchID, + }) + if err != nil { + t.Fatalf("EnqueueJob: %v", err) + } + _, err = db.ClaimJob("worker-amend") + if err != nil { + t.Fatalf("ClaimJob: %v", err) + } + err = db.CompleteJob(job.ID, "test", "prompt", "looks good") + if err != nil { + t.Fatalf("CompleteJob: %v", err) + } + + // Amend message only + repo.run("commit", "--amend", "-m", "amended message") + newSHA := repo.headSHA() + + if oldSHA == newSHA { + t.Fatal("SHAs should differ after amend") + } + + newPatchID := gitpkg.GetPatchID(repo.dir, newSHA) + if patchID != newPatchID { + t.Fatalf("patch-ids should match for message-only amend: %s != %s", + patchID, newPatchID) + } + + info, err := gitpkg.GetCommitInfo(repo.dir, newSHA) + if err != nil { + t.Fatalf("GetCommitInfo: %v", err) + } + + reqData := RemapRequest{ + RepoPath: repoDir, + Mappings: []RemapMapping{ + { + OldSHA: oldSHA, + NewSHA: newSHA, + PatchID: patchID, + Author: info.Author, + Subject: info.Subject, + Timestamp: info.Timestamp.Format(time.RFC3339), + }, + }, + } + req := testutil.MakeJSONRequest(t, http.MethodPost, "/api/remap", reqData) + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var result map[string]int + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatal(err) + } + if result["remapped"] != 1 { + t.Errorf("expected remapped=1, got %d", result["remapped"]) + } + + // Review should be reachable via new SHA + review, err := db.GetReviewByCommitSHA(newSHA) + if err != nil { + t.Fatalf("GetReviewByCommitSHA(%s): %v", newSHA, err) + } + if review == nil { + t.Fatal("review should be reachable via new SHA after amend remap") + } + + // Old SHA should no longer find the review + oldReview, err := db.GetReviewByCommitSHA(oldSHA) + if err == nil && oldReview != nil { + t.Error("old SHA should not find the review after remap") + } +} diff --git a/internal/daemon/server.go b/internal/daemon/server.go index 18571848..37ddbdd1 100644 --- a/internal/daemon/server.go +++ b/internal/daemon/server.go @@ -86,6 +86,7 @@ func NewServer(db *storage.DB, cfg *config.Config, configPath string) *Server { mux.HandleFunc("/api/status", s.handleStatus) mux.HandleFunc("/api/stream/events", s.handleStreamEvents) mux.HandleFunc("/api/jobs/batch", s.handleBatchJobs) + mux.HandleFunc("/api/remap", s.handleRemap) mux.HandleFunc("/api/sync/now", s.handleSyncNow) mux.HandleFunc("/api/sync/status", s.handleSyncStatus) @@ -140,9 +141,13 @@ func (s *Server) Start(ctx context.Context) error { // Check for outdated hooks in registered repos if repos, err := s.db.ListRepos(); err == nil { for _, repo := range repos { - if hookNeedsUpgrade(repo.RootPath) { + if hookNeedsUpgrade(repo.RootPath, "post-commit", hookVersionMarker) { log.Printf("Warning: outdated post-commit hook in %s -- run 'roborev init' to upgrade", repo.RootPath) } + if hookNeedsUpgrade(repo.RootPath, "post-rewrite", postRewriteHookVersionMarker) || + hookMissing(repo.RootPath, "post-rewrite") { + log.Printf("Warning: missing or outdated post-rewrite hook in %s -- run 'roborev init' to install", repo.RootPath) + } } } @@ -158,19 +163,43 @@ func (s *Server) Start(ctx context.Context) error { // hookVersionMarker identifies the current hook version. const hookVersionMarker = "post-commit hook v2" +const postRewriteHookVersionMarker = "post-rewrite hook v1" -// hookNeedsUpgrade checks whether a repo's post-commit hook is outdated. -func hookNeedsUpgrade(repoPath string) bool { +// hookNeedsUpgrade checks whether a repo's named hook contains roborev +// but is outdated (missing the given version marker). +func hookNeedsUpgrade(repoPath, hookName, versionMarker string) bool { hooksDir, err := git.GetHooksPath(repoPath) if err != nil { return false } - content, err := os.ReadFile(filepath.Join(hooksDir, "post-commit")) + content, err := os.ReadFile(filepath.Join(hooksDir, hookName)) if err != nil { return false } s := string(content) - return strings.Contains(strings.ToLower(s), "roborev") && !strings.Contains(s, hookVersionMarker) + return strings.Contains(strings.ToLower(s), "roborev") && + !strings.Contains(s, versionMarker) +} + +// hookMissing checks whether a repo has roborev installed (post-commit +// hook present) but is missing the named hook entirely. +func hookMissing(repoPath, hookName string) bool { + hooksDir, err := git.GetHooksPath(repoPath) + if err != nil { + return false + } + pcContent, err := os.ReadFile(filepath.Join(hooksDir, "post-commit")) + if err != nil { + return false + } + if !strings.Contains(strings.ToLower(string(pcContent)), "roborev") { + return false + } + content, err := os.ReadFile(filepath.Join(hooksDir, hookName)) + if err != nil { + return true + } + return !strings.Contains(strings.ToLower(string(content)), "roborev") } // Stop gracefully shuts down the server @@ -681,6 +710,8 @@ func (s *Server) handleEnqueue(w http.ResponseWriter, r *http.Request) { return } + patchID := git.GetPatchID(gitCwd, sha) + job, err = s.db.EnqueueJob(storage.EnqueueOpts{ RepoID: repo.ID, CommitID: commit.ID, @@ -690,6 +721,7 @@ func (s *Server) handleEnqueue(w http.ResponseWriter, r *http.Request) { Model: model, Reasoning: reasoning, ReviewType: req.ReviewType, + PatchID: patchID, }) if err != nil { writeError(w, http.StatusInternalServerError, fmt.Sprintf("enqueue job: %v", err)) @@ -1386,6 +1418,117 @@ func (s *Server) handleAddressReview(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]any{"success": true}) } +// RemapRequest is the request body for POST /api/remap. +type RemapRequest struct { + RepoPath string `json:"repo_path"` + Mappings []RemapMapping `json:"mappings"` +} + +// RemapMapping maps a pre-rewrite SHA to its post-rewrite replacement. +type RemapMapping struct { + OldSHA string `json:"old_sha"` + NewSHA string `json:"new_sha"` + PatchID string `json:"patch_id"` + Author string `json:"author"` + Subject string `json:"subject"` + Timestamp string `json:"timestamp"` // RFC3339 +} + +func (s *Server) handleRemap(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + // Cap body size: ~200 bytes per mapping, 1000 max → 1MB is generous. + const maxRemapBody = 1 << 20 // 1MB + r.Body = http.MaxBytesReader(w, r.Body, maxRemapBody) + + var req RemapRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + writeError(w, http.StatusRequestEntityTooLarge, + "request body too large") + return + } + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + const maxMappings = 1000 + if len(req.Mappings) > maxMappings { + writeError(w, http.StatusBadRequest, + fmt.Sprintf("too many mappings (%d, max %d)", + len(req.Mappings), maxMappings)) + return + } + + if req.RepoPath == "" { + writeError(w, http.StatusBadRequest, "repo_path is required") + return + } + + repoRoot, err := git.GetMainRepoRoot(req.RepoPath) + if err != nil { + writeError(w, http.StatusBadRequest, + fmt.Sprintf("not a git repository: %s", req.RepoPath)) + return + } + + repo, err := s.db.GetRepoByPath(repoRoot) + if errors.Is(err, sql.ErrNoRows) { + writeError(w, http.StatusNotFound, + fmt.Sprintf("unknown repo: %s", repoRoot)) + return + } + if err != nil { + writeError(w, http.StatusInternalServerError, + fmt.Sprintf("lookup repo: %v", err)) + return + } + + // Validate all timestamps upfront before modifying any state. + timestamps := make([]time.Time, len(req.Mappings)) + for i, m := range req.Mappings { + ts, err := time.Parse(time.RFC3339, m.Timestamp) + if err != nil { + writeError(w, http.StatusBadRequest, + fmt.Sprintf("invalid timestamp %q: %v", m.Timestamp, err)) + return + } + timestamps[i] = ts + } + + var remapped, skipped int + for i, m := range req.Mappings { + n, err := s.db.RemapJob( + repo.ID, m.OldSHA, m.NewSHA, m.PatchID, + m.Author, m.Subject, timestamps[i], + ) + if err != nil { + skipped++ + continue + } + remapped += n + if n == 0 { + skipped++ + } + } + + if remapped > 0 { + s.broadcaster.Broadcast(Event{ + Type: "review.remapped", + TS: time.Now(), + Repo: repo.RootPath, + }) + } + + writeJSON(w, map[string]int{ + "remapped": remapped, "skipped": skipped, + }) +} + func (s *Server) handleStreamEvents(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeError(w, http.StatusMethodNotAllowed, "method not allowed") diff --git a/internal/daemon/server_test.go b/internal/daemon/server_test.go index 53abeee9..1a6c5d0e 100644 --- a/internal/daemon/server_test.go +++ b/internal/daemon/server_test.go @@ -3267,3 +3267,223 @@ func TestHandleBatchJobs(t *testing.T) { } }) } + +func TestHandleRemap(t *testing.T) { + server, db, tmpDir := newTestServer(t) + + // Set up a git repo so handleRemap can resolve paths + repoDir := filepath.Join(tmpDir, "remap-repo") + if err := os.MkdirAll(repoDir, 0755); err != nil { + t.Fatal(err) + } + for _, args := range [][]string{ + {"init"}, + {"config", "user.email", "test@test.com"}, + {"config", "user.name", "Test"}, + } { + cmd := exec.Command("git", args...) + cmd.Dir = repoDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v: %s: %v", args, out, err) + } + } + + // Resolve symlinks to get the canonical path + // (macOS /var -> /private/var symlink) + resolvedDir, err := filepath.EvalSymlinks(repoDir) + if err != nil { + resolvedDir = repoDir + } + + repo, err := db.GetOrCreateRepo(resolvedDir) + if err != nil { + t.Fatal(err) + } + commit, err := db.GetOrCreateCommit(repo.ID, "oldsha111", "Test", "old commit", time.Now()) + if err != nil { + t.Fatal(err) + } + _, err = db.EnqueueJob(storage.EnqueueOpts{ + RepoID: repo.ID, + CommitID: commit.ID, + GitRef: "oldsha111", + Agent: "test", + PatchID: "patchXYZ", + }) + if err != nil { + t.Fatal(err) + } + + t.Run("remap updates job", func(t *testing.T) { + reqData := RemapRequest{ + RepoPath: repoDir, + Mappings: []RemapMapping{ + { + OldSHA: "oldsha111", + NewSHA: "newsha222", + PatchID: "patchXYZ", + Author: "Test", + Subject: "new commit", + Timestamp: time.Now().Format(time.RFC3339), + }, + }, + } + req := testutil.MakeJSONRequest(t, http.MethodPost, "/api/remap", reqData) + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var result map[string]int + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatal(err) + } + if result["remapped"] != 1 { + t.Errorf("expected remapped=1, got %d", result["remapped"]) + } + }) + + t.Run("remap with non-git path returns 400", func(t *testing.T) { + reqData := RemapRequest{ + RepoPath: "/nonexistent/repo", + Mappings: []RemapMapping{ + {OldSHA: "a", NewSHA: "b", PatchID: "c", Author: "x", Subject: "y", Timestamp: time.Now().Format(time.RFC3339)}, + }, + } + req := testutil.MakeJSONRequest(t, http.MethodPost, "/api/remap", reqData) + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) + } + }) + + t.Run("remap with unregistered repo returns 404", func(t *testing.T) { + // Create a valid git repo that is NOT registered in the DB + unregistered := filepath.Join(tmpDir, "unregistered-repo") + if err := os.MkdirAll(unregistered, 0755); err != nil { + t.Fatal(err) + } + cmd := exec.Command("git", "init") + cmd.Dir = unregistered + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git init: %s: %v", out, err) + } + + reqData := RemapRequest{ + RepoPath: unregistered, + Mappings: []RemapMapping{ + {OldSHA: "a", NewSHA: "b", PatchID: "c", Author: "x", Subject: "y", Timestamp: time.Now().Format(time.RFC3339)}, + }, + } + req := testutil.MakeJSONRequest(t, http.MethodPost, "/api/remap", reqData) + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String()) + } + }) + + t.Run("remap with invalid timestamp returns 400", func(t *testing.T) { + reqData := RemapRequest{ + RepoPath: repoDir, + Mappings: []RemapMapping{ + { + OldSHA: "oldsha111", + NewSHA: "newsha333", + PatchID: "patchXYZ", + Author: "Test", + Subject: "bad ts", + Timestamp: "not-a-timestamp", + }, + }, + } + req := testutil.MakeJSONRequest(t, http.MethodPost, "/api/remap", reqData) + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) + } + }) + + t.Run("remap with empty repo_path returns 400", func(t *testing.T) { + reqData := RemapRequest{ + RepoPath: "", + Mappings: []RemapMapping{ + { + OldSHA: "a", + NewSHA: "b", + PatchID: "c", + Author: "x", + Subject: "y", + Timestamp: time.Now().Format(time.RFC3339), + }, + }, + } + req := testutil.MakeJSONRequest( + t, http.MethodPost, "/api/remap", reqData, + ) + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", + w.Code, w.Body.String()) + } + }) + + t.Run("remap rejects too many mappings", func(t *testing.T) { + mappings := make([]RemapMapping, 1001) + for i := range mappings { + mappings[i] = RemapMapping{ + OldSHA: "a", NewSHA: "b", PatchID: "c", + Author: "x", Subject: "y", + Timestamp: time.Now().Format(time.RFC3339), + } + } + reqData := RemapRequest{ + RepoPath: repoDir, + Mappings: mappings, + } + req := testutil.MakeJSONRequest( + t, http.MethodPost, "/api/remap", reqData, + ) + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", + w.Code, w.Body.String()) + } + }) + + t.Run("remap rejects oversized body", func(t *testing.T) { + // Build a payload larger than 1MB. + // JSON-encode repoDir so Windows backslashes are escaped. + escapedDir, _ := json.Marshal(repoDir) + body := []byte(`{"repo_path":` + string(escapedDir) + `,"mappings":[`) + entry := []byte(`{"old_sha":"a","new_sha":"b","patch_id":"c","author":"x","subject":"y","timestamp":"2026-01-01T00:00:00Z"},`) + for len(body) < 1<<20+1 { + body = append(body, entry...) + } + body = append(body[:len(body)-1], []byte(`]}`)...) + + req := httptest.NewRequest( + http.MethodPost, "/api/remap", + bytes.NewReader(body), + ) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + server.handleRemap(w, req) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("expected 413, got %d: %s", + w.Code, w.Body.String()) + } + }) +} diff --git a/internal/git/git.go b/internal/git/git.go index 349d60b7..26621b87 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -917,6 +917,48 @@ func getRemoteURLByName(repoPath, name string) string { return strings.TrimSpace(string(out)) } +// GetPatchID returns the stable patch-id for a commit. Patch-ids are +// content-based hashes of the diff, so two commits with the same code +// change (e.g. before and after a rebase) share the same patch-id. +// Returns "" for merge commits, empty commits, or on any error. +func GetPatchID(repoPath, sha string) string { + show := exec.Command("git", "-c", "color.ui=false", "show", sha) + show.Dir = repoPath + + patchID := exec.Command("git", "patch-id", "--stable") + patchID.Dir = repoPath + + pipe, err := show.StdoutPipe() + if err != nil { + return "" + } + patchID.Stdin = pipe + + var out bytes.Buffer + patchID.Stdout = &out + + if err := show.Start(); err != nil { + return "" + } + if err := patchID.Start(); err != nil { + pipe.Close() // unblock show if pipe buffer is full + _ = show.Wait() + return "" + } + + _ = show.Wait() // only patchID's exit status matters + if err := patchID.Wait(); err != nil { + return "" + } + + // Output format: " \n" + fields := strings.Fields(out.String()) + if len(fields) < 1 { + return "" + } + return fields[0] +} + func getAnyRemoteURL(repoPath string) string { // List all remotes cmd := exec.Command("git", "remote") diff --git a/internal/git/git_test.go b/internal/git/git_test.go index 74f3ce64..486642c0 100644 --- a/internal/git/git_test.go +++ b/internal/git/git_test.go @@ -1507,3 +1507,68 @@ func TestIsAncestor(t *testing.T) { } }) } + +func TestGetPatchID(t *testing.T) { + t.Run("stable across rebase", func(t *testing.T) { + repo := NewTestRepo(t) + // Use -b to name the initial branch explicitly + repo.Run("checkout", "-b", "main") + repo.CommitFile("base.txt", "base", "initial") + + // Create a commit on a branch + repo.Run("checkout", "-b", "feature") + repo.CommitFile("feature.txt", "hello", "add feature") + sha1 := repo.HeadSHA() + patchID1 := GetPatchID(repo.Dir, sha1) + + if patchID1 == "" { + t.Fatal("expected non-empty patch-id") + } + + // Rebase onto a new base commit + repo.Run("checkout", "main") + repo.CommitFile("other.txt", "other", "another commit") + repo.Run("checkout", "feature") + repo.Run("rebase", "main") + sha2 := repo.HeadSHA() + patchID2 := GetPatchID(repo.Dir, sha2) + + if sha1 == sha2 { + t.Fatal("SHAs should differ after rebase") + } + if patchID1 != patchID2 { + t.Errorf("patch-ids should match: %s != %s", patchID1, patchID2) + } + }) + + t.Run("different for modified commits", func(t *testing.T) { + repo := NewTestRepo(t) + repo.CommitFile("a.txt", "content-a", "commit a") + sha1 := repo.HeadSHA() + + repo.CommitFile("b.txt", "content-b", "commit b") + sha2 := repo.HeadSHA() + + pid1 := GetPatchID(repo.Dir, sha1) + pid2 := GetPatchID(repo.Dir, sha2) + + if pid1 == "" || pid2 == "" { + t.Fatal("expected non-empty patch-ids") + } + if pid1 == pid2 { + t.Error("different diffs should produce different patch-ids") + } + }) + + t.Run("empty for empty commit", func(t *testing.T) { + repo := NewTestRepo(t) + repo.CommitFile("a.txt", "content", "first") + repo.Run("commit", "--allow-empty", "-m", "empty") + sha := repo.HeadSHA() + + pid := GetPatchID(repo.Dir, sha) + if pid != "" { + t.Errorf("expected empty patch-id for empty commit, got %s", pid) + } + }) +} diff --git a/internal/storage/db.go b/internal/storage/db.go index ea332ea4..a44ee8cb 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -580,6 +580,22 @@ func (db *DB) migrate() error { } } + // Migration: add patch_id column to review_jobs if missing + err = db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('review_jobs') WHERE name = 'patch_id'`).Scan(&count) + if err != nil { + return fmt.Errorf("check patch_id column: %w", err) + } + if count == 0 { + _, err = db.Exec(`ALTER TABLE review_jobs ADD COLUMN patch_id TEXT`) + if err != nil { + return fmt.Errorf("add patch_id column: %w", err) + } + _, err = db.Exec(`CREATE INDEX IF NOT EXISTS idx_review_jobs_patch_id ON review_jobs(patch_id)`) + if err != nil { + return fmt.Errorf("create idx_review_jobs_patch_id: %w", err) + } + } + // Migration: add index on reviews.addressed for server-side filtering _, err = db.Exec(`CREATE INDEX IF NOT EXISTS idx_reviews_addressed ON reviews(addressed)`) if err != nil { diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index cffd7c51..bc09c02c 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -2492,6 +2492,123 @@ func TestListBranchesWithCounts(t *testing.T) { }) } +func TestPatchIDMigration(t *testing.T) { + db := openTestDB(t) + defer db.Close() + + var count int + err := db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('review_jobs') WHERE name = 'patch_id'`).Scan(&count) + if err != nil { + t.Fatalf("check patch_id column: %v", err) + } + if count != 1 { + t.Errorf("expected patch_id column to exist, got count=%d", count) + } +} + +func TestEnqueueJobWithPatchID(t *testing.T) { + db := openTestDB(t) + defer db.Close() + + repo := createRepo(t, db, "/tmp/test-patch-id") + commit := createCommit(t, db, repo.ID, "abc123") + + job, err := db.EnqueueJob(EnqueueOpts{ + RepoID: repo.ID, + CommitID: commit.ID, + GitRef: "abc123", + Agent: "test", + PatchID: "deadbeef1234", + }) + if err != nil { + t.Fatalf("EnqueueJob: %v", err) + } + if job.PatchID != "deadbeef1234" { + t.Errorf("expected PatchID=deadbeef1234, got %q", job.PatchID) + } + + // Verify it round-trips through GetJobByID + got, err := db.GetJobByID(job.ID) + if err != nil { + t.Fatalf("GetJobByID: %v", err) + } + if got.PatchID != "deadbeef1234" { + t.Errorf("GetJobByID: expected PatchID=deadbeef1234, got %q", got.PatchID) + } +} + +func TestRemapJobGitRef(t *testing.T) { + db := openTestDB(t) + defer db.Close() + + repo := createRepo(t, db, "/tmp/test-remap") + commit := createCommit(t, db, repo.ID, "oldsha") + + t.Run("remap updates matching jobs", func(t *testing.T) { + job, err := db.EnqueueJob(EnqueueOpts{ + RepoID: repo.ID, + CommitID: commit.ID, + GitRef: "oldsha", + Agent: "test", + PatchID: "patchabc", + }) + if err != nil { + t.Fatalf("EnqueueJob: %v", err) + } + + newCommit := createCommit(t, db, repo.ID, "newsha") + n, err := db.RemapJobGitRef(repo.ID, "oldsha", "newsha", "patchabc", newCommit.ID) + if err != nil { + t.Fatalf("RemapJobGitRef: %v", err) + } + if n != 1 { + t.Errorf("expected 1 row updated, got %d", n) + } + + got, err := db.GetJobByID(job.ID) + if err != nil { + t.Fatalf("GetJobByID: %v", err) + } + if got.GitRef != "newsha" { + t.Errorf("expected git_ref=newsha, got %q", got.GitRef) + } + }) + + t.Run("skips on patch_id mismatch", func(t *testing.T) { + commit2 := createCommit(t, db, repo.ID, "sha2") + _, err := db.EnqueueJob(EnqueueOpts{ + RepoID: repo.ID, + CommitID: commit2.ID, + GitRef: "sha2", + Agent: "test", + PatchID: "patch_original", + }) + if err != nil { + t.Fatalf("EnqueueJob: %v", err) + } + + newCommit := createCommit(t, db, repo.ID, "sha2_new") + n, err := db.RemapJobGitRef(repo.ID, "sha2", "sha2_new", "patch_different", newCommit.ID) + if err != nil { + t.Fatalf("RemapJobGitRef: %v", err) + } + if n != 0 { + t.Errorf("expected 0 rows updated (patch_id mismatch), got %d", n) + } + }) + + t.Run("returns 0 for no matches", func(t *testing.T) { + newCommit := createCommit(t, db, repo.ID, "nonexistent_new") + n, err := db.RemapJobGitRef(repo.ID, "nonexistent", "nonexistent_new", "patch", newCommit.ID) + if err != nil { + t.Fatalf("RemapJobGitRef: %v", err) + } + if n != 0 { + t.Errorf("expected 0 rows updated, got %d", n) + } + }) +} + func openTestDB(t *testing.T) *DB { t.Helper() tmpl, err := getTemplatePath() diff --git a/internal/storage/jobs.go b/internal/storage/jobs.go index 679965a8..ec15e275 100644 --- a/internal/storage/jobs.go +++ b/internal/storage/jobs.go @@ -3,6 +3,7 @@ package storage import ( "context" "database/sql" + "fmt" "log" "slices" "strings" @@ -719,6 +720,7 @@ type EnqueueOpts struct { Model string Reasoning string ReviewType string // e.g. "security" — changes which system prompt is used + PatchID string // Stable patch-id for rebase tracking DiffContent string // For dirty reviews (captured at enqueue time) Prompt string // For task jobs (pre-stored prompt) OutputPrefix string // Prefix to prepend to review output @@ -779,12 +781,12 @@ func (db *DB) EnqueueJob(opts EnqueueOpts) (*ReviewJob, error) { result, err := db.Exec(` INSERT INTO review_jobs (repo_id, commit_id, git_ref, branch, agent, model, reasoning, - status, job_type, review_type, diff_content, prompt, agentic, output_prefix, + status, job_type, review_type, patch_id, diff_content, prompt, agentic, output_prefix, uuid, source_machine_id, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, 'queued', ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES (?, ?, ?, ?, ?, ?, ?, 'queued', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, opts.RepoID, commitIDParam, gitRef, nullString(opts.Branch), opts.Agent, nullString(opts.Model), reasoning, - jobType, opts.ReviewType, + jobType, opts.ReviewType, nullString(opts.PatchID), nullString(opts.DiffContent), nullString(opts.Prompt), agenticInt, nullString(opts.OutputPrefix), uid, machineID, nowStr) @@ -803,6 +805,7 @@ func (db *DB) EnqueueJob(opts EnqueueOpts) (*ReviewJob, error) { Reasoning: reasoning, JobType: jobType, ReviewType: opts.ReviewType, + PatchID: opts.PatchID, Status: JobStatusQueued, EnqueuedAt: now, Prompt: opts.Prompt, @@ -863,10 +866,11 @@ func (db *DB) ClaimJob(workerID string) (*ReviewJob, error) { var jobType sql.NullString var reviewType sql.NullString var outputPrefix sql.NullString + var patchID sql.NullString err = db.QueryRow(` SELECT j.id, j.repo_id, j.commit_id, j.git_ref, j.branch, j.agent, j.model, j.reasoning, j.status, j.enqueued_at, r.root_path, r.name, c.subject, j.diff_content, j.prompt, COALESCE(j.agentic, 0), j.job_type, j.review_type, - j.output_prefix + j.output_prefix, j.patch_id FROM review_jobs j JOIN repos r ON r.id = j.repo_id LEFT JOIN commits c ON c.id = j.commit_id @@ -875,7 +879,7 @@ func (db *DB) ClaimJob(workerID string) (*ReviewJob, error) { LIMIT 1 `, workerID).Scan(&job.ID, &job.RepoID, &commitID, &job.GitRef, &branch, &job.Agent, &model, &job.Reasoning, &job.Status, &enqueuedAt, &job.RepoPath, &job.RepoName, &commitSubject, &diffContent, &prompt, &agenticInt, &jobType, &reviewType, - &outputPrefix) + &outputPrefix, &patchID) if err != nil { return nil, err } @@ -908,6 +912,9 @@ func (db *DB) ClaimJob(workerID string) (*ReviewJob, error) { if outputPrefix.Valid { job.OutputPrefix = outputPrefix.String } + if patchID.Valid { + job.PatchID = patchID.String + } job.EnqueuedAt = parseSQLiteTime(enqueuedAt) job.Status = JobStatusRunning job.WorkerID = workerID @@ -1151,7 +1158,7 @@ func (db *DB) ListJobs(statusFilter string, repoFilter string, limit, offset int SELECT j.id, j.repo_id, j.commit_id, j.git_ref, j.branch, j.agent, j.reasoning, j.status, j.enqueued_at, j.started_at, j.finished_at, j.worker_id, j.error, j.prompt, j.retry_count, COALESCE(j.agentic, 0), r.root_path, r.name, c.subject, rv.addressed, rv.output, - j.source_machine_id, j.uuid, j.model, j.job_type, j.review_type + j.source_machine_id, j.uuid, j.model, j.job_type, j.review_type, j.patch_id FROM review_jobs j JOIN repos r ON r.id = j.repo_id LEFT JOIN commits c ON c.id = j.commit_id @@ -1218,7 +1225,7 @@ func (db *DB) ListJobs(statusFilter string, repoFilter string, limit, offset int for rows.Next() { var j ReviewJob var enqueuedAt string - var startedAt, finishedAt, workerID, errMsg, prompt, output, sourceMachineID, jobUUID, model, branch, jobTypeStr, reviewTypeStr sql.NullString + var startedAt, finishedAt, workerID, errMsg, prompt, output, sourceMachineID, jobUUID, model, branch, jobTypeStr, reviewTypeStr, patchIDStr sql.NullString var commitID sql.NullInt64 var commitSubject sql.NullString var addressed sql.NullInt64 @@ -1227,7 +1234,7 @@ func (db *DB) ListJobs(statusFilter string, repoFilter string, limit, offset int err := rows.Scan(&j.ID, &j.RepoID, &commitID, &j.GitRef, &branch, &j.Agent, &j.Reasoning, &j.Status, &enqueuedAt, &startedAt, &finishedAt, &workerID, &errMsg, &prompt, &j.RetryCount, &agentic, &j.RepoPath, &j.RepoName, &commitSubject, &addressed, &output, - &sourceMachineID, &jobUUID, &model, &jobTypeStr, &reviewTypeStr) + &sourceMachineID, &jobUUID, &model, &jobTypeStr, &reviewTypeStr, &patchIDStr) if err != nil { return nil, err } @@ -1272,6 +1279,9 @@ func (db *DB) ListJobs(statusFilter string, repoFilter string, limit, offset int if reviewTypeStr.Valid { j.ReviewType = reviewTypeStr.String } + if patchIDStr.Valid { + j.PatchID = patchIDStr.String + } if branch.Valid { j.Branch = branch.String } @@ -1349,18 +1359,18 @@ func (db *DB) GetJobByID(id int64) (*ReviewJob, error) { var commitSubject sql.NullString var agentic int - var model, branch, jobTypeStr, reviewTypeStr sql.NullString + var model, branch, jobTypeStr, reviewTypeStr, patchIDStr sql.NullString err := db.QueryRow(` SELECT j.id, j.repo_id, j.commit_id, j.git_ref, j.branch, j.agent, j.reasoning, j.status, j.enqueued_at, j.started_at, j.finished_at, j.worker_id, j.error, j.prompt, COALESCE(j.agentic, 0), - r.root_path, r.name, c.subject, j.model, j.job_type, j.review_type + r.root_path, r.name, c.subject, j.model, j.job_type, j.review_type, j.patch_id FROM review_jobs j JOIN repos r ON r.id = j.repo_id LEFT JOIN commits c ON c.id = j.commit_id WHERE j.id = ? `, id).Scan(&j.ID, &j.RepoID, &commitID, &j.GitRef, &branch, &j.Agent, &j.Reasoning, &j.Status, &enqueuedAt, &startedAt, &finishedAt, &workerID, &errMsg, &prompt, &agentic, - &j.RepoPath, &j.RepoName, &commitSubject, &model, &jobTypeStr, &reviewTypeStr) + &j.RepoPath, &j.RepoName, &commitSubject, &model, &jobTypeStr, &reviewTypeStr, &patchIDStr) if err != nil { return nil, err } @@ -1399,6 +1409,9 @@ func (db *DB) GetJobByID(id int64) (*ReviewJob, error) { if reviewTypeStr.Valid { j.ReviewType = reviewTypeStr.String } + if patchIDStr.Valid { + j.PatchID = patchIDStr.String + } if branch.Valid { j.Branch = branch.String } @@ -1448,3 +1461,96 @@ func (db *DB) UpdateJobBranch(jobID int64, branch string) (int64, error) { } return result.RowsAffected() } + +// RemapJobGitRef updates git_ref and commit_id for jobs matching +// oldSHA in a repo, used after rebases to preserve review history. +// If a job has a stored patch_id that differs from the provided one, +// that job is skipped (the commit's content changed). +// Returns the number of rows updated. +func (db *DB) RemapJobGitRef( + repoID int64, oldSHA, newSHA, patchID string, newCommitID int64, +) (int, error) { + now := time.Now().Format(time.RFC3339) + result, err := db.Exec(` + UPDATE review_jobs + SET git_ref = ?, commit_id = ?, patch_id = ?, updated_at = ? + WHERE git_ref = ? AND repo_id = ? + AND (patch_id IS NULL OR patch_id = '' OR patch_id = ?) + `, newSHA, newCommitID, nullString(patchID), now, oldSHA, repoID, patchID) + if err != nil { + return 0, fmt.Errorf("remap job git_ref: %w", err) + } + n, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(n), nil +} + +// RemapJob atomically checks for matching jobs, creates the commit +// row, and updates git_ref — all in a single transaction to prevent +// orphan commit rows or races between concurrent remaps. +func (db *DB) RemapJob( + repoID int64, oldSHA, newSHA, patchID string, + author, subject string, timestamp time.Time, +) (int, error) { + tx, err := db.Begin() + if err != nil { + return 0, fmt.Errorf("begin remap tx: %w", err) + } + defer tx.Rollback() //nolint:errcheck + + var matchCount int + err = tx.QueryRow(` + SELECT COUNT(*) FROM review_jobs + WHERE git_ref = ? AND repo_id = ? + AND (patch_id IS NULL OR patch_id = '' OR patch_id = ?) + `, oldSHA, repoID, patchID).Scan(&matchCount) + if err != nil { + return 0, fmt.Errorf("count matching jobs: %w", err) + } + if matchCount == 0 { + return 0, nil + } + + // Create or find commit row for the new SHA + var commitID int64 + err = tx.QueryRow( + `SELECT id FROM commits WHERE repo_id = ? AND sha = ?`, + repoID, newSHA, + ).Scan(&commitID) + if err == sql.ErrNoRows { + result, insertErr := tx.Exec(` + INSERT INTO commits (repo_id, sha, author, subject, timestamp) + VALUES (?, ?, ?, ?, ?) + `, repoID, newSHA, author, subject, + timestamp.Format(time.RFC3339)) + if insertErr != nil { + return 0, fmt.Errorf("create commit: %w", insertErr) + } + commitID, _ = result.LastInsertId() + } else if err != nil { + return 0, fmt.Errorf("find commit: %w", err) + } + + now := time.Now().Format(time.RFC3339) + result, err := tx.Exec(` + UPDATE review_jobs + SET git_ref = ?, commit_id = ?, patch_id = ?, updated_at = ? + WHERE git_ref = ? AND repo_id = ? + AND (patch_id IS NULL OR patch_id = '' OR patch_id = ?) + `, newSHA, commitID, nullString(patchID), now, + oldSHA, repoID, patchID) + if err != nil { + return 0, fmt.Errorf("remap job git_ref: %w", err) + } + n, err := result.RowsAffected() + if err != nil { + return 0, err + } + + if err := tx.Commit(); err != nil { + return 0, fmt.Errorf("commit remap tx: %w", err) + } + return int(n), nil +} diff --git a/internal/storage/models.go b/internal/storage/models.go index 107b4393..18829f4f 100644 --- a/internal/storage/models.go +++ b/internal/storage/models.go @@ -63,6 +63,7 @@ type ReviewJob struct { DiffContent *string `json:"diff_content,omitempty"` // For dirty reviews (uncommitted changes) Agentic bool `json:"agentic"` // Enable agentic mode (allow file edits) ReviewType string `json:"review_type,omitempty"` // Review type (e.g., "security") - changes system prompt + PatchID string `json:"patch_id,omitempty"` // Stable patch-id for rebase tracking OutputPrefix string `json:"output_prefix,omitempty"` // Prefix to prepend to review output // Sync fields diff --git a/internal/storage/postgres.go b/internal/storage/postgres.go index 50b1feb7..dbaf0a8b 100644 --- a/internal/storage/postgres.go +++ b/internal/storage/postgres.go @@ -14,12 +14,12 @@ import ( ) // PostgreSQL schema version - increment when schema changes -const pgSchemaVersion = 4 +const pgSchemaVersion = 5 // pgSchemaName is the PostgreSQL schema used to isolate roborev tables const pgSchemaName = "roborev" -//go:embed schemas/postgres_v4.sql +//go:embed schemas/postgres_v5.sql var pgSchemaSQL string // pgSchemaStatements returns the individual DDL statements for schema creation. @@ -182,6 +182,10 @@ func (p *PgPool) EnsureSchema(ctx context.Context) error { if err != nil { return fmt.Errorf("create job_type index: %w", err) } + _, err = p.pool.Exec(ctx, `CREATE INDEX IF NOT EXISTS idx_review_jobs_patch_id ON review_jobs(patch_id)`) + if err != nil { + return fmt.Errorf("create patch_id index: %w", err) + } } else if currentVersion > pgSchemaVersion { return fmt.Errorf("database schema version %d is newer than supported version %d", currentVersion, pgSchemaVersion) } else if currentVersion < pgSchemaVersion { @@ -235,6 +239,17 @@ func (p *PgPool) EnsureSchema(ctx context.Context) error { return fmt.Errorf("migrate to v4 (add review_type column): %w", err) } } + if currentVersion < 5 { + // Migration 4->5: Add patch_id column to review_jobs + _, err = p.pool.Exec(ctx, `ALTER TABLE review_jobs ADD COLUMN IF NOT EXISTS patch_id TEXT`) + if err != nil { + return fmt.Errorf("migrate to v5 (add patch_id column): %w", err) + } + _, err = p.pool.Exec(ctx, `CREATE INDEX IF NOT EXISTS idx_review_jobs_patch_id ON review_jobs(patch_id)`) + if err != nil { + return fmt.Errorf("migrate to v5 (add patch_id index): %w", err) + } + } // Update version _, err = p.pool.Exec(ctx, `INSERT INTO schema_version (version) VALUES ($1) ON CONFLICT (version) DO NOTHING`, pgSchemaVersion) if err != nil { @@ -479,18 +494,21 @@ func (p *PgPool) Tx(ctx context.Context, fn func(tx pgx.Tx) error) error { func (p *PgPool) UpsertJob(ctx context.Context, j SyncableJob, pgRepoID int64, pgCommitID *int64) error { _, err := p.pool.Exec(ctx, ` INSERT INTO review_jobs ( - uuid, repo_id, commit_id, git_ref, agent, model, reasoning, job_type, review_type, status, agentic, + uuid, repo_id, commit_id, git_ref, agent, model, reasoning, job_type, review_type, patch_id, status, agentic, enqueued_at, started_at, finished_at, prompt, diff_content, error, source_machine_id, updated_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, NOW()) + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, NOW()) ON CONFLICT (uuid) DO UPDATE SET status = EXCLUDED.status, finished_at = EXCLUDED.finished_at, error = EXCLUDED.error, model = COALESCE(EXCLUDED.model, review_jobs.model), + git_ref = EXCLUDED.git_ref, + commit_id = EXCLUDED.commit_id, + patch_id = EXCLUDED.patch_id, updated_at = NOW() `, j.UUID, pgRepoID, pgCommitID, j.GitRef, j.Agent, nullString(j.Model), nullString(j.Reasoning), - defaultStr(j.JobType, "review"), j.ReviewType, j.Status, j.Agentic, j.EnqueuedAt, j.StartedAt, j.FinishedAt, + defaultStr(j.JobType, "review"), j.ReviewType, nullString(j.PatchID), j.Status, j.Agentic, j.EnqueuedAt, j.StartedAt, j.FinishedAt, nullString(j.Prompt), j.DiffContent, nullString(j.Error), j.SourceMachineID) return err } @@ -536,6 +554,7 @@ type PulledJob struct { Reasoning string JobType string ReviewType string + PatchID string Status string Agentic bool EnqueuedAt time.Time @@ -566,7 +585,7 @@ func (p *PgPool) PullJobs(ctx context.Context, excludeMachineID string, cursor s rows, err := p.pool.Query(ctx, ` SELECT j.uuid, r.identity, COALESCE(c.sha, ''), COALESCE(c.author, ''), COALESCE(c.subject, ''), COALESCE(c.timestamp, '1970-01-01'::timestamptz), - j.git_ref, j.agent, COALESCE(j.model, ''), COALESCE(j.reasoning, ''), COALESCE(j.job_type, 'review'), COALESCE(j.review_type, ''), j.status, j.agentic, + j.git_ref, j.agent, COALESCE(j.model, ''), COALESCE(j.reasoning, ''), COALESCE(j.job_type, 'review'), COALESCE(j.review_type, ''), COALESCE(j.patch_id, ''), j.status, j.agentic, j.enqueued_at, j.started_at, j.finished_at, COALESCE(j.prompt, ''), j.diff_content, COALESCE(j.error, ''), j.source_machine_id, j.updated_at, j.id @@ -593,7 +612,7 @@ func (p *PgPool) PullJobs(ctx context.Context, excludeMachineID string, cursor s err := rows.Scan( &j.UUID, &j.RepoIdentity, &j.CommitSHA, &j.CommitAuthor, &j.CommitSubject, &j.CommitTimestamp, - &j.GitRef, &j.Agent, &j.Model, &j.Reasoning, &j.JobType, &j.ReviewType, &j.Status, &j.Agentic, + &j.GitRef, &j.Agent, &j.Model, &j.Reasoning, &j.JobType, &j.ReviewType, &j.PatchID, &j.Status, &j.Agentic, &j.EnqueuedAt, &j.StartedAt, &j.FinishedAt, &j.Prompt, &diffContent, &j.Error, &j.SourceMachineID, &j.UpdatedAt, &lastID, @@ -862,17 +881,20 @@ func (p *PgPool) BatchUpsertJobs(ctx context.Context, jobs []JobWithPgIDs) ([]bo j := jw.Job batch.Queue(` INSERT INTO review_jobs ( - uuid, repo_id, commit_id, git_ref, agent, reasoning, job_type, review_type, status, agentic, + uuid, repo_id, commit_id, git_ref, agent, reasoning, job_type, review_type, patch_id, status, agentic, enqueued_at, started_at, finished_at, prompt, diff_content, error, source_machine_id, updated_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, NOW()) + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, NOW()) ON CONFLICT (uuid) DO UPDATE SET status = EXCLUDED.status, finished_at = EXCLUDED.finished_at, error = EXCLUDED.error, + git_ref = EXCLUDED.git_ref, + commit_id = EXCLUDED.commit_id, + patch_id = EXCLUDED.patch_id, updated_at = NOW() `, j.UUID, jw.PgRepoID, jw.PgCommitID, j.GitRef, j.Agent, nullString(j.Reasoning), - defaultStr(j.JobType, "review"), j.ReviewType, j.Status, j.Agentic, j.EnqueuedAt, j.StartedAt, j.FinishedAt, + defaultStr(j.JobType, "review"), j.ReviewType, nullString(j.PatchID), j.Status, j.Agentic, j.EnqueuedAt, j.StartedAt, j.FinishedAt, nullString(j.Prompt), j.DiffContent, nullString(j.Error), j.SourceMachineID) } diff --git a/internal/storage/reviews.go b/internal/storage/reviews.go index b92cb424..27109fea 100644 --- a/internal/storage/reviews.go +++ b/internal/storage/reviews.go @@ -14,14 +14,14 @@ func (db *DB) GetReviewByJobID(jobID int64) (*Review, error) { var addressed int var job ReviewJob var enqueuedAt string - var startedAt, finishedAt, workerID, errMsg, reviewUUID, model, jobTypeStr, reviewTypeStr sql.NullString + var startedAt, finishedAt, workerID, errMsg, reviewUUID, model, jobTypeStr, reviewTypeStr, patchIDStr sql.NullString var commitID sql.NullInt64 var commitSubject sql.NullString err := db.QueryRow(` SELECT rv.id, rv.job_id, rv.agent, rv.prompt, rv.output, rv.created_at, rv.addressed, rv.uuid, j.id, j.repo_id, j.commit_id, j.git_ref, j.agent, j.reasoning, j.status, j.enqueued_at, - j.started_at, j.finished_at, j.worker_id, j.error, j.model, j.job_type, j.review_type, + j.started_at, j.finished_at, j.worker_id, j.error, j.model, j.job_type, j.review_type, j.patch_id, rp.root_path, rp.name, c.subject FROM reviews rv JOIN review_jobs j ON j.id = rv.job_id @@ -30,7 +30,7 @@ func (db *DB) GetReviewByJobID(jobID int64) (*Review, error) { WHERE rv.job_id = ? `, jobID).Scan(&r.ID, &r.JobID, &r.Agent, &r.Prompt, &r.Output, &createdAt, &addressed, &reviewUUID, &job.ID, &job.RepoID, &commitID, &job.GitRef, &job.Agent, &job.Reasoning, &job.Status, &enqueuedAt, - &startedAt, &finishedAt, &workerID, &errMsg, &model, &jobTypeStr, &reviewTypeStr, + &startedAt, &finishedAt, &workerID, &errMsg, &model, &jobTypeStr, &reviewTypeStr, &patchIDStr, &job.RepoPath, &job.RepoName, &commitSubject) if err != nil { return nil, err @@ -56,6 +56,9 @@ func (db *DB) GetReviewByJobID(jobID int64) (*Review, error) { if reviewTypeStr.Valid { job.ReviewType = reviewTypeStr.String } + if patchIDStr.Valid { + job.PatchID = patchIDStr.String + } job.EnqueuedAt = parseSQLiteTime(enqueuedAt) if startedAt.Valid { t := parseSQLiteTime(startedAt.String) @@ -91,7 +94,7 @@ func (db *DB) GetReviewByCommitSHA(sha string) (*Review, error) { var addressed int var job ReviewJob var enqueuedAt string - var startedAt, finishedAt, workerID, errMsg, reviewUUID, model, jobTypeStr, reviewTypeStr sql.NullString + var startedAt, finishedAt, workerID, errMsg, reviewUUID, model, jobTypeStr, reviewTypeStr, patchIDStr sql.NullString var commitID sql.NullInt64 var commitSubject sql.NullString @@ -99,7 +102,7 @@ func (db *DB) GetReviewByCommitSHA(sha string) (*Review, error) { err := db.QueryRow(` SELECT rv.id, rv.job_id, rv.agent, rv.prompt, rv.output, rv.created_at, rv.addressed, rv.uuid, j.id, j.repo_id, j.commit_id, j.git_ref, j.agent, j.reasoning, j.status, j.enqueued_at, - j.started_at, j.finished_at, j.worker_id, j.error, j.model, j.job_type, j.review_type, + j.started_at, j.finished_at, j.worker_id, j.error, j.model, j.job_type, j.review_type, j.patch_id, rp.root_path, rp.name, c.subject FROM reviews rv JOIN review_jobs j ON j.id = rv.job_id @@ -110,7 +113,7 @@ func (db *DB) GetReviewByCommitSHA(sha string) (*Review, error) { LIMIT 1 `, sha).Scan(&r.ID, &r.JobID, &r.Agent, &r.Prompt, &r.Output, &createdAt, &addressed, &reviewUUID, &job.ID, &job.RepoID, &commitID, &job.GitRef, &job.Agent, &job.Reasoning, &job.Status, &enqueuedAt, - &startedAt, &finishedAt, &workerID, &errMsg, &model, &jobTypeStr, &reviewTypeStr, + &startedAt, &finishedAt, &workerID, &errMsg, &model, &jobTypeStr, &reviewTypeStr, &patchIDStr, &job.RepoPath, &job.RepoName, &commitSubject) if err != nil { return nil, err @@ -135,6 +138,9 @@ func (db *DB) GetReviewByCommitSHA(sha string) (*Review, error) { if reviewTypeStr.Valid { job.ReviewType = reviewTypeStr.String } + if patchIDStr.Valid { + job.PatchID = patchIDStr.String + } r.CreatedAt = parseSQLiteTime(createdAt) job.EnqueuedAt = parseSQLiteTime(enqueuedAt) diff --git a/internal/storage/schemas/postgres_v5.sql b/internal/storage/schemas/postgres_v5.sql new file mode 100644 index 00000000..d24c5352 --- /dev/null +++ b/internal/storage/schemas/postgres_v5.sql @@ -0,0 +1,99 @@ +-- PostgreSQL schema version 5 +-- Added patch_id column to review_jobs for rebase tracking. +-- Note: Version is managed by EnsureSchema(), not this file. + +CREATE SCHEMA IF NOT EXISTS roborev; + +CREATE TABLE IF NOT EXISTS roborev.schema_version ( + version INTEGER PRIMARY KEY, + applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS roborev.machines ( + id SERIAL PRIMARY KEY, + machine_id UUID UNIQUE NOT NULL, + name TEXT, + last_seen_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS roborev.repos ( + id SERIAL PRIMARY KEY, + identity TEXT UNIQUE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS roborev.commits ( + id SERIAL PRIMARY KEY, + repo_id INTEGER REFERENCES roborev.repos(id), + sha TEXT NOT NULL, + author TEXT NOT NULL, + subject TEXT NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + UNIQUE(repo_id, sha) +); + +CREATE TABLE IF NOT EXISTS roborev.review_jobs ( + id SERIAL PRIMARY KEY, + uuid UUID UNIQUE NOT NULL, + repo_id INTEGER NOT NULL REFERENCES roborev.repos(id), + commit_id INTEGER REFERENCES roborev.commits(id), + git_ref TEXT NOT NULL, + branch TEXT, + agent TEXT NOT NULL, + model TEXT, + reasoning TEXT, + job_type TEXT NOT NULL DEFAULT 'review', + review_type TEXT NOT NULL DEFAULT '', + patch_id TEXT, + status TEXT NOT NULL CHECK(status IN ('done', 'failed', 'canceled')), + agentic BOOLEAN DEFAULT FALSE, + enqueued_at TIMESTAMP WITH TIME ZONE NOT NULL, + started_at TIMESTAMP WITH TIME ZONE, + finished_at TIMESTAMP WITH TIME ZONE, + prompt TEXT, + diff_content TEXT, + error TEXT, + source_machine_id UUID NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS roborev.reviews ( + id SERIAL PRIMARY KEY, + uuid UUID UNIQUE NOT NULL, + job_uuid UUID NOT NULL REFERENCES roborev.review_jobs(uuid), + agent TEXT NOT NULL, + prompt TEXT NOT NULL, + output TEXT NOT NULL, + addressed BOOLEAN NOT NULL DEFAULT FALSE, + updated_by_machine_id UUID NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS roborev.responses ( + id SERIAL PRIMARY KEY, + uuid UUID UNIQUE NOT NULL, + job_uuid UUID NOT NULL REFERENCES roborev.review_jobs(uuid), + responder TEXT NOT NULL, + response TEXT NOT NULL, + source_machine_id UUID NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_review_jobs_source ON roborev.review_jobs(source_machine_id); +CREATE INDEX IF NOT EXISTS idx_review_jobs_updated ON roborev.review_jobs(updated_at); +-- Note: idx_review_jobs_branch, idx_review_jobs_job_type, and +-- idx_review_jobs_patch_id are created by migration code, not here +-- (to support upgrades from older versions where those columns +-- don't exist yet). +CREATE INDEX IF NOT EXISTS idx_reviews_job_uuid ON roborev.reviews(job_uuid); +CREATE INDEX IF NOT EXISTS idx_reviews_updated ON roborev.reviews(updated_at); +CREATE INDEX IF NOT EXISTS idx_responses_job_uuid ON roborev.responses(job_uuid); +CREATE INDEX IF NOT EXISTS idx_responses_id ON roborev.responses(id); + +CREATE TABLE IF NOT EXISTS roborev.sync_metadata ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); diff --git a/internal/storage/sync.go b/internal/storage/sync.go index 15e3aade..3b04d55f 100644 --- a/internal/storage/sync.go +++ b/internal/storage/sync.go @@ -275,6 +275,7 @@ type SyncableJob struct { Reasoning string JobType string ReviewType string + PatchID string Status string Agentic bool EnqueuedAt time.Time @@ -294,7 +295,7 @@ func (db *DB) GetJobsToSync(machineID string, limit int) ([]SyncableJob, error) SELECT j.id, j.uuid, j.repo_id, COALESCE(r.identity, ''), j.commit_id, COALESCE(c.sha, ''), COALESCE(c.author, ''), COALESCE(c.subject, ''), COALESCE(c.timestamp, ''), - j.git_ref, j.agent, COALESCE(j.model, ''), COALESCE(j.reasoning, ''), COALESCE(j.job_type, 'review'), COALESCE(j.review_type, ''), j.status, j.agentic, + j.git_ref, j.agent, COALESCE(j.model, ''), COALESCE(j.reasoning, ''), COALESCE(j.job_type, 'review'), COALESCE(j.review_type, ''), COALESCE(j.patch_id, ''), j.status, j.agentic, j.enqueued_at, COALESCE(j.started_at, ''), COALESCE(j.finished_at, ''), COALESCE(j.prompt, ''), j.diff_content, COALESCE(j.error, ''), j.source_machine_id, j.updated_at @@ -329,7 +330,7 @@ func (db *DB) GetJobsToSync(machineID string, limit int) ([]SyncableJob, error) err := rows.Scan( &j.ID, &j.UUID, &j.RepoID, &j.RepoIdentity, &commitID, &j.CommitSHA, &j.CommitAuthor, &j.CommitSubject, &commitTimestamp, - &j.GitRef, &j.Agent, &j.Model, &j.Reasoning, &j.JobType, &j.ReviewType, &j.Status, &j.Agentic, + &j.GitRef, &j.Agent, &j.Model, &j.Reasoning, &j.JobType, &j.ReviewType, &j.PatchID, &j.Status, &j.Agentic, &enqueuedAt, &startedAt, &finishedAt, &j.Prompt, &diffContent, &j.Error, &j.SourceMachineID, &updatedAt, @@ -573,19 +574,22 @@ func (db *DB) UpsertPulledJob(j PulledJob, repoID int64, commitID *int64) error now := time.Now().UTC().Format(time.RFC3339) _, err := db.Exec(` INSERT INTO review_jobs ( - uuid, repo_id, commit_id, git_ref, agent, model, reasoning, job_type, review_type, status, agentic, + uuid, repo_id, commit_id, git_ref, agent, model, reasoning, job_type, review_type, patch_id, status, agentic, enqueued_at, started_at, finished_at, prompt, diff_content, error, source_machine_id, updated_at, synced_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(uuid) DO UPDATE SET status = excluded.status, finished_at = excluded.finished_at, error = excluded.error, model = COALESCE(excluded.model, review_jobs.model), + git_ref = excluded.git_ref, + commit_id = excluded.commit_id, + patch_id = excluded.patch_id, updated_at = excluded.updated_at, synced_at = ? `, j.UUID, repoID, commitID, j.GitRef, j.Agent, nullStr(j.Model), j.Reasoning, j.JobType, - j.ReviewType, j.Status, j.Agentic, j.EnqueuedAt.Format(time.RFC3339), + j.ReviewType, nullStr(j.PatchID), j.Status, j.Agentic, j.EnqueuedAt.Format(time.RFC3339), nullTimeStr(j.StartedAt), nullTimeStr(j.FinishedAt), nullStr(j.Prompt), j.DiffContent, nullStr(j.Error), j.SourceMachineID, j.UpdatedAt.Format(time.RFC3339), now, now) diff --git a/internal/storage/sync_test.go b/internal/storage/sync_test.go index 09b7a358..f8f6f731 100644 --- a/internal/storage/sync_test.go +++ b/internal/storage/sync_test.go @@ -2327,6 +2327,276 @@ func TestSyncOrder_FullWorkflow(t *testing.T) { } } +func TestPatchIDSyncRoundTrip(t *testing.T) { + db := openTestDB(t) + defer db.Close() + + machineID, err := db.GetMachineID() + if err != nil { + t.Fatalf("GetMachineID: %v", err) + } + + repo, err := db.GetOrCreateRepo(t.TempDir()) + if err != nil { + t.Fatalf("GetOrCreateRepo: %v", err) + } + commit, err := db.GetOrCreateCommit( + repo.ID, "patchid-sync-sha", "Author", "Subject", time.Now(), + ) + if err != nil { + t.Fatalf("GetOrCreateCommit: %v", err) + } + + // Enqueue a job with a patch_id + job, err := db.EnqueueJob(EnqueueOpts{ + RepoID: repo.ID, + CommitID: commit.ID, + GitRef: "patchid-sync-sha", + Agent: "test", + PatchID: "deadbeef9999", + }) + if err != nil { + t.Fatalf("EnqueueJob: %v", err) + } + + // Complete the job so it becomes sync-eligible + _, err = db.ClaimJob("worker-sync") + if err != nil { + t.Fatalf("ClaimJob: %v", err) + } + err = db.CompleteJob(job.ID, "test", "prompt", "output") + if err != nil { + t.Fatalf("CompleteJob: %v", err) + } + + // Verify patch_id appears in GetJobsToSync + syncJobs, err := db.GetJobsToSync(machineID, 100) + if err != nil { + t.Fatalf("GetJobsToSync: %v", err) + } + var found *SyncableJob + for i := range syncJobs { + if syncJobs[i].ID == job.ID { + found = &syncJobs[i] + break + } + } + if found == nil { + t.Fatal("job not returned by GetJobsToSync") + } + if found.PatchID != "deadbeef9999" { + t.Errorf("GetJobsToSync PatchID: got %q, want %q", + found.PatchID, "deadbeef9999") + } + + // Simulate pull: upsert back into a fresh DB via UpsertPulledJob + db2 := openTestDB(t) + defer db2.Close() + + repo2, err := db2.GetOrCreateRepo(t.TempDir()) + if err != nil { + t.Fatalf("db2 GetOrCreateRepo: %v", err) + } + commitID2 := int64(0) + if found.CommitSHA != "" { + c, err := db2.GetOrCreateCommit( + repo2.ID, found.CommitSHA, + found.CommitAuthor, found.CommitSubject, found.CommitTimestamp, + ) + if err != nil { + t.Fatalf("db2 GetOrCreateCommit: %v", err) + } + commitID2 = c.ID + } + + pulledJob := PulledJob{ + UUID: found.UUID, + RepoIdentity: found.RepoIdentity, + GitRef: found.GitRef, + Agent: found.Agent, + Model: found.Model, + Reasoning: found.Reasoning, + JobType: found.JobType, + ReviewType: found.ReviewType, + PatchID: found.PatchID, + Status: found.Status, + Agentic: found.Agentic, + EnqueuedAt: found.EnqueuedAt, + StartedAt: found.StartedAt, + FinishedAt: found.FinishedAt, + Prompt: found.Prompt, + DiffContent: found.DiffContent, + Error: found.Error, + SourceMachineID: found.SourceMachineID, + UpdatedAt: found.UpdatedAt, + } + + cid := &commitID2 + err = db2.UpsertPulledJob(pulledJob, repo2.ID, cid) + if err != nil { + t.Fatalf("UpsertPulledJob: %v", err) + } + + // Verify patch_id survived the round trip + var patchIDVal sql.NullString + err = db2.QueryRow( + `SELECT patch_id FROM review_jobs WHERE uuid = ?`, found.UUID, + ).Scan(&patchIDVal) + if err != nil { + t.Fatalf("query patch_id in db2: %v", err) + } + if !patchIDVal.Valid || patchIDVal.String != "deadbeef9999" { + t.Errorf("patch_id after UpsertPulledJob: got %v, want %q", + patchIDVal, "deadbeef9999") + } +} + +func TestRemapJobGitRef_RunningJob(t *testing.T) { + // Per design doc: remap should include running jobs for maximum coverage. + // CompleteJob updates status/finished_at by job ID, not git_ref, + // so there is no race. + db := openTestDB(t) + defer db.Close() + + repo := createRepo(t, db, "/tmp/test-remap-running") + commit := createCommit(t, db, repo.ID, "running-oldsha") + + job, err := db.EnqueueJob(EnqueueOpts{ + RepoID: repo.ID, + CommitID: commit.ID, + GitRef: "running-oldsha", + Agent: "test", + PatchID: "patchRUN", + }) + if err != nil { + t.Fatalf("EnqueueJob: %v", err) + } + + // Set job to running + setJobStatus(t, db, job.ID, JobStatusRunning) + + // Remap should succeed on running jobs + newCommit := createCommit(t, db, repo.ID, "running-newsha") + n, err := db.RemapJobGitRef( + repo.ID, "running-oldsha", "running-newsha", + "patchRUN", newCommit.ID, + ) + if err != nil { + t.Fatalf("RemapJobGitRef: %v", err) + } + if n != 1 { + t.Errorf("expected 1 row updated, got %d", n) + } + + got, err := db.GetJobByID(job.ID) + if err != nil { + t.Fatalf("GetJobByID: %v", err) + } + if got.GitRef != "running-newsha" { + t.Errorf("expected git_ref=running-newsha, got %q", got.GitRef) + } + if JobStatus(got.Status) != JobStatusRunning { + t.Errorf("status should remain running, got %q", got.Status) + } +} + +func TestRemapTriggersResync(t *testing.T) { + // After remapping a synced job, updated_at should exceed synced_at, + // causing GetJobsToSync to include it again. + db := openTestDB(t) + defer db.Close() + + machineID, err := db.GetMachineID() + if err != nil { + t.Fatalf("GetMachineID: %v", err) + } + + repo, err := db.GetOrCreateRepo(t.TempDir()) + if err != nil { + t.Fatalf("GetOrCreateRepo: %v", err) + } + commit, err := db.GetOrCreateCommit( + repo.ID, "resync-oldsha", "Author", "Subject", time.Now(), + ) + if err != nil { + t.Fatalf("GetOrCreateCommit: %v", err) + } + + job, err := db.EnqueueJob(EnqueueOpts{ + RepoID: repo.ID, + CommitID: commit.ID, + GitRef: "resync-oldsha", + Agent: "test", + PatchID: "patchRESYNC", + }) + if err != nil { + t.Fatalf("EnqueueJob: %v", err) + } + + // Complete and mark synced + _, err = db.ClaimJob("worker-resync") + if err != nil { + t.Fatalf("ClaimJob: %v", err) + } + err = db.CompleteJob(job.ID, "test", "prompt", "output") + if err != nil { + t.Fatalf("CompleteJob: %v", err) + } + // Set synced_at to a past time so remap's updated_at is guaranteed later + pastTime := time.Now().UTC().Add(-time.Hour).Format(time.RFC3339) + _, err = db.Exec( + `UPDATE review_jobs SET synced_at = ?, updated_at = ? WHERE id = ?`, + pastTime, pastTime, job.ID, + ) + if err != nil { + t.Fatalf("set synced_at: %v", err) + } + + // Verify job is NOT returned by GetJobsToSync (updated_at == synced_at) + preJobs, err := db.GetJobsToSync(machineID, 100) + if err != nil { + t.Fatalf("GetJobsToSync (pre): %v", err) + } + for _, j := range preJobs { + if j.ID == job.ID { + t.Fatal("job should not appear in sync before remap") + } + } + + // Remap the job — this sets updated_at to time.Now() which is after pastTime + newCommit := createCommit(t, db, repo.ID, "resync-newsha") + n, err := db.RemapJobGitRef( + repo.ID, "resync-oldsha", "resync-newsha", + "patchRESYNC", newCommit.ID, + ) + if err != nil { + t.Fatalf("RemapJobGitRef: %v", err) + } + if n != 1 { + t.Errorf("expected 1 row updated, got %d", n) + } + + // Now GetJobsToSync should include it again + postJobs, err := db.GetJobsToSync(machineID, 100) + if err != nil { + t.Fatalf("GetJobsToSync (post): %v", err) + } + found := false + for _, j := range postJobs { + if j.ID == job.ID { + found = true + if j.GitRef != "resync-newsha" { + t.Errorf("synced job git_ref: got %q, want %q", + j.GitRef, "resync-newsha") + } + break + } + } + if !found { + t.Error("remapped job should appear in GetJobsToSync") + } +} + func TestUpsertPulledJob_BackfillsModel(t *testing.T) { // This test verifies that upserting a pulled job with a model value backfills // an existing job that has NULL model (COALESCE behavior in SQLite)