diff --git a/.github/workflows/test-proxy.lock.yml b/.github/workflows/test-proxy.lock.yml
index 7862e2149d..5e85b32028 100644
--- a/.github/workflows/test-proxy.lock.yml
+++ b/.github/workflows/test-proxy.lock.yml
@@ -799,6 +799,23 @@ jobs:
       - name: Clean up engine output files
         run: |
           rm -f output.txt
+      - name: Extract squid access logs
+        if: always()
+        run: |
+          mkdir -p /tmp/access-logs
+          echo 'Extracting access.log from squid-proxy-fetch container'
+          if docker ps -a --format '{{.Names}}' | grep -q '^squid-proxy-fetch$'; then
+            docker cp squid-proxy-fetch:/var/log/squid/access.log /tmp/access-logs/access-fetch.log 2>/dev/null || echo 'No access.log found for fetch'
+          else
+            echo 'Container squid-proxy-fetch not found'
+          fi
+      - name: Upload squid access logs
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: access.log
+          path: /tmp/access-logs/
+          if-no-files-found: warn
       - name: Parse agent logs for step summary
         if: always()
         uses: actions/github-script@v7
diff --git a/pkg/cli/access_log.go b/pkg/cli/access_log.go
new file mode 100644
index 0000000000..80d5da8107
--- /dev/null
+++ b/pkg/cli/access_log.go
@@ -0,0 +1,328 @@
+package cli
+
+import (
+	"bufio"
+	"fmt"
+	neturl "net/url"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+
+	"github.com/githubnext/gh-aw/pkg/console"
+)
+
+// AccessLogEntry represents a parsed squid access log entry
+type AccessLogEntry struct {
+	Timestamp string
+	Duration  string
+	ClientIP  string
+	Status    string
+	Size      string
+	Method    string
+	URL       string
+	User      string
+	Hierarchy string
+	Type      string
+}
+
+// DomainAnalysis represents analysis of domains from access logs
+type DomainAnalysis struct {
+	AllowedDomains []string
+	DeniedDomains  []string
+	TotalRequests  int
+	AllowedCount   int
+	DeniedCount    int
+}
+
+// parseSquidAccessLog parses a squid access log file and extracts domain information
+func parseSquidAccessLog(logPath string, verbose bool) (*DomainAnalysis, error) {
+	file, err := os.Open(logPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open access log: %w", err)
+	}
+	defer file.Close()
+
+	analysis := &DomainAnalysis{
+		AllowedDomains: []string{},
+		DeniedDomains:  []string{},
+	}
+
+	allowedDomainsSet := make(map[string]bool)
+	deniedDomainsSet := make(map[string]bool)
+
+	scanner := bufio.NewScanner(file)
+	for scanner.Scan() {
+		line := strings.TrimSpace(scanner.Text())
+		if line == "" || strings.HasPrefix(line, "#") {
+			continue
+		}
+
+		entry, err := parseSquidLogLine(line)
+		if err != nil {
+			if verbose {
+				fmt.Println(console.FormatWarningMessage(fmt.Sprintf("Failed to parse log line: %v", err)))
+			}
+			continue
+		}
+
+		analysis.TotalRequests++
+
+		// Extract domain from URL
+		domain := extractDomainFromURL(entry.URL)
+		if domain == "" {
+			continue
+		}
+
+		// Determine whether the request was allowed or denied based on the
+		// status field. Squid typically returns:
+		//   - 200, 206, 304: allowed/successful
+		//   - 403: forbidden (denied by ACL)
+		//   - 407: proxy authentication required
+		//   - 502, 503: connection/upstream errors
+		// Anything without a 2xx/304 result code (e.g. TCP_DENIED/403)
+		// counts as denied.
+		statusCode := entry.Status
+		isAllowed := strings.Contains(statusCode, "/200") ||
+			strings.Contains(statusCode, "/206") ||
+			strings.Contains(statusCode, "/304")
+
+		if isAllowed {
+			analysis.AllowedCount++
+			if !allowedDomainsSet[domain] {
+				allowedDomainsSet[domain] = true
+				analysis.AllowedDomains = append(analysis.AllowedDomains, domain)
+			}
+		} else {
+			analysis.DeniedCount++
+			if !deniedDomainsSet[domain] {
+				deniedDomainsSet[domain] = true
+				analysis.DeniedDomains = append(analysis.DeniedDomains, domain)
+			}
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		return nil, fmt.Errorf("error reading access log: %w", err)
+	}
+
+	// Sort domains for consistent output
+	sort.Strings(analysis.AllowedDomains)
+	sort.Strings(analysis.DeniedDomains)
+
+	return analysis, nil
+}
+
+// parseSquidLogLine parses a single squid access log line.
+// Squid log format: timestamp duration client status size method url user hierarchy type
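+// For example, an entry such as (values are illustrative, not from a real run):
+//
+//	1701234567.123 180 192.168.1.100 TCP_DENIED/403 0 CONNECT github.com:443 - HIER_NONE/- -
+//
+// maps to Status "TCP_DENIED/403", Method "CONNECT", and URL "github.com:443".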
+func parseSquidLogLine(line string) (*AccessLogEntry, error) {
+	fields := strings.Fields(line)
+	if len(fields) < 10 {
+		return nil, fmt.Errorf("invalid log line format: expected at least 10 fields, got %d", len(fields))
+	}
+
+	return &AccessLogEntry{
+		Timestamp: fields[0],
+		Duration:  fields[1],
+		ClientIP:  fields[2],
+		Status:    fields[3],
+		Size:      fields[4],
+		Method:    fields[5],
+		URL:       fields[6],
+		User:      fields[7],
+		Hierarchy: fields[8],
+		Type:      fields[9],
+	}, nil
+}
+
+// extractDomainFromURL extracts the domain from a URL.
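+// Full URLs are parsed with net/url; CONNECT requests are logged by squid as
+// "host:port" (e.g. "github.com:443"), so the port suffix is stripped; anything
+// else is assumed to already be a bare hostname.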
+func extractDomainFromURL(url string) string {
+	// Handle different URL formats
+	if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
+		// Parse full URL
+		parsedURL, err := neturl.Parse(url)
+		if err != nil {
+			return ""
+		}
+		return parsedURL.Hostname()
+	}
+
+	// Handle CONNECT requests (domain:port format)
+	if strings.Contains(url, ":") {
+		parts := strings.Split(url, ":")
+		if len(parts) >= 2 {
+			return parts[0]
+		}
+	}
+
+	// Handle direct domain
+	return url
+}
+
+// analyzeAccessLogs analyzes access logs in a run directory
+func analyzeAccessLogs(runDir string, verbose bool) (*DomainAnalysis, error) {
+	// Check for access log files in the access.log directory
+	accessLogsDir := filepath.Join(runDir, "access.log")
+	if _, err := os.Stat(accessLogsDir); err == nil {
+		return analyzeMultipleAccessLogs(accessLogsDir, verbose)
+	}
+
+	// No access logs found
+	if verbose {
+		fmt.Println(console.FormatInfoMessage(fmt.Sprintf("No access logs found in %s", runDir)))
+	}
+	return nil, nil
+}
+
+// analyzeMultipleAccessLogs analyzes multiple separate access log files
+func analyzeMultipleAccessLogs(accessLogsDir string, verbose bool) (*DomainAnalysis, error) {
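+	// Each proxy-enabled tool gets its own log, extracted by the generated
+	// workflow step as access-<tool>.log (e.g. access-fetch.log), so glob
+	// for that prefix rather than a single fixed file name.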
+	files, err := filepath.Glob(filepath.Join(accessLogsDir, "access-*.log"))
+	if err != nil {
+		return nil, fmt.Errorf("failed to find access log files: %w", err)
+	}
+
+	if len(files) == 0 {
+		if verbose {
+			fmt.Println(console.FormatInfoMessage(fmt.Sprintf("No access log files found in %s", accessLogsDir)))
+		}
+		return nil, nil
+	}
+
+	if verbose {
+		fmt.Println(console.FormatInfoMessage(fmt.Sprintf("Analyzing %d access log files from %s", len(files), accessLogsDir)))
+	}
+
+	// Aggregate analysis from all files
+	aggregatedAnalysis := &DomainAnalysis{
+		AllowedDomains: []string{},
+		DeniedDomains:  []string{},
+	}
+
+	allAllowedDomains := make(map[string]bool)
+	allDeniedDomains := make(map[string]bool)
+
+	for _, file := range files {
+		if verbose {
+			fmt.Println(console.FormatInfoMessage(fmt.Sprintf("Parsing %s", filepath.Base(file))))
+		}
+
+		analysis, err := parseSquidAccessLog(file, verbose)
+		if err != nil {
+			if verbose {
+				fmt.Println(console.FormatWarningMessage(fmt.Sprintf("Failed to parse %s: %v", filepath.Base(file), err)))
+			}
+			continue
+		}
+
+		// Aggregate the metrics
+		aggregatedAnalysis.TotalRequests += analysis.TotalRequests
+		aggregatedAnalysis.AllowedCount += analysis.AllowedCount
+		aggregatedAnalysis.DeniedCount += analysis.DeniedCount
+
+		// Collect unique domains
+		for _, domain := range analysis.AllowedDomains {
+			allAllowedDomains[domain] = true
+		}
+		for _, domain := range analysis.DeniedDomains {
+			allDeniedDomains[domain] = true
+		}
+	}
+
+	// Convert maps to sorted slices
+	for domain := range allAllowedDomains {
+		aggregatedAnalysis.AllowedDomains = append(aggregatedAnalysis.AllowedDomains, domain)
+	}
+	for domain := range allDeniedDomains {
+		aggregatedAnalysis.DeniedDomains = append(aggregatedAnalysis.DeniedDomains, domain)
+	}
+
+	sort.Strings(aggregatedAnalysis.AllowedDomains)
+	sort.Strings(aggregatedAnalysis.DeniedDomains)
+
+	return aggregatedAnalysis, nil
+}
+
+// displayAccessLogAnalysis displays analysis of access logs from all runs with improved formatting
+func displayAccessLogAnalysis(processedRuns []ProcessedRun, verbose bool) {
+	if len(processedRuns) == 0 {
+		return
+	}
+
+	// Collect all access analyses
+	var analyses []*DomainAnalysis
+	runsWithAccess := 0
+	for _, pr := range processedRuns {
+		if pr.AccessAnalysis != nil {
+			analyses = append(analyses, pr.AccessAnalysis)
+			runsWithAccess++
+		}
+	}
+
+	if len(analyses) == 0 {
+		fmt.Println(console.FormatInfoMessage("No access logs found in downloaded runs"))
+		return
+	}
+
+	// Aggregate statistics
+	totalRequests := 0
+	totalAllowed := 0
+	totalDenied := 0
+	allAllowedDomains := make(map[string]bool)
+	allDeniedDomains := make(map[string]bool)
+
+	for _, analysis := range analyses {
+		totalRequests += analysis.TotalRequests
+		totalAllowed += analysis.AllowedCount
+		totalDenied += analysis.DeniedCount
+
+		for _, domain := range analysis.AllowedDomains {
+			allAllowedDomains[domain] = true
+		}
+		for _, domain := range analysis.DeniedDomains {
+			allDeniedDomains[domain] = true
+		}
+	}
+
+	fmt.Println()
+
+	// Display allowed domains with better formatting
+	if len(allAllowedDomains) > 0 {
+		fmt.Println(console.FormatSuccessMessage(fmt.Sprintf("✅ Allowed Domains (%d):", len(allAllowedDomains))))
+		allowedList := make([]string, 0, len(allAllowedDomains))
+		for domain := range allAllowedDomains {
+			allowedList = append(allowedList, domain)
+		}
+		sort.Strings(allowedList)
+		for _, domain := range allowedList {
+			fmt.Println(console.FormatListItem(domain))
+		}
+		fmt.Println()
+	}
+
+	// Display denied domains with better formatting
+	if len(allDeniedDomains) > 0 {
+		fmt.Println(console.FormatErrorMessage(fmt.Sprintf("❌ Denied Domains (%d):", len(allDeniedDomains))))
+		deniedList := make([]string, 0, len(allDeniedDomains))
+		for domain := range allDeniedDomains {
+			deniedList = append(deniedList, domain)
+		}
+		sort.Strings(deniedList)
+		for _, domain := range deniedList {
+			fmt.Println(console.FormatListItem(domain))
+		}
+		fmt.Println()
+	}
+
+	if verbose && len(analyses) > 1 {
+		// Show per-run breakdown with improved formatting
+		fmt.Println(console.FormatInfoMessage("📋 Per-run breakdown:"))
+		for _, pr := range processedRuns {
+			if pr.AccessAnalysis != nil {
+				analysis := pr.AccessAnalysis
+				fmt.Printf("  %s Run %d: %d requests (%d allowed, %d denied)\n",
+					console.FormatListItem(""),
+					pr.Run.DatabaseID, analysis.TotalRequests, analysis.AllowedCount, analysis.DeniedCount)
+			}
+		}
+		fmt.Println()
+	}
+}
diff --git a/pkg/cli/access_log_test.go b/pkg/cli/access_log_test.go
new file mode 100644
index 0000000000..7e6a1d945e
--- /dev/null
+++ b/pkg/cli/access_log_test.go
@@ -0,0 +1,179 @@
+package cli
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+func TestAccessLogParsing(t *testing.T) {
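+	// The fixture below fabricates four squid entries (timestamps and IPs
+	// are made up): two with 2xx results and two with TCP_DENIED/403, so
+	// the parser should report two allowed and two denied requests.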
[]string{"api.github.com", "example.com"} + if len(analysis.AllowedDomains) != len(expectedAllowed) { + t.Errorf("Expected %d allowed domains, got %d", len(expectedAllowed), len(analysis.AllowedDomains)) + } + + // Check denied domains + expectedDenied := []string{"github.com", "malicious.site"} + if len(analysis.DeniedDomains) != len(expectedDenied) { + t.Errorf("Expected %d denied domains, got %d", len(expectedDenied), len(analysis.DeniedDomains)) + } +} + +func TestAnalyzeAccessLogsDirectory(t *testing.T) { + // Create a temporary directory structure + tempDir := t.TempDir() + + // Test case 1: Multiple access logs in access-logs subdirectory + accessLogsDir := filepath.Join(tempDir, "run1", "access.log") + err := os.MkdirAll(accessLogsDir, 0755) + if err != nil { + t.Fatalf("Failed to create access.log directory: %v", err) + } + + fetchLogContent := `1701234567.123 180 192.168.1.100 TCP_MISS/200 1234 GET http://example.com/api/data - HIER_DIRECT/93.184.216.34 text/html` + fetchLogPath := filepath.Join(accessLogsDir, "access-fetch.log") + err = os.WriteFile(fetchLogPath, []byte(fetchLogContent), 0644) + if err != nil { + t.Fatalf("Failed to create test access-fetch.log: %v", err) + } + + analysis, err := analyzeAccessLogs(filepath.Join(tempDir, "run1"), false) + if err != nil { + t.Fatalf("Failed to analyze access logs: %v", err) + } + + if analysis == nil { + t.Fatal("Expected analysis result, got nil") + } + + if analysis.TotalRequests != 1 { + t.Errorf("Expected 1 total request, got %d", analysis.TotalRequests) + } + + // Test case 2: No access logs + run2Dir := filepath.Join(tempDir, "run2") + err = os.MkdirAll(run2Dir, 0755) + if err != nil { + t.Fatalf("Failed to create run2 directory: %v", err) + } + + analysis, err = analyzeAccessLogs(run2Dir, false) + if err != nil { + t.Fatalf("Failed to analyze no access logs: %v", err) + } + + if analysis != nil { + t.Errorf("Expected nil analysis for no access logs, got %+v", analysis) + } +} + +func TestExtractDomainFromURL(t *testing.T) { + tests := []struct { + url string + expected string + }{ + {"http://example.com/path", "example.com"}, + {"https://api.github.com/repos", "api.github.com"}, + {"github.com:443", "github.com"}, + {"malicious.site", "malicious.site"}, + {"http://sub.domain.com:8080/path", "sub.domain.com"}, + } + + for _, test := range tests { + result := extractDomainFromURL(test.url) + if result != test.expected { + t.Errorf("extractDomainFromURL(%q) = %q, expected %q", test.url, result, test.expected) + } + } +} diff --git a/pkg/cli/logs.go b/pkg/cli/logs.go index 3c7765cacb..625ad9cf2b 100644 --- a/pkg/cli/logs.go +++ b/pkg/cli/logs.go @@ -44,16 +44,23 @@ type WorkflowRun struct { // This is now an alias to the shared type in workflow package type LogMetrics = workflow.LogMetrics +// ProcessedRun represents a workflow run with its associated analysis +type ProcessedRun struct { + Run WorkflowRun + AccessAnalysis *DomainAnalysis +} + // ErrNoArtifacts indicates that a workflow run has no artifacts var ErrNoArtifacts = errors.New("no artifacts found for this run") // DownloadResult represents the result of downloading artifacts for a single run type DownloadResult struct { - Run WorkflowRun - Metrics LogMetrics - Error error - Skipped bool - LogsPath string + Run WorkflowRun + Metrics LogMetrics + AccessAnalysis *DomainAnalysis + Error error + Skipped bool + LogsPath string } // Constants for the iterative algorithm @@ -199,7 +206,7 @@ func DownloadWorkflowLogs(workflowName string, count int, startDate, endDate, ou 
fmt.Println(console.FormatInfoMessage("Fetching workflow runs from GitHub Actions...")) } - var processedRuns []WorkflowRun + var processedRuns []ProcessedRun var beforeDate string iteration := 0 @@ -315,12 +322,19 @@ func DownloadWorkflowLogs(workflowName string, count int, startDate, endDate, ou run.EstimatedCost = result.Metrics.EstimatedCost run.LogsPath = result.LogsPath + // Store access analysis for later display (we'll access it via the result) + // No need to modify the WorkflowRun struct for this + // Always use GitHub API timestamps for duration calculation if !run.StartedAt.IsZero() && !run.UpdatedAt.IsZero() { run.Duration = run.UpdatedAt.Sub(run.StartedAt) } - processedRuns = append(processedRuns, run) + processedRun := ProcessedRun{ + Run: run, + AccessAnalysis: result.AccessAnalysis, + } + processedRuns = append(processedRuns, processedRun) batchProcessed++ } @@ -354,7 +368,14 @@ func DownloadWorkflowLogs(workflowName string, count int, startDate, endDate, ou } // Display overview table - displayLogsOverview(processedRuns, outputDir) + workflowRuns := make([]WorkflowRun, len(processedRuns)) + for i, pr := range processedRuns { + workflowRuns[i] = pr.Run + } + displayLogsOverview(workflowRuns, outputDir) + + // Display access log analysis + displayAccessLogAnalysis(processedRuns, verbose) // Display logs location prominently absOutputDir, _ := filepath.Abs(outputDir) @@ -417,6 +438,15 @@ func downloadRunArtifactsConcurrent(runs []WorkflowRun, outputDir string, verbos metrics = LogMetrics{} } result.Metrics = metrics + + // Analyze access logs if available + accessAnalysis, accessErr := analyzeAccessLogs(runOutputDir, verbose) + if accessErr != nil { + if verbose { + fmt.Println(console.FormatWarningMessage(fmt.Sprintf("Failed to analyze access logs for run %d: %v", run.DatabaseID, accessErr))) + } + } + result.AccessAnalysis = accessAnalysis } return result @@ -483,6 +513,9 @@ func listWorkflowRunsWithPagination(workflowName string, count int, startDate, e errMsg := err.Error() outputMsg := string(output) combinedMsg := errMsg + " " + outputMsg + if verbose { + fmt.Println(console.FormatVerboseMessage(outputMsg)) + } if strings.Contains(combinedMsg, "exit status 4") || strings.Contains(combinedMsg, "exit status 1") || strings.Contains(combinedMsg, "not logged into any GitHub hosts") || @@ -560,6 +593,10 @@ func downloadRunArtifacts(runID int64, outputDir string, verbose bool) error { spinner.Stop() } if err != nil { + if verbose { + fmt.Println(console.FormatVerboseMessage(string(output))) + } + // Check if it's because there are no artifacts if strings.Contains(string(output), "no valid artifacts") || strings.Contains(string(output), "not found") { // Clean up empty directory diff --git a/pkg/workflow/compiler.go b/pkg/workflow/compiler.go index 2343a5a5a6..0d1b601c8e 100644 --- a/pkg/workflow/compiler.go +++ b/pkg/workflow/compiler.go @@ -2370,6 +2370,10 @@ func (c *Compiler) generateMainJobSteps(yaml *strings.Builder, data *WorkflowDat c.generateEngineOutputCollection(yaml, engine) } + // Extract and upload squid access logs (if any proxy tools were used) + c.generateExtractAccessLogs(yaml, data.Tools) + c.generateUploadAccessLogs(yaml, data.Tools) + // parse agent logs for GITHUB_STEP_SUMMARY c.generateLogParsing(yaml, engine, logFileFull) @@ -2452,6 +2456,64 @@ func (c *Compiler) generateUploadAwInfo(yaml *strings.Builder) { yaml.WriteString(" if-no-files-found: warn\n") } +func (c *Compiler) generateExtractAccessLogs(yaml *strings.Builder, tools map[string]any) { + // 
+func (c *Compiler) generateExtractAccessLogs(yaml *strings.Builder, tools map[string]any) {
+	// Check if any tools require proxy setup
+	var proxyTools []string
+	for toolName, toolConfig := range tools {
+		if toolConfigMap, ok := toolConfig.(map[string]any); ok {
+			needsProxySetup, _ := needsProxy(toolConfigMap)
+			if needsProxySetup {
+				proxyTools = append(proxyTools, toolName)
+			}
+		}
+	}
+
+	// If no proxy tools, no access logs to extract
+	if len(proxyTools) == 0 {
+		return
+	}
+
+	yaml.WriteString("      - name: Extract squid access logs\n")
+	yaml.WriteString("        if: always()\n")
+	yaml.WriteString("        run: |\n")
+	yaml.WriteString("          mkdir -p /tmp/access-logs\n")
+
+	for _, toolName := range proxyTools {
+		fmt.Fprintf(yaml, "          echo 'Extracting access.log from squid-proxy-%s container'\n", toolName)
+		fmt.Fprintf(yaml, "          if docker ps -a --format '{{.Names}}' | grep -q '^squid-proxy-%s$'; then\n", toolName)
+		fmt.Fprintf(yaml, "            docker cp squid-proxy-%s:/var/log/squid/access.log /tmp/access-logs/access-%s.log 2>/dev/null || echo 'No access.log found for %s'\n", toolName, toolName, toolName)
+		yaml.WriteString("          else\n")
+		fmt.Fprintf(yaml, "            echo 'Container squid-proxy-%s not found'\n", toolName)
+		yaml.WriteString("          fi\n")
+	}
+}
+
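+// generateUploadAccessLogs emits a step that uploads /tmp/access-logs/ as a
+// single artifact named "access.log"; when the logs command later downloads a
+// run, that artifact lands in an access.log directory, which is the layout
+// analyzeAccessLogs looks for.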
+func (c *Compiler) generateUploadAccessLogs(yaml *strings.Builder, tools map[string]any) {
+	// Check if any tools require proxy setup
+	var proxyTools []string
+	for toolName, toolConfig := range tools {
+		if toolConfigMap, ok := toolConfig.(map[string]any); ok {
+			needsProxySetup, _ := needsProxy(toolConfigMap)
+			if needsProxySetup {
+				proxyTools = append(proxyTools, toolName)
+			}
+		}
+	}
+
+	// If no proxy tools, no access logs to upload
+	if len(proxyTools) == 0 {
+		return
+	}
+
+	yaml.WriteString("      - name: Upload squid access logs\n")
+	yaml.WriteString("        if: always()\n")
+	yaml.WriteString("        uses: actions/upload-artifact@v4\n")
+	yaml.WriteString("        with:\n")
+	yaml.WriteString("          name: access.log\n")
+	yaml.WriteString("          path: /tmp/access-logs/\n")
+	yaml.WriteString("          if-no-files-found: warn\n")
+}
+
 func (c *Compiler) generatePrompt(yaml *strings.Builder, data *WorkflowData, engine AgenticEngine) {
 	yaml.WriteString("      - name: Create prompt\n")
diff --git a/pkg/workflow/compiler_test.go b/pkg/workflow/compiler_test.go
index 10fd4db59d..b27d5e5b8a 100644
--- a/pkg/workflow/compiler_test.go
+++ b/pkg/workflow/compiler_test.go
@@ -5782,3 +5782,88 @@ func TestComputeAllowedToolsWithSafeOutputs(t *testing.T) {
 		})
 	}
 }
+
+func TestAccessLogUploadConditional(t *testing.T) {
+	compiler := NewCompiler(false, "", "test")
+
+	tests := []struct {
+		name        string
+		tools       map[string]any
+		expectSteps bool
+	}{
+		{
+			name: "no tools - no access log steps",
+			tools: map[string]any{
+				"github": map[string]any{
+					"allowed": []any{"list_issues"},
+				},
+			},
+			expectSteps: false,
+		},
+		{
+			name: "tool with container but no network permissions - no access log steps",
+			tools: map[string]any{
+				"simple": map[string]any{
+					"mcp": map[string]any{
+						"type":      "stdio",
+						"container": "simple/tool",
+					},
+					"allowed": []any{"test"},
+				},
+			},
+			expectSteps: false,
+		},
+		{
+			name: "tool with container and network permissions - access log steps generated",
+			tools: map[string]any{
+				"fetch": map[string]any{
+					"mcp": map[string]any{
+						"type":      "stdio",
+						"container": "mcp/fetch",
+					},
+					"permissions": map[string]any{
+						"network": map[string]any{
+							"allowed": []any{"example.com"},
+						},
+					},
+					"allowed": []any{"fetch"},
+				},
+			},
+			expectSteps: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var yaml strings.Builder
+
+			// Test generateExtractAccessLogs
+			compiler.generateExtractAccessLogs(&yaml, tt.tools)
+			extractContent := yaml.String()
+
+			// Test generateUploadAccessLogs
+			yaml.Reset()
+			compiler.generateUploadAccessLogs(&yaml, tt.tools)
+			uploadContent := yaml.String()
+
+			hasExtractStep := strings.Contains(extractContent, "name: Extract squid access logs")
+			hasUploadStep := strings.Contains(uploadContent, "name: Upload squid access logs")
+
+			if tt.expectSteps {
+				if !hasExtractStep {
+					t.Errorf("Expected extract step to be generated but it wasn't")
+				}
+				if !hasUploadStep {
+					t.Errorf("Expected upload step to be generated but it wasn't")
+				}
+			} else {
+				if hasExtractStep {
+					t.Errorf("Expected no extract step but one was generated")
+				}
+				if hasUploadStep {
+					t.Errorf("Expected no upload step but one was generated")
+				}
+			}
+		})
+	}
+}