Skip to content

Commit

Permalink
fixing tests related to context passthrough
Browse files Browse the repository at this point in the history
  • Loading branch information
tednaleid committed Feb 17, 2024
1 parent de6651e commit 6420fb5
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 69 deletions.
31 changes: 18 additions & 13 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ func (buildInfo BuildInfo) ToString() string {
return buildInfo.Version + " " + buildInfo.Commit + " " + buildInfo.Date
}

// RunCommand allows us to mock out the args and input/output streams for testing
func RunCommand(
buildInfo BuildInfo,
args []string,
in io.Reader,
err io.Writer,
out io.Writer,
runBlock func(context *execcontext.Context),
) error {
command := setupCommand(buildInfo, in, err, out, runBlock)
return command.Run(ctx.Background(), args)
}

// create the cli.Command so it is wired up with the given in/stdout/stderr and runBlock
// this lets us mock out the input/output streams
// runBlock is where we can wire up the request and response workers and start processing (or mock for tests)
func setupCommand(
buildInfo BuildInfo,
in io.Reader,
Expand Down Expand Up @@ -194,19 +210,8 @@ func setupCommand(
}
}

// RunCommand allows us to mock out the args and input/output streams for testing
func RunCommand(
buildInfo BuildInfo,
args []string,
in io.Reader,
err io.Writer,
out io.Writer,
runBlock func(context *execcontext.Context),
) error {
command := setupCommand(buildInfo, in, err, out, runBlock)
return command.Run(ctx.Background(), args)
}

