Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InMemoryChannel: reject request for wrong audience #7449

Merged
merged 12 commits into from
Dec 5, 2023
51 changes: 51 additions & 0 deletions pkg/channel/event_receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import (
nethttp "net/http"
"time"

"knative.dev/eventing/pkg/apis/feature"

"knative.dev/eventing/pkg/auth"

"github.com/cloudevents/sdk-go/v2/event"
"github.com/cloudevents/sdk-go/v2/protocol/http"
"go.uber.org/zap"
Expand Down Expand Up @@ -65,6 +69,9 @@ type EventReceiver struct {
hostToChannelFunc ResolveChannelFromHostFunc
pathToChannelFunc ResolveChannelFromPathFunc
reporter StatsReporter
tokenVerifier *auth.OIDCTokenVerifier
audience string
withContext func(context.Context) context.Context
}

// EventReceiverFunc is the function to be called for handling the event.
Expand Down Expand Up @@ -100,6 +107,21 @@ func ResolveChannelFromPath(PathToChannelFunc ResolveChannelFromPathFunc) EventR
}
}

func OIDCTokenVerification(tokenVerifier *auth.OIDCTokenVerifier, audience string) EventReceiverOptions {
return func(r *EventReceiver) error {
r.tokenVerifier = tokenVerifier
r.audience = audience
return nil
}
}

func ReceiverWithContextFunc(fn func(context.Context) context.Context) EventReceiverOptions {
return func(r *EventReceiver) error {
r.withContext = fn
return nil
}
}

