diff --git a/v2/client/client_test.go b/v2/client/client_test.go index e9fd21ea5..ad0c6bca9 100644 --- a/v2/client/client_test.go +++ b/v2/client/client_test.go @@ -12,10 +12,12 @@ import ( "fmt" "io/ioutil" "log" + "net" "net/http" "net/http/httptest" "net/url" "strings" + "sync" "testing" "time" @@ -357,6 +359,48 @@ func TestClientReceive(t *testing.T) { } } +func TestClientContext(t *testing.T) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("error creating listener: %v", err) + } + defer listener.Close() + type key string + + c, err := client.NewHTTP(cehttp.WithListener(listener), cehttp.WithMiddleware(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), key("inner"), "bar") + next.ServeHTTP(w, r.WithContext(ctx)) + }) + })) + + if err != nil { + t.Fatalf("error creating client: %v", err) + } + var wg sync.WaitGroup + wg.Add(1) + handler := func(ctx context.Context) { + if v := ctx.Value(key("outer")); v != "foo" { + t.Errorf("expected context to have outer value, got %v", v) + } + if v := ctx.Value(key("inner")); v != "bar" { + t.Errorf("expected context to have inner value, got %v", v) + } + wg.Done() + } + go func() { + c.StartReceiver(context.WithValue(context.Background(), key("outer"), "foo"), handler) + }() + + body := strings.NewReader(`{"data":{"msg":"hello","sq":"42"},"datacontenttype":"application/json","id":"AABBCCDDEE","source":"/unit/test/client","specversion":"0.3","time":%q,"type":"unit.test.client"}`) + resp, err := http.Post(fmt.Sprintf("http://%s", listener.Addr().String()), "application/cloudevents+json", body) + if err != nil { + t.Errorf("err sending request, response: %v, err: %v", resp, err) + } + + wg.Wait() +} + type requestValidation struct { Host string Headers http.Header diff --git a/v2/client/invoker.go b/v2/client/invoker.go index a3a83016c..403fb0f55 100644 --- a/v2/client/invoker.go +++ b/v2/client/invoker.go @@ -128,7 +128,7 @@ func (r *receiveInvoker) IsResponder() bool { func computeInboundContext(message binding.Message, fallback context.Context, inboundContextDecorators []func(context.Context, binding.Message) context.Context) context.Context { result := fallback if mctx, ok := message.(binding.MessageContext); ok { - result = mctx.Context() + result = cecontext.ValuesDelegating(mctx.Context(), fallback) } for _, f := range inboundContextDecorators { result = f(result, message) diff --git a/v2/context/delegating.go b/v2/context/delegating.go new file mode 100644 index 000000000..434a4da7a --- /dev/null +++ b/v2/context/delegating.go @@ -0,0 +1,25 @@ +package context + +import "context" + +type valuesDelegating struct { + context.Context + parent context.Context +} + +// ValuesDelegating wraps a child and parent context. It will perform Value() +// lookups first on the child, and then fall back to the child. All other calls +// go solely to the child context. +func ValuesDelegating(child, parent context.Context) context.Context { + return &valuesDelegating{ + Context: child, + parent: parent, + } +} + +func (c *valuesDelegating) Value(key interface{}) interface{} { + if val := c.Context.Value(key); val != nil { + return val + } + return c.parent.Value(key) +} diff --git a/v2/context/delegating_test.go b/v2/context/delegating_test.go new file mode 100644 index 000000000..62ad004f9 --- /dev/null +++ b/v2/context/delegating_test.go @@ -0,0 +1,78 @@ +package context + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestValuesDelegating(t *testing.T) { + type key string + tests := []struct { + name string + child context.Context + parent context.Context + assert func(*testing.T, context.Context) + }{ + { + name: "it delegates to child first", + child: context.WithValue(context.Background(), key("foo"), "foo"), + parent: context.WithValue(context.Background(), key("foo"), "bar"), + assert: func(t *testing.T, c context.Context) { + if v := c.Value(key("foo")); v != "foo" { + t.Errorf("expected child value, got %s", v) + } + }, + }, + { + name: "it delegates to parent if missing from child", + child: context.Background(), + parent: context.WithValue(context.Background(), key("foo"), "foo"), + assert: func(t *testing.T, c context.Context) { + if v := c.Value(key("foo")); v != "foo" { + t.Errorf("expected parent value, got %s", v) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ValuesDelegating(tt.child, tt.parent) + tt.assert(t, got) + }) + } +} +func TestValuesDelegatingDelegatesOtherwiseToChild(t *testing.T) { + parent, parentCancel := context.WithCancel(context.Background()) + child, childCancel := context.WithCancel(context.Background()) + derived := ValuesDelegating(child, parent) + + ch := make(chan string) + go func() { + <-derived.Done() + ch <- "derived" + }() + go func() { + <-child.Done() + ch <- "child" + }() + go func() { + <-parent.Done() + ch <- "parent" + }() + + parentCancel() + v1 := <-ch + if v1 != "parent" { + t.Errorf("cancelling parent should not cancel child or derived: %s", v1) + } + childCancel() + v2 := <-ch + v3 := <-ch + diff := cmp.Diff([]string{"derived", "child"}, []string{v2, v3}, cmpopts.SortSlices(func(a, b string) bool { return a < b })) + if diff != "" { + t.Errorf("unexpected (-want, +got) = %v", diff) + } +}