Skip to content

Commit

Permalink
Resource Hook Pre-Decode Utilities (#18548)
Browse files Browse the repository at this point in the history
Add some generic type hook wrappers to first decode the data

There seems to be a pattern for Validation, Mutation and Write Authorization hooks where they first need to decode the Any data before doing the domain specific work.

This PR introduces 3 new functions to generate wrappers around the other hooks to pre-decode the data into a DecodedResource and pass that in instead of the original pbresource.Resource.

This PR also updates the various catalog data types to use the new hook generators.
  • Loading branch information
mkeeler authored Oct 26, 2023
1 parent ea91e58 commit 5698353
Show file tree
Hide file tree
Showing 33 changed files with 603 additions and 429 deletions.
14 changes: 6 additions & 8 deletions internal/auth/internal/types/computed_traffic_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedComputedTrafficPermissions = resource.DecodedResource[*pbauth.ComputedTrafficPermissions]

func RegisterComputedTrafficPermission(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.ComputedTrafficPermissionsType,
Expand All @@ -26,16 +28,12 @@ func RegisterComputedTrafficPermission(r resource.Registry) {
})
}

func ValidateComputedTrafficPermissions(res *pbresource.Resource) error {
var ctp pbauth.ComputedTrafficPermissions

if err := res.Data.UnmarshalTo(&ctp); err != nil {
return resource.NewErrDataParse(&ctp, err)
}
var ValidateComputedTrafficPermissions = resource.DecodeAndValidate(validateComputedTrafficPermissions)

func validateComputedTrafficPermissions(res *DecodedComputedTrafficPermissions) error {
var merr error

for i, permission := range ctp.AllowPermissions {
for i, permission := range res.Data.AllowPermissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "allow_permissions",
Expand All @@ -48,7 +46,7 @@ func ValidateComputedTrafficPermissions(res *pbresource.Resource) error {
}
}

for i, permission := range ctp.DenyPermissions {
for i, permission := range res.Data.DenyPermissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "deny_permissions",
Expand Down
64 changes: 17 additions & 47 deletions internal/auth/internal/types/traffic_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedTrafficPermissions = resource.DecodedResource[*pbauth.TrafficPermissions]

func RegisterTrafficPermissions(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.TrafficPermissionsType,
Proto: &pbauth.TrafficPermissions{},
ACLs: &resource.ACLHooks{
Read: aclReadHookTrafficPermissions,
Write: aclWriteHookTrafficPermissions,
Read: resource.DecodeAndAuthorizeRead(aclReadHookTrafficPermissions),
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookTrafficPermissions),
List: resource.NoOpACLListHook,
},
Validate: ValidateTrafficPermissions,
Expand All @@ -27,28 +29,20 @@ func RegisterTrafficPermissions(r resource.Registry) {
})
}

func MutateTrafficPermissions(res *pbresource.Resource) error {
var tp pbauth.TrafficPermissions

if err := res.Data.UnmarshalTo(&tp); err != nil {
return resource.NewErrDataParse(&tp, err)
}
var MutateTrafficPermissions = resource.DecodeAndMutate(mutateTrafficPermissions)

func mutateTrafficPermissions(res *DecodedTrafficPermissions) (bool, error) {
var changed bool

for _, p := range tp.Permissions {
for _, p := range res.Data.Permissions {
for _, s := range p.Sources {
if updated := normalizedTenancyForSource(s, res.Id.Tenancy); updated {
changed = true
}
}
}

if !changed {
return nil
}

return res.Data.MarshalFrom(&tp)
return changed, nil
}

func normalizedTenancyForSource(src *pbauth.Source, parentTenancy *pbresource.Tenancy) bool {
Expand Down Expand Up @@ -110,17 +104,13 @@ func firstNonEmptyString(a, b, c string) (string, bool) {
return c, true
}

func ValidateTrafficPermissions(res *pbresource.Resource) error {
var tp pbauth.TrafficPermissions

if err := res.Data.UnmarshalTo(&tp); err != nil {
return resource.NewErrDataParse(&tp, err)
}
var ValidateTrafficPermissions = resource.DecodeAndValidate(validateTrafficPermissions)

func validateTrafficPermissions(res *DecodedTrafficPermissions) error {
var merr error

// enumcover:pbauth.Action
switch tp.Action {
switch res.Data.Action {
case pbauth.Action_ACTION_ALLOW:
case pbauth.Action_ACTION_DENY:
case pbauth.Action_ACTION_UNSPECIFIED:
Expand All @@ -132,14 +122,14 @@ func ValidateTrafficPermissions(res *pbresource.Resource) error {
})
}

if tp.Destination == nil || (len(tp.Destination.IdentityName) == 0) {
if res.Data.Destination == nil || (len(res.Data.Destination.IdentityName) == 0) {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "data.destination",
Wrapped: resource.ErrEmpty,
})
}
// Validate permissions
for i, permission := range tp.Permissions {
for i, permission := range res.Data.Permissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "permissions",
Expand Down Expand Up @@ -271,30 +261,10 @@ func isLocalPeer(p string) bool {
return p == "local" || p == ""
}

func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, _ *pbresource.ID, res *pbresource.Resource) error {
if res == nil {
return resource.ErrNeedResource
}
return authorizeDestination(res, func(dest string) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(dest, authzContext)
})
func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(res.Data.Destination.IdentityName, authzContext)
}

func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error {
return authorizeDestination(res, func(dest string) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(dest, authzContext)
})
}

func authorizeDestination(res *pbresource.Resource, intentionAllowed func(string) error) error {
tp, err := resource.Decode[*pbauth.TrafficPermissions](res)
if err != nil {
return err
}
// Check intention:x permissions for identity
err = intentionAllowed(tp.Data.Destination.IdentityName)
if err != nil {
return err
}
return nil
func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(res.Data.Destination.IdentityName, authzContext)
}
11 changes: 10 additions & 1 deletion internal/auth/internal/types/workload_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedWorkloadIdentity = resource.DecodedResource[*pbauth.WorkloadIdentity]

func RegisterWorkloadIdentity(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.WorkloadIdentityType,
Expand All @@ -20,10 +22,17 @@ func RegisterWorkloadIdentity(r resource.Registry) {
Write: aclWriteHookWorkloadIdentity,
List: resource.NoOpACLListHook,
},
Validate: nil,
Validate: ValidateWorkloadIdentity,
})
}

var ValidateWorkloadIdentity = resource.DecodeAndValidate(validateWorkloadIdentity)

func validateWorkloadIdentity(res *DecodedWorkloadIdentity) error {
// currently the WorkloadIdentity type has no fields.
return nil
}

func aclReadHookWorkloadIdentity(
authorizer acl.Authorizer,
authzCtx *acl.AuthorizerContext,
Expand Down
10 changes: 10 additions & 0 deletions internal/auth/internal/types/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,13 @@ func TestWorkloadIdentityACLs(t *testing.T) {
})
}
}

