Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resource Hook Pre-Decode Utilities #18548

Merged
merged 1 commit into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions internal/auth/internal/types/computed_traffic_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedComputedTrafficPermissions = resource.DecodedResource[*pbauth.ComputedTrafficPermissions]

func RegisterComputedTrafficPermission(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.ComputedTrafficPermissionsType,
Expand All @@ -26,16 +28,12 @@ func RegisterComputedTrafficPermission(r resource.Registry) {
})
}

func ValidateComputedTrafficPermissions(res *pbresource.Resource) error {
var ctp pbauth.ComputedTrafficPermissions

if err := res.Data.UnmarshalTo(&ctp); err != nil {
return resource.NewErrDataParse(&ctp, err)
}
var ValidateComputedTrafficPermissions = resource.DecodeAndValidate(validateComputedTrafficPermissions)

func validateComputedTrafficPermissions(res *DecodedComputedTrafficPermissions) error {
var merr error

for i, permission := range ctp.AllowPermissions {
for i, permission := range res.Data.AllowPermissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "allow_permissions",
Expand All @@ -48,7 +46,7 @@ func ValidateComputedTrafficPermissions(res *pbresource.Resource) error {
}
}

for i, permission := range ctp.DenyPermissions {
for i, permission := range res.Data.DenyPermissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "deny_permissions",
Expand Down
64 changes: 17 additions & 47 deletions internal/auth/internal/types/traffic_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedTrafficPermissions = resource.DecodedResource[*pbauth.TrafficPermissions]

func RegisterTrafficPermissions(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.TrafficPermissionsType,
Proto: &pbauth.TrafficPermissions{},
ACLs: &resource.ACLHooks{
Read: aclReadHookTrafficPermissions,
Write: aclWriteHookTrafficPermissions,
Read: resource.DecodeAndAuthorizeRead(aclReadHookTrafficPermissions),
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookTrafficPermissions),
List: resource.NoOpACLListHook,
},
Validate: ValidateTrafficPermissions,
Expand All @@ -27,28 +29,20 @@ func RegisterTrafficPermissions(r resource.Registry) {
})
}

func MutateTrafficPermissions(res *pbresource.Resource) error {
var tp pbauth.TrafficPermissions

if err := res.Data.UnmarshalTo(&tp); err != nil {
return resource.NewErrDataParse(&tp, err)
}
var MutateTrafficPermissions = resource.DecodeAndMutate(mutateTrafficPermissions)

func mutateTrafficPermissions(res *DecodedTrafficPermissions) (bool, error) {
var changed bool

for _, p := range tp.Permissions {
for _, p := range res.Data.Permissions {
for _, s := range p.Sources {
if updated := normalizedTenancyForSource(s, res.Id.Tenancy); updated {
changed = true
}
}
}

if !changed {
return nil
}

return res.Data.MarshalFrom(&tp)
return changed, nil
}

func normalizedTenancyForSource(src *pbauth.Source, parentTenancy *pbresource.Tenancy) bool {
Expand Down Expand Up @@ -110,17 +104,13 @@ func firstNonEmptyString(a, b, c string) (string, bool) {
return c, true
}

func ValidateTrafficPermissions(res *pbresource.Resource) error {
var tp pbauth.TrafficPermissions

if err := res.Data.UnmarshalTo(&tp); err != nil {
return resource.NewErrDataParse(&tp, err)
}
var ValidateTrafficPermissions = resource.DecodeAndValidate(validateTrafficPermissions)

func validateTrafficPermissions(res *DecodedTrafficPermissions) error {
var merr error

// enumcover:pbauth.Action
switch tp.Action {
switch res.Data.Action {
case pbauth.Action_ACTION_ALLOW:
case pbauth.Action_ACTION_DENY:
case pbauth.Action_ACTION_UNSPECIFIED:
Expand All @@ -132,14 +122,14 @@ func ValidateTrafficPermissions(res *pbresource.Resource) error {
})
}

if tp.Destination == nil || (len(tp.Destination.IdentityName) == 0) {
if res.Data.Destination == nil || (len(res.Data.Destination.IdentityName) == 0) {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "data.destination",
Wrapped: resource.ErrEmpty,
})
}
// Validate permissions
for i, permission := range tp.Permissions {
for i, permission := range res.Data.Permissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "permissions",
Expand Down Expand Up @@ -271,30 +261,10 @@ func isLocalPeer(p string) bool {
return p == "local" || p == ""
}

func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, _ *pbresource.ID, res *pbresource.Resource) error {
if res == nil {
return resource.ErrNeedResource
}
return authorizeDestination(res, func(dest string) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(dest, authzContext)
})
func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(res.Data.Destination.IdentityName, authzContext)
}

func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error {
return authorizeDestination(res, func(dest string) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(dest, authzContext)
})
}

func authorizeDestination(res *pbresource.Resource, intentionAllowed func(string) error) error {
tp, err := resource.Decode[*pbauth.TrafficPermissions](res)
if err != nil {
return err
}
// Check intention:x permissions for identity
err = intentionAllowed(tp.Data.Destination.IdentityName)
if err != nil {
return err
}
return nil
func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(res.Data.Destination.IdentityName, authzContext)
}
11 changes: 10 additions & 1 deletion internal/auth/internal/types/workload_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedWorkloadIdentity = resource.DecodedResource[*pbauth.WorkloadIdentity]

func RegisterWorkloadIdentity(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.WorkloadIdentityType,
Expand All @@ -20,10 +22,17 @@ func RegisterWorkloadIdentity(r resource.Registry) {
Write: aclWriteHookWorkloadIdentity,
List: resource.NoOpACLListHook,
},
Validate: nil,
Validate: ValidateWorkloadIdentity,
})
}

var ValidateWorkloadIdentity = resource.DecodeAndValidate(validateWorkloadIdentity)

func validateWorkloadIdentity(res *DecodedWorkloadIdentity) error {
// currently the WorkloadIdentity type has no fields.
return nil
}

func aclReadHookWorkloadIdentity(
authorizer acl.Authorizer,
authzCtx *acl.AuthorizerContext,
Expand Down
10 changes: 10 additions & 0 deletions internal/auth/internal/types/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,13 @@ func TestWorkloadIdentityACLs(t *testing.T) {
})
}
}

