Skip to content

Commit

Permalink
[analyze] Add client filter to detect successful unsafe HTTP requests (
Browse files Browse the repository at this point in the history
…#3305)

* Move analyzer client to its own file

* Add analyzer client filter to detect successful unsafe HTTP requests

* Close response body in test
  • Loading branch information
mcastorina authored Sep 18, 2024
1 parent 1b59a5e commit b2da2a6
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 66 deletions.
66 changes: 0 additions & 66 deletions pkg/analyzer/analyzers/analyzers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@ package analyzers
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"

"github.com/fatih/color"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/pb/analyzerpb"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)
Expand Down Expand Up @@ -181,68 +177,6 @@ var YellowWriter = color.New(color.FgYellow).SprintFunc()
var RedWriter = color.New(color.FgRed).SprintFunc()
var DefaultWriter = color.New().SprintFunc()

type AnalyzeClient struct {
http.Client
LoggingEnabled bool
LogFile string
}

func CreateLogFileName(baseName string) string {
// Get the current time
currentTime := time.Now()

// Format the time as "2024_06_30_07_15_30"
timeString := currentTime.Format("2006_01_02_15_04_05")

// Create the log file name
logFileName := fmt.Sprintf("%s_%s.log", timeString, baseName)
return logFileName
}

func NewAnalyzeClient(cfg *config.Config) *http.Client {
if cfg == nil || !cfg.LoggingEnabled {
return &http.Client{}
}
return &http.Client{
Transport: LoggingRoundTripper{
parent: http.DefaultTransport,
logFile: cfg.LogFile,
},
}
}

type LoggingRoundTripper struct {
parent http.RoundTripper
// TODO: io.Writer
logFile string
}

func (r LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
startTime := time.Now()

resp, err := r.parent.RoundTrip(req)
if err != nil {
return resp, err
}

// TODO: JSON
logEntry := fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d\n", startTime.Format(time.RFC3339), req.Method, req.URL.Path, resp.StatusCode)

// Open log file in append mode.
file, err := os.OpenFile(r.logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return resp, fmt.Errorf("failed to open log file: %w", err)
}
defer file.Close()

// Write log entry to file.
if _, err := file.WriteString(logEntry); err != nil {
return resp, fmt.Errorf("failed to write log entry to file: %w", err)
}

return resp, nil
}

// BindAllPermissions creates a Binding for each permission to the given
// resource.
func BindAllPermissions(r Resource, perms ...Permission) []Binding {
Expand Down
119 changes: 119 additions & 0 deletions pkg/analyzer/analyzers/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package analyzers

import (
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
)

type AnalyzeClient struct {
http.Client
LoggingEnabled bool
LogFile string
}

func CreateLogFileName(baseName string) string {
// Get the current time
currentTime := time.Now()

// Format the time as "2024_06_30_07_15_30"
timeString := currentTime.Format("2006_01_02_15_04_05")

// Create the log file name
logFileName := fmt.Sprintf("%s_%s.log", timeString, baseName)
return logFileName
}

func NewAnalyzeClient(cfg *config.Config) *http.Client {
client := &http.Client{
Transport: AnalyzerRoundTripper{parent: http.DefaultTransport},
}
if cfg == nil || !cfg.LoggingEnabled {
return client
}
return &http.Client{
Transport: LoggingRoundTripper{
parent: client.Transport,
logFile: cfg.LogFile,
},
}
}

type LoggingRoundTripper struct {
parent http.RoundTripper
// TODO: io.Writer
logFile string
}

func (r LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
startTime := time.Now()

resp, parentErr := r.parent.RoundTrip(req)
if resp == nil {
return resp, parentErr
}

// TODO: JSON
var logEntry string
if parentErr != nil {
logEntry = fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d, Error: %s\n",
startTime.Format(time.RFC3339),
req.Method,
req.URL.Path,
resp.StatusCode,
parentErr.Error(),
)
} else {
logEntry = fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d\n",
startTime.Format(time.RFC3339),
req.Method,
req.URL.Path,
resp.StatusCode,
)
}

// Open log file in append mode.
file, err := os.OpenFile(r.logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return resp, fmt.Errorf("failed to open log file: %w", err)
}
defer file.Close()

// Write log entry to file.
if _, err := file.WriteString(logEntry); err != nil {
return resp, fmt.Errorf("failed to write log entry to file: %w", err)
}

return resp, parentErr
}

type AnalyzerRoundTripper struct {
parent http.RoundTripper
}

func (r AnalyzerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := r.parent.RoundTrip(req)
if err != nil || methodIsSafe(req.Method) {
return resp, err
}
// Check that unsafe methods did NOT return a valid status code.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return resp, fmt.Errorf("non-safe request returned success")
}
return resp, nil
}

// methodIsSafe is a helper method to check whether the HTTP method is safe according to MDN Web Docs.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods#safe_idempotent_and_cacheable_request_methods
func methodIsSafe(method string) bool {
switch strings.ToUpper(method) {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
return true
default:
return false
}
}
102 changes: 102 additions & 0 deletions pkg/analyzer/analyzers/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package analyzers

import (
"net/http"
"net/http/httptest"
"testing"
)

func TestAnalyzerClientUnsafeSuccess(t *testing.T) {
testCases := []struct {
name string
method string
expectedStatus int
expectedError bool
}{
{
name: "Safe method (GET)",
method: http.MethodGet,
expectedStatus: http.StatusOK,
expectedError: false,
},
{
name: "Safe method (HEAD)",
method: http.MethodHead,
expectedStatus: http.StatusOK,
expectedError: false,
},
{
name: "Safe method (OPTIONS)",
method: http.MethodOptions,
expectedStatus: http.StatusOK,
expectedError: false,
},
{
name: "Safe method (TRACE)",
method: http.MethodTrace,
expectedStatus: http.StatusOK,
expectedError: false,
},
{
name: "Unsafe method (POST) with success status",
method: http.MethodPost,
expectedStatus: http.StatusOK,
expectedError: true,
},
{
name: "Unsafe method (PUT) with success status",
method: http.MethodPut,
expectedStatus: http.StatusOK,
expectedError: true,
},
{
name: "Unsafe method (DELETE) with success status",
method: http.MethodDelete,
expectedStatus: http.StatusOK,
expectedError: true,
},
{
name: "Unsafe method (POST) with error status",
method: http.MethodPost,
expectedStatus: http.StatusInternalServerError,
expectedError: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a test server that returns the expected status code
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.expectedStatus)
}))
defer server.Close()

// Create a test request
req, err := http.NewRequest(tc.method, server.URL, nil)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}

// Create the AnalyzerRoundTripper with a test client
client := NewAnalyzeClient(nil)

// Perform the request
resp, err := client.Do(req)
if resp != nil {
_ = resp.Body.Close()
}

// Check the error
if err != nil && !tc.expectedError {
t.Errorf("Unexpected error: %v", err)
} else if err == nil && tc.expectedError {
t.Errorf("Expected error, but got nil")
}

// Check the response status code
if resp != nil && resp.StatusCode != tc.expectedStatus {
t.Errorf("Expected status code: %d, but got: %d", tc.expectedStatus, resp.StatusCode)
}
})
}
}

0 comments on commit b2da2a6

Please sign in to comment.