Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 207 additions & 96 deletions pkg/cli/access_log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (

"github.com/githubnext/gh-aw/pkg/stringutil"
"github.com/githubnext/gh-aw/pkg/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAccessLogParsing(t *testing.T) {
Expand All @@ -22,44 +24,29 @@ func TestAccessLogParsing(t *testing.T) {
// Write test log file
accessLogPath := filepath.Join(tempDir, "access.log")
err := os.WriteFile(accessLogPath, []byte(testLogContent), 0644)
if err != nil {
t.Fatalf("Failed to create test access.log: %v", err)
}
require.NoError(t, err, "should create test access log file")

// Test parsing
analysis, err := parseSquidAccessLog(accessLogPath, false)
if err != nil {
t.Fatalf("Failed to parse access log: %v", err)
}
require.NoError(t, err, "should parse valid squid access log")
require.NotNil(t, analysis, "should return analysis result")

// Verify results
if analysis.TotalRequests != 4 {
t.Errorf("Expected 4 total requests, got %d", analysis.TotalRequests)
}

if analysis.AllowedCount != 2 {
t.Errorf("Expected 2 allowed requests, got %d", analysis.AllowedCount)
}

if analysis.BlockedCount != 2 {
t.Errorf("Expected 2 denied requests, got %d", analysis.BlockedCount)
}
assert.Equal(t, 4, analysis.TotalRequests, "should count all log entries")
assert.Equal(t, 2, analysis.AllowedCount, "should count allowed requests")
assert.Equal(t, 2, analysis.BlockedCount, "should count blocked requests")

// Check allowed domains
expectedAllowed := []string{"api.github.com", "example.com"}
if len(analysis.AllowedDomains) != len(expectedAllowed) {
t.Errorf("Expected %d allowed domains, got %d", len(expectedAllowed), len(analysis.AllowedDomains))
}
assert.Len(t, analysis.AllowedDomains, len(expectedAllowed), "should extract correct number of allowed domains")
}

func TestMultipleAccessLogAnalysis(t *testing.T) {
// Create a temporary directory for the test
tempDir := testutil.TempDir(t, "test-*")
accessLogsDir := filepath.Join(tempDir, "access.log")
err := os.MkdirAll(accessLogsDir, 0755)
if err != nil {
t.Fatalf("Failed to create access.log directory: %v", err)
}
require.NoError(t, err, "should create access.log directory")

// Create test access log content for multiple MCP servers
fetchLogContent := `1701234567.123 180 192.168.1.100 TCP_MISS/200 1234 GET http://example.com/api/data - HIER_DIRECT/93.184.216.34 text/html
Expand All @@ -71,94 +58,62 @@ func TestMultipleAccessLogAnalysis(t *testing.T) {
// Write separate log files for different MCP servers
fetchLogPath := filepath.Join(accessLogsDir, "access-fetch.log")
err = os.WriteFile(fetchLogPath, []byte(fetchLogContent), 0644)
if err != nil {
t.Fatalf("Failed to create test access-fetch.log: %v", err)
}
require.NoError(t, err, "should create test access-fetch.log")

browserLogPath := filepath.Join(accessLogsDir, "access-browser.log")
err = os.WriteFile(browserLogPath, []byte(browserLogContent), 0644)
if err != nil {
t.Fatalf("Failed to create test access-browser.log: %v", err)
}
require.NoError(t, err, "should create test access-browser.log")

// Test analysis of multiple access logs
analysis, err := analyzeMultipleAccessLogs(accessLogsDir, false)
if err != nil {
t.Fatalf("Failed to analyze multiple access logs: %v", err)
}
require.NoError(t, err, "should analyze multiple access logs")
require.NotNil(t, analysis, "should return analysis result")

// Verify aggregated results
if analysis.TotalRequests != 4 {
t.Errorf("Expected 4 total requests, got %d", analysis.TotalRequests)
}

if analysis.AllowedCount != 2 {
t.Errorf("Expected 2 allowed requests, got %d", analysis.AllowedCount)
}

if analysis.BlockedCount != 2 {
t.Errorf("Expected 2 denied requests, got %d", analysis.BlockedCount)
}
assert.Equal(t, 4, analysis.TotalRequests, "should count all requests from multiple logs")
assert.Equal(t, 2, analysis.AllowedCount, "should count allowed requests")
assert.Equal(t, 2, analysis.BlockedCount, "should count blocked requests")

// Check allowed domains
expectedAllowed := []string{"api.github.com", "example.com"}
if len(analysis.AllowedDomains) != len(expectedAllowed) {
t.Errorf("Expected %d allowed domains, got %d", len(expectedAllowed), len(analysis.AllowedDomains))
}
assert.Len(t, analysis.AllowedDomains, len(expectedAllowed), "should extract correct number of allowed domains")

// Check blocked domains
expectedDenied := []string{"github.com", "malicious.site"}
if len(analysis.BlockedDomains) != len(expectedDenied) {
t.Errorf("Expected %d blocked domains, got %d", len(expectedDenied), len(analysis.BlockedDomains))
}
assert.Len(t, analysis.BlockedDomains, len(expectedDenied), "should extract correct number of blocked domains")
}

func TestAnalyzeAccessLogsDirectory(t *testing.T) {
// Create a temporary directory structure
tempDir := testutil.TempDir(t, "test-*")

// Test case 1: Multiple access logs in access-logs subdirectory
accessLogsDir := filepath.Join(tempDir, "run1", "access.log")
err := os.MkdirAll(accessLogsDir, 0755)
if err != nil {
t.Fatalf("Failed to create access.log directory: %v", err)
}

fetchLogContent := `1701234567.123 180 192.168.1.100 TCP_MISS/200 1234 GET http://example.com/api/data - HIER_DIRECT/93.184.216.34 text/html`
fetchLogPath := filepath.Join(accessLogsDir, "access-fetch.log")
err = os.WriteFile(fetchLogPath, []byte(fetchLogContent), 0644)
if err != nil {
t.Fatalf("Failed to create test access-fetch.log: %v", err)
}

analysis, err := analyzeAccessLogs(filepath.Join(tempDir, "run1"), false)
if err != nil {
t.Fatalf("Failed to analyze access logs: %v", err)
}

if analysis == nil {
t.Fatal("Expected analysis result, got nil")
}

if analysis.TotalRequests != 1 {
t.Errorf("Expected 1 total request, got %d", analysis.TotalRequests)
}

// Test case 2: No access logs
run2Dir := filepath.Join(tempDir, "run2")
err = os.MkdirAll(run2Dir, 0755)
if err != nil {
t.Fatalf("Failed to create run2 directory: %v", err)
}

analysis, err = analyzeAccessLogs(run2Dir, false)
if err != nil {
t.Fatalf("Failed to analyze no access logs: %v", err)
}

if analysis != nil {
t.Errorf("Expected nil analysis for no access logs, got %+v", analysis)
}
t.Run("multiple access logs in subdirectory", func(t *testing.T) {
// Test case 1: Multiple access logs in access-logs subdirectory
accessLogsDir := filepath.Join(tempDir, "run1", "access.log")
err := os.MkdirAll(accessLogsDir, 0755)
require.NoError(t, err, "should create access.log directory")

fetchLogContent := `1701234567.123 180 192.168.1.100 TCP_MISS/200 1234 GET http://example.com/api/data - HIER_DIRECT/93.184.216.34 text/html`
fetchLogPath := filepath.Join(accessLogsDir, "access-fetch.log")
err = os.WriteFile(fetchLogPath, []byte(fetchLogContent), 0644)
require.NoError(t, err, "should create test access-fetch.log")

analysis, err := analyzeAccessLogs(filepath.Join(tempDir, "run1"), false)
require.NoError(t, err, "should analyze access logs")
require.NotNil(t, analysis, "should return analysis for valid logs")
assert.Equal(t, 1, analysis.TotalRequests, "should count request from log file")
})

t.Run("no access logs - returns nil", func(t *testing.T) {
// Test case 2: No access logs
run2Dir := filepath.Join(tempDir, "run2")
err := os.MkdirAll(run2Dir, 0755)
require.NoError(t, err, "should create run2 directory")

analysis, err := analyzeAccessLogs(run2Dir, false)
require.NoError(t, err, "should not error when no logs present")
assert.Nil(t, analysis, "should return nil when no logs found")
})
}

func TestExtractDomainFromURL(t *testing.T) {
Expand All @@ -173,10 +128,166 @@ func TestExtractDomainFromURL(t *testing.T) {
{"http://sub.domain.com:8080/path", "sub.domain.com"},
}

for _, test := range tests {
result := stringutil.ExtractDomainFromURL(test.url)
if result != test.expected {
t.Errorf("stringutil.ExtractDomainFromURL(%q) = %q, expected %q", test.url, result, test.expected)
}
for _, tt := range tests {
t.Run(tt.url, func(t *testing.T) {
result := stringutil.ExtractDomainFromURL(tt.url)
assert.Equal(t, tt.expected, result, "should extract correct domain from URL")
})
}
}

func TestParseSquidLogLine(t *testing.T) {
tests := []struct {
name string
line string
expected *AccessLogEntry
shouldErr bool
}{
{
name: "valid squid log line",
line: "1701234567.123 180 192.168.1.100 TCP_MISS/200 1234 GET http://example.com/api - HIER_DIRECT/93.184.216.34 text/html",
expected: &AccessLogEntry{
Timestamp: "1701234567.123",
Duration: "180",
ClientIP: "192.168.1.100",
Status: "TCP_MISS/200",
Size: "1234",
Method: "GET",
URL: "http://example.com/api",
User: "-",
Hierarchy: "HIER_DIRECT/93.184.216.34",
Type: "text/html",
},
shouldErr: false,
},
{
name: "valid denied request",
line: "1701234568.456 250 192.168.1.100 TCP_DENIED/403 0 CONNECT github.com:443 - HIER_NONE/- -",
expected: &AccessLogEntry{
Timestamp: "1701234568.456",
Duration: "250",
ClientIP: "192.168.1.100",
Status: "TCP_DENIED/403",
Size: "0",
Method: "CONNECT",
URL: "github.com:443",
User: "-",
Hierarchy: "HIER_NONE/-",
Type: "-",
},
shouldErr: false,
},
{
name: "insufficient fields - should error",
line: "1701234567.123 180 192.168.1.100",
shouldErr: true,
},
{
name: "empty line",
line: "",
shouldErr: true,
},
{
name: "exactly 9 fields - should error",
line: "1701234567.123 180 192.168.1.100 TCP_MISS/200 1234 GET http://example.com/api - HIER_DIRECT/93.184.216.34",
shouldErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseSquidLogLine(tt.line)

if tt.shouldErr {
require.Error(t, err, "should return error for invalid line")
assert.Nil(t, result, "should not return entry on error")
} else {
require.NoError(t, err, "should parse valid log line")
require.NotNil(t, result, "should return parsed entry")
assert.Equal(t, tt.expected.Timestamp, result.Timestamp, "timestamp should match")
assert.Equal(t, tt.expected.Duration, result.Duration, "duration should match")
assert.Equal(t, tt.expected.ClientIP, result.ClientIP, "client IP should match")
assert.Equal(t, tt.expected.Status, result.Status, "status should match")
assert.Equal(t, tt.expected.Size, result.Size, "size should match")
assert.Equal(t, tt.expected.Method, result.Method, "method should match")
assert.Equal(t, tt.expected.URL, result.URL, "URL should match")
assert.Equal(t, tt.expected.User, result.User, "user should match")
assert.Equal(t, tt.expected.Hierarchy, result.Hierarchy, "hierarchy should match")
assert.Equal(t, tt.expected.Type, result.Type, "type should match")
}
})
}
}

func TestAddMetrics(t *testing.T) {
tests := []struct {
name string
base *DomainAnalysis
toAdd LogAnalysis
expected *DomainAnalysis
}{
{
name: "add valid domain analysis",
base: &DomainAnalysis{
TotalRequests: 10,
AllowedCount: 8,
BlockedCount: 2,
},
toAdd: &DomainAnalysis{
TotalRequests: 5,
AllowedCount: 4,
BlockedCount: 1,
},
expected: &DomainAnalysis{
TotalRequests: 15,
AllowedCount: 12,
BlockedCount: 3,
},
},
{
name: "add zero values",
base: &DomainAnalysis{
TotalRequests: 10,
AllowedCount: 8,
BlockedCount: 2,
},
toAdd: &DomainAnalysis{
TotalRequests: 0,
AllowedCount: 0,
BlockedCount: 0,
},
expected: &DomainAnalysis{
TotalRequests: 10,
AllowedCount: 8,
BlockedCount: 2,
},
},
{
name: "add to empty base",
base: &DomainAnalysis{
TotalRequests: 0,
AllowedCount: 0,
BlockedCount: 0,
},
toAdd: &DomainAnalysis{
TotalRequests: 5,
AllowedCount: 3,
BlockedCount: 2,
},
expected: &DomainAnalysis{
TotalRequests: 5,
AllowedCount: 3,
BlockedCount: 2,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.base.AddMetrics(tt.toAdd)
assert.Equal(t, tt.expected.TotalRequests, tt.base.TotalRequests, "total requests should match")
assert.Equal(t, tt.expected.AllowedCount, tt.base.AllowedCount, "allowed count should match")
assert.Equal(t, tt.expected.BlockedCount, tt.base.BlockedCount, "blocked count should match")
})
}
}
Loading