func TestWorkloadIdentity_ParseError(t *testing.T) {
rsc := resourcetest.Resource(pbauth.WorkloadIdentityType, "example").
WithData(t, &pbauth.TrafficPermissions{}).
Build()

err := ValidateWorkloadIdentity(rsc)
var parseErr resource.ErrDataParse
require.ErrorAs(t, err, &parseErr)
}
19 changes: 5 additions & 14 deletions internal/catalog/internal/types/acl_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,22 @@ func aclReadHookResourceWithWorkloadSelector(authorizer acl.Authorizer, authzCon
return authorizer.ToAllowAuthorizer().ServiceReadAllowed(id.GetName(), authzContext)
}

func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error {
if res == nil {
return resource.ErrNeedResource
}

decodedService, err := resource.Decode[T](res)
if err != nil {
return resource.ErrNeedResource
}

func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, r *resource.DecodedResource[T]) error {
// First check service:write on the name.
err = authorizer.ToAllowAuthorizer().ServiceWriteAllowed(res.GetId().GetName(), authzContext)
err := authorizer.ToAllowAuthorizer().ServiceWriteAllowed(r.GetId().GetName(), authzContext)
if err != nil {
return err
}

// Then also check whether we're allowed to select a service.
for _, name := range decodedService.GetData().GetWorkloads().GetNames() {
for _, name := range r.Data.GetWorkloads().GetNames() {
err = authorizer.ToAllowAuthorizer().ServiceReadAllowed(name, authzContext)
if err != nil {
return err
}
}

for _, prefix := range decodedService.GetData().GetWorkloads().GetPrefixes() {
for _, prefix := range r.Data.GetWorkloads().GetPrefixes() {
err = authorizer.ToAllowAuthorizer().ServiceReadPrefixAllowed(prefix, authzContext)
if err != nil {
return err
Expand All @@ -50,7 +41,7 @@ func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer ac
func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks {
return &resource.ACLHooks{
Read: aclReadHookResourceWithWorkloadSelector,
Write: aclWriteHookResourceWithWorkloadSelector[T],
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookResourceWithWorkloadSelector[T]),
List: resource.NoOpACLListHook,
}
}
15 changes: 6 additions & 9 deletions internal/catalog/internal/types/dns_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedDNSPolicy = resource.DecodedResource[*pbcatalog.DNSPolicy]

func RegisterDNSPolicy(r resource.Registry) {
r.Register(resource.Registration{
Type: pbcatalog.DNSPolicyType,
Expand All @@ -23,25 +24,21 @@ func RegisterDNSPolicy(r resource.Registry) {
})
}

func ValidateDNSPolicy(res *pbresource.Resource) error {
var policy pbcatalog.DNSPolicy

if err := res.Data.UnmarshalTo(&policy); err != nil {
return resource.NewErrDataParse(&policy, err)
}
var ValidateDNSPolicy = resource.DecodeAndValidate(validateDNSPolicy)

func validateDNSPolicy(res *DecodedDNSPolicy) error {
var err error
// Ensure that this resource isn't useless and is attempting to
// select at least one workload.
if selErr := ValidateSelector(policy.Workloads, false); selErr != nil {
if selErr := ValidateSelector(res.Data.Workloads, false); selErr != nil {
err = multierror.Append(err, resource.ErrInvalidField{
Name: "workloads",
Wrapped: selErr,
})
}

// Validate the weights
if weightErr := validateDNSPolicyWeights(policy.Weights); weightErr != nil {
if weightErr := validateDNSPolicyWeights(res.Data.Weights); weightErr != nil {
err = multierror.Append(err, resource.ErrInvalidField{
Name: "weights",
Wrapped: weightErr,
Expand Down
Loading
Loading