-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[analyze] Add client filter to detect successful unsafe HTTP requests (…
…#3305) * Move analyzer client to its own file * Add analyzer client filter to detect successful unsafe HTTP requests * Close response body in test
- Loading branch information
1 parent
1b59a5e
commit b2da2a6
Showing
3 changed files
with
221 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package analyzers | ||
|
||
import ( | ||
"fmt" | ||
"net/http" | ||
"os" | ||
"strings" | ||
"time" | ||
|
||
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config" | ||
) | ||
|
||
type AnalyzeClient struct { | ||
http.Client | ||
LoggingEnabled bool | ||
LogFile string | ||
} | ||
|
||
func CreateLogFileName(baseName string) string { | ||
// Get the current time | ||
currentTime := time.Now() | ||
|
||
// Format the time as "2024_06_30_07_15_30" | ||
timeString := currentTime.Format("2006_01_02_15_04_05") | ||
|
||
// Create the log file name | ||
logFileName := fmt.Sprintf("%s_%s.log", timeString, baseName) | ||
return logFileName | ||
} | ||
|
||
func NewAnalyzeClient(cfg *config.Config) *http.Client { | ||
client := &http.Client{ | ||
Transport: AnalyzerRoundTripper{parent: http.DefaultTransport}, | ||
} | ||
if cfg == nil || !cfg.LoggingEnabled { | ||
return client | ||
} | ||
return &http.Client{ | ||
Transport: LoggingRoundTripper{ | ||
parent: client.Transport, | ||
logFile: cfg.LogFile, | ||
}, | ||
} | ||
} | ||
|
||
type LoggingRoundTripper struct { | ||
parent http.RoundTripper | ||
// TODO: io.Writer | ||
logFile string | ||
} | ||
|
||
func (r LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | ||
startTime := time.Now() | ||
|
||
resp, parentErr := r.parent.RoundTrip(req) | ||
if resp == nil { | ||
return resp, parentErr | ||
} | ||
|
||
// TODO: JSON | ||
var logEntry string | ||
if parentErr != nil { | ||
logEntry = fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d, Error: %s\n", | ||
startTime.Format(time.RFC3339), | ||
req.Method, | ||
req.URL.Path, | ||
resp.StatusCode, | ||
parentErr.Error(), | ||
) | ||
} else { | ||
logEntry = fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d\n", | ||
startTime.Format(time.RFC3339), | ||
req.Method, | ||
req.URL.Path, | ||
resp.StatusCode, | ||
) | ||
} | ||
|
||
// Open log file in append mode. | ||
file, err := os.OpenFile(r.logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||
if err != nil { | ||
return resp, fmt.Errorf("failed to open log file: %w", err) | ||
} | ||
defer file.Close() | ||
|
||
// Write log entry to file. | ||
if _, err := file.WriteString(logEntry); err != nil { | ||
return resp, fmt.Errorf("failed to write log entry to file: %w", err) | ||
} | ||
|
||
return resp, parentErr | ||
} | ||
|
||
type AnalyzerRoundTripper struct { | ||
parent http.RoundTripper | ||
} | ||
|
||
func (r AnalyzerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | ||
resp, err := r.parent.RoundTrip(req) | ||
if err != nil || methodIsSafe(req.Method) { | ||
return resp, err | ||
} | ||
// Check that unsafe methods did NOT return a valid status code. | ||
if resp.StatusCode >= 200 && resp.StatusCode < 300 { | ||
return resp, fmt.Errorf("non-safe request returned success") | ||
} | ||
return resp, nil | ||
} | ||
|
||
// methodIsSafe is a helper method to check whether the HTTP method is safe according to MDN Web Docs. | ||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods#safe_idempotent_and_cacheable_request_methods | ||
func methodIsSafe(method string) bool { | ||
switch strings.ToUpper(method) { | ||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: | ||
return true | ||
default: | ||
return false | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package analyzers | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
) | ||
|
||
func TestAnalyzerClientUnsafeSuccess(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
method string | ||
expectedStatus int | ||
expectedError bool | ||
}{ | ||
{ | ||
name: "Safe method (GET)", | ||
method: http.MethodGet, | ||
expectedStatus: http.StatusOK, | ||
expectedError: false, | ||
}, | ||
{ | ||
name: "Safe method (HEAD)", | ||
method: http.MethodHead, | ||
expectedStatus: http.StatusOK, | ||
expectedError: false, | ||
}, | ||
{ | ||
name: "Safe method (OPTIONS)", | ||
method: http.MethodOptions, | ||
expectedStatus: http.StatusOK, | ||
expectedError: false, | ||
}, | ||
{ | ||
name: "Safe method (TRACE)", | ||
method: http.MethodTrace, | ||
expectedStatus: http.StatusOK, | ||
expectedError: false, | ||
}, | ||
{ | ||
name: "Unsafe method (POST) with success status", | ||
method: http.MethodPost, | ||
expectedStatus: http.StatusOK, | ||
expectedError: true, | ||
}, | ||
{ | ||
name: "Unsafe method (PUT) with success status", | ||
method: http.MethodPut, | ||
expectedStatus: http.StatusOK, | ||
expectedError: true, | ||
}, | ||
{ | ||
name: "Unsafe method (DELETE) with success status", | ||
method: http.MethodDelete, | ||
expectedStatus: http.StatusOK, | ||
expectedError: true, | ||
}, | ||
{ | ||
name: "Unsafe method (POST) with error status", | ||
method: http.MethodPost, | ||
expectedStatus: http.StatusInternalServerError, | ||
expectedError: false, | ||
}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
// Create a test server that returns the expected status code | ||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(tc.expectedStatus) | ||
})) | ||
defer server.Close() | ||
|
||
// Create a test request | ||
req, err := http.NewRequest(tc.method, server.URL, nil) | ||
if err != nil { | ||
t.Fatalf("Failed to create test request: %v", err) | ||
} | ||
|
||
// Create the AnalyzerRoundTripper with a test client | ||
client := NewAnalyzeClient(nil) | ||
|
||
// Perform the request | ||
resp, err := client.Do(req) | ||
if resp != nil { | ||
_ = resp.Body.Close() | ||
} | ||
|
||
// Check the error | ||
if err != nil && !tc.expectedError { | ||
t.Errorf("Unexpected error: %v", err) | ||
} else if err == nil && tc.expectedError { | ||
t.Errorf("Expected error, but got nil") | ||
} | ||
|
||
// Check the response status code | ||
if resp != nil && resp.StatusCode != tc.expectedStatus { | ||
t.Errorf("Expected status code: %d, but got: %d", tc.expectedStatus, resp.StatusCode) | ||
} | ||
}) | ||
} | ||
} |