Skip to content

Commit

Permalink
Add some generic type hook wrappers to first decode the data
Browse files Browse the repository at this point in the history
There seems to be a pattern for Validation, Mutation and Write Authorization hooks where they first need to decode the Any data before doing the domain specific work.

This PR introduces 3 new functions to generate wrappers around the other hooks to pre-decode the data into a DecodedResource and pass that in instead of the original pbresource.Resource.

This PR also updates the various catalog data types to use the new hook generators.
  • Loading branch information
mkeeler committed Oct 23, 2023
1 parent fea35e6 commit 95a62ca
Show file tree
Hide file tree
Showing 33 changed files with 597 additions and 429 deletions.
14 changes: 6 additions & 8 deletions internal/auth/internal/types/computed_traffic_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedComputedTrafficPermissions = resource.DecodedResource[*pbauth.ComputedTrafficPermissions]

func RegisterComputedTrafficPermission(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.ComputedTrafficPermissionsType,
Expand All @@ -26,16 +28,12 @@ func RegisterComputedTrafficPermission(r resource.Registry) {
})
}

func ValidateComputedTrafficPermissions(res *pbresource.Resource) error {
var ctp pbauth.ComputedTrafficPermissions

if err := res.Data.UnmarshalTo(&ctp); err != nil {
return resource.NewErrDataParse(&ctp, err)
}
var ValidateComputedTrafficPermissions = resource.DecodeAndValidate(validateComputedTrafficPermissions)

func validateComputedTrafficPermissions(res *DecodedComputedTrafficPermissions) error {
var merr error

for i, permission := range ctp.AllowPermissions {
for i, permission := range res.Data.AllowPermissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "allow_permissions",
Expand All @@ -48,7 +46,7 @@ func ValidateComputedTrafficPermissions(res *pbresource.Resource) error {
}
}

for i, permission := range ctp.DenyPermissions {
for i, permission := range res.Data.DenyPermissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "deny_permissions",
Expand Down
64 changes: 17 additions & 47 deletions internal/auth/internal/types/traffic_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedTrafficPermissions = resource.DecodedResource[*pbauth.TrafficPermissions]

func RegisterTrafficPermissions(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.TrafficPermissionsType,
Proto: &pbauth.TrafficPermissions{},
ACLs: &resource.ACLHooks{
Read: aclReadHookTrafficPermissions,
Write: aclWriteHookTrafficPermissions,
Read: resource.DecodeAndAuthorizeRead(aclReadHookTrafficPermissions),
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookTrafficPermissions),
List: resource.NoOpACLListHook,
},
Validate: ValidateTrafficPermissions,
Expand All @@ -27,28 +29,20 @@ func RegisterTrafficPermissions(r resource.Registry) {
})
}

func MutateTrafficPermissions(res *pbresource.Resource) error {
var tp pbauth.TrafficPermissions

if err := res.Data.UnmarshalTo(&tp); err != nil {
return resource.NewErrDataParse(&tp, err)
}
var MutateTrafficPermissions = resource.DecodeAndMutate(mutateTrafficPermissions)

func mutateTrafficPermissions(res *DecodedTrafficPermissions) (bool, error) {
var changed bool

for _, p := range tp.Permissions {
for _, p := range res.Data.Permissions {
for _, s := range p.Sources {
if updated := normalizedTenancyForSource(s, res.Id.Tenancy); updated {
changed = true
}
}
}

if !changed {
return nil
}

return res.Data.MarshalFrom(&tp)
return changed, nil
}

func normalizedTenancyForSource(src *pbauth.Source, parentTenancy *pbresource.Tenancy) bool {
Expand Down Expand Up @@ -110,17 +104,13 @@ func firstNonEmptyString(a, b, c string) (string, bool) {
return c, true
}

func ValidateTrafficPermissions(res *pbresource.Resource) error {
var tp pbauth.TrafficPermissions

if err := res.Data.UnmarshalTo(&tp); err != nil {
return resource.NewErrDataParse(&tp, err)
}
var ValidateTrafficPermissions = resource.DecodeAndValidate(validateTrafficPermissions)

func validateTrafficPermissions(res *DecodedTrafficPermissions) error {
var merr error

// enumcover:pbauth.Action
switch tp.Action {
switch res.Data.Action {
case pbauth.Action_ACTION_ALLOW:
case pbauth.Action_ACTION_DENY:
case pbauth.Action_ACTION_UNSPECIFIED:
Expand All @@ -132,14 +122,14 @@ func ValidateTrafficPermissions(res *pbresource.Resource) error {
})
}

if tp.Destination == nil || (len(tp.Destination.IdentityName) == 0) {
if res.Data.Destination == nil || (len(res.Data.Destination.IdentityName) == 0) {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "data.destination",
Wrapped: resource.ErrEmpty,
})
}
// Validate permissions
for i, permission := range tp.Permissions {
for i, permission := range res.Data.Permissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "permissions",
Expand Down Expand Up @@ -271,30 +261,10 @@ func isLocalPeer(p string) bool {
return p == "local" || p == ""
}

func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, _ *pbresource.ID, res *pbresource.Resource) error {
if res == nil {
return resource.ErrNeedResource
}
return authorizeDestination(res, func(dest string) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(dest, authzContext)
})
func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(res.Data.Destination.IdentityName, authzContext)
}

func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error {
return authorizeDestination(res, func(dest string) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(dest, authzContext)
})
}

func authorizeDestination(res *pbresource.Resource, intentionAllowed func(string) error) error {
tp, err := resource.Decode[*pbauth.TrafficPermissions](res)
if err != nil {
return err
}
// Check intention:x permissions for identity
err = intentionAllowed(tp.Data.Destination.IdentityName)
if err != nil {
return err
}
return nil
func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(res.Data.Destination.IdentityName, authzContext)
}
11 changes: 10 additions & 1 deletion internal/auth/internal/types/workload_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedWorkloadIdentity = resource.DecodedResource[*pbauth.WorkloadIdentity]

func RegisterWorkloadIdentity(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.WorkloadIdentityType,
Expand All @@ -20,10 +22,17 @@ func RegisterWorkloadIdentity(r resource.Registry) {
Write: aclWriteHookWorkloadIdentity,
List: resource.NoOpACLListHook,
},
Validate: nil,
Validate: ValidateWorkloadIdentity,
})
}

var ValidateWorkloadIdentity = resource.DecodeAndValidate(validateWorkloadIdentity)

func validateWorkloadIdentity(res *DecodedWorkloadIdentity) error {
// currently the WorkloadIdentity type has no fields.
return nil
}

func aclReadHookWorkloadIdentity(
authorizer acl.Authorizer,
authzCtx *acl.AuthorizerContext,
Expand Down
10 changes: 10 additions & 0 deletions internal/auth/internal/types/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,13 @@ func TestWorkloadIdentityACLs(t *testing.T) {
})
}
}