// NewEventReceiver creates an event receiver passing new events to the
// receiverFunc.
func NewEventReceiver(receiverFunc EventReceiverFunc, logger *zap.Logger, reporter StatsReporter, opts ...EventReceiverOptions) (*EventReceiver, error) {
Expand Down Expand Up @@ -153,6 +175,12 @@ func (r *EventReceiver) Start(ctx context.Context) error {
}

func (r *EventReceiver) ServeHTTP(response nethttp.ResponseWriter, request *nethttp.Request) {
ctx := request.Context()

if r.withContext != nil {
ctx = r.withContext(ctx)
}

response.Header().Set("Allow", "POST, OPTIONS")
if request.Method == nethttp.MethodOptions {
response.Header().Set("WebHook-Allowed-Origin", "*") // Accept from any Origin:
Expand Down Expand Up @@ -218,6 +246,29 @@ func (r *EventReceiver) ServeHTTP(response nethttp.ResponseWriter, request *neth
return
}

/// Here we do the OIDC audience verification
features := feature.FromContext(ctx)
if features.IsOIDCAuthentication() {
r.logger.Debug("OIDC authentication is enabled")

token := auth.GetJWTFromHeader(request.Header)
if token == "" {
r.logger.Warn(fmt.Sprintf("No JWT in %s header provided while feature %s is enabled", auth.AuthHeaderKey, feature.OIDCAuthentication))
response.WriteHeader(nethttp.StatusUnauthorized)
return
}

if _, err := r.tokenVerifier.VerifyJWT(ctx, token, r.audience); err != nil {
r.logger.Warn("no valid JWT provided", zap.Error(err))
response.WriteHeader(nethttp.StatusUnauthorized)
return
}

r.logger.Debug("Request contained a valid JWT. Continuing...")
} else {
r.logger.Debug("OIDC authentication is disabled")
}

Leo6Leo marked this conversation as resolved.
Show resolved Hide resolved
err = r.receiverFunc(request.Context(), channel, *event, utils.PassThroughHeaders(request.Header))
if err != nil {
if _, ok := err.(*UnknownChannelError); ok {
Expand Down
33 changes: 17 additions & 16 deletions pkg/channel/fanout/fanout_event_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ type FanoutEventHandler struct {
// rather than a member variable.
timeout time.Duration

reporter channel.StatsReporter
logger *zap.Logger
eventTypeHandler *eventtype.EventTypeAutoHandler
channelAddressable *duckv1.KReference
channelUID *types.UID
reporter channel.StatsReporter
logger *zap.Logger
eventTypeHandler *eventtype.EventTypeAutoHandler
channelRef *duckv1.KReference
channelUID *types.UID
}

// NewFanoutEventHandler creates a new fanout.EventHandler.
Expand All @@ -101,20 +101,21 @@ func NewFanoutEventHandler(
config Config,
reporter channel.StatsReporter,
eventTypeHandler *eventtype.EventTypeAutoHandler,
channelAddressable *duckv1.KReference,
channelRef *duckv1.KReference,
channelUID *types.UID,
eventDispatcher *kncloudevents.Dispatcher,
receiverOpts ...channel.EventReceiverOptions,

) (*FanoutEventHandler, error) {
handler := &FanoutEventHandler{
logger: logger,
timeout: defaultTimeout,
reporter: reporter,
asyncHandler: config.AsyncHandler,
eventTypeHandler: eventTypeHandler,
channelAddressable: channelAddressable,
channelUID: channelUID,
eventDispatcher: eventDispatcher,
logger: logger,
timeout: defaultTimeout,
reporter: reporter,
asyncHandler: config.AsyncHandler,
eventTypeHandler: eventTypeHandler,
channelRef: channelRef,
channelUID: channelUID,
eventDispatcher: eventDispatcher,
}
handler.subscriptions = make([]Subscription, len(config.Subscriptions))
copy(handler.subscriptions, config.Subscriptions)
Expand Down Expand Up @@ -184,15 +185,15 @@ func (f *FanoutEventHandler) GetSubscriptions(ctx context.Context) []Subscriptio
}

func (f *FanoutEventHandler) autoCreateEventType(ctx context.Context, evnt event.Event) {
if f.channelAddressable == nil {
if f.channelRef == nil {
f.logger.Warn("No addressable for channel")
return
} else {
if f.channelUID == nil {
f.logger.Warn("No channelUID provided, unable to autocreate event type")
return
}
err := f.eventTypeHandler.AutoCreateEventType(ctx, &evnt, f.channelAddressable, *f.channelUID)
err := f.eventTypeHandler.AutoCreateEventType(ctx, &evnt, f.channelRef, *f.channelUID)
if err != nil {
f.logger.Warn("EventTypeCreate failed")
return
Expand Down
1 change: 1 addition & 0 deletions pkg/reconciler/inmemorychannel/dispatcher/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ func NewController(
eventingClient: eventingclient.Get(ctx).EventingV1beta2(),
eventTypeLister: eventtypeinformer.Get(ctx).Lister(),
eventDispatcher: kncloudevents.NewDispatcher(oidcTokenProvider),
tokenVerifier: auth.NewOIDCTokenVerifier(ctx),
}

var globalResync func(obj interface{})
Expand Down
14 changes: 14 additions & 0 deletions pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
eventingduckv1 "knative.dev/eventing/pkg/apis/duck/v1"
"knative.dev/eventing/pkg/apis/feature"
v1 "knative.dev/eventing/pkg/apis/messaging/v1"
"knative.dev/eventing/pkg/auth"
"knative.dev/eventing/pkg/channel"
"knative.dev/eventing/pkg/channel/fanout"
"knative.dev/eventing/pkg/channel/multichannelfanout"
Expand All @@ -57,6 +58,7 @@ type Reconciler struct {
eventingClient eventingv1beta2.EventingV1beta2Interface
featureStore *feature.Store
eventDispatcher *kncloudevents.Dispatcher
tokenVerifier *auth.OIDCTokenVerifier
}

// Check the interfaces Reconciler should implement
Expand Down Expand Up @@ -111,6 +113,10 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec
UID = &imc.UID
}

wc := func(ctx context.Context) context.Context {
return r.featureStore.ToContext(ctx)
}

// First grab the host based MultiChannelFanoutMessage httpHandler
httpHandler := r.multiChannelEventHandler.GetChannelHandler(config.HostName)
if httpHandler == nil {
Expand All @@ -123,6 +129,8 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec
channelRef,
UID,
r.eventDispatcher,
channel.OIDCTokenVerification(r.tokenVerifier, audience(imc)),
channel.ReceiverWithContextFunc(wc),
)
if err != nil {
logging.FromContext(ctx).Error("Failed to create a new fanout.EventHandler", err)
Expand Down Expand Up @@ -153,6 +161,8 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec
UID,
r.eventDispatcher,
channel.ResolveChannelFromPath(channel.ParseChannelFromPath),
channel.OIDCTokenVerification(r.tokenVerifier, audience(imc)),
channel.ReceiverWithContextFunc(wc),
)
if err != nil {
logging.FromContext(ctx).Error("Failed to create a new fanout.EventHandler", err)
Expand Down Expand Up @@ -295,3 +305,7 @@ func toKReference(imc *v1.InMemoryChannel) *duckv1.KReference {
Address: imc.Status.Address.Name,
}
}

func audience(imc *v1.InMemoryChannel) string {
return auth.GetAudience(v1.SchemeGroupVersion.WithKind("InMemoryChannel"), imc.ObjectMeta)
}
2 changes: 1 addition & 1 deletion test/auth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestChannelImplSupportsOIDC(t *testing.T) {
name := feature.MakeRandomK8sName("channelimpl")
env.Prerequisite(ctx, t, channel.ImplGoesReady(name))

env.Test(ctx, t, oidc.AddressableHasAudiencePopulated(channel_impl.GVR(), channel_impl.GVK().Kind, name, env.Namespace()))
env.TestSet(ctx, t, oidc.AddressableOIDCConformance(channel_impl.GVR(), channel_impl.GVK().Kind, name, env.Namespace()))
Cali0707 marked this conversation as resolved.
Show resolved Hide resolved
}

func TestParallelSupportsOIDC(t *testing.T) {
Expand Down
Loading