// ProcessRequests wires up the request and response workers with channels
// and asks the parser to start sending requests
func ProcessRequests(context *execcontext.Context) {
requestsWithContextChannel := make(chan parser.RequestWithContext)
responsesWithContextChannel := make(chan *responses.ResponseWithContext)
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ go 1.21.5

require (
github.com/stretchr/testify v1.8.4
github.com/urfave/cli/v3 v3.0.0-alpha7
github.com/urfave/cli/v3 v3.0.0-alpha9
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/urfave/cli/v3 v3.0.0-alpha7 h1:dj+WjtBA2StTinGwue+o2oyFFvo8aQ/AGb5MYvUqk/8=
github.com/urfave/cli/v3 v3.0.0-alpha7/go.mod h1:0kK/RUFHyh+yIKSfWxwheGndfnrvYSmYFVeKCh03ZUc=
github.com/urfave/cli/v3 v3.0.0-alpha9 h1:P0RMy5fQm1AslQS+XCmy9UknDXctOmG/q/FZkUFnJSo=
github.com/urfave/cli/v3 v3.0.0-alpha9/go.mod h1:0kK/RUFHyh+yIKSfWxwheGndfnrvYSmYFVeKCh03ZUc=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e h1:+SOyEddqYF09QP7vr7CgJ1eti3pY9Fn3LHO1M1r/0sI=
github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/urfave/cli.v1 v1.20.0 h1:NdAVW6RYxDif9DhDHaAortIu956m2c0v+09AZBPTbE0=
Expand Down
8 changes: 7 additions & 1 deletion parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ func SendUrlsRequests(
if len(record) > 0 {
url := record[0]
request := createRequest(url, nil, requestMethod, staticHeaders)
requestsWithContext <- RequestWithContext{Request: request, RequestContext: record[1:]}
recordContext := record[1:]

if len(recordContext) == 0 {
recordContext = nil
}

requestsWithContext <- RequestWithContext{Request: request, RequestContext: recordContext}
}
}
return nil
Expand Down
7 changes: 3 additions & 4 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,14 @@ func TestSendGetRequestUrlsHaveDefaultHeaders(t *testing.T) {
assert.Equal(t, "https://example.com/bar", request.URL.String(), "expected url")
assert.Equal(t, "GET", request.Method, "expected method")
assert.Equal(t, request.Header["Connection"][0], "keep-alive", "Connection header")
assert.Equal(t, []string{}, requestContext, "expected nil context")
assert.Equal(t, []string(nil), requestContext, "expected nil string context")

secondRequestWithContext := <-requestsWithContext
secondRequest := secondRequestWithContext.Request
secondRequestContext := secondRequestWithContext.RequestContext

assert.Equal(t, "https://example.com/qux", secondRequest.URL.String(), "expected url")
assert.Equal(t, []string{}, secondRequestContext, "expected nil context")

assert.Equal(t, []string(nil), secondRequestContext, "expected nil string context")
}

func TestSendGetRequestUrlsAddGivenHeaders(t *testing.T) {
Expand All @@ -65,7 +64,7 @@ func TestSendGetRequestUrlsAddGivenHeaders(t *testing.T) {
assert.Equal(t, request.Header["Connection"][0], "keep-alive", "Connection header")
assert.Equal(t, request.Header["X-Test"][0], "foo")
assert.Equal(t, request.Header["X-Test2"][0], "bar")
assert.Equal(t, []string{}, requestContext, "expected nil context")
assert.Equal(t, []string(nil), requestContext, "expected nil string context")
}

func TestSendRequestsHasRaggedRequestContext(t *testing.T) {
Expand Down
51 changes: 30 additions & 21 deletions responses/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,17 @@ func StartResponseWorkers(responsesWithContext <-chan *ResponseWithContext, cont

for i := 1; i <= context.ResponseWorkers; i++ {
go func() {
emitResponseWithContext := determineEmitResponseWithContextFn(context)
var emitResponse emitResponseWithContextFn
if context.JsonEnvelope {
emitResponse = determineEmitJsonResponseWithContextFn(context.ResponseBody)
} else {
emitResponse = determineEmitResponseFn(context.ResponseBody)
}

if context.WriteFiles {
responseSavingWorker(responsesWithContext, context, emitResponseWithContext)
responseSavingWorker(responsesWithContext, context, emitResponse)
} else {
responsePrintingWorker(responsesWithContext, context, emitResponseWithContext)
responsePrintingWorker(responsesWithContext, context, emitResponse)
}
responseWaitGroup.Done()
}()
Expand Down Expand Up @@ -98,23 +103,24 @@ func responsePrintingWorker(
type emitResponseFn func(response *http.Response, out io.Writer) (bytesWritten int64, err error)
type emitResponseWithContextFn func(responseWithContext *ResponseWithContext, out io.Writer) (bytesWritten int64, err error)

// we might wrap the body response in a JSON envelope
func determineEmitResponseWithContextFn(context *execcontext.Context) emitResponseWithContextFn {
bodyResponseFn := determineEmitBodyResponseFn(context)

if context.JsonEnvelope {
return jsonEnvelopeResponseFn(bodyResponseFn, context)
}
// emits the response without the context, context is only supported in JSON output
func determineEmitResponseFn(responseBody config.ResponseBodyType) emitResponseWithContextFn {
bodyResponseFn := determineEmitBodyResponseFn(responseBody)

// not emitting the context, just the body response
return func(responseWithContext *ResponseWithContext, out io.Writer) (bytesWritten int64, err error) {
return bodyResponseFn(responseWithContext.Response, out)
}
}

// surrounds the responsesBody with a JSON envelope that includes the context of the request (if any)
func determineEmitJsonResponseWithContextFn(responseBody config.ResponseBodyType) emitResponseWithContextFn {
bodyResponseFn := determineEmitBodyResponseFn(responseBody)
return jsonEnvelopeResponseFn(bodyResponseFn, responseBody)
}

// returns a function that will emit the JSON envelope around the response body
// the JSON envelope will include the url and http code along with the response body
func jsonEnvelopeResponseFn(bodyResponseFn emitResponseFn, context *execcontext.Context) emitResponseWithContextFn {
func jsonEnvelopeResponseFn(bodyResponseFn emitResponseFn, responseBody config.ResponseBodyType) emitResponseWithContextFn {
return func(responseWithContext *ResponseWithContext, out io.Writer) (bytesWritten int64, err error) {
var bodyBytesWritten int64
var contextBytesWritten int64
Expand All @@ -135,7 +141,7 @@ func jsonEnvelopeResponseFn(bodyResponseFn emitResponseFn, context *execcontext.
}

// emit the body response
if context.ResponseBody == config.Discard || context.ResponseBody == config.Raw {
if responseBody == config.Discard || responseBody == config.Raw {
// no need to wrap either of these in quotes, Raw is assumed to be JSON
bodyBytesWritten, err = bodyResponseFn(response, out)
} else {
Expand Down Expand Up @@ -163,16 +169,19 @@ func jsonEnvelopeResponseFn(bodyResponseFn emitResponseFn, context *execcontext.
}
}

// Add requestContext to JSON if it is not nil
// Add requestContext to JSON if it is not nil/null
if requestContext != nil {
requestContextJson, err := json.Marshal(requestContext)
if err != nil {
return bytesWritten, err
}
contextBytesWritten, err = appendString(bytesWritten, out, fmt.Sprintf(", \"context\": %s", string(requestContextJson)))
bytesWritten += contextBytesWritten
if err != nil {
return bytesWritten, err
requestContextString := string(requestContextJson)
if requestContextString != "null" {
contextBytesWritten, err = appendString(bytesWritten, out, fmt.Sprintf(", \"context\": %s", string(requestContextJson)))
bytesWritten += contextBytesWritten
if err != nil {
return bytesWritten, err
}
}
}

Expand All @@ -193,8 +202,8 @@ func appendString(bytesPreviouslyWritten int64, out io.Writer, s string) (int64,
return bytesPreviouslyWritten + int64(appendedBytes), err
}

func determineEmitBodyResponseFn(context *execcontext.Context) emitResponseFn {
switch context.ResponseBody {
func determineEmitBodyResponseFn(responseBody config.ResponseBodyType) emitResponseFn {
switch responseBody {
case config.Raw:
return emitRawBody
case config.Sha256:
Expand All @@ -206,7 +215,7 @@ func determineEmitBodyResponseFn(context *execcontext.Context) emitResponseFn {
case config.Base64:
return emitBase64Body
default:
panic(fmt.Sprintf("unknown response body type %s", context.ResponseBody))
panic(fmt.Sprintf("unknown response body type %s", responseBody))
}
}

Expand Down
Loading

0 comments on commit 6420fb5

Please sign in to comment.