func TestWorkloadIdentity_ParseError(t *testing.T) {
rsc := resourcetest.Resource(pbauth.WorkloadIdentityType, "example").
WithData(t, &pbauth.TrafficPermissions{}).
Build()

err := ValidateWorkloadIdentity(rsc)
var parseErr resource.ErrDataParse
require.ErrorAs(t, err, &parseErr)
}
19 changes: 5 additions & 14 deletions internal/catalog/internal/types/acl_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,22 @@ func aclReadHookResourceWithWorkloadSelector(authorizer acl.Authorizer, authzCon
return authorizer.ToAllowAuthorizer().ServiceReadAllowed(id.GetName(), authzContext)
}

func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error {
if res == nil {
return resource.ErrNeedResource
}

decodedService, err := resource.Decode[T](res)
if err != nil {
return resource.ErrNeedResource
}

func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, r *resource.DecodedResource[T]) error {
// First check service:write on the name.
err = authorizer.ToAllowAuthorizer().ServiceWriteAllowed(res.GetId().GetName(), authzContext)
err := authorizer.ToAllowAuthorizer().ServiceWriteAllowed(r.GetId().GetName(), authzContext)
if err != nil {
return err
}

// Then also check whether we're allowed to select a service.
for _, name := range decodedService.GetData().GetWorkloads().GetNames() {
for _, name := range r.Data.GetWorkloads().GetNames() {
err = authorizer.ToAllowAuthorizer().ServiceReadAllowed(name, authzContext)
if err != nil {
return err
}
}

for _, prefix := range decodedService.GetData().GetWorkloads().GetPrefixes() {
for _, prefix := range r.Data.GetWorkloads().GetPrefixes() {
err = authorizer.ToAllowAuthorizer().ServiceReadPrefixAllowed(prefix, authzContext)
if err != nil {
return err
Expand All @@ -50,7 +41,7 @@ func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer ac
func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks {
return &resource.ACLHooks{
Read: aclReadHookResourceWithWorkloadSelector,
Write: aclWriteHookResourceWithWorkloadSelector[T],
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookResourceWithWorkloadSelector[T]),
List: resource.NoOpACLListHook,
}
}
15 changes: 6 additions & 9 deletions internal/catalog/internal/types/dns_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedDNSPolicy = resource.DecodedResource[*pbcatalog.DNSPolicy]

func RegisterDNSPolicy(r resource.Registry) {
r.Register(resource.Registration{
Type: pbcatalog.DNSPolicyType,
Expand All @@ -23,25 +24,21 @@ func RegisterDNSPolicy(r resource.Registry) {
})
}

func ValidateDNSPolicy(res *pbresource.Resource) error {
var policy pbcatalog.DNSPolicy

if err := res.Data.UnmarshalTo(&policy); err != nil {
return resource.NewErrDataParse(&policy, err)
}
var ValidateDNSPolicy = resource.DecodeAndValidate(validateDNSPolicy)

func validateDNSPolicy(res *DecodedDNSPolicy) error {
var err error
// Ensure that this resource isn't useless and is attempting to
// select at least one workload.
if selErr := ValidateSelector(policy.Workloads, false); selErr != nil {
if selErr := ValidateSelector(res.Data.Workloads, false); selErr != nil {
err = multierror.Append(err, resource.ErrInvalidField{
Name: "workloads",
Wrapped: selErr,
})
}

// Validate the weights
if weightErr := validateDNSPolicyWeights(policy.Weights); weightErr != nil {
if weightErr := validateDNSPolicyWeights(res.Data.Weights); weightErr != nil {
err = multierror.Append(err, resource.ErrInvalidField{
Name: "weights",
Wrapped: weightErr,
Expand Down
Loading

0 comments on commit 5698353

Please sign in to comment.