diff --git a/agent/xds/jwt_authn.go b/agent/xds/jwt_authn.go index 5840a04770eb..0dc95f5eecee 100644 --- a/agent/xds/jwt_authn.go +++ b/agent/xds/jwt_authn.go @@ -16,13 +16,13 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" ) -var ( +const ( jwtEnvoyFilter = "envoy.filters.http.jwt_authn" jwtMetadataKeyPrefix = "jwt_payload" ) // This is an intermediate JWTProvider form used to associate -// unique keys to providers +// unique payload keys to providers type jwtAuthnProvider struct { ComputedName string Provider *structs.IntentionJWTProvider @@ -116,6 +116,11 @@ func getPermissionsProviders(p []*structs.IntentionPermission) []*jwtAuthnProvid return reqs } +// makeComputedProviderName is used to create names for unique provider per permission +// This is to stop jwt claims cross validation across permissions/providers. +// +// eg. If Permission x is the 3rd permission and has a provider of original name okta +// this function will return okta_3 as the computed provider name func makeComputedProviderName(name string, perm *structs.IntentionPermission, idx int) string { if perm == nil { return name @@ -123,6 +128,13 @@ func makeComputedProviderName(name string, perm *structs.IntentionPermission, id return fmt.Sprintf("%s_%d", name, idx) } +// buildPayloadInMetadataKey is used to create a unique payload key per provider/permissions. +// This is to ensure claims are validated/forwarded specifically under the right permission/path +// and ensure we don't accidentally validate claims from different permissions/providers. +// +// eg. With a provider named okta, the second permission in permission list will have a provider of: +// okta_2 and a payload key of: jwt_payload_okta_2. Whereas an okta provider with no specific permission +// will have a payload key of: jwt_payload_okta func buildPayloadInMetadataKey(providerName string, perm *structs.IntentionPermission, idx int) string { return fmt.Sprintf("%s_%s", jwtMetadataKeyPrefix, makeComputedProviderName(providerName, perm, idx)) } diff --git a/agent/xds/rbac.go b/agent/xds/rbac.go index 2b3ead31cb49..071dc054a4c8 100644 --- a/agent/xds/rbac.go +++ b/agent/xds/rbac.go @@ -236,8 +236,7 @@ func intentionToIntermediateRBACForm( var c []*JWTInfo for _, prov := range ixn.JWT.Providers { if len(prov.VerifyClaims) > 0 { - ji := &JWTInfo{Claims: prov.VerifyClaims, MetadataPayloadKey: buildPayloadInMetadataKey(prov.Name, nil, 0)} - c = append(c, ji) + c = append(c, makeJWTInfos(prov, nil, 0)) } } if len(c) > 0 { @@ -259,8 +258,7 @@ func intentionToIntermediateRBACForm( var c []*JWTInfo for _, prov := range perm.JWT.Providers { if len(prov.VerifyClaims) > 0 { - ji := &JWTInfo{Claims: prov.VerifyClaims, MetadataPayloadKey: buildPayloadInMetadataKey(prov.Name, perm, k)} - c = append(c, ji) + c = append(c, makeJWTInfos(prov, perm, k)) } } if len(c) > 0 { @@ -280,6 +278,10 @@ func intentionToIntermediateRBACForm( return rixn } +func makeJWTInfos(p *structs.IntentionJWTProvider, perm *structs.IntentionPermission, permKey int) *JWTInfo { + return &JWTInfo{Claims: p.VerifyClaims, MetadataPayloadKey: buildPayloadInMetadataKey(p.Name, perm, permKey)} +} + type intentionAction int type JWTInfo struct { @@ -582,7 +584,7 @@ func makeRBACRules( finalPrincipals := optimizePrincipals([]*envoy_rbac_v3.Principal{rbacIxn.ComputedPrincipal}) if len(infos) > 0 { - claimsPrincipal := infoListToPrincipals(infos) + claimsPrincipal := jwtInfosToPrincipals(infos) finalPrincipals = append(finalPrincipals, claimsPrincipal) } // For L7: we should generate one Policy per Principal and list all of the Permissions @@ -599,7 +601,7 @@ func makeRBACRules( principalsL4 = append(principalsL4, rbacIxn.ComputedPrincipal) // Append JWT principals to list of principals if len(infos) > 0 { - claimsPrincipal := infoListToPrincipals(infos) + claimsPrincipal := jwtInfosToPrincipals(infos) principalsL4 = append(principalsL4, claimsPrincipal) } } @@ -630,20 +632,25 @@ func collectJWTInfos(rbacIxn *rbacIntention) []*JWTInfo { return infos } -func infoListToPrincipals(c []*JWTInfo) *envoy_rbac_v3.Principal { +func jwtInfosToPrincipals(c []*JWTInfo) *envoy_rbac_v3.Principal { ps := make([]*envoy_rbac_v3.Principal, 0) for _, jwtInfo := range c { if jwtInfo != nil { for _, claim := range jwtInfo.Claims { - ps = append(ps, claimToPrincipal(claim, jwtInfo.MetadataPayloadKey)) + ps = append(ps, jwtClaimToPrincipal(claim, jwtInfo.MetadataPayloadKey)) } } } return orPrincipals(ps) } -func claimToPrincipal(c *structs.IntentionJWTClaimVerification, payloadKey string) *envoy_rbac_v3.Principal { +// jwtClaimToPrincipal takes in a payloadkey which is generated by using provider name, permission index with +// a jwt_payload prefix. See buildPayloadInMetadataKey in agent/xds/jwt_authn.go for more info on payloadkey +// +// This uniquely generated payloadKey is the segment in the path to validate the JWT claims. The subsequent keys +// come from the Path in the IntentionJWTClaimVerification object. +func jwtClaimToPrincipal(c *structs.IntentionJWTClaimVerification, payloadKey string) *envoy_rbac_v3.Principal { segments := pathToSegments(c.Path, payloadKey) return &envoy_rbac_v3.Principal{ @@ -665,6 +672,22 @@ func claimToPrincipal(c *structs.IntentionJWTClaimVerification, payloadKey strin } } +// pathToSegments generates an array of MetadataMatcher_PathSegment that starts with the payloadkey +// and is followed by all existing strings in the path. +// +// eg. calling: pathToSegments([]string{"perms", "roles"}, "jwt_payload_okta") should return the following: +// +// []*envoy_matcher_v3.MetadataMatcher_PathSegment{ +// { +// Segment: &envoy_matcher_v3.MetadataMatcher_PathSegment_Key{Key: "jwt_payload_okta"}, +// }, +// { +// Segment: &envoy_matcher_v3.MetadataMatcher_PathSegment_Key{Key: "perms"}, +// }, +// { +// Segment: &envoy_matcher_v3.MetadataMatcher_PathSegment_Key{Key: "roles"}, +// }, +// }, func pathToSegments(paths []string, payloadKey string) []*envoy_matcher_v3.MetadataMatcher_PathSegment { segments := make([]*envoy_matcher_v3.MetadataMatcher_PathSegment, 0, len(paths)) diff --git a/agent/xds/rbac_test.go b/agent/xds/rbac_test.go index f930d55b46fe..76f4467bffa6 100644 --- a/agent/xds/rbac_test.go +++ b/agent/xds/rbac_test.go @@ -486,12 +486,8 @@ func TestMakeRBACNetworkAndHTTPFilters(t *testing.T) { } testIntentionWithJWT := func(src string, action structs.IntentionAction, jwt *structs.IntentionJWTRequirement, perms ...*structs.IntentionPermission) *structs.Intention { ixn := testIntention(t, src, "api", action) - if jwt != nil { - ixn.JWT = jwt - } - if action != "" { - ixn.Action = action - } + ixn.JWT = jwt + ixn.Action = action if perms != nil { ixn.Permissions = perms ixn.Action = "" @@ -1206,7 +1202,7 @@ func TestPathToSegments(t *testing.T) { } } -func TestClaimToPrincipal(t *testing.T) { +func TestJwtClaimToPrincipal(t *testing.T) { var ( firstClaim = structs.IntentionJWTClaimVerification{ Path: []string{"perms"}, @@ -1295,7 +1291,7 @@ func TestClaimToPrincipal(t *testing.T) { for name, tt := range tests { tt := tt t.Run(name, func(t *testing.T) { - principal := infoListToPrincipals(tt.jwtInfos) + principal := jwtInfosToPrincipals(tt.jwtInfos) require.Equal(t, principal, tt.expected) }) }