From edfef92b9d43a8f32422db511a5e5df233807749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20St=C3=A4bler?= Date: Wed, 27 Sep 2023 09:23:38 +0200 Subject: [PATCH] Refactor kncloudevents Dispatcher to a class instead of functions to include the token handler later --- cmd/broker/filter/main.go | 5 +- cmd/broker/ingress/main.go | 6 +- pkg/broker/filter/filter_handler.go | 18 +++-- pkg/broker/filter/filter_handler_test.go | 16 ++++- pkg/broker/ingress/ingress_handler.go | 16 +++-- pkg/broker/ingress/ingress_handler_test.go | 8 ++- pkg/channel/fanout/fanout_event_handler.go | 6 +- .../fanout/fanout_event_handler_test.go | 15 +++- .../multi_channel_fanout_event_handler.go | 5 +- ...multi_channel_fanout_event_handler_test.go | 34 ++++++++-- pkg/inmemorychannel/event_dispatcher_test.go | 21 ++++-- pkg/kncloudevents/event_dispatcher.go | 68 ++++++++++++++----- pkg/kncloudevents/event_dispatcher_test.go | 27 ++++++-- .../inmemorychannel/dispatcher/controller.go | 4 ++ .../dispatcher/inmemorychannel.go | 4 +- .../dispatcher/inmemorychannel_test.go | 21 ++++-- 16 files changed, 216 insertions(+), 58 deletions(-) diff --git a/cmd/broker/filter/main.go b/cmd/broker/filter/main.go index 13d882be7be..94709479205 100644 --- a/cmd/broker/filter/main.go +++ b/cmd/broker/filter/main.go @@ -38,6 +38,7 @@ import ( "knative.dev/eventing/cmd/broker" "knative.dev/eventing/pkg/apis/feature" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/broker/filter" triggerinformer "knative.dev/eventing/pkg/client/injection/informers/eventing/v1/trigger" "knative.dev/eventing/pkg/reconciler/names" @@ -64,6 +65,7 @@ func main() { metrics.MemStatsOrDie(ctx) cfg := injection.ParseAndGetRESTConfigOrDie() + ctx = injection.WithConfig(ctx, cfg) var env envConfig if err := envconfig.Process("", &env); err != nil { @@ -118,9 +120,10 @@ func main() { reporter := filter.NewStatsReporter(env.ContainerName, kmeta.ChildName(env.PodName, uuid.New().String())) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) // We are running both the receiver (takes messages in from the Broker) and the dispatcher (send // the messages to the triggers' subscribers) in this binary. - handler, err := filter.NewHandler(logger, triggerinformer.Get(ctx), reporter, ctxFunc) + handler, err := filter.NewHandler(logger, oidcTokenProvider, triggerinformer.Get(ctx), reporter, ctxFunc) if err != nil { logger.Fatal("Error creating Handler", zap.Error(err)) } diff --git a/cmd/broker/ingress/main.go b/cmd/broker/ingress/main.go index bd8376dbe67..0eb838e115a 100644 --- a/cmd/broker/ingress/main.go +++ b/cmd/broker/ingress/main.go @@ -42,6 +42,7 @@ import ( cmdbroker "knative.dev/eventing/cmd/broker" "knative.dev/eventing/pkg/apis/feature" + "knative.dev/eventing/pkg/auth" broker "knative.dev/eventing/pkg/broker" "knative.dev/eventing/pkg/broker/ingress" eventingclient "knative.dev/eventing/pkg/client/injection/client" @@ -82,6 +83,7 @@ func main() { metrics.MemStatsOrDie(ctx) cfg := injection.ParseAndGetRESTConfigOrDie() + ctx = injection.WithConfig(ctx, cfg) var env envConfig if err := envconfig.Process("", &env); err != nil { @@ -150,7 +152,9 @@ func main() { reporter := ingress.NewStatsReporter(env.ContainerName, kmeta.ChildName(env.PodName, uuid.New().String())) - handler, err = ingress.NewHandler(logger, reporter, broker.TTLDefaulter(logger, int32(env.MaxTTL)), brokerInformer) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + + handler, err = ingress.NewHandler(logger, reporter, broker.TTLDefaulter(logger, int32(env.MaxTTL)), brokerInformer, oidcTokenProvider) if err != nil { logger.Fatal("Error creating Handler", zap.Error(err)) } diff --git a/pkg/broker/filter/filter_handler.go b/pkg/broker/filter/filter_handler.go index 70e8eb9b2b0..b692811a11c 100644 --- a/pkg/broker/filter/filter_handler.go +++ b/pkg/broker/filter/filter_handler.go @@ -38,6 +38,7 @@ import ( "knative.dev/pkg/logging" "knative.dev/eventing/pkg/apis" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/utils" eventingv1 "knative.dev/eventing/pkg/apis/eventing/v1" @@ -69,6 +70,8 @@ type Handler struct { // reporter reports stats of status code and dispatch time reporter StatsReporter + eventDispatcher *kncloudevents.Dispatcher + triggerLister eventinglisters.TriggerLister logger *zap.Logger withContext func(ctx context.Context) context.Context @@ -76,7 +79,7 @@ type Handler struct { } // NewHandler creates a new Handler and its associated EventReceiver. -func NewHandler(logger *zap.Logger, triggerInformer v1.TriggerInformer, reporter StatsReporter, wc func(ctx context.Context) context.Context) (*Handler, error) { +func NewHandler(logger *zap.Logger, oidcTokenProvider *auth.OIDCTokenProvider, triggerInformer v1.TriggerInformer, reporter StatsReporter, wc func(ctx context.Context) context.Context) (*Handler, error) { kncloudevents.ConfigureConnectionArgs(&kncloudevents.ConnectionArgs{ MaxIdleConns: defaultMaxIdleConnections, MaxIdleConnsPerHost: defaultMaxIdleConnectionsPerHost, @@ -124,11 +127,12 @@ func NewHandler(logger *zap.Logger, triggerInformer v1.TriggerInformer, reporter }) return &Handler{ - reporter: reporter, - triggerLister: triggerInformer.Lister(), - logger: logger, - withContext: wc, - filtersMap: fm, + reporter: reporter, + eventDispatcher: kncloudevents.NewDispatcher(oidcTokenProvider), + triggerLister: triggerInformer.Lister(), + logger: logger, + withContext: wc, + filtersMap: fm, }, nil } @@ -238,7 +242,7 @@ func (h *Handler) send(ctx context.Context, writer http.ResponseWriter, headers additionalHeaders := headers.Clone() additionalHeaders.Set(apis.KnNamespaceHeader, t.GetNamespace()) - dispatchInfo, err := kncloudevents.SendEvent(ctx, *event, target, kncloudevents.WithHeader(additionalHeaders)) + dispatchInfo, err := h.eventDispatcher.SendEvent(ctx, *event, target, kncloudevents.WithHeader(additionalHeaders)) if err != nil { h.logger.Error("failed to send event", zap.Error(err)) diff --git a/pkg/broker/filter/filter_handler_test.go b/pkg/broker/filter/filter_handler_test.go index eecdf596581..d7a4c58c14b 100644 --- a/pkg/broker/filter/filter_handler_test.go +++ b/pkg/broker/filter/filter_handler_test.go @@ -39,6 +39,7 @@ import ( "k8s.io/apimachinery/pkg/types" eventingv1 "knative.dev/eventing/pkg/apis/eventing/v1" "knative.dev/eventing/pkg/apis/feature" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/broker" "knative.dev/eventing/pkg/eventfilter/subscriptionsapi" "knative.dev/pkg/apis" @@ -46,6 +47,9 @@ import ( reconcilertesting "knative.dev/pkg/reconciler/testing" triggerinformerfake "knative.dev/eventing/pkg/client/injection/informers/eventing/v1/trigger/fake" + + // Fake injection client + _ "knative.dev/pkg/client/injection/kube/client/fake" ) const ( @@ -425,6 +429,9 @@ func TestReceiver(t *testing.T) { s := httptest.NewServer(&fh) defer s.Close() + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller())) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + // Replace the SubscriberURI to point at our fake server. for _, trig := range tc.triggers { if trig.Status.SubscriberURI != nil && trig.Status.SubscriberURI.String() == toBeReplaced { @@ -439,7 +446,8 @@ func TestReceiver(t *testing.T) { } reporter := &mockReporter{} r, err := NewHandler( - zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller())), + logger, + oidcTokenProvider, triggerinformerfake.Get(ctx), reporter, func(ctx context.Context) context.Context { @@ -606,6 +614,9 @@ func TestReceiver_WithSubscriptionsAPI(t *testing.T) { filtersMap := subscriptionsapi.NewFiltersMap() + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller())) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + // Replace the SubscriberURI to point at our fake server. for _, trig := range tc.triggers { if trig.Status.SubscriberURI != nil && trig.Status.SubscriberURI.String() == toBeReplaced { @@ -621,7 +632,8 @@ func TestReceiver_WithSubscriptionsAPI(t *testing.T) { } reporter := &mockReporter{} r, err := NewHandler( - zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller())), + logger, + oidcTokenProvider, triggerinformerfake.Get(ctx), reporter, func(ctx context.Context) context.Context { diff --git a/pkg/broker/ingress/ingress_handler.go b/pkg/broker/ingress/ingress_handler.go index 7c4fbeb5795..ffdfde58529 100644 --- a/pkg/broker/ingress/ingress_handler.go +++ b/pkg/broker/ingress/ingress_handler.go @@ -39,6 +39,7 @@ import ( "knative.dev/eventing/pkg/apis/eventing" eventingv1 "knative.dev/eventing/pkg/apis/eventing/v1" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/broker" v1 "knative.dev/eventing/pkg/client/informers/externalversions/eventing/v1" eventinglisters "knative.dev/eventing/pkg/client/listers/eventing/v1" @@ -64,9 +65,11 @@ type Handler struct { EvenTypeHandler *eventtype.EventTypeAutoHandler Logger *zap.Logger + + eventDispatcher *kncloudevents.Dispatcher } -func NewHandler(logger *zap.Logger, reporter StatsReporter, defaulter client.EventDefaulter, brokerInformer v1.BrokerInformer) (*Handler, error) { +func NewHandler(logger *zap.Logger, reporter StatsReporter, defaulter client.EventDefaulter, brokerInformer v1.BrokerInformer, oidcTokenProvider *auth.OIDCTokenProvider) (*Handler, error) { connectionArgs := kncloudevents.ConnectionArgs{ MaxIdleConns: defaultMaxIdleConnections, MaxIdleConnsPerHost: defaultMaxIdleConnectionsPerHost, @@ -107,10 +110,11 @@ func NewHandler(logger *zap.Logger, reporter StatsReporter, defaulter client.Eve }) return &Handler{ - Defaulter: defaulter, - Reporter: reporter, - Logger: logger, - BrokerLister: brokerInformer.Lister(), + Defaulter: defaulter, + Reporter: reporter, + Logger: logger, + BrokerLister: brokerInformer.Lister(), + eventDispatcher: kncloudevents.NewDispatcher(oidcTokenProvider), }, nil } @@ -282,7 +286,7 @@ func (h *Handler) receive(ctx context.Context, headers http.Header, event *cloud return http.StatusBadRequest, kncloudevents.NoDuration } - dispatchInfo, err := kncloudevents.SendEvent(ctx, *event, *channelAddress, kncloudevents.WithHeader(headers)) + dispatchInfo, err := h.eventDispatcher.SendEvent(ctx, *event, *channelAddress, kncloudevents.WithHeader(headers)) if err != nil { h.Logger.Error("failed to dispatch event", zap.Error(err)) return http.StatusInternalServerError, kncloudevents.NoDuration diff --git a/pkg/broker/ingress/ingress_handler_test.go b/pkg/broker/ingress/ingress_handler_test.go index 9eb08c18790..e9b0b614d2f 100644 --- a/pkg/broker/ingress/ingress_handler_test.go +++ b/pkg/broker/ingress/ingress_handler_test.go @@ -36,10 +36,14 @@ import ( "knative.dev/eventing/pkg/apis/eventing" eventingv1 "knative.dev/eventing/pkg/apis/eventing/v1" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/broker" reconcilertesting "knative.dev/pkg/reconciler/testing" brokerinformerfake "knative.dev/eventing/pkg/client/injection/informers/eventing/v1/broker/fake" + + // Fake injection client + _ "knative.dev/pkg/client/injection/kube/client/fake" ) const ( @@ -281,7 +285,9 @@ func TestHandler_ServeHTTP(t *testing.T) { brokerinformerfake.Get(ctx).Informer().GetStore().Add(b) } - h, err := NewHandler(logger, &mockReporter{}, tc.defaulter, brokerinformerfake.Get(ctx)) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + + h, err := NewHandler(logger, &mockReporter{}, tc.defaulter, brokerinformerfake.Get(ctx), oidcTokenProvider) if err != nil { t.Fatal("Unable to create receiver:", err) } diff --git a/pkg/channel/fanout/fanout_event_handler.go b/pkg/channel/fanout/fanout_event_handler.go index 2bc142ef9bc..6f82030e4ea 100644 --- a/pkg/channel/fanout/fanout_event_handler.go +++ b/pkg/channel/fanout/fanout_event_handler.go @@ -81,6 +81,8 @@ type FanoutEventHandler struct { receiver *channel.EventReceiver + eventDispatcher *kncloudevents.Dispatcher + // TODO: Plumb context through the receiver and dispatcher and use that to store the timeout, // rather than a member variable. timeout time.Duration @@ -100,6 +102,7 @@ func NewFanoutEventHandler( eventTypeHandler *eventtype.EventTypeAutoHandler, channelAddressable *duckv1.KReference, channelUID *types.UID, + eventDispatcher *kncloudevents.Dispatcher, receiverOpts ...channel.EventReceiverOptions, ) (*FanoutEventHandler, error) { handler := &FanoutEventHandler{ @@ -110,6 +113,7 @@ func NewFanoutEventHandler( eventTypeHandler: eventTypeHandler, channelAddressable: channelAddressable, channelUID: channelUID, + eventDispatcher: eventDispatcher, } handler.subscriptions = make([]Subscription, len(config.Subscriptions)) copy(handler.subscriptions, config.Subscriptions) @@ -313,7 +317,7 @@ func (f *FanoutEventHandler) dispatch(ctx context.Context, subs []Subscription, // makeFanoutRequest sends the request to exactly one subscription. It handles both the `call` and // the `sink` portions of the subscription. func (f *FanoutEventHandler) makeFanoutRequest(ctx context.Context, event event.Event, additionalHeaders nethttp.Header, sub Subscription) (*kncloudevents.DispatchInfo, error) { - return kncloudevents.SendEvent(ctx, event, sub.Subscriber, + return f.eventDispatcher.SendEvent(ctx, event, sub.Subscriber, kncloudevents.WithHeader(additionalHeaders), kncloudevents.WithReply(sub.Reply), kncloudevents.WithDeadLetterSink(sub.DeadLetter), diff --git a/pkg/channel/fanout/fanout_event_handler_test.go b/pkg/channel/fanout/fanout_event_handler_test.go index 902e9d4f609..44ab95ccf82 100644 --- a/pkg/channel/fanout/fanout_event_handler_test.go +++ b/pkg/channel/fanout/fanout_event_handler_test.go @@ -27,10 +27,13 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "k8s.io/client-go/rest" "k8s.io/utils/pointer" eventingduckv1 "knative.dev/eventing/pkg/apis/duck/v1" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/kncloudevents" duckv1 "knative.dev/pkg/apis/duck/v1" + "knative.dev/pkg/injection" cloudevents "github.com/cloudevents/sdk-go/v2" "github.com/cloudevents/sdk-go/v2/binding" @@ -43,6 +46,8 @@ import ( "knative.dev/pkg/apis" "knative.dev/eventing/pkg/channel" + fakekubeclient "knative.dev/pkg/client/injection/kube/client/fake" + _ "knative.dev/pkg/system/testing" ) // Domains used in subscriptions, which will be replaced by the real domains of the started HTTP @@ -318,6 +323,10 @@ func TestFanoutEventHandler_ServeHTTP(t *testing.T) { } func testFanoutEventHandler(t *testing.T, async bool, receiverFunc channel.EventReceiverFunc, timeout time.Duration, inSubs []Subscription, subscriberHandler func(http.ResponseWriter, *http.Request), subscriberReqs int, replierHandler func(http.ResponseWriter, *http.Request), replierReqs int, expectedStatus int) { + ctx := context.Background() + ctx, _ = fakekubeclient.With(ctx) + ctx = injection.WithConfig(ctx, &rest.Config{}) + var subscriberServerWg *sync.WaitGroup reporter := channel.NewStatsReporter("testcontainer", "testpod") if subscriberReqs != 0 { @@ -362,6 +371,9 @@ func testFanoutEventHandler(t *testing.T, async bool, receiverFunc channel.Event t.Fatal(err) } + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) + calledChan := make(chan bool, 1) recvOptionFunc := func(*channel.EventReceiver) error { calledChan <- true @@ -378,6 +390,7 @@ func testFanoutEventHandler(t *testing.T, async bool, receiverFunc channel.Event nil, nil, nil, + dispatcher, recvOptionFunc, ) <-calledChan @@ -403,7 +416,7 @@ func testFanoutEventHandler(t *testing.T, async bool, receiverFunc channel.Event reqCtx, _ := trace.StartSpan(context.TODO(), "bla") req := httptest.NewRequest(http.MethodPost, "http://channelname.channelnamespace/", nil).WithContext(reqCtx) - ctx := context.Background() + ctx = context.Background() if err := bindingshttp.WriteRequest(ctx, binding.ToMessage(&event), req); err != nil { t.Fatal("WriteRequest =", err) diff --git a/pkg/channel/multichannelfanout/multi_channel_fanout_event_handler.go b/pkg/channel/multichannelfanout/multi_channel_fanout_event_handler.go index a849bda4fdc..20bbeb7d3e9 100644 --- a/pkg/channel/multichannelfanout/multi_channel_fanout_event_handler.go +++ b/pkg/channel/multichannelfanout/multi_channel_fanout_event_handler.go @@ -35,6 +35,7 @@ import ( "knative.dev/eventing/pkg/channel" "knative.dev/eventing/pkg/channel/fanout" + "knative.dev/eventing/pkg/kncloudevents" ) type MultiChannelEventHandler interface { @@ -64,7 +65,7 @@ func NewEventHandler(_ context.Context, logger *zap.Logger) *EventHandler { // NewEventHandlerWithConfig creates a new Handler with the specified configuration. This is really meant for tests // where you want to apply a fully specified configuration for tests. Reconciler operates on single channel at a time. -func NewEventHandlerWithConfig(_ context.Context, logger *zap.Logger, conf Config, reporter channel.StatsReporter, recvOptions ...channel.EventReceiverOptions) (*EventHandler, error) { +func NewEventHandlerWithConfig(_ context.Context, logger *zap.Logger, conf Config, reporter channel.StatsReporter, eventDispatcher *kncloudevents.Dispatcher, recvOptions ...channel.EventReceiverOptions) (*EventHandler, error) { handlers := make(map[string]fanout.EventHandler, len(conf.ChannelConfigs)) for _, cc := range conf.ChannelConfigs { @@ -73,7 +74,7 @@ func NewEventHandlerWithConfig(_ context.Context, logger *zap.Logger, conf Confi if key == "" { continue } - handler, err := fanout.NewFanoutEventHandler(logger, cc.FanoutConfig, reporter, cc.EventTypeHandler, cc.ChannelAddressable, cc.ChannelUID, recvOptions...) + handler, err := fanout.NewFanoutEventHandler(logger, cc.FanoutConfig, reporter, cc.EventTypeHandler, cc.ChannelAddressable, cc.ChannelUID, eventDispatcher, recvOptions...) if err != nil { logger.Error("Failed creating new fanout handler.", zap.Error(err)) return nil, err diff --git a/pkg/channel/multichannelfanout/multi_channel_fanout_event_handler_test.go b/pkg/channel/multichannelfanout/multi_channel_fanout_event_handler_test.go index 49db902cac0..455d3fc9746 100644 --- a/pkg/channel/multichannelfanout/multi_channel_fanout_event_handler_test.go +++ b/pkg/channel/multichannelfanout/multi_channel_fanout_event_handler_test.go @@ -28,12 +28,19 @@ import ( "github.com/google/uuid" "go.uber.org/zap" "go.uber.org/zap/zaptest" + "k8s.io/client-go/rest" "knative.dev/pkg/apis" duckv1 "knative.dev/pkg/apis/duck/v1" + "knative.dev/pkg/injection" "knative.dev/pkg/ptr" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/channel" "knative.dev/eventing/pkg/channel/fanout" + "knative.dev/eventing/pkg/kncloudevents" + + fakekubeclient "knative.dev/pkg/client/injection/kube/client/fake" + _ "knative.dev/pkg/system/testing" ) var ( @@ -67,12 +74,20 @@ func TestNewEventHandlerWithConfig(t *testing.T) { reporter := channel.NewStatsReporter("testcontainer", "testpod") for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + ctx, _ = fakekubeclient.With(ctx) + ctx = injection.WithConfig(ctx, &rest.Config{}) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller())) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) _, err := NewEventHandlerWithConfig( context.TODO(), logger, tc.config, reporter, + dispatcher, ) if tc.createErr != "" { if err == nil { @@ -89,10 +104,17 @@ func TestNewEventHandlerWithConfig(t *testing.T) { } func TestNewEventHandler(t *testing.T) { + ctx := context.Background() + ctx, _ = fakekubeclient.With(ctx) + ctx = injection.WithConfig(ctx, &rest.Config{}) + handlerName := "handler.example.com" reporter := channel.NewStatsReporter("testcontainer", "testpod") logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller())) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) + handler := NewEventHandler(context.TODO(), logger) h := handler.GetChannelHandler(handlerName) if len(handler.handlers) != 0 { @@ -101,7 +123,7 @@ func TestNewEventHandler(t *testing.T) { if h != nil { t.Errorf("Found handler for %q but not expected", handlerName) } - f, err := fanout.NewFanoutEventHandler(logger, fanout.Config{}, reporter, nil, nil, nil) + f, err := fanout.NewFanoutEventHandler(logger, fanout.Config{}, reporter, nil, nil, nil, dispatcher) if err != nil { t.Error("Failed to create FanoutMessagHandler: ", err) } @@ -303,6 +325,10 @@ func TestServeHTTPEventHandler(t *testing.T) { } for n, tc := range testCases { t.Run(n, func(t *testing.T) { + ctx := context.Background() + ctx, _ = fakekubeclient.With(ctx) + ctx = injection.WithConfig(ctx, &rest.Config{}) + server := httptest.NewServer(fakeHandler(tc.respStatusCode)) defer server.Close() @@ -311,13 +337,13 @@ func TestServeHTTPEventHandler(t *testing.T) { logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller())) reporter := channel.NewStatsReporter("testcontainer", "testpod") - h, err := NewEventHandlerWithConfig(context.TODO(), logger, tc.config, reporter, tc.recvOptions...) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) + h, err := NewEventHandlerWithConfig(context.TODO(), logger, tc.config, reporter, dispatcher, tc.recvOptions...) if err != nil { t.Fatalf("Unexpected NewHandler error: '%v'", err) } - ctx := context.Background() - event := cloudevents.NewEvent(cloudevents.VersionV1) id := uuid.New().String() diff --git a/pkg/inmemorychannel/event_dispatcher_test.go b/pkg/inmemorychannel/event_dispatcher_test.go index 61369135fc1..904933a4f8a 100644 --- a/pkg/inmemorychannel/event_dispatcher_test.go +++ b/pkg/inmemorychannel/event_dispatcher_test.go @@ -33,17 +33,21 @@ import ( "github.com/cloudevents/sdk-go/v2/test" "github.com/pkg/errors" "go.uber.org/zap" + "k8s.io/client-go/rest" "knative.dev/pkg/apis" duckv1 "knative.dev/pkg/apis/duck/v1" + "knative.dev/pkg/injection" "knative.dev/pkg/tracing" tracingconfig "knative.dev/pkg/tracing/config" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/channel" "knative.dev/eventing/pkg/channel/fanout" "knative.dev/eventing/pkg/channel/multichannelfanout" "knative.dev/eventing/pkg/kncloudevents" + fakekubeclient "knative.dev/pkg/client/injection/kube/client/fake" logtesting "knative.dev/pkg/logging/testing" _ "knative.dev/pkg/system/testing" ) @@ -104,7 +108,13 @@ func TestDispatcher_close(t *testing.T) { // This test emulates a real dispatcher usage func TestDispatcher_dispatch(t *testing.T) { + ctx := context.Background() + ctx, _ = fakekubeclient.With(ctx) + ctx = injection.WithConfig(ctx, &rest.Config{}) + logger, err := zap.NewDevelopment(zap.AddStacktrace(zap.WarnLevel)) + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) reporter := channel.NewStatsReporter("testcontainer", "testpod") if err != nil { t.Fatal(err) @@ -230,7 +240,7 @@ func TestDispatcher_dispatch(t *testing.T) { }, } - sh, err := multichannelfanout.NewEventHandlerWithConfig(context.TODO(), logger, config, reporter) + sh, err := multichannelfanout.NewEventHandlerWithConfig(context.TODO(), logger, config, reporter, dispatcher) if err != nil { t.Fatal(err) } @@ -245,21 +255,22 @@ func TestDispatcher_dispatch(t *testing.T) { Logger: logger, } - dispatcher := NewEventDispatcher(dispatcherArgs) + inMemoryDispatcher := NewEventDispatcher(dispatcherArgs) serverCtx, cancel := context.WithCancel(context.Background()) defer cancel() // Start the dispatcher go func() { - if err := dispatcher.Start(serverCtx); err != nil { + if err := inMemoryDispatcher.Start(serverCtx); err != nil { t.Error(err) } }() - dispatcher.WaitReady() + inMemoryDispatcher.WaitReady() // Ok now everything should be ready to send the event - dispatchInfo, err := kncloudevents.SendEvent(context.TODO(), test.FullEvent(), *mustParseUrlToAddressable(t, channelAProxy.URL)) + d := kncloudevents.NewDispatcher(oidcTokenProvider) + dispatchInfo, err := d.SendEvent(context.TODO(), test.FullEvent(), *mustParseUrlToAddressable(t, channelAProxy.URL)) if err != nil { t.Fatal(err) } diff --git a/pkg/kncloudevents/event_dispatcher.go b/pkg/kncloudevents/event_dispatcher.go index bd55c5d66ef..442b6125215 100644 --- a/pkg/kncloudevents/event_dispatcher.go +++ b/pkg/kncloudevents/event_dispatcher.go @@ -32,6 +32,7 @@ import ( cehttp "github.com/cloudevents/sdk-go/v2/protocol/http" "github.com/hashicorp/go-retryablehttp" "go.opencensus.io/trace" + "k8s.io/apimachinery/pkg/types" "knative.dev/pkg/apis" duckv1 "knative.dev/pkg/apis/duck/v1" @@ -39,6 +40,7 @@ import ( "knative.dev/pkg/system" eventingapis "knative.dev/eventing/pkg/apis" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/utils" "knative.dev/eventing/pkg/broker" @@ -101,28 +103,50 @@ func WithTransformers(transformers ...binding.Transformer) SendOption { } } +func WithOIDCAuthentication(serviceAccount *types.NamespacedName) SendOption { + return func(sc *senderConfig) error { + if serviceAccount != nil && serviceAccount.Name != "" && serviceAccount.Namespace == "" { + sc.oidcServiceAccount = serviceAccount + return nil + } else { + return fmt.Errorf("service account name and namespace for OIDC authentication must not be empty") + } + } +} + type senderConfig struct { - reply *duckv1.Addressable - deadLetterSink *duckv1.Addressable - additionalHeaders http.Header - retryConfig *RetryConfig - transformers binding.Transformers + reply *duckv1.Addressable + deadLetterSink *duckv1.Addressable + additionalHeaders http.Header + retryConfig *RetryConfig + transformers binding.Transformers + oidcServiceAccount *types.NamespacedName +} + +type Dispatcher struct { + oidcTokenProvider *auth.OIDCTokenProvider +} + +func NewDispatcher(oidcTokenProvider *auth.OIDCTokenProvider) *Dispatcher { + return &Dispatcher{ + oidcTokenProvider: oidcTokenProvider, + } } // SendEvent sends the given event to the given destination. -func SendEvent(ctx context.Context, event event.Event, destination duckv1.Addressable, options ...SendOption) (*DispatchInfo, error) { +func (d *Dispatcher) SendEvent(ctx context.Context, event event.Event, destination duckv1.Addressable, options ...SendOption) (*DispatchInfo, error) { // clone the event since: // - we mutate the event and the callers might not expect this // - it might produce data races if the caller is trying to read the event in different go routines c := event.Clone() message := binding.ToMessage(&c) - return SendMessage(ctx, message, destination, options...) + return d.SendMessage(ctx, message, destination, options...) } // SendMessage sends the given message to the given destination. // SendMessage is kept for compatibility and SendEvent should be used whenever possible. -func SendMessage(ctx context.Context, message binding.Message, destination duckv1.Addressable, options ...SendOption) (*DispatchInfo, error) { +func (d *Dispatcher) SendMessage(ctx context.Context, message binding.Message, destination duckv1.Addressable, options ...SendOption) (*DispatchInfo, error) { config := &senderConfig{ additionalHeaders: make(http.Header), } @@ -134,10 +158,10 @@ func SendMessage(ctx context.Context, message binding.Message, destination duckv } } - return send(ctx, message, destination, config) + return d.send(ctx, message, destination, config) } -func send(ctx context.Context, message binding.Message, destination duckv1.Addressable, config *senderConfig) (*DispatchInfo, error) { +func (d *Dispatcher) send(ctx context.Context, message binding.Message, destination duckv1.Addressable, config *senderConfig) (*DispatchInfo, error) { dispatchExecutionInfo := &DispatchInfo{} // All messages that should be finished at the end of this function @@ -167,12 +191,12 @@ func send(ctx context.Context, message binding.Message, destination duckv1.Addre } additionalHeadersForDestination.Set("Prefer", "reply") - ctx, responseMessage, dispatchExecutionInfo, err := executeRequest(ctx, destination, message, additionalHeadersForDestination, config.retryConfig, config.transformers) + ctx, responseMessage, dispatchExecutionInfo, err := d.executeRequest(ctx, destination, message, additionalHeadersForDestination, config.retryConfig, config.oidcServiceAccount, config.transformers) if err != nil { // If DeadLetter is configured, then send original message with knative error extensions if config.deadLetterSink != nil { dispatchTransformers := dispatchExecutionInfoTransformers(destination.URL, dispatchExecutionInfo) - _, deadLetterResponse, dispatchExecutionInfo, deadLetterErr := executeRequest(ctx, *config.deadLetterSink, message, config.additionalHeaders, config.retryConfig, append(config.transformers, dispatchTransformers)) + _, deadLetterResponse, dispatchExecutionInfo, deadLetterErr := d.executeRequest(ctx, *config.deadLetterSink, message, config.additionalHeaders, config.retryConfig, config.oidcServiceAccount, append(config.transformers, dispatchTransformers)) if deadLetterErr != nil { return dispatchExecutionInfo, fmt.Errorf("unable to complete request to either %s (%v) or %s (%v)", destination.URL, err, config.deadLetterSink.URL, deadLetterErr) } @@ -208,12 +232,12 @@ func send(ctx context.Context, message binding.Message, destination duckv1.Addre // send reply - ctx, responseResponseMessage, dispatchExecutionInfo, err := executeRequest(ctx, *config.reply, responseMessage, responseAdditionalHeaders, config.retryConfig, config.transformers) + ctx, responseResponseMessage, dispatchExecutionInfo, err := d.executeRequest(ctx, *config.reply, responseMessage, responseAdditionalHeaders, config.retryConfig, config.oidcServiceAccount, config.transformers) if err != nil { // If DeadLetter is configured, then send original message with knative error extensions if config.deadLetterSink != nil { dispatchTransformers := dispatchExecutionInfoTransformers(config.reply.URL, dispatchExecutionInfo) - _, deadLetterResponse, dispatchExecutionInfo, deadLetterErr := executeRequest(ctx, *config.deadLetterSink, message, responseAdditionalHeaders, config.retryConfig, append(config.transformers, dispatchTransformers)) + _, deadLetterResponse, dispatchExecutionInfo, deadLetterErr := d.executeRequest(ctx, *config.deadLetterSink, message, responseAdditionalHeaders, config.retryConfig, config.oidcServiceAccount, append(config.transformers, dispatchTransformers)) if deadLetterErr != nil { return dispatchExecutionInfo, fmt.Errorf("failed to forward reply to %s (%v) and failed to send it to the dead letter sink %s (%v)", config.reply.URL, err, config.deadLetterSink.URL, deadLetterErr) } @@ -233,7 +257,7 @@ func send(ctx context.Context, message binding.Message, destination duckv1.Addre return dispatchExecutionInfo, nil } -func executeRequest(ctx context.Context, target duckv1.Addressable, message cloudevents.Message, additionalHeaders http.Header, retryConfig *RetryConfig, transformers ...binding.Transformer) (context.Context, cloudevents.Message, *DispatchInfo, error) { +func (d *Dispatcher) executeRequest(ctx context.Context, target duckv1.Addressable, message cloudevents.Message, additionalHeaders http.Header, retryConfig *RetryConfig, oidcServiceAccount *types.NamespacedName, transformers ...binding.Transformer) (context.Context, cloudevents.Message, *DispatchInfo, error) { dispatchInfo := DispatchInfo{ Duration: NoDuration, ResponseCode: NoResponse, @@ -247,7 +271,7 @@ func executeRequest(ctx context.Context, target duckv1.Addressable, message clou transformers = append(transformers, tracing.PopulateSpan(span, target.URL.String())) } - req, err := createRequest(ctx, message, target, additionalHeaders, transformers...) + req, err := d.createRequest(ctx, message, target, additionalHeaders, oidcServiceAccount, transformers...) if err != nil { return ctx, nil, &dispatchInfo, fmt.Errorf("failed to create request: %w", err) } @@ -305,7 +329,7 @@ func executeRequest(ctx context.Context, target duckv1.Addressable, message clou return ctx, responseMessage, &dispatchInfo, nil } -func createRequest(ctx context.Context, message binding.Message, target duckv1.Addressable, additionalHeaders http.Header, transformers ...binding.Transformer) (*http.Request, error) { +func (d *Dispatcher) createRequest(ctx context.Context, message binding.Message, target duckv1.Addressable, additionalHeaders http.Header, oidcServiceAccount *types.NamespacedName, transformers ...binding.Transformer) (*http.Request, error) { request, err := http.NewRequestWithContext(ctx, "POST", target.URL.String(), nil) if err != nil { return nil, fmt.Errorf("could not create http request: %w", err) @@ -319,6 +343,16 @@ func createRequest(ctx context.Context, message binding.Message, target duckv1.A request.Header[key] = val } + if oidcServiceAccount != nil { + if target.Audience != nil && *target.Audience != "" { + jwt, err := d.oidcTokenProvider.GetJWT(*oidcServiceAccount, *target.Audience) + if err != nil { + return nil, fmt.Errorf("could not get JWT: %w", err) + } + request.Header.Set("Authorization", fmt.Sprintf("Bearer: %s", jwt)) + } + } + return request, nil } diff --git a/pkg/kncloudevents/event_dispatcher_test.go b/pkg/kncloudevents/event_dispatcher_test.go index 8f4b5e13035..a988e7aa358 100644 --- a/pkg/kncloudevents/event_dispatcher_test.go +++ b/pkg/kncloudevents/event_dispatcher_test.go @@ -38,14 +38,19 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/rest" + "knative.dev/pkg/injection" rectesting "knative.dev/pkg/reconciler/testing" "knative.dev/pkg/apis" duckv1 "knative.dev/pkg/apis/duck/v1" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/eventingtls/eventingtlstesting" "knative.dev/eventing/pkg/kncloudevents" "knative.dev/eventing/pkg/utils" + fakekubeclient "knative.dev/pkg/client/injection/kube/client/fake" + _ "knative.dev/pkg/system/testing" ) var ( @@ -779,6 +784,12 @@ func TestSendEvent(t *testing.T) { } for n, tc := range testCases { t.Run(n, func(t *testing.T) { + ctx := context.Background() + ctx, _ = fakekubeclient.With(ctx) + ctx = injection.WithConfig(ctx, &rest.Config{}) + + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) destHandler := &fakeHandler{ t: t, response: tc.fakeResponse, @@ -824,7 +835,7 @@ func TestSendEvent(t *testing.T) { } event.SetData(cloudevents.ApplicationJSON, tc.body) - ctx := context.Background() + ctx = context.Background() destination := duckv1.Addressable{ URL: getOnlyDomainURL(t, tc.sendToDestination, destServer.URL), @@ -850,7 +861,7 @@ func TestSendEvent(t *testing.T) { if tc.header != nil { headers = utils.PassThroughHeaders(tc.header) } - info, err := kncloudevents.SendMessage(ctx, message, destination, + info, err := dispatcher.SendMessage(ctx, message, destination, kncloudevents.WithReply(reply), kncloudevents.WithDeadLetterSink(deadLetterSink), kncloudevents.WithHeader(headers)) @@ -923,6 +934,8 @@ func TestDispatchMessageToTLSEndpoint(t *testing.T) { // give the servers a bit time to fully shutdown to prevent port clashes time.Sleep(500 * time.Millisecond) }() + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) eventToSend := test.FullEvent() // destination @@ -945,7 +958,7 @@ func TestDispatchMessageToTLSEndpoint(t *testing.T) { // send event message := binding.ToMessage(&eventToSend) - info, err := kncloudevents.SendMessage(ctx, message, destination) + info, err := dispatcher.SendMessage(ctx, message, destination) require.Nil(t, err) require.Equal(t, 200, info.ResponseCode) @@ -969,6 +982,8 @@ func TestDispatchMessageToTLSEndpointWithReply(t *testing.T) { // give the servers a bit time to fully shutdown to prevent port clashes time.Sleep(500 * time.Millisecond) }() + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) eventToSend := test.FullEvent() eventToReply := test.FullEvent() @@ -1008,7 +1023,7 @@ func TestDispatchMessageToTLSEndpointWithReply(t *testing.T) { // send event message := binding.ToMessage(&eventToSend) - info, err := kncloudevents.SendMessage(ctx, message, destination, kncloudevents.WithReply(&reply)) + info, err := dispatcher.SendMessage(ctx, message, destination, kncloudevents.WithReply(&reply)) require.Nil(t, err) require.Equal(t, 200, info.ResponseCode) @@ -1032,6 +1047,8 @@ func TestDispatchMessageToTLSEndpointWithDeadLetterSink(t *testing.T) { // give the servers a bit time to fully shutdown to prevent port clashes time.Sleep(500 * time.Millisecond) }() + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) eventToSend := test.FullEvent() // destination @@ -1066,7 +1083,7 @@ func TestDispatchMessageToTLSEndpointWithDeadLetterSink(t *testing.T) { // send event message := binding.ToMessage(&eventToSend) - info, err := kncloudevents.SendMessage(ctx, message, destination, kncloudevents.WithDeadLetterSink(&dls)) + info, err := dispatcher.SendMessage(ctx, message, destination, kncloudevents.WithDeadLetterSink(&dls)) require.Nil(t, err) require.Equal(t, 200, info.ResponseCode) diff --git a/pkg/reconciler/inmemorychannel/dispatcher/controller.go b/pkg/reconciler/inmemorychannel/dispatcher/controller.go index 7d4b4394429..d344455a17a 100644 --- a/pkg/reconciler/inmemorychannel/dispatcher/controller.go +++ b/pkg/reconciler/inmemorychannel/dispatcher/controller.go @@ -30,6 +30,7 @@ import ( "github.com/kelseyhightower/envconfig" "knative.dev/pkg/kmeta" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/channel/multichannelfanout" "knative.dev/eventing/pkg/eventingtls" "knative.dev/eventing/pkg/kncloudevents" @@ -118,12 +119,15 @@ func NewController( chMsgHandler: sh, } + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + r := &Reconciler{ multiChannelEventHandler: sh, reporter: reporter, messagingClientSet: eventingclient.Get(ctx).MessagingV1(), eventingClient: eventingclient.Get(ctx).EventingV1beta2(), eventTypeLister: eventtypeinformer.Get(ctx).Lister(), + eventDispatcher: kncloudevents.NewDispatcher(oidcTokenProvider), } impl := inmemorychannelreconciler.NewImpl(ctx, r, func(impl *controller.Impl) controller.Options { diff --git a/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go b/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go index 14e157db781..cb1024dec5c 100644 --- a/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go +++ b/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go @@ -56,6 +56,7 @@ type Reconciler struct { eventTypeLister v1beta2.EventTypeLister eventingClient eventingv1beta2.EventingV1beta2Interface featureStore *feature.Store + eventDispatcher *kncloudevents.Dispatcher } // Check the interfaces Reconciler should implement @@ -108,7 +109,6 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec channelRef = toKReference(imc) UID = &imc.UID - } // First grab the host based MultiChannelFanoutMessage httpHandler @@ -122,6 +122,7 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec eventTypeAutoHandler, channelRef, UID, + r.eventDispatcher, ) if err != nil { logging.FromContext(ctx).Error("Failed to create a new fanout.EventHandler", err) @@ -150,6 +151,7 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec eventTypeAutoHandler, channelRef, UID, + r.eventDispatcher, channel.ResolveChannelFromPath(channel.ParseChannelFromPath), ) if err != nil { diff --git a/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel_test.go b/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel_test.go index a1fda055b0d..5a1ebb34884 100644 --- a/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel_test.go +++ b/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel_test.go @@ -40,6 +40,7 @@ import ( eventingduckv1 "knative.dev/eventing/pkg/apis/duck/v1" "knative.dev/eventing/pkg/apis/feature" v1 "knative.dev/eventing/pkg/apis/messaging/v1" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/channel/fanout" fakeeventingclient "knative.dev/eventing/pkg/client/injection/client/fake" "knative.dev/eventing/pkg/client/injection/reconciler/messaging/v1/inmemorychannel" @@ -491,13 +492,17 @@ func TestReconciler_ReconcileKind(t *testing.T) { }, } for n, tc := range testCases { - ctx, fakeEventingClient := fakeeventingclient.With(context.Background(), tc.imc) + ctx, _ := SetupFakeContext(t) + ctx, fakeEventingClient := fakeeventingclient.With(ctx, tc.imc) feature.ToContext(ctx, feature.Flags{ feature.EvenTypeAutoCreate: feature.Disabled, }) + + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) // Just run the tests once with no existing handler (creates the handler) and once // with an existing, so we exercise both paths at once. - fh, err := fanout.NewFanoutEventHandler(nil, fanout.Config{}, nil, nil, nil, nil) + fh, err := fanout.NewFanoutEventHandler(nil, fanout.Config{}, nil, nil, nil, nil, dispatcher) if err != nil { t.Error(err) } @@ -542,7 +547,11 @@ func TestReconciler_InvalidInputs(t *testing.T) { }, } for n, tc := range testCases { - fh, err := fanout.NewFanoutEventHandler(nil, fanout.Config{}, nil, nil, nil, nil) + ctx, _ := SetupFakeContext(t) + + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) + fh, err := fanout.NewFanoutEventHandler(nil, fanout.Config{}, nil, nil, nil, nil, dispatcher) if err != nil { t.Error(err) } @@ -572,7 +581,11 @@ func TestReconciler_Deletion(t *testing.T) { }, } for n, tc := range testCases { - fh, err := fanout.NewFanoutEventHandler(nil, fanout.Config{}, nil, nil, nil, nil) + ctx, _ := SetupFakeContext(t) + + oidcTokenProvider := auth.NewOIDCTokenProvider(ctx) + dispatcher := kncloudevents.NewDispatcher(oidcTokenProvider) + fh, err := fanout.NewFanoutEventHandler(nil, fanout.Config{}, nil, nil, nil, nil, dispatcher) if err != nil { t.Error(err) }