diff --git a/cmd/connectconformance/main.go b/cmd/connectconformance/main.go index 6303123b..bb4d9c49 100644 --- a/cmd/connectconformance/main.go +++ b/cmd/connectconformance/main.go @@ -47,6 +47,7 @@ const ( tlsKeyFlagName = "key" portFlagName = "port" bindFlagName = "bind" + traceFlagName = "trace" ) type flags struct { @@ -65,6 +66,7 @@ type flags struct { tlsKeyFile string port uint bind string + trace bool } func main() { @@ -157,6 +159,8 @@ func bind(cmd *cobra.Command, flags *flags) { "in client mode, the port number on which the reference server should listen (implies --max-servers=1)") cmd.Flags().StringVar(&flags.bind, bindFlagName, internal.DefaultHost, "in client mode, the bind address on which the reference server should listen (0.0.0.0 means listen on all interfaces)") + cmd.Flags().BoolVar(&flags.trace, traceFlagName, false, + "if true, full HTTP traces will be captured and shown alongside failing test cases") } func run(flags *flags, cobraFlags *pflag.FlagSet, command []string) { //nolint:gocyclo @@ -298,6 +302,7 @@ func run(flags *flags, cobraFlags *pflag.FlagSet, command []string) { //nolint:g TLSKeyFile: flags.tlsKeyFile, ServerPort: flags.port, ServerBind: flags.bind, + HTTPTrace: flags.trace, }, internal.NewPrinter(os.Stdout), internal.NewPrinter(os.Stderr), diff --git a/cmd/referenceclient/main.go b/cmd/referenceclient/main.go index 9871168f..e9f05d2c 100644 --- a/cmd/referenceclient/main.go +++ b/cmd/referenceclient/main.go @@ -23,7 +23,7 @@ import ( ) func main() { - err := referenceclient.Run(context.Background(), os.Args, os.Stdin, os.Stdout, os.Stderr) + err := referenceclient.Run(context.Background(), os.Args, os.Stdin, os.Stdout, os.Stderr, nil) if err != nil { log.Fatalf("an error occurred running the reference client: %s", err.Error()) } diff --git a/internal/app/connectconformance/connectconformance.go b/internal/app/connectconformance/connectconformance.go index f1759106..0b288b3d 100644 --- a/internal/app/connectconformance/connectconformance.go +++ b/internal/app/connectconformance/connectconformance.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "io" "os" "path" "sort" @@ -32,6 +33,7 @@ import ( "connectrpc.com/conformance/internal/app/referenceclient" "connectrpc.com/conformance/internal/app/referenceserver" conformancev1 "connectrpc.com/conformance/internal/gen/proto/go/connectrpc/conformance/v1" + "connectrpc.com/conformance/internal/tracer" "golang.org/x/sync/semaphore" "google.golang.org/protobuf/proto" ) @@ -54,6 +56,7 @@ type Flags struct { TLSKeyFile string ServerPort uint ServerBind string + HTTPTrace bool } func Run(flags *Flags, logPrinter internal.Printer, errPrinter internal.Printer) (bool, error) { @@ -221,6 +224,11 @@ func run( //nolint:gocyclo } } + var trace *tracer.Tracer + if flags.HTTPTrace { + trace = &tracer.Tracer{} + } + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -232,7 +240,9 @@ func run( //nolint:gocyclo start: runInProcess([]string{ "reference-client", "-p", strconv.Itoa(int(flags.Parallelism)), - }, referenceclient.Run), + }, func(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, errWriter io.WriteCloser) error { + return referenceclient.Run(ctx, args, inReader, outWriter, errWriter, trace) + }), isReferenceImpl: true, }, { @@ -252,7 +262,7 @@ func run( //nolint:gocyclo } } - results := newResults(knownFailing, knownFlaky) + results := newResults(knownFailing, knownFlaky, trace) for _, clientInfo := range clients { clientProcess, err := runClient(ctx, clientInfo.start) @@ -272,7 +282,9 @@ func run( //nolint:gocyclo "-bind", flags.ServerBind, "-cert", flags.TLSCertFile, "-key", flags.TLSKeyFile, - }, referenceserver.RunInReferenceMode), + }, func(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, errWriter io.WriteCloser) error { + return referenceserver.RunInReferenceMode(ctx, args, inReader, outWriter, errWriter, trace) + }), isReferenceImpl: true, }, { @@ -349,6 +361,7 @@ func run( //nolint:gocyclo errPrinter, results, clientProcess, + trace, ) }(ctx, clientInfo, serverInfo, svrInstance) } diff --git a/internal/app/connectconformance/results.go b/internal/app/connectconformance/results.go index f9a2017e..5cb4ccfb 100644 --- a/internal/app/connectconformance/results.go +++ b/internal/app/connectconformance/results.go @@ -16,6 +16,7 @@ package connectconformance import ( "bytes" + "context" "errors" "fmt" "reflect" @@ -23,9 +24,11 @@ import ( "strconv" "strings" "sync" + "time" "connectrpc.com/conformance/internal" conformancev1 "connectrpc.com/conformance/internal/gen/proto/go/connectrpc/conformance/v1" + "connectrpc.com/conformance/internal/tracer" "connectrpc.com/connect" "github.com/google/go-cmp/cmp" "google.golang.org/protobuf/proto" @@ -42,16 +45,21 @@ const timeoutCheckGracePeriodMillis = 500 type testResults struct { knownFailing *testTrie knownFlaky *testTrie + tracer *tracer.Tracer + + traceWaitGroup sync.WaitGroup mu sync.Mutex outcomes map[string]testOutcome + traces map[string]*tracer.Trace serverSideband map[string]string } -func newResults(knownFailing, knownFlaky *testTrie) *testResults { +func newResults(knownFailing, knownFlaky *testTrie, tracer *tracer.Tracer) *testResults { return &testResults{ knownFailing: knownFailing, knownFlaky: knownFlaky, + tracer: tracer, outcomes: map[string]testOutcome{}, serverSideband: map[string]string{}, } @@ -74,6 +82,36 @@ func (r *testResults) setOutcomeLocked(testCase string, setupError bool, err err knownFailing: r.knownFailing.match(strings.Split(testCase, "/")), knownFlaky: r.knownFlaky.match(strings.Split(testCase, "/")), } + r.fetchTrace(testCase) +} + +//nolint:contextcheck,nolintlint // intentionally using context.Background; nolintlint incorrectly complains about this +func (r *testResults) fetchTrace(testCase string) { + if r.tracer == nil { + return + } + r.traceWaitGroup.Add(1) + go func() { + defer r.traceWaitGroup.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + trace, err := r.tracer.Await(ctx, testCase) + r.tracer.Clear(testCase) + if err != nil { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + outcome := r.outcomes[testCase] + if outcome.actualFailure == nil || outcome.setupError || outcome.knownFlaky || outcome.knownFailing { + return + } + if r.traces == nil { + r.traces = map[string]*tracer.Trace{} + } + r.traces[testCase] = trace + }() } // failedToStart marks all the given test cases with the given setup error. @@ -186,6 +224,7 @@ func (r *testResults) processSidebandInfoLocked() { } func (r *testResults) report(printer internal.Printer) bool { + r.traceWaitGroup.Wait() // make sure all traces have been received r.mu.Lock() defer r.mu.Unlock() if len(r.serverSideband) > 0 { @@ -208,6 +247,12 @@ func (r *testResults) report(printer internal.Printer) bool { switch { case !expectError && outcome.actualFailure != nil: printer.Printf("FAILED: %s:\n%s", name, indent(outcome.actualFailure.Error())) + trace := r.traces[name] + if trace != nil { + printer.Printf("---- HTTP Trace ----") + trace.Print(printer) + printer.Printf("--------------------") + } failed++ case expectError && outcome.actualFailure == nil: printer.Printf("FAILED: %s was expected to fail but did not", name) diff --git a/internal/app/connectconformance/results_test.go b/internal/app/connectconformance/results_test.go index f09b1cf1..a59a2ac2 100644 --- a/internal/app/connectconformance/results_test.go +++ b/internal/app/connectconformance/results_test.go @@ -30,7 +30,7 @@ import ( func TestResults_SetOutcome(t *testing.T) { t.Parallel() - results := newResults(makeKnownFailing(), makeKnownFlaky()) + results := newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("foo/bar/1", false, nil) results.setOutcome("foo/bar/2", true, errors.New("fail")) results.setOutcome("foo/bar/3", false, errors.New("fail")) @@ -58,7 +58,7 @@ func TestResults_SetOutcome(t *testing.T) { func TestResults_FailedToStart(t *testing.T) { t.Parallel() - results := newResults(makeKnownFailing(), makeKnownFlaky()) + results := newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.failedToStart([]*conformancev1.TestCase{ {Request: &conformancev1.ClientCompatRequest{TestName: "foo/bar/1"}}, {Request: &conformancev1.ClientCompatRequest{TestName: "known-to-fail/1"}}, @@ -76,7 +76,7 @@ func TestResults_FailedToStart(t *testing.T) { func TestResults_FailRemaining(t *testing.T) { t.Parallel() - results := newResults(makeKnownFailing(), makeKnownFlaky()) + results := newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("foo/bar/1", false, nil) results.setOutcome("known-to-fail/1", false, errors.New("fail")) results.failRemaining([]*conformancev1.TestCase{ @@ -101,7 +101,7 @@ func TestResults_FailRemaining(t *testing.T) { func TestResults_Failed(t *testing.T) { t.Parallel() - results := newResults(makeKnownFailing(), makeKnownFlaky()) + results := newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.failed("foo/bar/1", &conformancev1.ClientErrorResult{Message: "fail"}) results.failed("known-to-fail/1", &conformancev1.ClientErrorResult{Message: "fail"}) @@ -116,7 +116,7 @@ func TestResults_Failed(t *testing.T) { func TestResults_Assert(t *testing.T) { t.Parallel() - results := newResults(makeKnownFailing(), makeKnownFlaky()) + results := newResults(makeKnownFailing(), makeKnownFlaky(), nil) payload1 := &conformancev1.ClientResponseResult{ Payloads: []*conformancev1.ConformancePayload{ {Data: []byte{0, 1, 2, 3, 4}}, @@ -688,7 +688,7 @@ func TestResults_Assert_ReportsAllErrors(t *testing.T) { testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() - results := newResults(&testTrie{}, &testTrie{}) + results := newResults(&testTrie{}, &testTrie{}, nil) expected := &conformancev1.ClientResponseResult{} err := protojson.Unmarshal(([]byte)(testCase.expected), expected) @@ -722,7 +722,7 @@ func TestResults_Assert_ReportsAllErrors(t *testing.T) { func TestResults_ServerSideband(t *testing.T) { t.Parallel() - results := newResults(makeKnownFailing(), makeKnownFlaky()) + results := newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("foo/bar/1", false, nil) results.setOutcome("foo/bar/2", false, errors.New("fail")) results.setOutcome("foo/bar/3", false, nil) @@ -745,7 +745,7 @@ func TestResults_ServerSideband(t *testing.T) { func TestResults_Report(t *testing.T) { t.Parallel() - results := newResults(makeKnownFailing(), makeKnownFlaky()) + results := newResults(makeKnownFailing(), makeKnownFlaky(), nil) logger := &linePrinter{} // No test cases? Report success. @@ -753,42 +753,42 @@ func TestResults_Report(t *testing.T) { require.True(t, success) // Only successful outcomes? Report success. - results = newResults(makeKnownFailing(), makeKnownFlaky()) + results = newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("foo/bar/1", false, nil) success = results.report(logger) require.True(t, success) // Unexpected failure? Report failure. - results = newResults(makeKnownFailing(), makeKnownFlaky()) + results = newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("foo/bar/1", false, errors.New("ruh roh")) success = results.report(logger) require.False(t, success) // Unexpected failure during setup? Report failure. - results = newResults(makeKnownFailing(), makeKnownFlaky()) + results = newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("foo/bar/1", true, errors.New("ruh roh")) success = results.report(logger) require.False(t, success) // Expected failure? Report success. - results = newResults(makeKnownFailing(), makeKnownFlaky()) + results = newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("known-to-fail/1", false, errors.New("ruh roh")) success = results.report(logger) require.True(t, success) // Setup error from expected failure? Report failure (setup errors never acceptable). - results = newResults(makeKnownFailing(), makeKnownFlaky()) + results = newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("known-to-fail/1", true, errors.New("ruh roh")) success = results.report(logger) require.False(t, success) // Flaky? Report success whether it passes or fails - results = newResults(makeKnownFailing(), makeKnownFlaky()) + results = newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("known-to-flake/1", false, nil) // succeeds success = results.report(logger) require.True(t, success) - results = newResults(makeKnownFailing(), makeKnownFlaky()) + results = newResults(makeKnownFailing(), makeKnownFlaky(), nil) results.setOutcome("known-to-flake/1", false, errors.New("ruh roh")) success = results.report(logger) require.True(t, success) diff --git a/internal/app/connectconformance/server_runner.go b/internal/app/connectconformance/server_runner.go index b31e1d4e..3f0265b9 100644 --- a/internal/app/connectconformance/server_runner.go +++ b/internal/app/connectconformance/server_runner.go @@ -26,6 +26,7 @@ import ( "connectrpc.com/conformance/internal" conformancev1 "connectrpc.com/conformance/internal/gen/proto/go/connectrpc/conformance/v1" + "connectrpc.com/conformance/internal/tracer" "google.golang.org/protobuf/proto" ) @@ -51,6 +52,7 @@ func runTestCasesForServer( errPrinter internal.Printer, results *testResults, client clientRunner, + tracer *tracer.Tracer, ) { expectations := make(map[string]*conformancev1.ClientResponseResult, len(testCases)) for _, testCase := range testCases { @@ -181,6 +183,7 @@ func runTestCasesForServer( } } + tracer.Init(req.TestName) wg.Add(1) err := client.sendRequest(req, func(name string, resp *conformancev1.ClientCompatResponse, err error) { defer wg.Done() diff --git a/internal/app/connectconformance/server_runner_test.go b/internal/app/connectconformance/server_runner_test.go index ea0198bd..beb59321 100644 --- a/internal/app/connectconformance/server_runner_test.go +++ b/internal/app/connectconformance/server_runner_test.go @@ -194,7 +194,7 @@ func TestRunTestCasesForServer(t *testing.T) { testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() - results := newResults(&testTrie{}, &testTrie{}) + results := newResults(&testTrie{}, &testTrie{}, nil) var procAddr atomic.Pointer[process] // populated when server process created var actualSvrRequest bytes.Buffer @@ -271,6 +271,7 @@ func TestRunTestCasesForServer(t *testing.T) { discardPrinter{}, results, &client, + nil, ) if testCase.svrFailsToStart { diff --git a/internal/app/connectconformance/testsuites/cancellation.yaml b/internal/app/connectconformance/testsuites/cancellation.yaml index e66cbff9..057d51fa 100644 --- a/internal/app/connectconformance/testsuites/cancellation.yaml +++ b/internal/app/connectconformance/testsuites/cancellation.yaml @@ -32,11 +32,11 @@ testCases: method: ClientStream streamType: STREAM_TYPE_CLIENT_STREAM cancel: - afterCloseSendMs: 50 + afterCloseSendMs: 5 requestMessages: - "@type": type.googleapis.com/connectrpc.conformance.v1.ClientStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseHeaders: - name: x-custom-header value: ["foo"] @@ -61,7 +61,7 @@ testCases: requestMessages: - "@type": type.googleapis.com/connectrpc.conformance.v1.ServerStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseHeaders: - name: x-custom-header value: ["foo"] @@ -85,7 +85,7 @@ testCases: requestMessages: - "@type": type.googleapis.com/connectrpc.conformance.v1.ServerStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" @@ -97,7 +97,7 @@ testCases: requests: - "@type": type.googleapis.com/connectrpc.conformance.v1.ServerStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" @@ -114,7 +114,7 @@ testCases: requestMessages: - "@type": type.googleapis.com/connectrpc.conformance.v1.BidiStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" @@ -135,7 +135,7 @@ testCases: requestMessages: - "@type": type.googleapis.com/connectrpc.conformance.v1.BidiStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" @@ -150,7 +150,7 @@ testCases: requests: - "@type": type.googleapis.com/connectrpc.conformance.v1.BidiStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" @@ -237,7 +237,7 @@ testCases: requestMessages: - "@type": type.googleapis.com/connectrpc.conformance.v1.BidiStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" @@ -257,7 +257,7 @@ testCases: requestMessages: - "@type": type.googleapis.com/connectrpc.conformance.v1.BidiStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" @@ -271,7 +271,7 @@ testCases: requests: - "@type": type.googleapis.com/connectrpc.conformance.v1.BidiStreamRequest responseDefinition: - responseDelayMs: 100 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" @@ -304,11 +304,11 @@ testCases: method: BidiStream streamType: STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM cancel: - afterCloseSendMs: 0 + afterCloseSendMs: 5 requestMessages: - "@type": type.googleapis.com/connectrpc.conformance.v1.BidiStreamRequest responseDefinition: - responseDelayMs: 250 + responseDelayMs: 200 responseData: - "dGVzdCByZXNwb25zZQ==" - "dGVzdCByZXNwb25zZQ==" diff --git a/internal/app/referenceclient/client.go b/internal/app/referenceclient/client.go index ab0bfcd3..0fc7fa45 100644 --- a/internal/app/referenceclient/client.go +++ b/internal/app/referenceclient/client.go @@ -36,6 +36,7 @@ import ( "connectrpc.com/conformance/internal/compression" v1 "connectrpc.com/conformance/internal/gen/proto/go/connectrpc/conformance/v1" "connectrpc.com/conformance/internal/gen/proto/go/connectrpc/conformance/v1/conformancev1connect" + "connectrpc.com/conformance/internal/tracer" "connectrpc.com/connect" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" @@ -47,7 +48,7 @@ import ( // is written to the 'out' writer, including any errors encountered during the actual run. Any error // returned from this function is indicative of an issue with the reader or writer and should not be related // to the actual run. -func Run(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, _ io.WriteCloser) (retErr error) { +func Run(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, _ io.WriteCloser, trace *tracer.Tracer) (retErr error) { flags := flag.NewFlagSet(args[0], flag.ContinueOnError) json := flags.Bool("json", false, "whether to use the JSON format for marshaling / unmarshaling messages") parallel := flags.Uint("p", uint(runtime.GOMAXPROCS(0))*4, "the number of parallel RPCs to issue") @@ -109,7 +110,7 @@ func Run(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, defer wg.Done() defer sema.Release(1) - result, err := invoke(ctx, &req) + result, err := invoke(ctx, &req, trace) // Build the result for the out writer. resp := &v1.ClientCompatResponse{ @@ -146,7 +147,7 @@ func Run(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, // returned from this function indicates a runtime/unexpected internal error and is not indicative of a // Connect error returned from calling an RPC. Any error (i.e. a Connect error) that _is_ returned from // the actual RPC invocation will be present in the returned ClientResponseResult. -func invoke(ctx context.Context, req *v1.ClientCompatRequest) (*v1.ClientResponseResult, error) { +func invoke(ctx context.Context, req *v1.ClientCompatRequest, trace *tracer.Tracer) (*v1.ClientResponseResult, error) { tlsConf, err := createTLSConfig(req) if err != nil { return nil, err @@ -215,6 +216,9 @@ func invoke(ctx context.Context, req *v1.ClientCompatRequest) (*v1.ClientRespons case v1.HTTPVersion_HTTP_VERSION_UNSPECIFIED: return nil, errors.New("an HTTP version must be specified") } + if trace != nil { + transport = tracer.TracingRoundTripper(transport, trace) + } if req.RawRequest != nil { transport = &rawRequestSender{transport: transport, rawRequest: req.RawRequest} diff --git a/internal/app/referenceserver/server.go b/internal/app/referenceserver/server.go index 292b8a2b..0e3206ca 100644 --- a/internal/app/referenceserver/server.go +++ b/internal/app/referenceserver/server.go @@ -22,6 +22,7 @@ import ( "flag" "fmt" "io" + "log" "net" "net/http" "os" @@ -34,6 +35,7 @@ import ( "connectrpc.com/conformance/internal/compression" v1 "connectrpc.com/conformance/internal/gen/proto/go/connectrpc/conformance/v1" "connectrpc.com/conformance/internal/gen/proto/go/connectrpc/conformance/v1/conformancev1connect" + "connectrpc.com/conformance/internal/tracer" connect "connectrpc.com/connect" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" @@ -44,17 +46,17 @@ import ( // Run runs the server according to server config read from the 'in' reader. func Run(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, errWriter io.WriteCloser) error { - return run(ctx, false, args, inReader, outWriter, errWriter) + return run(ctx, false, args, inReader, outWriter, errWriter, nil) } // RunInReferenceMode is just like Run except that it performs additional checks // that only the conformance reference server runs. These checks do not work if // the server is run as a server under test, only when run as a reference server. -func RunInReferenceMode(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, errWriter io.WriteCloser) error { - return run(ctx, true, args, inReader, outWriter, errWriter) +func RunInReferenceMode(ctx context.Context, args []string, inReader io.ReadCloser, outWriter, errWriter io.WriteCloser, tracer *tracer.Tracer) error { + return run(ctx, true, args, inReader, outWriter, errWriter, tracer) } -func run(ctx context.Context, referenceMode bool, args []string, inReader io.ReadCloser, outWriter, errWriter io.WriteCloser) error { +func run(ctx context.Context, referenceMode bool, args []string, inReader io.ReadCloser, outWriter, errWriter io.WriteCloser, tracer *tracer.Tracer) error { flags := flag.NewFlagSet(args[0], flag.ContinueOnError) json := flags.Bool("json", false, "whether to use the JSON format for marshaling / unmarshaling messages") host := flags.String("bind", internal.DefaultHost, "the bind address for the conformance server") @@ -87,7 +89,7 @@ func run(ctx context.Context, referenceMode bool, args []string, inReader io.Rea // Create an HTTP server based on the request errPrinter := internal.NewPrinter(errWriter) - server, certBytes, err := createServer(req, net.JoinHostPort(*host, strconv.Itoa(*port)), *tlsCert, *tlsKey, referenceMode, errPrinter) + server, certBytes, err := createServer(req, net.JoinHostPort(*host, strconv.Itoa(*port)), *tlsCert, *tlsKey, referenceMode, errPrinter, tracer) if err != nil { return err } @@ -163,7 +165,7 @@ func (s *stdHTTPServer) Addr() string { } // Creates an HTTP server using the provided ServerCompatRequest. -func createServer(req *v1.ServerCompatRequest, listenAddr, tlsCertFile, tlsKeyFile string, referenceMode bool, errPrinter internal.Printer) (httpServer, []byte, error) { +func createServer(req *v1.ServerCompatRequest, listenAddr, tlsCertFile, tlsKeyFile string, referenceMode bool, errPrinter internal.Printer, trace *tracer.Tracer) (httpServer, []byte, error) { mux := http.NewServeMux() interceptors := []connect.Interceptor{serverNameHandlerInterceptor{}} if referenceMode { @@ -199,6 +201,9 @@ func createServer(req *v1.ServerCompatRequest, listenAddr, tlsCertFile, tlsKeyFi handler = referenceServerChecks(handler, errPrinter) handler = rawResponder(handler, errPrinter) } + if trace != nil { + handler = tracer.TracingHandler(handler, trace) + } // The server needs a lenient cors setup so that it can handle testing // browser clients. handler = cors.New(cors.Options{ @@ -283,6 +288,7 @@ func newH1Server(handler http.Handler, listenAddr string, tlsConf *tls.Config) ( Handler: handler, TLSConfig: tlsConf, ReadHeaderTimeout: 5 * time.Second, + ErrorLog: nopLogger(), } lis, err := net.Listen("tcp", listenAddr) if err != nil { @@ -301,6 +307,7 @@ func newH2Server(handler http.Handler, listenAddr string, tlsConf *tls.Config) ( Handler: handler, TLSConfig: tlsConf, ReadHeaderTimeout: 5 * time.Second, + ErrorLog: nopLogger(), } lis, err := net.Listen("tcp", listenAddr) if err != nil { @@ -346,3 +353,9 @@ func (s *http3Server) GracefulShutdown(_ time.Duration) error { func (s *http3Server) Addr() string { return s.lis.Addr().String() } + +//nolint:forbidigo // must refer to log package in order to suppress it in net/http server +func nopLogger() *log.Logger { + // TODO: enable logging via -v option or env variable? + return log.New(io.Discard, "", 0) +} diff --git a/internal/tracer/builder.go b/internal/tracer/builder.go new file mode 100644 index 00000000..1a675f92 --- /dev/null +++ b/internal/tracer/builder.go @@ -0,0 +1,91 @@ +// Copyright 2023-2024 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tracer + +import ( + "net/http" + "sync" + "time" +) + +// builder accumulates events to build a trace. +type builder struct { + collector Collector + start time.Time + + mu sync.Mutex + trace Trace + reqCount, respCount int +} + +// newBuilder creates a new builder for the given HTTP operation. The +// returned builder will already have a RequestStart event, based on +// the given request, so callers should NOT explicitly call builder.add +// to add such an event. +func newBuilder(req *http.Request, collector Collector) *builder { + testName := req.Header.Get("x-test-case-name") + return &builder{ + collector: collector, + start: time.Now(), + trace: Trace{ + TestName: testName, + Events: []Event{&RequestStart{Request: req}}, + }, + } +} + +// add adds the given event to the trace being built. +func (b *builder) add(event Event) { + b.mu.Lock() + defer b.mu.Unlock() + if b.trace.TestName == "" { + return + } + switch event := event.(type) { + case *ResponseStart: + b.trace.Response = event.Response + case *ResponseError: + b.trace.Err = event.Err + case *RequestBodyEnd: + if b.trace.Err != nil { + b.trace.Err = event.Err + } + case *ResponseBodyEnd: + if b.trace.Err != nil { + b.trace.Err = event.Err + } + case *RequestBodyData: + event.MessageIndex = b.reqCount + b.reqCount++ + case *ResponseBodyData: + event.MessageIndex = b.respCount + b.respCount++ + } + event.setEventOffset(time.Since(b.start)) + b.trace.Events = append(b.trace.Events, event) +} + +// build builds the trace and provides the data to the given Tracer. +func (b *builder) build() { + b.mu.Lock() + trace := b.trace + b.trace = Trace{} // reset; subsequent calls to add or build ignored + b.mu.Unlock() + + if trace.TestName == "" { + return + } + b.collector.Complete(trace) +} diff --git a/internal/tracer/middleware.go b/internal/tracer/middleware.go new file mode 100644 index 00000000..6b7779c3 --- /dev/null +++ b/internal/tracer/middleware.go @@ -0,0 +1,185 @@ +// Copyright 2023-2024 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tracer + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strconv" + "strings" +) + +// TracingRoundTripper applies tracing to the given transport. The returned +// round tripper will record traces of all operations to the given tracer. +func TracingRoundTripper(transport http.RoundTripper, collector Collector) http.RoundTripper { + return roundTripperFunc(func(req *http.Request) (*http.Response, error) { + builder := newBuilder(req, collector) + req = req.Clone(req.Context()) + req.Body = newReader(req.Header, req.Body, true, builder) + resp, err := transport.RoundTrip(req) + if err != nil { + builder.add(&ResponseError{Err: err}) + builder.build() + return nil, err + } + builder.add(&ResponseStart{Response: resp}) + respClone := *resp + respClone.Body = newReader(resp.Header, resp.Body, false, builder) + return &respClone, nil + }) +} + +// TracingHandler applies tracing middleware to the given handler. The returned +// handler will record traces of all operations to the given tracer. +func TracingHandler(handler http.Handler, collector Collector) http.Handler { + return http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { + builder := newBuilder(req, collector) + req = req.Clone(req.Context()) + req.Body = newReader(req.Header, req.Body, true, builder) + traceWriter := &tracingResponseWriter{ + respWriter: respWriter, + req: req, + builder: builder, + } + + handler.ServeHTTP( + traceWriter, + req, + ) + + traceWriter.tryFinish(nil) + }) +} + +type tracingResponseWriter struct { + respWriter http.ResponseWriter + req *http.Request + builder *builder + started bool + resp *http.Response + finished bool + + dataTracer dataTracer +} + +func (t *tracingResponseWriter) Unwrap() http.ResponseWriter { + return t.respWriter +} + +func (t *tracingResponseWriter) Header() http.Header { + return t.respWriter.Header() +} + +func (t *tracingResponseWriter) Write(data []byte) (int, error) { + if !t.started { + t.WriteHeader(http.StatusOK) + } + n, err := t.respWriter.Write(data) + t.dataTracer.trace(data[:n]) + if err != nil { + t.tryFinish(err) + } + return n, err +} + +func (t *tracingResponseWriter) WriteHeader(statusCode int) { + if t.started { + return + } + t.started = true + t.respWriter.WriteHeader(statusCode) + isStreamProtocol, decompressor := propertiesFromHeaders(t.Header()) + t.dataTracer = dataTracer{ + isRequest: false, + isStreamProtocol: isStreamProtocol, + decompressor: decompressor, + builder: t.builder, + } + contentLenStr := t.Header().Get("Content-Length") + contentLen := int64(-1) + if contentLenStr != "" { + if intVal, err := strconv.ParseInt(contentLenStr, 10, 64); err == nil { + contentLen = intVal + } + } + t.resp = &http.Response{ + Header: t.respWriter.Header(), + Body: io.NopCloser(bytes.NewBuffer(nil)), // empty body + Status: fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)), + StatusCode: statusCode, + Proto: t.req.Proto, + ProtoMajor: t.req.ProtoMajor, + ProtoMinor: t.req.ProtoMinor, + ContentLength: contentLen, + TLS: t.req.TLS, + Trailer: http.Header{}, + } + for _, trailerNames := range t.Header().Values("Trailer") { + for _, trailerName := range strings.Split(trailerNames, ",") { + trailerName = strings.TrimSpace(trailerName) + if trailerName == "" { + continue + } + t.resp.Trailer[trailerName] = nil + } + } + t.builder.add(&ResponseStart{Response: t.resp}) +} + +func (t *tracingResponseWriter) Flush() { + flusher, ok := t.respWriter.(http.Flusher) + if ok { + flusher.Flush() + } +} + +func (t *tracingResponseWriter) tryFinish(err error) { + if t.finished { + return // already finished + } + if !t.started { + t.WriteHeader(http.StatusOK) + } + + t.finished = true + t.dataTracer.emitUnfinished() + t.builder.add(&ResponseBodyEnd{Err: err}) + t.setTrailers() + t.builder.build() +} + +func (t *tracingResponseWriter) setTrailers() { + for trailerName := range t.resp.Trailer { + t.resp.Trailer[trailerName] = t.resp.Header[trailerName] + } + for key, vals := range t.resp.Header { + trailerKey := strings.TrimPrefix(key, http.TrailerPrefix) + if trailerKey == key { + // no prefix trimmed, so not a trailer + continue + } + existing := t.resp.Trailer[trailerKey] + t.resp.Trailer[trailerKey] = append(existing, vals...) + delete(t.resp.Header, key) + } +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (r roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return r(req) +} diff --git a/internal/tracer/reader.go b/internal/tracer/reader.go new file mode 100644 index 00000000..a666eafb --- /dev/null +++ b/internal/tracer/reader.go @@ -0,0 +1,314 @@ +// Copyright 2023-2024 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tracer + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync/atomic" + + "connectrpc.com/conformance/internal/compression" + conformancev1 "connectrpc.com/conformance/internal/gen/proto/go/connectrpc/conformance/v1" + "connectrpc.com/connect" +) + +const prefixLen = 5 + +type tracingReader struct { + reader io.ReadCloser + builder *builder + isRequest bool + closed atomic.Bool + + dataTracer dataTracer +} + +func newReader(headers http.Header, reader io.ReadCloser, isRequest bool, builder *builder) io.ReadCloser { + isStream, decompressor := propertiesFromHeaders(headers) + return &tracingReader{ + reader: reader, + isRequest: isRequest, + builder: builder, + dataTracer: dataTracer{ + isRequest: isRequest, + isStreamProtocol: isStream, + decompressor: decompressor, + builder: builder, + }, + } +} + +func (t *tracingReader) Read(data []byte) (n int, err error) { + n, err = t.reader.Read(data) + t.dataTracer.trace(data[:n]) + if err != nil { + if errors.Is(err, io.EOF) { + t.tryFinish(nil) + } else { + t.tryFinish(err) + } + } + return n, err +} + +func (t *tracingReader) Close() error { + err := t.reader.Close() + if err != nil { + t.tryFinish(fmt.Errorf("close: %w", err)) + } else { + t.tryFinish(errors.New("closed before fully consumed")) + } + return err +} + +func (t *tracingReader) tryFinish(err error) { + if !t.closed.CompareAndSwap(false, true) { + return // already finished + } + + t.dataTracer.emitUnfinished() + + if t.isRequest { + t.builder.add(&RequestBodyEnd{Err: err}) + return + } + + // On the response side, when the body reaches the end, whole thing is done. + t.builder.add(&ResponseBodyEnd{Err: err}) + t.builder.build() +} + +// dataTracer is responsible for translating bytes read/written into trace events. +type dataTracer struct { + isRequest bool + isStreamProtocol bool + decompressor connect.Decompressor + builder *builder + + prefix []byte + env *Envelope + expecting uint32 + actual uint64 + endStream *bytes.Buffer +} + +func (d *dataTracer) trace(data []byte) { + if !d.isStreamProtocol { + d.actual += uint64(len(data)) + return + } + for { + if len(data) == 0 { + return + } + + if d.expecting == 0 { + // still reading envelope prefix + n, done := d.tracePrefix(data) + if !done { + // need to read more data to finish prefix + return + } + data = data[n:] + continue + } + + n, done := d.traceMessage(data) + if !done { + // need to read more data to finish message + return + } + data = data[n:] + } +} + +func (d *dataTracer) tracePrefix(data []byte) (int, bool) { + need := prefixLen - len(d.prefix) + if len(data) < need { + // envelope still not complete... + d.prefix = append(d.prefix, data...) + return need, false + } + + d.prefix = append(d.prefix, data[:need]...) + d.env = &Envelope{ + Len: binary.BigEndian.Uint32(d.prefix[1:]), + Flags: d.prefix[0], + } + d.expecting = d.env.Len + d.prefix = d.prefix[:0] + if d.expecting == 0 { + // If we're not expecting any more data for this message, go + // ahead and emit event. + if d.isRequest { + d.builder.add(&RequestBodyData{ + Envelope: d.env, + Len: 0, + }) + } else { + d.builder.add(&ResponseBodyData{ + Envelope: d.env, + Len: 0, + }) + } + d.env = nil + } else if !d.isRequest && (d.env.Flags&0x82) != 0 { + // This is a response end-stream message. Capture the contents. + d.endStream = bytes.NewBuffer(make([]byte, 0, d.env.Len)) + } + return need, true +} + +func (d *dataTracer) traceMessage(data []byte) (int, bool) { + need := int(d.expecting - uint32(d.actual)) + if len(data) < need { + // message still not complete... + d.actual += uint64(len(data)) + if d.endStream != nil { + _, _ = d.endStream.Write(data) + } + return need, false + } + + if d.isRequest { + d.builder.add(&RequestBodyData{ + Envelope: d.env, + Len: uint64(d.expecting), + }) + } else { + d.builder.add(&ResponseBodyData{ + Envelope: d.env, + Len: uint64(d.expecting), + }) + } + if d.endStream != nil { //nolint:nestif + _, _ = d.endStream.Write(data[:need]) + var content string + if d.decompressor == nil { + content = d.endStream.String() + } else { + var uncompressed bytes.Buffer + if err := d.decompressor.Reset(d.endStream); err == nil { + _, err := uncompressed.ReadFrom(d.decompressor) + if err == nil { + content = uncompressed.String() + } + } + } + if content != "" { + d.builder.add(&ResponseBodyEndStream{ + Content: content, + }) + } + d.endStream = nil + } + d.env = nil + d.expecting = 0 + d.actual = 0 + return need, true +} + +func (d *dataTracer) emitUnfinished() { + var unfinished uint64 + if d.expecting == 0 && len(d.prefix) > 0 { + unfinished = uint64(len(d.prefix)) + } else { + unfinished = d.actual + } + + if unfinished > 0 { + if d.isRequest { + d.builder.add(&RequestBodyData{ + Envelope: d.env, + Len: unfinished, + }) + } else { + d.builder.add(&ResponseBodyData{ + Envelope: d.env, + Len: unfinished, + }) + } + } + + d.endStream = nil // we didn't finish reading end-stream message; discard what we got + d.env = nil + d.expecting = 0 + d.actual = 0 + d.prefix = d.prefix[:0] +} + +// brokenDecompressor is a no-op implementation that treats all compressed +// messages as if they were empty. +type brokenDecompressor struct{} + +func (brokenDecompressor) Read([]byte) (n int, err error) { + return 0, io.EOF +} + +func (brokenDecompressor) Close() error { + return nil +} + +func (brokenDecompressor) Reset(io.Reader) error { + return nil +} + +func propertiesFromHeaders(headers http.Header) (isStream bool, decomp connect.Decompressor) { + contentType := strings.ToLower(headers.Get("Content-Type")) + if headers.Get("Content-Encoding") != "" { + // full body is encoded, so don't bother trying to parse stream + return false, brokenDecompressor{} + } + switch { + case strings.HasPrefix(contentType, "application/connect"): + return true, getDecompressor(headers.Get("Connect-Content-Encoding")) + case strings.HasPrefix(contentType, "application/grpc"): + return true, getDecompressor(headers.Get("Grpc-Encoding")) + default: + // We should only need a decompressor for streams (to decompress the end-stream message) + // So for non-stream protocols, this no-op decompressor should suffice. + return false, brokenDecompressor{} + } +} + +func getDecompressor(encoding string) connect.Decompressor { + var comp conformancev1.Compression + switch strings.ToLower(encoding) { + case "", "identity": + comp = conformancev1.Compression_COMPRESSION_IDENTITY + case "gzip": + comp = conformancev1.Compression_COMPRESSION_GZIP + case "br": + comp = conformancev1.Compression_COMPRESSION_BR + case "zstd": + comp = conformancev1.Compression_COMPRESSION_ZSTD + case "deflate": + comp = conformancev1.Compression_COMPRESSION_DEFLATE + case "snappy": + comp = conformancev1.Compression_COMPRESSION_SNAPPY + default: + return brokenDecompressor{} + } + decomp, err := compression.GetDecompressor(comp) + if err != nil { + return brokenDecompressor{} + } + return decomp +} diff --git a/internal/tracer/tracer.go b/internal/tracer/tracer.go new file mode 100644 index 00000000..3cee78e0 --- /dev/null +++ b/internal/tracer/tracer.go @@ -0,0 +1,376 @@ +// Copyright 2023-2024 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tracer + +import ( + "context" + "fmt" + "net/http" + "sort" + "strings" + "sync" + "time" + + "connectrpc.com/conformance/internal" +) + +const ( + requestPrefix = " request>" + responsePrefix = "response<" +) + +// Tracer stores traces as they are produced and makes them available to a consumer. +// Each operation, identified by a test name, must first be initialized by the consumer +// via Init. The producer then populates the information for that operation via Complete. +// The consumer can then use Await to retrieve the trace (which may be produced +// asynchronously) and should finally use Clear, to free up resources associated with +// the operation. (If Clear is never called, the Tracer will use more and more memory, +// but limited by the amount to store all traces for every operation traced.) +type Tracer struct { + mu sync.Mutex + traces map[string]*traceResult +} + +// Init initializes the tracer to accept data for a trace for the given test name. +// This must be called before Clear, Complete, or Await for the same name. +func (t *Tracer) Init(testName string) { + if t == nil { + return + } + var result traceResult + result.done = make(chan struct{}) + t.mu.Lock() + defer t.mu.Unlock() + if t.traces == nil { + t.traces = map[string]*traceResult{} + } + t.traces[testName] = &result +} + +// Clear clears the data for the given test name. This frees up resources so +// that the tracer doesn't use more memory than necessary. +func (t *Tracer) Clear(testName string) { + if t == nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + delete(t.traces, testName) +} + +// Complete marks a test as complete with the given trace data. If Clear +// has already been called or Init was never called, this does nothing. +func (t *Tracer) Complete(trace Trace) { + if t == nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + result := t.traces[trace.TestName] + if result == nil || result.done == nil { + return + } + done := result.done + result.trace = trace + result.done = nil + close(done) +} + +// Await waits for the given test to complete and for its trace data to +// become available. It returns a context error if the given context is +// cancelled or its deadline is reached before completion. It also returns +// an error if Clear has alreadu been called for the test or if Init was +// never called. +func (t *Tracer) Await(ctx context.Context, testName string) (*Trace, error) { + if t == nil { + return nil, fmt.Errorf("%s: tracing not enabled", testName) + } + t.mu.Lock() + result := t.traces[testName] + var done chan struct{} + if result != nil { + done = result.done + } + t.mu.Unlock() + if result == nil { + return nil, fmt.Errorf("%s: trace already cleared", testName) + } + if done == nil { + return &result.trace, nil + } + select { + case <-done: + return &result.trace, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Trace represents the sequence of activity for a single HTTP operation. +type Trace struct { + TestName string + Request *http.Request + Response *http.Response + Err error + Events []Event +} + +func (t *Trace) Print(printer internal.Printer) { + for _, event := range t.Events { + event.print(printer) + } + if t.Response != nil && len(t.Response.Trailer) > 0 { + printer.Printf(responsePrefix) + printHeaders(responsePrefix, t.Response.Trailer, printer) + } +} + +// Collector is a consumer of traces. This is usually an +// instance of *Tracer, but is an interface so that the implementation +// can vary, even allowing decorating or intercepting the method on +// *Tracer. +type Collector interface { + // Complete accepts a trace once it is completed. + Complete(Trace) +} + +var _ Collector = (*Tracer)(nil) + +// Event is a single item in a sequence of activity for an HTTP operation. +type Event interface { + setEventOffset(time.Duration) + print(internal.Printer) +} + +// Envelope represents the metadata about an enveloped message in an +// RPC stream. Streaming protocols prefix each message with this +// metadata. +type Envelope struct { + Flags byte + Len uint32 +} + +// RequestStart is an event that represents when the request starts. This +// is recorded when the client sends the request or when the server +// receives it. This is always the first event for an HTTP operation. +type RequestStart struct { + Request *http.Request + + eventOffset +} + +func (r *RequestStart) print(printer internal.Printer) { + urlClone := *r.Request.URL + if urlClone.Host == "" { + urlClone.Host = "..." + } + if r.Request.TLS != nil { + urlClone.Scheme = "https" + } else { + urlClone.Scheme = "http" + } + printer.Printf("%s %9.3fms %s %s %s", requestPrefix, r.offsetMillis(), r.Request.Method, urlClone.String(), r.Request.Proto) + printHeaders(requestPrefix, r.Request.Header, printer) + printer.Printf(requestPrefix) +} + +// RequestBodyData represents some data written to or read from the +// request body. These operations are "chunked" so that a single event +// represents a full message (or incomplete, partial message if a full +// message is not written or read). +type RequestBodyData struct { + // For streaming protocols, each message is + // enveloped and this should be non-nil. It may + // be nil in a streaming protocol if an envelope + // prefix was expected, but only a partial prefix + // could be written/read. In such a case, a + // RequestBodyData event is emitted that has no + // envelope and whose Len field indicates the + // number of bytes written/read of the incomplete + // prefix. + Envelope *Envelope + // Actual length of the data, which could differ + // from the length indicated in the envelope if + // the full message could not be written/read. + Len uint64 + + // Sequentially numbered index. The first message + // in the stream should have an index of zero, and + // then one, etc. + MessageIndex int + + eventOffset +} + +func (r *RequestBodyData) print(printer internal.Printer) { + printData(requestPrefix, r.offsetMillis(), r.MessageIndex, r.Envelope, r.Len, printer) +} + +// RequestBodyEnd represents the end of the request body being reached. +// The Err value is the error returned from the final read (on the server) +// or call to close the body (on the client). If the final read returned +// io.EOF, Err will be nil. So a non-nil Err means an abnormal conclusion +// to the operation. No more request events will appear after this. +type RequestBodyEnd struct { + Err error + + eventOffset +} + +func (r *RequestBodyEnd) print(printer internal.Printer) { + if r.Err != nil { + printer.Printf("%s %9.3fms body end (err=%v)", requestPrefix, r.offsetMillis(), r.Err) + } else { + printer.Printf("%s %9.3fms body end", requestPrefix, r.offsetMillis()) + } +} + +// ResponseStart is an event that represents when the response starts. This +// is recorded when the client receives the response headers or when the +// server sends them. This will precede all other response events. +type ResponseStart struct { + Response *http.Response + + eventOffset +} + +func (r *ResponseStart) print(printer internal.Printer) { + printer.Printf("%s %9.3fms %s", responsePrefix, r.offsetMillis(), r.Response.Status) + printHeaders(responsePrefix, r.Response.Header, printer) + printer.Printf(responsePrefix) +} + +// ResponseError is an event that represents when the response fails. This +// is recorded when the client receives an error instead of a response, like +// due to a network error. +type ResponseError struct { + Err error + + eventOffset +} + +func (r *ResponseError) print(printer internal.Printer) { + printer.Printf("%s %9.3fms failed: %v", responsePrefix, r.offsetMillis(), r.Err) +} + +// ResponseBodyData represents some data written to or read from the +// response body. These operations are "chunked" so that a single event +// represents a full message (or incomplete, partial message if a full +// message is not written or read). +type ResponseBodyData struct { + // For streaming protocols, each message is + // enveloped and this should be non-nil. It may + // be nil in a streaming protocol if an envelope + // prefix was expected, but only a partial prefix + // could be written/read. In such a case, a + // ResponseBodyData event is emitted that has no + // envelope and whose Len field indicates the + // number of bytes written/read of the incomplete + // prefix. + Envelope *Envelope + // Actual length of the data, which could differ + // from the length indicated in the envelope if + // the full message could not be written/read. + Len uint64 + + // Sequentially numbered index. The first message + // in the stream should have an index of zero, and + // then one, etc. + MessageIndex int + + eventOffset +} + +func (r *ResponseBodyData) print(printer internal.Printer) { + printData(responsePrefix, r.offsetMillis(), r.MessageIndex, r.Envelope, r.Len, printer) +} + +// ResponseBodyEndStream represents the an "end-stream" message in the +// Connect streaming and gRPC-Web protocols. It is a special representation +// of the operation's status and trailers that is part of the response +// body. +type ResponseBodyEndStream struct { + Content string + + eventOffset +} + +func (r *ResponseBodyEndStream) print(printer internal.Printer) { + lines := strings.Split(r.Content, "\n") + for _, line := range lines { + line = strings.Trim(line, "\r") + printer.Printf("%s %11s eos: %s", responsePrefix, "", line) + } +} + +// ResponseBodyEnd represents the end of the response body being reached. +// The Err value is the error returned from the final read (on the client) +// or final write (on the server). If the final read returned io.EOF, Err +// will be nil. So a non-nil Err means an abnormal conclusion to the +// operation. No more events will appear after this. +type ResponseBodyEnd struct { + Err error + + eventOffset +} + +func (r *ResponseBodyEnd) print(printer internal.Printer) { + if r.Err != nil { + printer.Printf("%s %9.3fms body end (err=%v)", responsePrefix, r.offsetMillis(), r.Err) + } else { + printer.Printf("%s %9.3fms body end", responsePrefix, r.offsetMillis()) + } +} + +type traceResult struct { + trace Trace + done chan struct{} +} + +type eventOffset struct { + Offset time.Duration +} + +func (o *eventOffset) setEventOffset(offset time.Duration) { + o.Offset = offset +} + +func (o *eventOffset) offsetMillis() float64 { + return o.Offset.Seconds() * 1000 +} + +func printHeaders(prefix string, headers http.Header, printer internal.Printer) { + keys := make([]string, 0, len(headers)) + for key := range headers { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + for _, val := range headers[key] { + printer.Printf("%s %11s %s: %s", prefix, "", key, val) + } + } +} + +func printData(prefix string, offsetMillis float64, index int, env *Envelope, length uint64, printer internal.Printer) { + if env != nil { + printer.Printf("%s %9.3fms message #%d: prefix: flags=%d, len=%d", prefix, offsetMillis, index+1, env.Flags, env.Len) + if length > 0 { + printer.Printf("%s %11s message #%d: data: %d/%d bytes", prefix, "", index+1, length, env.Len) + } + } else { + printer.Printf("%s %9.3fms message #%d: data: %d bytes", prefix, offsetMillis, index+1, length) + } +}