diff --git a/client.go b/client.go index a3762567..1a370012 100644 --- a/client.go +++ b/client.go @@ -154,6 +154,7 @@ type Client struct { invalidHooks []ErrorHook panicHooks []ErrorHook rateLimiter RateLimiter + generateCurlOnDebug bool } // User type is to hold an username and password information @@ -443,12 +444,13 @@ func (c *Client) R() *Request { RawPathParams: map[string]string{}, Debug: c.Debug, - client: c, - multipartFiles: []*File{}, - multipartFields: []*MultipartField{}, - jsonEscapeHTML: c.jsonEscapeHTML, - log: c.log, - responseBodyLimit: c.ResponseBodyLimit, + client: c, + multipartFiles: []*File{}, + multipartFields: []*MultipartField{}, + jsonEscapeHTML: c.jsonEscapeHTML, + log: c.log, + responseBodyLimit: c.ResponseBodyLimit, + generateCurlOnDebug: c.generateCurlOnDebug, } return r } @@ -1130,6 +1132,23 @@ func (c *Client) DisableTrace() *Client { return c } +// EnableGenerateCurlOnDebug method enables the generation of CURL commands in the debug log. +// It works in conjunction with debug mode. +// +// NOTE: Use with care. +// - Potential to leak sensitive data in the debug log from [Request] and [Response]. +// - Beware of memory usage since the request body is reread. +func (c *Client) EnableGenerateCurlOnDebug() *Client { + c.generateCurlOnDebug = true + return c +} + +// DisableGenerateCurlOnDebug method disables the option set by [Client.EnableGenerateCurlOnDebug]. +func (c *Client) DisableGenerateCurlOnDebug() *Client { + c.generateCurlOnDebug = false + return c +} + // IsProxySet method returns the true is proxy is set from resty client otherwise // false. By default proxy is set from environment, refer to `http.ProxyFromEnvironment`. func (c *Client) IsProxySet() bool { diff --git a/examples/debug_curl_test.go b/curl_cmd_test.go similarity index 81% rename from examples/debug_curl_test.go rename to curl_cmd_test.go index 72bb68d9..55cac02b 100644 --- a/examples/debug_curl_test.go +++ b/curl_cmd_test.go @@ -1,4 +1,4 @@ -package examples +package resty import ( "io" @@ -6,16 +6,11 @@ import ( "os" "strings" "testing" - - "github.com/go-resty/resty/v2" ) // 1. Generate curl for unexecuted request(dry-run) -func TestGenerateUnexcutedCurl(t *testing.T) { - ts := createHttpbinServer(0) - defer ts.Close() - - req := resty.New().R(). +func TestGenerateUnexecutedCurl(t *testing.T) { + req := dclr(). SetBody(map[string]string{ "name": "Alex", }). @@ -25,7 +20,8 @@ func TestGenerateUnexcutedCurl(t *testing.T) { }, ) - curlCmdUnexecuted := req.GenerateCurlCommand() + curlCmdUnexecuted := req.EnableGenerateCurlOnDebug().GenerateCurlCommand() + req.DisableGenerateCurlOnDebug() if !strings.Contains(curlCmdUnexecuted, "Cookie: count=1") || !strings.Contains(curlCmdUnexecuted, "curl -X GET") || @@ -39,13 +35,14 @@ func TestGenerateUnexcutedCurl(t *testing.T) { // 2. Generate curl for executed request func TestGenerateExecutedCurl(t *testing.T) { - ts := createHttpbinServer(0) + ts := createPostServer(t) defer ts.Close() data := map[string]string{ "name": "Alex", } - req := resty.New().R(). + c := dcl() + req := c.R(). SetBody(data). SetCookies( []*http.Cookie{ @@ -53,14 +50,17 @@ func TestGenerateExecutedCurl(t *testing.T) { }, ) - url := ts.URL + "/post" + url := ts.URL + "/curl-cmd-post" resp, err := req. - EnableTrace(). + EnableGenerateCurlOnDebug(). Post(url) if err != nil { t.Fatal(err) } curlCmdExecuted := resp.Request.GenerateCurlCommand() + + c.DisableGenerateCurlOnDebug() + req.DisableGenerateCurlOnDebug() if !strings.Contains(curlCmdExecuted, "Cookie: count=1") || !strings.Contains(curlCmdExecuted, "curl -X POST") || !strings.Contains(curlCmdExecuted, `-d '{"name":"Alex"}'`) || @@ -73,7 +73,7 @@ func TestGenerateExecutedCurl(t *testing.T) { // 3. Generate curl in debug mode func TestDebugModeCurl(t *testing.T) { - ts := createHttpbinServer(0) + ts := createPostServer(t) defer ts.Close() // 1. Capture stderr @@ -81,7 +81,8 @@ func TestDebugModeCurl(t *testing.T) { defer restore() // 2. Build request - req := resty.New().R(). + c := New() + req := c.EnableGenerateCurlOnDebug().R(). SetBody(map[string]string{ "name": "Alex", }). @@ -92,12 +93,15 @@ func TestDebugModeCurl(t *testing.T) { ) // 3. Execute request: set debug mode - url := ts.URL + "/post" + url := ts.URL + "/curl-cmd-post" _, err := req.SetDebug(true).Post(url) if err != nil { t.Fatal(err) } + c.DisableGenerateCurlOnDebug() + req.DisableGenerateCurlOnDebug() + // 4. test output curl output := getOutput() if !strings.Contains(output, "Cookie: count=1") || diff --git a/examples/BUILD.bazel b/examples/BUILD.bazel deleted file mode 100644 index 849ea4e6..00000000 --- a/examples/BUILD.bazel +++ /dev/null @@ -1,10 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -go_test( - name = "examples_test", - srcs = [ - "debug_curl_test.go", - "server_test.go", - ], - deps = ["//:resty"], -) diff --git a/examples/server_test.go b/examples/server_test.go deleted file mode 100644 index 285f8b64..00000000 --- a/examples/server_test.go +++ /dev/null @@ -1,162 +0,0 @@ -package examples - -import ( - "bytes" - "encoding/json" - "fmt" - ioutil "io" - "net/http" - "net/http/httptest" - "net/url" - "strings" -) - -const maxMultipartMemory = 4 << 30 // 4MB - -// tlsCert: -// -// 0 No certificate -// 1 With self-signed certificate -// 2 With custom certificate from CA(todo) -func createHttpbinServer(tlsCert int) (ts *httptest.Server) { - ts = createTestServer(func(w http.ResponseWriter, r *http.Request) { - httpbinHandler(w, r) - }, tlsCert) - - return ts -} - -func httpbinHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - body, _ := ioutil.ReadAll(r.Body) - r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) // important!! - m := map[string]interface{}{ - "args": parseRequestArgs(r), - "headers": dumpRequestHeader(r), - "data": string(body), - "json": nil, - "form": map[string]string{}, - "files": map[string]string{}, - "method": r.Method, - "origin": r.RemoteAddr, - "url": r.URL.String(), - } - - // 1. parse text/plain - if strings.HasPrefix(r.Header.Get("Content-Type"), "text/plain") { - m["data"] = string(body) - } - - // 2. parse application/json - if strings.HasPrefix(r.Header.Get("Content-Type"), "application/json") { - var data interface{} - if err := json.Unmarshal(body, &data); err != nil { - m["err"] = err.Error() - } else { - m["json"] = data - } - } - - // 3. parse application/x-www-form-urlencoded - if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") { - m["form"] = parseQueryString(string(body)) - } - - // 4. parse multipart/form-data - if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { - form, files := readMultipartForm(r) - m["form"] = form - m["files"] = files - } - buf, _ := json.Marshal(m) - _, _ = w.Write(buf) -} - -func readMultipartForm(r *http.Request) (map[string]string, map[string]string) { - if err := r.ParseMultipartForm(maxMultipartMemory); err != nil { - if err != http.ErrNotMultipart { - panic(fmt.Sprintf("error on parse multipart form array: %v", err)) - } - } - // parse form data - formData := make(map[string]string) - for k, vs := range r.PostForm { - for _, v := range vs { - formData[k] = v - } - } - // parse files - files := make(map[string]string) - if r.MultipartForm != nil && r.MultipartForm.File != nil { - for key, fhs := range r.MultipartForm.File { - // if len(fhs)>0 - // f, err := fhs[0].Open() - files[key] = fhs[0].Filename - } - } - return formData, files -} - -func dumpRequestHeader(req *http.Request) string { - var res strings.Builder - headers := sortHeaders(req) - for _, kv := range headers { - res.WriteString(kv[0] + ": " + kv[1] + "\n") - } - return res.String() -} - -// sortHeaders -func sortHeaders(request *http.Request) [][2]string { - headers := [][2]string{} - for k, vs := range request.Header { - for _, v := range vs { - headers = append(headers, [2]string{k, v}) - } - } - n := len(headers) - for i := 0; i < n; i++ { - for j := n - 1; j > i; j-- { - jj := j - 1 - h1, h2 := headers[j], headers[jj] - if h1[0] < h2[0] { - headers[jj], headers[j] = headers[j], headers[jj] - } - } - } - return headers -} - -func parseRequestArgs(request *http.Request) map[string]string { - query := request.URL.RawQuery - return parseQueryString(query) -} - -func parseQueryString(query string) map[string]string { - params := map[string]string{} - paramsList, _ := url.ParseQuery(query) - for key, vals := range paramsList { - // params[key] = vals[len(vals)-1] - params[key] = strings.Join(vals, ",") - } - return params -} - -/* -* - - tlsCert: - 0 no certificate - 1 with self-signed certificate - 2 with custom certificate from CA(todo) -*/ -func createTestServer(fn func(w http.ResponseWriter, r *http.Request), tlsCert int) (ts *httptest.Server) { - if tlsCert == 0 { - // 1. http test server - ts = httptest.NewServer(http.HandlerFunc(fn)) - } else if tlsCert == 1 { - // 2. https test server: https://stackoverflow.com/questions/54899550/create-https-test-server-for-any-client - ts = httptest.NewUnstartedServer(http.HandlerFunc(fn)) - ts.StartTLS() - } - return ts -} diff --git a/middleware.go b/middleware.go index 805f4ce9..84f0fdea 100644 --- a/middleware.go +++ b/middleware.go @@ -308,7 +308,7 @@ func addCredentials(c *Client, r *Request) error { } func createCurlCmd(c *Client, r *Request) (err error) { - if r.trace { + if r.Debug && r.generateCurlOnDebug { if r.resultCurlCmd == nil { r.resultCurlCmd = new(string) } @@ -338,10 +338,14 @@ func requestLogger(c *Client, r *Request) error { } } - reqLog := "\n==============================================================================\n" + - "~~~ REQUEST(curl) ~~~\n" + - fmt.Sprintf("CURL:\n %v\n", buildCurlRequest(r.RawRequest, r.client.httpClient.Jar)) + - "~~~ REQUEST ~~~\n" + + reqLog := "\n==============================================================================\n" + + if r.Debug && r.generateCurlOnDebug { + reqLog += "~~~ REQUEST(CURL) ~~~\n" + + fmt.Sprintf(" %v\n", *r.resultCurlCmd) + } + + reqLog += "~~~ REQUEST ~~~\n" + fmt.Sprintf("%s %s %s\n", r.Method, rr.URL.RequestURI(), rr.Proto) + fmt.Sprintf("HOST : %s\n", rr.URL.Host) + fmt.Sprintf("HEADERS:\n%s\n", composeHeaders(c, r, rl.Header)) + diff --git a/request.go b/request.go index cfbe89b4..8ad4a6ad 100644 --- a/request.go +++ b/request.go @@ -74,24 +74,27 @@ type Request struct { multipartFields []*MultipartField retryConditions []RetryConditionFunc responseBodyLimit int + generateCurlOnDebug bool } // Generate curl command for the request. func (r *Request) GenerateCurlCommand() string { + if !(r.Debug && r.generateCurlOnDebug) { + return "" + } + if r.resultCurlCmd != nil { return *r.resultCurlCmd - } else { - if r.RawRequest == nil { - r.client.executeBefore(r) // mock with r.Get("/") - } - if r.resultCurlCmd == nil { - r.resultCurlCmd = new(string) - } - if *r.resultCurlCmd == "" { - *r.resultCurlCmd = buildCurlRequest(r.RawRequest, r.client.httpClient.Jar) - } - return *r.resultCurlCmd } + + if r.RawRequest == nil { + r.client.executeBefore(r) // mock with r.Get("/") + } + if r.resultCurlCmd == nil { + r.resultCurlCmd = new(string) + } + *r.resultCurlCmd = buildCurlRequest(r.RawRequest, r.client.httpClient.Jar) + return *r.resultCurlCmd } // Context method returns the Context if its already set in request @@ -834,6 +837,24 @@ func (r *Request) EnableTrace() *Request { return r } +// EnableGenerateCurlOnDebug method enables the generation of CURL commands in the debug log. +// It works in conjunction with debug mode. It overrides the options set by the [Client]. +// +// NOTE: Use with care. +// - Potential to leak sensitive data in the debug log from [Request] and [Response]. +// - Beware of memory usage since the request body is reread. +func (r *Request) EnableGenerateCurlOnDebug() *Request { + r.generateCurlOnDebug = true + return r +} + +// DisableGenerateCurlOnDebug method disables the option set by [Request.EnableGenerateCurlOnDebug]. +// It overrides the options set by the [Client]. +func (r *Request) DisableGenerateCurlOnDebug() *Request { + r.generateCurlOnDebug = false + return r +} + // TraceInfo method returns the trace info for the request. // If either the Client or Request EnableTrace function has not been called // prior to the request being made, an empty TraceInfo object will be returned. diff --git a/resty_test.go b/resty_test.go index 95ef0b51..e10421c1 100644 --- a/resty_test.go +++ b/resty_test.go @@ -308,6 +308,16 @@ func createPostServer(t *testing.T) *httptest.Server { body, _ := io.ReadAll(r.Body) assertEqual(t, r.URL.Query().Get("body"), string(body)) w.WriteHeader(http.StatusOK) + case "/curl-cmd-post": + cookie := http.Cookie{ + Name: "testserver", + Domain: "localhost", + Path: "/", + Expires: time.Now().AddDate(0, 0, 1), + Value: "yes", + } + http.SetCookie(w, &cookie) + w.WriteHeader(http.StatusOK) } } }) diff --git a/util_curl.go b/util_curl.go index e50c3b3a..2bd91270 100644 --- a/util_curl.go +++ b/util_curl.go @@ -14,6 +14,7 @@ import ( func buildCurlRequest(req *http.Request, httpCookiejar http.CookieJar) (curl string) { // 1. Generate curl raw headers + curl = "curl -X " + req.Method + " " // req.Host + req.URL.Path + "?" + req.URL.RawQuery + " " + req.Proto + " " headers := dumpCurlHeaders(req) @@ -22,6 +23,7 @@ func buildCurlRequest(req *http.Request, httpCookiejar http.CookieJar) (curl str } // 2. Generate curl cookies + // TODO validate this block of code, I think its not required since cookie captured via Headers if cookieJar, ok := httpCookiejar.(*cookiejar.Jar); ok { cookies := cookieJar.Cookies(req.URL) if len(cookies) > 0 {