func TestWorkloadIdentity_ParseError(t *testing.T) {
rsc := resourcetest.Resource(pbauth.WorkloadIdentityType, "example").
WithData(t, &pbauth.TrafficPermissions{}).
Build()

err := ValidateWorkloadIdentity(rsc)
var parseErr resource.ErrDataParse
require.ErrorAs(t, err, &parseErr)
}
19 changes: 5 additions & 14 deletions internal/catalog/internal/types/acl_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,22 @@ func aclReadHookResourceWithWorkloadSelector(authorizer acl.Authorizer, authzCon
return authorizer.ToAllowAuthorizer().ServiceReadAllowed(id.GetName(), authzContext)
}

func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error {
if res == nil {
return resource.ErrNeedResource
}

decodedService, err := resource.Decode[T](res)
if err != nil {
return resource.ErrNeedResource
}

func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, r *resource.DecodedResource[T]) error {
// First check service:write on the name.
err = authorizer.ToAllowAuthorizer().ServiceWriteAllowed(res.GetId().GetName(), authzContext)
err := authorizer.ToAllowAuthorizer().ServiceWriteAllowed(r.GetId().GetName(), authzContext)
if err != nil {
return err
}

// Then also check whether we're allowed to select a service.
for _, name := range decodedService.GetData().GetWorkloads().GetNames() {
for _, name := range r.Data.GetWorkloads().GetNames() {
err = authorizer.ToAllowAuthorizer().ServiceReadAllowed(name, authzContext)
if err != nil {
return err
}
}

for _, prefix := range decodedService.GetData().GetWorkloads().GetPrefixes() {
for _, prefix := range r.Data.GetWorkloads().GetPrefixes() {
err = authorizer.ToAllowAuthorizer().ServiceReadPrefixAllowed(prefix, authzContext)
if err != nil {
return err
Expand All @@ -50,7 +41,7 @@ func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer ac
func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks {
return &resource.ACLHooks{
Read: aclReadHookResourceWithWorkloadSelector,
Write: aclWriteHookResourceWithWorkloadSelector[T],
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookResourceWithWorkloadSelector[T]),
List: resource.NoOpACLListHook,
}
}
15 changes: 6 additions & 9 deletions internal/catalog/internal/types/dns_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedDNSPolicy = resource.DecodedResource[*pbcatalog.DNSPolicy]

func RegisterDNSPolicy(r resource.Registry) {
r.Register(resource.Registration{
Type: pbcatalog.DNSPolicyType,
Expand All @@ -23,25 +24,21 @@ func RegisterDNSPolicy(r resource.Registry) {
})
}

func ValidateDNSPolicy(res *pbresource.Resource) error {
var policy pbcatalog.DNSPolicy

if err := res.Data.UnmarshalTo(&policy); err != nil {
return resource.NewErrDataParse(&policy, err)
}
var ValidateDNSPolicy = resource.DecodeAndValidate(validateDNSPolicy)

func validateDNSPolicy(res *DecodedDNSPolicy) error {
var err error
// Ensure that this resource isn't useless and is attempting to
// select at least one workload.
if selErr := ValidateSelector(policy.Workloads, false); selErr != nil {
if selErr := ValidateSelector(res.Data.Workloads, false); selErr != nil {
err = multierror.Append(err, resource.ErrInvalidField{
Name: "workloads",
Wrapped: selErr,
})
}

// Validate the weights
if weightErr := validateDNSPolicyWeights(policy.Weights); weightErr != nil {
if weightErr := validateDNSPolicyWeights(res.Data.Weights); weightErr != nil {
err = multierror.Append(err, resource.ErrInvalidField{
Name: "weights",
Wrapped: weightErr,
Expand Down
Loading

0 comments on commit 95a62ca

Please sign in to comment.