diff --git a/internal/auth/internal/types/computed_traffic_permissions.go b/internal/auth/internal/types/computed_traffic_permissions.go index 0a32e13d2926..800d2a8fb66f 100644 --- a/internal/auth/internal/types/computed_traffic_permissions.go +++ b/internal/auth/internal/types/computed_traffic_permissions.go @@ -12,6 +12,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedComputedTrafficPermissions = resource.DecodedResource[*pbauth.ComputedTrafficPermissions] + func RegisterComputedTrafficPermission(r resource.Registry) { r.Register(resource.Registration{ Type: pbauth.ComputedTrafficPermissionsType, @@ -26,16 +28,12 @@ func RegisterComputedTrafficPermission(r resource.Registry) { }) } -func ValidateComputedTrafficPermissions(res *pbresource.Resource) error { - var ctp pbauth.ComputedTrafficPermissions - - if err := res.Data.UnmarshalTo(&ctp); err != nil { - return resource.NewErrDataParse(&ctp, err) - } +var ValidateComputedTrafficPermissions = resource.DecodeAndValidate(validateComputedTrafficPermissions) +func validateComputedTrafficPermissions(res *DecodedComputedTrafficPermissions) error { var merr error - for i, permission := range ctp.AllowPermissions { + for i, permission := range res.Data.AllowPermissions { wrapErr := func(err error) error { return resource.ErrInvalidListElement{ Name: "allow_permissions", @@ -48,7 +46,7 @@ func ValidateComputedTrafficPermissions(res *pbresource.Resource) error { } } - for i, permission := range ctp.DenyPermissions { + for i, permission := range res.Data.DenyPermissions { wrapErr := func(err error) error { return resource.ErrInvalidListElement{ Name: "deny_permissions", diff --git a/internal/auth/internal/types/traffic_permissions.go b/internal/auth/internal/types/traffic_permissions.go index 78d53c70c628..bf22fdb0b5fa 100644 --- a/internal/auth/internal/types/traffic_permissions.go +++ b/internal/auth/internal/types/traffic_permissions.go @@ -12,13 +12,15 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedTrafficPermissions = resource.DecodedResource[*pbauth.TrafficPermissions] + func RegisterTrafficPermissions(r resource.Registry) { r.Register(resource.Registration{ Type: pbauth.TrafficPermissionsType, Proto: &pbauth.TrafficPermissions{}, ACLs: &resource.ACLHooks{ - Read: aclReadHookTrafficPermissions, - Write: aclWriteHookTrafficPermissions, + Read: resource.DecodeAndAuthorizeRead(aclReadHookTrafficPermissions), + Write: resource.DecodeAndAuthorizeWrite(aclWriteHookTrafficPermissions), List: resource.NoOpACLListHook, }, Validate: ValidateTrafficPermissions, @@ -27,16 +29,12 @@ func RegisterTrafficPermissions(r resource.Registry) { }) } -func MutateTrafficPermissions(res *pbresource.Resource) error { - var tp pbauth.TrafficPermissions - - if err := res.Data.UnmarshalTo(&tp); err != nil { - return resource.NewErrDataParse(&tp, err) - } +var MutateTrafficPermissions = resource.DecodeAndMutate(mutateTrafficPermissions) +func mutateTrafficPermissions(res *DecodedTrafficPermissions) (bool, error) { var changed bool - for _, p := range tp.Permissions { + for _, p := range res.Data.Permissions { for _, s := range p.Sources { if updated := normalizedTenancyForSource(s, res.Id.Tenancy); updated { changed = true @@ -44,11 +42,7 @@ func MutateTrafficPermissions(res *pbresource.Resource) error { } } - if !changed { - return nil - } - - return res.Data.MarshalFrom(&tp) + return changed, nil } func normalizedTenancyForSource(src *pbauth.Source, parentTenancy *pbresource.Tenancy) bool { @@ -110,17 +104,13 @@ func firstNonEmptyString(a, b, c string) (string, bool) { return c, true } -func ValidateTrafficPermissions(res *pbresource.Resource) error { - var tp pbauth.TrafficPermissions - - if err := res.Data.UnmarshalTo(&tp); err != nil { - return resource.NewErrDataParse(&tp, err) - } +var ValidateTrafficPermissions = resource.DecodeAndValidate(validateTrafficPermissions) +func validateTrafficPermissions(res *DecodedTrafficPermissions) error { var merr error // enumcover:pbauth.Action - switch tp.Action { + switch res.Data.Action { case pbauth.Action_ACTION_ALLOW: case pbauth.Action_ACTION_DENY: case pbauth.Action_ACTION_UNSPECIFIED: @@ -132,14 +122,14 @@ func ValidateTrafficPermissions(res *pbresource.Resource) error { }) } - if tp.Destination == nil || (len(tp.Destination.IdentityName) == 0) { + if res.Data.Destination == nil || (len(res.Data.Destination.IdentityName) == 0) { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "data.destination", Wrapped: resource.ErrEmpty, }) } // Validate permissions - for i, permission := range tp.Permissions { + for i, permission := range res.Data.Permissions { wrapErr := func(err error) error { return resource.ErrInvalidListElement{ Name: "permissions", @@ -271,30 +261,10 @@ func isLocalPeer(p string) bool { return p == "local" || p == "" } -func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, _ *pbresource.ID, res *pbresource.Resource) error { - if res == nil { - return resource.ErrNeedResource - } - return authorizeDestination(res, func(dest string) error { - return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(dest, authzContext) - }) +func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error { + return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(res.Data.Destination.IdentityName, authzContext) } -func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error { - return authorizeDestination(res, func(dest string) error { - return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(dest, authzContext) - }) -} - -func authorizeDestination(res *pbresource.Resource, intentionAllowed func(string) error) error { - tp, err := resource.Decode[*pbauth.TrafficPermissions](res) - if err != nil { - return err - } - // Check intention:x permissions for identity - err = intentionAllowed(tp.Data.Destination.IdentityName) - if err != nil { - return err - } - return nil +func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error { + return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(res.Data.Destination.IdentityName, authzContext) } diff --git a/internal/auth/internal/types/workload_identity.go b/internal/auth/internal/types/workload_identity.go index 17334e66099e..a15fd5bf5b2d 100644 --- a/internal/auth/internal/types/workload_identity.go +++ b/internal/auth/internal/types/workload_identity.go @@ -10,6 +10,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedWorkloadIdentity = resource.DecodedResource[*pbauth.WorkloadIdentity] + func RegisterWorkloadIdentity(r resource.Registry) { r.Register(resource.Registration{ Type: pbauth.WorkloadIdentityType, @@ -20,10 +22,17 @@ func RegisterWorkloadIdentity(r resource.Registry) { Write: aclWriteHookWorkloadIdentity, List: resource.NoOpACLListHook, }, - Validate: nil, + Validate: ValidateWorkloadIdentity, }) } +var ValidateWorkloadIdentity = resource.DecodeAndValidate(validateWorkloadIdentity) + +func validateWorkloadIdentity(res *DecodedWorkloadIdentity) error { + // currently the WorkloadIdentity type has no fields. + return nil +} + func aclReadHookWorkloadIdentity( authorizer acl.Authorizer, authzCtx *acl.AuthorizerContext, diff --git a/internal/auth/internal/types/workload_identity_test.go b/internal/auth/internal/types/workload_identity_test.go index 8dfb22bc74a2..19ed4cbeea87 100644 --- a/internal/auth/internal/types/workload_identity_test.go +++ b/internal/auth/internal/types/workload_identity_test.go @@ -144,3 +144,13 @@ func TestWorkloadIdentityACLs(t *testing.T) { }) } } + +func TestWorkloadIdentity_ParseError(t *testing.T) { + rsc := resourcetest.Resource(pbauth.WorkloadIdentityType, "example"). + WithData(t, &pbauth.TrafficPermissions{}). + Build() + + err := ValidateWorkloadIdentity(rsc) + var parseErr resource.ErrDataParse + require.ErrorAs(t, err, &parseErr) +} diff --git a/internal/catalog/internal/types/acl_hooks.go b/internal/catalog/internal/types/acl_hooks.go index 8250767f7254..d9ddcb8e93cc 100644 --- a/internal/catalog/internal/types/acl_hooks.go +++ b/internal/catalog/internal/types/acl_hooks.go @@ -13,31 +13,22 @@ func aclReadHookResourceWithWorkloadSelector(authorizer acl.Authorizer, authzCon return authorizer.ToAllowAuthorizer().ServiceReadAllowed(id.GetName(), authzContext) } -func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error { - if res == nil { - return resource.ErrNeedResource - } - - decodedService, err := resource.Decode[T](res) - if err != nil { - return resource.ErrNeedResource - } - +func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, r *resource.DecodedResource[T]) error { // First check service:write on the name. - err = authorizer.ToAllowAuthorizer().ServiceWriteAllowed(res.GetId().GetName(), authzContext) + err := authorizer.ToAllowAuthorizer().ServiceWriteAllowed(r.GetId().GetName(), authzContext) if err != nil { return err } // Then also check whether we're allowed to select a service. - for _, name := range decodedService.GetData().GetWorkloads().GetNames() { + for _, name := range r.Data.GetWorkloads().GetNames() { err = authorizer.ToAllowAuthorizer().ServiceReadAllowed(name, authzContext) if err != nil { return err } } - for _, prefix := range decodedService.GetData().GetWorkloads().GetPrefixes() { + for _, prefix := range r.Data.GetWorkloads().GetPrefixes() { err = authorizer.ToAllowAuthorizer().ServiceReadPrefixAllowed(prefix, authzContext) if err != nil { return err @@ -50,7 +41,7 @@ func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer ac func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks { return &resource.ACLHooks{ Read: aclReadHookResourceWithWorkloadSelector, - Write: aclWriteHookResourceWithWorkloadSelector[T], + Write: resource.DecodeAndAuthorizeWrite(aclWriteHookResourceWithWorkloadSelector[T]), List: resource.NoOpACLListHook, } } diff --git a/internal/catalog/internal/types/dns_policy.go b/internal/catalog/internal/types/dns_policy.go index 8e9dd864a957..91dd2615455c 100644 --- a/internal/catalog/internal/types/dns_policy.go +++ b/internal/catalog/internal/types/dns_policy.go @@ -10,9 +10,10 @@ import ( "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedDNSPolicy = resource.DecodedResource[*pbcatalog.DNSPolicy] + func RegisterDNSPolicy(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.DNSPolicyType, @@ -23,17 +24,13 @@ func RegisterDNSPolicy(r resource.Registry) { }) } -func ValidateDNSPolicy(res *pbresource.Resource) error { - var policy pbcatalog.DNSPolicy - - if err := res.Data.UnmarshalTo(&policy); err != nil { - return resource.NewErrDataParse(&policy, err) - } +var ValidateDNSPolicy = resource.DecodeAndValidate(validateDNSPolicy) +func validateDNSPolicy(res *DecodedDNSPolicy) error { var err error // Ensure that this resource isn't useless and is attempting to // select at least one workload. - if selErr := ValidateSelector(policy.Workloads, false); selErr != nil { + if selErr := ValidateSelector(res.Data.Workloads, false); selErr != nil { err = multierror.Append(err, resource.ErrInvalidField{ Name: "workloads", Wrapped: selErr, @@ -41,7 +38,7 @@ func ValidateDNSPolicy(res *pbresource.Resource) error { } // Validate the weights - if weightErr := validateDNSPolicyWeights(policy.Weights); weightErr != nil { + if weightErr := validateDNSPolicyWeights(res.Data.Weights); weightErr != nil { err = multierror.Append(err, resource.ErrInvalidField{ Name: "weights", Wrapped: weightErr, diff --git a/internal/catalog/internal/types/failover_policy.go b/internal/catalog/internal/types/failover_policy.go index 047bb9a95b05..012150fc046d 100644 --- a/internal/catalog/internal/types/failover_policy.go +++ b/internal/catalog/internal/types/failover_policy.go @@ -15,6 +15,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedFailoverPolicy = resource.DecodedResource[*pbcatalog.FailoverPolicy] + func RegisterFailoverPolicy(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.FailoverPolicyType, @@ -24,36 +26,32 @@ func RegisterFailoverPolicy(r resource.Registry) { Validate: ValidateFailoverPolicy, ACLs: &resource.ACLHooks{ Read: aclReadHookFailoverPolicy, - Write: aclWriteHookFailoverPolicy, + Write: resource.DecodeAndAuthorizeWrite(aclWriteHookFailoverPolicy), List: resource.NoOpACLListHook, }, }) } -func MutateFailoverPolicy(res *pbresource.Resource) error { - var failover pbcatalog.FailoverPolicy - - if err := res.Data.UnmarshalTo(&failover); err != nil { - return resource.NewErrDataParse(&failover, err) - } +var MutateFailoverPolicy = resource.DecodeAndMutate(mutateFailoverPolicy) +func mutateFailoverPolicy(res *DecodedFailoverPolicy) (bool, error) { changed := false // Handle eliding empty configs. - if failover.Config != nil && failover.Config.IsEmpty() { - failover.Config = nil + if res.Data.Config != nil && res.Data.Config.IsEmpty() { + res.Data.Config = nil changed = true } - if failover.Config != nil { - if mutateFailoverConfig(res.Id.Tenancy, failover.Config) { + if res.Data.Config != nil { + if mutateFailoverConfig(res.Id.Tenancy, res.Data.Config) { changed = true } } - for port, pc := range failover.PortConfigs { + for port, pc := range res.Data.PortConfigs { if pc.IsEmpty() { - delete(failover.PortConfigs, port) + delete(res.Data.PortConfigs, port) changed = true } else { if mutateFailoverConfig(res.Id.Tenancy, pc) { @@ -61,16 +59,12 @@ func MutateFailoverPolicy(res *pbresource.Resource) error { } } } - if len(failover.PortConfigs) == 0 { - failover.PortConfigs = nil + if len(res.Data.PortConfigs) == 0 { + res.Data.PortConfigs = nil changed = true } - if !changed { - return nil - } - - return res.Data.MarshalFrom(&failover) + return changed, nil } func mutateFailoverConfig(policyTenancy *pbresource.Tenancy, config *pbcatalog.FailoverConfig) (changed bool) { @@ -109,35 +103,31 @@ func isLocalPeer(p string) bool { return p == "local" || p == "" } -func ValidateFailoverPolicy(res *pbresource.Resource) error { - var failover pbcatalog.FailoverPolicy - - if err := res.Data.UnmarshalTo(&failover); err != nil { - return resource.NewErrDataParse(&failover, err) - } +var ValidateFailoverPolicy = resource.DecodeAndValidate(validateFailoverPolicy) +func validateFailoverPolicy(res *DecodedFailoverPolicy) error { var merr error - if failover.Config == nil && len(failover.PortConfigs) == 0 { + if res.Data.Config == nil && len(res.Data.PortConfigs) == 0 { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "config", Wrapped: fmt.Errorf("at least one of config or port_configs must be set"), }) } - if failover.Config != nil { + if res.Data.Config != nil { wrapConfigErr := func(err error) error { return resource.ErrInvalidField{ Name: "config", Wrapped: err, } } - if cfgErr := validateFailoverConfig(failover.Config, false, wrapConfigErr); cfgErr != nil { + if cfgErr := validateFailoverConfig(res.Data.Config, false, wrapConfigErr); cfgErr != nil { merr = multierror.Append(merr, cfgErr) } } - for portName, pc := range failover.PortConfigs { + for portName, pc := range res.Data.PortConfigs { wrapConfigErr := func(err error) error { return resource.ErrInvalidMapValue{ Map: "port_configs", @@ -333,7 +323,7 @@ func aclReadHookFailoverPolicy(authorizer acl.Authorizer, authzContext *acl.Auth return authorizer.ToAllowAuthorizer().ServiceReadAllowed(serviceName, authzContext) } -func aclWriteHookFailoverPolicy(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error { +func aclWriteHookFailoverPolicy(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedFailoverPolicy) error { // FailoverPolicy is name-aligned with Service serviceName := res.Id.Name @@ -342,15 +332,10 @@ func aclWriteHookFailoverPolicy(authorizer acl.Authorizer, authzContext *acl.Aut return err } - dec, err := resource.Decode[*pbcatalog.FailoverPolicy](res) - if err != nil { - return err - } - // Ensure you have service:read on any destination that may be affected by // traffic FROM this config change. - if dec.Data.Config != nil { - for _, dest := range dec.Data.Config.Destinations { + if res.Data.Config != nil { + for _, dest := range res.Data.Config.Destinations { destAuthzContext := resource.AuthorizerContext(dest.Ref.GetTenancy()) destServiceName := dest.Ref.GetName() if err := authorizer.ToAllowAuthorizer().ServiceReadAllowed(destServiceName, destAuthzContext); err != nil { @@ -358,7 +343,7 @@ func aclWriteHookFailoverPolicy(authorizer acl.Authorizer, authzContext *acl.Aut } } } - for _, pc := range dec.Data.PortConfigs { + for _, pc := range res.Data.PortConfigs { for _, dest := range pc.Destinations { destAuthzContext := resource.AuthorizerContext(dest.Ref.GetTenancy()) destServiceName := dest.Ref.GetName() diff --git a/internal/catalog/internal/types/health_checks.go b/internal/catalog/internal/types/health_checks.go index 1333e2368d88..3d819e12885a 100644 --- a/internal/catalog/internal/types/health_checks.go +++ b/internal/catalog/internal/types/health_checks.go @@ -8,9 +8,10 @@ import ( "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedHealthChecks = resource.DecodedResource[*pbcatalog.HealthChecks] + func RegisterHealthChecks(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.HealthChecksType, @@ -21,17 +22,13 @@ func RegisterHealthChecks(r resource.Registry) { }) } -func ValidateHealthChecks(res *pbresource.Resource) error { - var checks pbcatalog.HealthChecks - - if err := res.Data.UnmarshalTo(&checks); err != nil { - return resource.NewErrDataParse(&checks, err) - } +var ValidateHealthChecks = resource.DecodeAndValidate(validateHealthChecks) +func validateHealthChecks(res *DecodedHealthChecks) error { var err error // Validate the workload selector - if selErr := ValidateSelector(checks.Workloads, false); selErr != nil { + if selErr := ValidateSelector(res.Data.Workloads, false); selErr != nil { err = multierror.Append(err, resource.ErrInvalidField{ Name: "workloads", Wrapped: selErr, @@ -39,7 +36,7 @@ func ValidateHealthChecks(res *pbresource.Resource) error { } // Validate each check - for idx, check := range checks.HealthChecks { + for idx, check := range res.Data.HealthChecks { if checkErr := validateCheck(check); checkErr != nil { err = multierror.Append(err, resource.ErrInvalidListElement{ Name: "checks", diff --git a/internal/catalog/internal/types/health_status.go b/internal/catalog/internal/types/health_status.go index fe92e858b025..c5ea7e106fa3 100644 --- a/internal/catalog/internal/types/health_status.go +++ b/internal/catalog/internal/types/health_status.go @@ -12,6 +12,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedHealthStatus = resource.DecodedResource[*pbcatalog.HealthStatus] + func RegisterHealthStatus(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.HealthStatusType, @@ -19,33 +21,29 @@ func RegisterHealthStatus(r resource.Registry) { Scope: resource.ScopeNamespace, Validate: ValidateHealthStatus, ACLs: &resource.ACLHooks{ - Read: aclReadHookHealthStatus, + Read: resource.AuthorizeReadWithResource(aclReadHookHealthStatus), Write: aclWriteHookHealthStatus, List: resource.NoOpACLListHook, }, }) } -func ValidateHealthStatus(res *pbresource.Resource) error { - var hs pbcatalog.HealthStatus - - if err := res.Data.UnmarshalTo(&hs); err != nil { - return resource.NewErrDataParse(&hs, err) - } +var ValidateHealthStatus = resource.DecodeAndValidate(validateHealthStatus) +func validateHealthStatus(res *DecodedHealthStatus) error { var err error // Should we allow empty types? I think for now it will be safest to require // the type field is set and we can relax this restriction in the future // if we deem it desirable. - if hs.Type == "" { + if res.Data.Type == "" { err = multierror.Append(err, resource.ErrInvalidField{ Name: "type", Wrapped: resource.ErrMissing, }) } - switch hs.Status { + switch res.Data.Status { case pbcatalog.Health_HEALTH_PASSING, pbcatalog.Health_HEALTH_WARNING, pbcatalog.Health_HEALTH_CRITICAL, @@ -61,7 +59,7 @@ func ValidateHealthStatus(res *pbresource.Resource) error { // owner is currently the resource that this HealthStatus applies to. If we // change this to be a parent reference within the HealthStatus.Data then // we could allow for other owners. - if res.Owner == nil { + if res.Resource.Owner == nil { err = multierror.Append(err, resource.ErrInvalidField{ Name: "owner", Wrapped: resource.ErrMissing, @@ -73,15 +71,13 @@ func ValidateHealthStatus(res *pbresource.Resource) error { return err } -func aclReadHookHealthStatus(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, _ *pbresource.ID, res *pbresource.Resource) error { - if res == nil { - return resource.ErrNeedResource - } +func aclReadHookHealthStatus(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error { // For a health status of a workload we need to check service:read perms. if res.GetOwner() != nil && resource.EqualType(res.GetOwner().GetType(), pbcatalog.WorkloadType) { return authorizer.ToAllowAuthorizer().ServiceReadAllowed(res.GetOwner().GetName(), authzContext) } + // For a health status of a node we need to check node:read perms. if res.GetOwner() != nil && resource.EqualType(res.GetOwner().GetType(), pbcatalog.NodeType) { return authorizer.ToAllowAuthorizer().NodeReadAllowed(res.GetOwner().GetName(), authzContext) } @@ -95,6 +91,7 @@ func aclWriteHookHealthStatus(authorizer acl.Authorizer, authzContext *acl.Autho return authorizer.ToAllowAuthorizer().ServiceWriteAllowed(res.GetOwner().GetName(), authzContext) } + // For a health status of a node we need to check node:write perms. if res.GetOwner() != nil && resource.EqualType(res.GetOwner().GetType(), pbcatalog.NodeType) { return authorizer.ToAllowAuthorizer().NodeWriteAllowed(res.GetOwner().GetName(), authzContext) } diff --git a/internal/catalog/internal/types/node.go b/internal/catalog/internal/types/node.go index 42ac833c6e7d..1ee68f22ca82 100644 --- a/internal/catalog/internal/types/node.go +++ b/internal/catalog/internal/types/node.go @@ -12,6 +12,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedNode = resource.DecodedResource[*pbcatalog.Node] + func RegisterNode(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.NodeType, @@ -31,16 +33,12 @@ func RegisterNode(r resource.Registry) { }) } -func ValidateNode(res *pbresource.Resource) error { - var node pbcatalog.Node - - if err := res.Data.UnmarshalTo(&node); err != nil { - return resource.NewErrDataParse(&node, err) - } +var ValidateNode = resource.DecodeAndValidate(validateNode) +func validateNode(res *DecodedNode) error { var err error // Validate that the node has at least 1 address - if len(node.Addresses) < 1 { + if len(res.Data.Addresses) < 1 { err = multierror.Append(err, resource.ErrInvalidField{ Name: "addresses", Wrapped: resource.ErrEmpty, @@ -48,7 +46,7 @@ func ValidateNode(res *pbresource.Resource) error { } // Validate each node address - for idx, addr := range node.Addresses { + for idx, addr := range res.Data.Addresses { if addrErr := validateNodeAddress(addr); addrErr != nil { err = multierror.Append(err, resource.ErrInvalidListElement{ Name: "addresses", diff --git a/internal/catalog/internal/types/service.go b/internal/catalog/internal/types/service.go index ad9351f0d54f..131e4042d72f 100644 --- a/internal/catalog/internal/types/service.go +++ b/internal/catalog/internal/types/service.go @@ -10,9 +10,10 @@ import ( "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedService = resource.DecodedResource[*pbcatalog.Service] + func RegisterService(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.ServiceType, @@ -24,37 +25,25 @@ func RegisterService(r resource.Registry) { }) } -func MutateService(res *pbresource.Resource) error { - var service pbcatalog.Service - - if err := res.Data.UnmarshalTo(&service); err != nil { - return err - } +var MutateService = resource.DecodeAndMutate(mutateService) +func mutateService(res *DecodedService) (bool, error) { changed := false // Default service port protocols. - for _, port := range service.Ports { + for _, port := range res.Data.Ports { if port.Protocol == pbcatalog.Protocol_PROTOCOL_UNSPECIFIED { port.Protocol = pbcatalog.Protocol_PROTOCOL_TCP changed = true } } - if !changed { - return nil - } - - return res.Data.MarshalFrom(&service) + return changed, nil } -func ValidateService(res *pbresource.Resource) error { - var service pbcatalog.Service - - if err := res.Data.UnmarshalTo(&service); err != nil { - return resource.NewErrDataParse(&service, err) - } +var ValidateService = resource.DecodeAndValidate(validateService) +func validateService(res *DecodedService) error { var err error // Validate the workload selector. We are allowing selectors with no @@ -62,7 +51,7 @@ func ValidateService(res *pbresource.Resource) error { // ServiceEndpoints objects for this service such as when desiring to // configure endpoint information for external services that are not // registered as workloads - if selErr := ValidateSelector(service.Workloads, true); selErr != nil { + if selErr := ValidateSelector(res.Data.Workloads, true); selErr != nil { err = multierror.Append(err, resource.ErrInvalidField{ Name: "workloads", Wrapped: selErr, @@ -72,7 +61,7 @@ func ValidateService(res *pbresource.Resource) error { usedVirtualPorts := make(map[uint32]int) // Validate each port - for idx, port := range service.Ports { + for idx, port := range res.Data.Ports { if usedIdx, found := usedVirtualPorts[port.VirtualPort]; found { err = multierror.Append(err, resource.ErrInvalidListElement{ Name: "ports", @@ -130,7 +119,7 @@ func ValidateService(res *pbresource.Resource) error { } // Validate that the Virtual IPs are all IP addresses - for idx, vip := range service.VirtualIps { + for idx, vip := range res.Data.VirtualIps { if vipErr := validateIPAddress(vip); vipErr != nil { err = multierror.Append(err, resource.ErrInvalidListElement{ Name: "virtual_ips", diff --git a/internal/catalog/internal/types/service_endpoints.go b/internal/catalog/internal/types/service_endpoints.go index 14f055fcba77..c540d88d75d2 100644 --- a/internal/catalog/internal/types/service_endpoints.go +++ b/internal/catalog/internal/types/service_endpoints.go @@ -14,6 +14,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedServiceEndpoints = resource.DecodedResource[*pbcatalog.ServiceEndpoints] + func RegisterServiceEndpoints(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.ServiceEndpointsType, @@ -45,13 +47,9 @@ func MutateServiceEndpoints(res *pbresource.Resource) error { return nil } -func ValidateServiceEndpoints(res *pbresource.Resource) error { - var svcEndpoints pbcatalog.ServiceEndpoints - - if err := res.Data.UnmarshalTo(&svcEndpoints); err != nil { - return resource.NewErrDataParse(&svcEndpoints, err) - } +var ValidateServiceEndpoints = resource.DecodeAndValidate[*pbcatalog.ServiceEndpoints](validateServiceEndpoints) +func validateServiceEndpoints(res *DecodedServiceEndpoints) error { var err error if !resource.EqualType(res.Owner.Type, pbcatalog.ServiceType) { err = multierror.Append(err, resource.ErrOwnerTypeInvalid{ @@ -78,8 +76,8 @@ func ValidateServiceEndpoints(res *pbresource.Resource) error { }) } - for idx, endpoint := range svcEndpoints.Endpoints { - if endpointErr := validateEndpoint(endpoint, res); endpointErr != nil { + for idx, endpoint := range res.Data.Endpoints { + if endpointErr := validateEndpoint(endpoint, res.Resource); endpointErr != nil { err = multierror.Append(err, resource.ErrInvalidListElement{ Name: "endpoints", Index: idx, diff --git a/internal/catalog/internal/types/virtual_ips.go b/internal/catalog/internal/types/virtual_ips.go index 9c7a06547405..be692f63ed65 100644 --- a/internal/catalog/internal/types/virtual_ips.go +++ b/internal/catalog/internal/types/virtual_ips.go @@ -12,6 +12,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedVirtualIPs = resource.DecodedResource[*pbcatalog.VirtualIPs] + func RegisterVirtualIPs(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.VirtualIPsType, @@ -30,15 +32,11 @@ func RegisterVirtualIPs(r resource.Registry) { }) } -func ValidateVirtualIPs(res *pbresource.Resource) error { - var vips pbcatalog.VirtualIPs - - if err := res.Data.UnmarshalTo(&vips); err != nil { - return resource.NewErrDataParse(&vips, err) - } +var ValidateVirtualIPs = resource.DecodeAndValidate(validateVirtualIPs) +func validateVirtualIPs(res *DecodedVirtualIPs) error { var err error - for idx, ip := range vips.Ips { + for idx, ip := range res.Data.Ips { if vipErr := validateIPAddress(ip.Address); vipErr != nil { err = multierror.Append(err, resource.ErrInvalidListElement{ Name: "ips", diff --git a/internal/catalog/internal/types/workload.go b/internal/catalog/internal/types/workload.go index db0175d46d1a..1c02928fcd95 100644 --- a/internal/catalog/internal/types/workload.go +++ b/internal/catalog/internal/types/workload.go @@ -15,6 +15,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type DecodedWorkload = resource.DecodedResource[*pbcatalog.Workload] + func RegisterWorkload(r resource.Registry) { r.Register(resource.Registration{ Type: pbcatalog.WorkloadType, @@ -23,23 +25,19 @@ func RegisterWorkload(r resource.Registry) { Validate: ValidateWorkload, ACLs: &resource.ACLHooks{ Read: aclReadHookWorkload, - Write: aclWriteHookWorkload, + Write: resource.DecodeAndAuthorizeWrite(aclWriteHookWorkload), List: resource.NoOpACLListHook, }, }) } -func ValidateWorkload(res *pbresource.Resource) error { - var workload pbcatalog.Workload - - if err := res.Data.UnmarshalTo(&workload); err != nil { - return resource.NewErrDataParse(&workload, err) - } +var ValidateWorkload = resource.DecodeAndValidate(validateWorkload) +func validateWorkload(res *DecodedWorkload) error { var err error // Validate that the workload has at least one port - if len(workload.Ports) < 1 { + if len(res.Data.Ports) < 1 { err = multierror.Append(err, resource.ErrInvalidField{ Name: "ports", Wrapped: resource.ErrEmpty, @@ -49,7 +47,7 @@ func ValidateWorkload(res *pbresource.Resource) error { var meshPorts []string // Validate the Workload Ports - for portName, port := range workload.Ports { + for portName, port := range res.Data.Ports { if portNameErr := ValidatePortName(portName); portNameErr != nil { err = multierror.Append(err, resource.ErrInvalidMapKey{ Map: "ports", @@ -100,12 +98,12 @@ func ValidateWorkload(res *pbresource.Resource) error { // If the workload is mesh enabled then a valid identity must be provided. // If not mesh enabled but a non-empty identity is provided then we still // validate that its valid. - if len(meshPorts) > 0 && workload.Identity == "" { + if len(meshPorts) > 0 && res.Data.Identity == "" { err = multierror.Append(err, resource.ErrInvalidField{ Name: "identity", Wrapped: resource.ErrMissing, }) - } else if workload.Identity != "" && !isValidDNSLabel(workload.Identity) { + } else if res.Data.Identity != "" && !isValidDNSLabel(res.Data.Identity) { err = multierror.Append(err, resource.ErrInvalidField{ Name: "identity", Wrapped: errNotDNSLabel, @@ -113,7 +111,7 @@ func ValidateWorkload(res *pbresource.Resource) error { } // Validate workload locality - if workload.Locality != nil && workload.Locality.Region == "" && workload.Locality.Zone != "" { + if res.Data.Locality != nil && res.Data.Locality.Region == "" && res.Data.Locality.Zone != "" { err = multierror.Append(err, resource.ErrInvalidField{ Name: "locality", Wrapped: errLocalityZoneNoRegion, @@ -122,8 +120,8 @@ func ValidateWorkload(res *pbresource.Resource) error { // Node associations are optional but if present the name should // be a valid DNS label. - if workload.NodeName != "" { - if !isValidDNSLabel(workload.NodeName) { + if res.Data.NodeName != "" { + if !isValidDNSLabel(res.Data.NodeName) { err = multierror.Append(err, resource.ErrInvalidField{ Name: "node_name", Wrapped: errNotDNSLabel, @@ -131,7 +129,7 @@ func ValidateWorkload(res *pbresource.Resource) error { } } - if len(workload.Addresses) < 1 { + if len(res.Data.Addresses) < 1 { err = multierror.Append(err, resource.ErrInvalidField{ Name: "addresses", Wrapped: resource.ErrEmpty, @@ -139,8 +137,8 @@ func ValidateWorkload(res *pbresource.Resource) error { } // Validate Workload Addresses - for idx, addr := range workload.Addresses { - if addrErr := validateWorkloadAddress(addr, workload.Ports); addrErr != nil { + for idx, addr := range res.Data.Addresses { + if addrErr := validateWorkloadAddress(addr, res.Data.Ports); addrErr != nil { err = multierror.Append(err, resource.ErrInvalidListElement{ Name: "addresses", Index: idx, @@ -156,26 +154,21 @@ func aclReadHookWorkload(authorizer acl.Authorizer, authzContext *acl.Authorizer return authorizer.ToAllowAuthorizer().ServiceReadAllowed(id.GetName(), authzContext) } -func aclWriteHookWorkload(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error { - decodedWorkload, err := resource.Decode[*pbcatalog.Workload](res) - if err != nil { - return resource.ErrNeedResource - } - +func aclWriteHookWorkload(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedWorkload) error { // First check service:write on the workload name. - err = authorizer.ToAllowAuthorizer().ServiceWriteAllowed(res.GetId().GetName(), authzContext) + err := authorizer.ToAllowAuthorizer().ServiceWriteAllowed(res.GetId().GetName(), authzContext) if err != nil { return err } // Check node:read permissions if node is specified. - if decodedWorkload.GetData().GetNodeName() != "" { - return authorizer.ToAllowAuthorizer().NodeReadAllowed(decodedWorkload.GetData().GetNodeName(), authzContext) + if res.Data.GetNodeName() != "" { + return authorizer.ToAllowAuthorizer().NodeReadAllowed(res.Data.GetNodeName(), authzContext) } // Check identity:read permissions if identity is specified. - if decodedWorkload.GetData().GetIdentity() != "" { - return authorizer.ToAllowAuthorizer().IdentityReadAllowed(decodedWorkload.GetData().GetIdentity(), authzContext) + if res.Data.GetIdentity() != "" { + return authorizer.ToAllowAuthorizer().IdentityReadAllowed(res.Data.GetIdentity(), authzContext) } return nil diff --git a/internal/mesh/internal/types/computed_routes.go b/internal/mesh/internal/types/computed_routes.go index 1f66cc97ac21..b572c01fc376 100644 --- a/internal/mesh/internal/types/computed_routes.go +++ b/internal/mesh/internal/types/computed_routes.go @@ -11,7 +11,6 @@ import ( "github.com/hashicorp/consul/internal/resource" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" ) const ( @@ -30,16 +29,12 @@ func RegisterComputedRoutes(r resource.Registry) { }) } -func ValidateComputedRoutes(res *pbresource.Resource) error { - var config pbmesh.ComputedRoutes - - if err := res.Data.UnmarshalTo(&config); err != nil { - return resource.NewErrDataParse(&config, err) - } +var ValidateComputedRoutes = resource.DecodeAndValidate(validateComputedRoutes) +func validateComputedRoutes(res *DecodedComputedRoutes) error { var merr error - if len(config.PortedConfigs) == 0 { + if len(res.Data.PortedConfigs) == 0 { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "ported_configs", Wrapped: resource.ErrEmpty, @@ -48,7 +43,7 @@ func ValidateComputedRoutes(res *pbresource.Resource) error { // TODO(rb): do more elaborate validation - for port, pmc := range config.PortedConfigs { + for port, pmc := range res.Data.PortedConfigs { wrapErr := func(err error) error { return resource.ErrInvalidMapValue{ Map: "ported_configs", diff --git a/internal/mesh/internal/types/decoded.go b/internal/mesh/internal/types/decoded.go index ee1244fdcb1b..be4836c066ff 100644 --- a/internal/mesh/internal/types/decoded.go +++ b/internal/mesh/internal/types/decoded.go @@ -15,6 +15,7 @@ type ( DecodedGRPCRoute = resource.DecodedResource[*pbmesh.GRPCRoute] DecodedTCPRoute = resource.DecodedResource[*pbmesh.TCPRoute] DecodedDestinationPolicy = resource.DecodedResource[*pbmesh.DestinationPolicy] + DecodedDestinationsConfiguration = resource.DecodedResource[*pbmesh.DestinationsConfiguration] DecodedComputedRoutes = resource.DecodedResource[*pbmesh.ComputedRoutes] DecodedComputedTrafficPermissions = resource.DecodedResource[*pbauth.ComputedTrafficPermissions] DecodedFailoverPolicy = resource.DecodedResource[*pbcatalog.FailoverPolicy] diff --git a/internal/mesh/internal/types/destination_policy.go b/internal/mesh/internal/types/destination_policy.go index 68b37345baf3..4fe3062367cf 100644 --- a/internal/mesh/internal/types/destination_policy.go +++ b/internal/mesh/internal/types/destination_policy.go @@ -29,23 +29,19 @@ func RegisterDestinationPolicy(r resource.Registry) { }) } -func ValidateDestinationPolicy(res *pbresource.Resource) error { - var policy pbmesh.DestinationPolicy - - if err := res.Data.UnmarshalTo(&policy); err != nil { - return resource.NewErrDataParse(&policy, err) - } +var ValidateDestinationPolicy = resource.DecodeAndValidate(validateDestinationPolicy) +func validateDestinationPolicy(res *DecodedDestinationPolicy) error { var merr error - if len(policy.PortConfigs) == 0 { + if len(res.Data.PortConfigs) == 0 { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "port_configs", Wrapped: resource.ErrEmpty, }) } - for port, pc := range policy.PortConfigs { + for port, pc := range res.Data.PortConfigs { wrapErr := func(err error) error { return resource.ErrInvalidMapValue{ Map: "port_configs", diff --git a/internal/mesh/internal/types/destinations.go b/internal/mesh/internal/types/destinations.go index 34287e627ab7..7de3011e3ef0 100644 --- a/internal/mesh/internal/types/destinations.go +++ b/internal/mesh/internal/types/destinations.go @@ -26,16 +26,12 @@ func RegisterDestinations(r resource.Registry) { }) } -func MutateDestinations(res *pbresource.Resource) error { - var destinations pbmesh.Destinations - - if err := res.Data.UnmarshalTo(&destinations); err != nil { - return resource.NewErrDataParse(&destinations, err) - } +var MutateDestinations = resource.DecodeAndMutate(mutateDestinations) +func mutateDestinations(res *DecodedDestinations) (bool, error) { changed := false - for _, dest := range destinations.Destinations { + for _, dest := range res.Data.Destinations { if dest.DestinationRef == nil { continue // skip; let the validation hook error out instead } @@ -56,41 +52,33 @@ func MutateDestinations(res *pbresource.Resource) error { } } - if !changed { - return nil - } - - return res.Data.MarshalFrom(&destinations) + return changed, nil } func isLocalPeer(p string) bool { return p == "local" || p == "" } -func ValidateDestinations(res *pbresource.Resource) error { - var destinations pbmesh.Destinations - - if err := res.Data.UnmarshalTo(&destinations); err != nil { - return resource.NewErrDataParse(&destinations, err) - } +var ValidateDestinations = resource.DecodeAndValidate(validateDestinations) +func validateDestinations(res *DecodedDestinations) error { var merr error - if selErr := catalog.ValidateSelector(destinations.Workloads, false); selErr != nil { + if selErr := catalog.ValidateSelector(res.Data.Workloads, false); selErr != nil { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "workloads", Wrapped: selErr, }) } - if destinations.GetPqDestinations() != nil { + if res.Data.GetPqDestinations() != nil { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "pq_destinations", Wrapped: resource.ErrUnsupported, }) } - for i, dest := range destinations.Destinations { + for i, dest := range res.Data.Destinations { wrapDestErr := func(err error) error { return resource.ErrInvalidListElement{ Name: "destinations", diff --git a/internal/mesh/internal/types/destinations_configuration.go b/internal/mesh/internal/types/destinations_configuration.go index fedbe40df48c..7d46d93ed999 100644 --- a/internal/mesh/internal/types/destinations_configuration.go +++ b/internal/mesh/internal/types/destinations_configuration.go @@ -10,7 +10,6 @@ import ( "github.com/hashicorp/consul/internal/resource" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" ) func RegisterDestinationsConfiguration(r resource.Registry) { @@ -23,17 +22,13 @@ func RegisterDestinationsConfiguration(r resource.Registry) { }) } -func ValidateDestinationsConfiguration(res *pbresource.Resource) error { - var cfg pbmesh.DestinationsConfiguration - - if err := res.Data.UnmarshalTo(&cfg); err != nil { - return resource.NewErrDataParse(&cfg, err) - } +var ValidateDestinationsConfiguration = resource.DecodeAndValidate(validateDestinationsConfiguration) +func validateDestinationsConfiguration(res *DecodedDestinationsConfiguration) error { var merr error // Validate the workload selector - if selErr := catalog.ValidateSelector(cfg.Workloads, false); selErr != nil { + if selErr := catalog.ValidateSelector(res.Data.Workloads, false); selErr != nil { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "workloads", Wrapped: selErr, diff --git a/internal/mesh/internal/types/grpc_route.go b/internal/mesh/internal/types/grpc_route.go index 630e416e611c..b861abccdc05 100644 --- a/internal/mesh/internal/types/grpc_route.go +++ b/internal/mesh/internal/types/grpc_route.go @@ -11,7 +11,6 @@ import ( "github.com/hashicorp/consul/internal/resource" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" ) func RegisterGRPCRoute(r resource.Registry) { @@ -25,20 +24,16 @@ func RegisterGRPCRoute(r resource.Registry) { }) } -func MutateGRPCRoute(res *pbresource.Resource) error { - var route pbmesh.GRPCRoute - - if err := res.Data.UnmarshalTo(&route); err != nil { - return resource.NewErrDataParse(&route, err) - } +var MutateGRPCRoute = resource.DecodeAndMutate(mutateGRPCRoute) +func mutateGRPCRoute(res *DecodedGRPCRoute) (bool, error) { changed := false - if mutateParentRefs(res.Id.Tenancy, route.ParentRefs) { + if mutateParentRefs(res.Id.Tenancy, res.Data.ParentRefs) { changed = true } - for _, rule := range route.Rules { + for _, rule := range res.Data.Rules { for _, backend := range rule.BackendRefs { if backend.BackendRef == nil || backend.BackendRef.Ref == nil { continue @@ -49,33 +44,25 @@ func MutateGRPCRoute(res *pbresource.Resource) error { } } - if !changed { - return nil - } - - return res.Data.MarshalFrom(&route) + return changed, nil } -func ValidateGRPCRoute(res *pbresource.Resource) error { - var route pbmesh.GRPCRoute - - if err := res.Data.UnmarshalTo(&route); err != nil { - return resource.NewErrDataParse(&route, err) - } +var ValidateGRPCRoute = resource.DecodeAndValidate(validateGRPCRoute) +func validateGRPCRoute(res *DecodedGRPCRoute) error { var merr error - if err := validateParentRefs(res.Id, route.ParentRefs); err != nil { + if err := validateParentRefs(res.Id, res.Data.ParentRefs); err != nil { merr = multierror.Append(merr, err) } - if len(route.Hostnames) > 0 { + if len(res.Data.Hostnames) > 0 { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "hostnames", Wrapped: errors.New("should not populate hostnames"), }) } - for i, rule := range route.Rules { + for i, rule := range res.Data.Rules { wrapRuleErr := func(err error) error { return resource.ErrInvalidListElement{ Name: "rules", diff --git a/internal/mesh/internal/types/http_route.go b/internal/mesh/internal/types/http_route.go index 0ac2dcbf5c8b..d32f55dc6cc8 100644 --- a/internal/mesh/internal/types/http_route.go +++ b/internal/mesh/internal/types/http_route.go @@ -13,7 +13,6 @@ import ( "github.com/hashicorp/consul/internal/resource" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" ) func RegisterHTTPRoute(r resource.Registry) { @@ -27,20 +26,16 @@ func RegisterHTTPRoute(r resource.Registry) { }) } -func MutateHTTPRoute(res *pbresource.Resource) error { - var route pbmesh.HTTPRoute - - if err := res.Data.UnmarshalTo(&route); err != nil { - return resource.NewErrDataParse(&route, err) - } +var MutateHTTPRoute = resource.DecodeAndMutate(mutateHTTPRoute) +func mutateHTTPRoute(res *DecodedHTTPRoute) (bool, error) { changed := false - if mutateParentRefs(res.Id.Tenancy, route.ParentRefs) { + if mutateParentRefs(res.Id.Tenancy, res.Data.ParentRefs) { changed = true } - for _, rule := range route.Rules { + for _, rule := range res.Data.Rules { for _, match := range rule.Matches { if match.Method != "" { norm := strings.ToUpper(match.Method) @@ -60,33 +55,25 @@ func MutateHTTPRoute(res *pbresource.Resource) error { } } - if !changed { - return nil - } - - return res.Data.MarshalFrom(&route) + return changed, nil } -func ValidateHTTPRoute(res *pbresource.Resource) error { - var route pbmesh.HTTPRoute - - if err := res.Data.UnmarshalTo(&route); err != nil { - return resource.NewErrDataParse(&route, err) - } +var ValidateHTTPRoute = resource.DecodeAndValidate(validateHTTPRoute) +func validateHTTPRoute(res *DecodedHTTPRoute) error { var merr error - if err := validateParentRefs(res.Id, route.ParentRefs); err != nil { + if err := validateParentRefs(res.Id, res.Data.ParentRefs); err != nil { merr = multierror.Append(merr, err) } - if len(route.Hostnames) > 0 { + if len(res.Data.Hostnames) > 0 { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "hostnames", Wrapped: errors.New("should not populate hostnames"), }) } - for i, rule := range route.Rules { + for i, rule := range res.Data.Rules { wrapRuleErr := func(err error) error { return resource.ErrInvalidListElement{ Name: "rules", diff --git a/internal/mesh/internal/types/proxy_configuration.go b/internal/mesh/internal/types/proxy_configuration.go index 081324d72167..9a4388a40f01 100644 --- a/internal/mesh/internal/types/proxy_configuration.go +++ b/internal/mesh/internal/types/proxy_configuration.go @@ -12,7 +12,6 @@ import ( "github.com/hashicorp/consul/internal/resource" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" "github.com/hashicorp/consul/sdk/iptables" ) @@ -27,52 +26,40 @@ func RegisterProxyConfiguration(r resource.Registry) { }) } -func MutateProxyConfiguration(res *pbresource.Resource) error { - var proxyCfg pbmesh.ProxyConfiguration - err := res.Data.UnmarshalTo(&proxyCfg) - if err != nil { - return resource.NewErrDataParse(&proxyCfg, err) - } +var MutateProxyConfiguration = resource.DecodeAndMutate(mutateProxyConfiguration) +func mutateProxyConfiguration(res *DecodedProxyConfiguration) (bool, error) { changed := false // Default the tproxy outbound port. - if proxyCfg.IsTransparentProxy() { - if proxyCfg.GetDynamicConfig().GetTransparentProxy() == nil { - proxyCfg.DynamicConfig.TransparentProxy = &pbmesh.TransparentProxy{ + if res.Data.IsTransparentProxy() { + if res.Data.GetDynamicConfig().GetTransparentProxy() == nil { + res.Data.DynamicConfig.TransparentProxy = &pbmesh.TransparentProxy{ OutboundListenerPort: iptables.DefaultTProxyOutboundPort, } changed = true - } else if proxyCfg.GetDynamicConfig().GetTransparentProxy().OutboundListenerPort == 0 { - proxyCfg.DynamicConfig.TransparentProxy.OutboundListenerPort = iptables.DefaultTProxyOutboundPort + } else if res.Data.GetDynamicConfig().GetTransparentProxy().OutboundListenerPort == 0 { + res.Data.DynamicConfig.TransparentProxy.OutboundListenerPort = iptables.DefaultTProxyOutboundPort changed = true } } - if !changed { - return nil - } - - return res.Data.MarshalFrom(&proxyCfg) + return changed, nil } -func ValidateProxyConfiguration(res *pbresource.Resource) error { - decodedProxyCfg, decodeErr := resource.Decode[*pbmesh.ProxyConfiguration](res) - if decodeErr != nil { - return resource.NewErrDataParse(decodedProxyCfg.GetData(), decodeErr) - } - proxyCfg := decodedProxyCfg.GetData() +var ValidateProxyConfiguration = resource.DecodeAndValidate(validateProxyConfiguration) +func validateProxyConfiguration(res *DecodedProxyConfiguration) error { var err error - if selErr := catalog.ValidateSelector(proxyCfg.Workloads, false); selErr != nil { + if selErr := catalog.ValidateSelector(res.Data.Workloads, false); selErr != nil { err = multierror.Append(err, resource.ErrInvalidField{ Name: "workloads", Wrapped: selErr, }) } - if proxyCfg.GetDynamicConfig() == nil && proxyCfg.GetBootstrapConfig() == nil { + if res.Data.GetDynamicConfig() == nil && res.Data.GetBootstrapConfig() == nil { err = multierror.Append(err, resource.ErrInvalidFields{ Names: []string{"dynamic_config", "bootstrap_config"}, Wrapped: errMissingProxyConfigData, @@ -80,14 +67,14 @@ func ValidateProxyConfiguration(res *pbresource.Resource) error { } // nolint:staticcheck - if proxyCfg.GetOpaqueConfig() != nil { + if res.Data.GetOpaqueConfig() != nil { err = multierror.Append(err, resource.ErrInvalidField{ Name: "opaque_config", Wrapped: resource.ErrUnsupported, }) } - if dynamicCfgErr := validateDynamicProxyConfiguration(proxyCfg.GetDynamicConfig()); dynamicCfgErr != nil { + if dynamicCfgErr := validateDynamicProxyConfiguration(res.Data.GetDynamicConfig()); dynamicCfgErr != nil { err = multierror.Append(err, resource.ErrInvalidField{ Name: "dynamic_config", Wrapped: dynamicCfgErr, diff --git a/internal/mesh/internal/types/proxy_state_template.go b/internal/mesh/internal/types/proxy_state_template.go index b84d0e9b45cb..1fd2e5c52562 100644 --- a/internal/mesh/internal/types/proxy_state_template.go +++ b/internal/mesh/internal/types/proxy_state_template.go @@ -49,25 +49,21 @@ func RegisterProxyStateTemplate(r resource.Registry) { }) } -func ValidateProxyStateTemplate(res *pbresource.Resource) error { - // TODO(v2): validate a lot more of this - - var pst pbmesh.ProxyStateTemplate +var ValidateProxyStateTemplate = resource.DecodeAndValidate(validateProxyStateTemplate) - if err := res.Data.UnmarshalTo(&pst); err != nil { - return resource.NewErrDataParse(&pst, err) - } +func validateProxyStateTemplate(res *DecodedProxyStateTemplate) error { + // TODO(v2): validate a lot more of this var merr error - if pst.ProxyState != nil { + if res.Data.ProxyState != nil { wrapProxyStateErr := func(err error) error { return resource.ErrInvalidField{ Name: "proxy_state", Wrapped: err, } } - for name, cluster := range pst.ProxyState.Clusters { + for name, cluster := range res.Data.ProxyState.Clusters { if name == "" { merr = multierror.Append(merr, wrapProxyStateErr(resource.ErrInvalidMapKey{ Map: "clusters", diff --git a/internal/mesh/internal/types/tcp_route.go b/internal/mesh/internal/types/tcp_route.go index c7470b14d55a..02dd5aaa10fd 100644 --- a/internal/mesh/internal/types/tcp_route.go +++ b/internal/mesh/internal/types/tcp_route.go @@ -10,7 +10,6 @@ import ( "github.com/hashicorp/consul/internal/resource" pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" ) func RegisterTCPRoute(r resource.Registry) { @@ -24,20 +23,16 @@ func RegisterTCPRoute(r resource.Registry) { }) } -func MutateTCPRoute(res *pbresource.Resource) error { - var route pbmesh.TCPRoute - - if err := res.Data.UnmarshalTo(&route); err != nil { - return resource.NewErrDataParse(&route, err) - } +var MutateTCPRoute = resource.DecodeAndMutate(mutateTCPRoute) +func mutateTCPRoute(res *DecodedTCPRoute) (bool, error) { changed := false - if mutateParentRefs(res.Id.Tenancy, route.ParentRefs) { + if mutateParentRefs(res.Id.Tenancy, res.Data.ParentRefs) { changed = true } - for _, rule := range route.Rules { + for _, rule := range res.Data.Rules { for _, backend := range rule.BackendRefs { if backend.BackendRef == nil || backend.BackendRef.Ref == nil { continue @@ -48,34 +43,26 @@ func MutateTCPRoute(res *pbresource.Resource) error { } } - if !changed { - return nil - } - - return res.Data.MarshalFrom(&route) + return changed, nil } -func ValidateTCPRoute(res *pbresource.Resource) error { - var route pbmesh.TCPRoute - - if err := res.Data.UnmarshalTo(&route); err != nil { - return resource.NewErrDataParse(&route, err) - } +var ValidateTCPRoute = resource.DecodeAndValidate(validateTCPRoute) +func validateTCPRoute(res *DecodedTCPRoute) error { var merr error - if err := validateParentRefs(res.Id, route.ParentRefs); err != nil { + if err := validateParentRefs(res.Id, res.Data.ParentRefs); err != nil { merr = multierror.Append(merr, err) } - if len(route.Rules) > 1 { + if len(res.Data.Rules) > 1 { merr = multierror.Append(merr, resource.ErrInvalidField{ Name: "rules", Wrapped: fmt.Errorf("must only specify a single rule for now"), }) } - for i, rule := range route.Rules { + for i, rule := range res.Data.Rules { wrapRuleErr := func(err error) error { return resource.ErrInvalidListElement{ Name: "rules", diff --git a/internal/mesh/internal/types/xroute.go b/internal/mesh/internal/types/xroute.go index 619c9cb68243..92e2136cd135 100644 --- a/internal/mesh/internal/types/xroute.go +++ b/internal/mesh/internal/types/xroute.go @@ -288,28 +288,17 @@ func isValidRetryCondition(retryOn string) bool { func xRouteACLHooks[R XRouteData]() *resource.ACLHooks { hooks := &resource.ACLHooks{ - Read: aclReadHookXRoute[R], - Write: aclWriteHookXRoute[R], + Read: resource.DecodeAndAuthorizeRead(aclReadHookXRoute[R]), + Write: resource.DecodeAndAuthorizeWrite(aclWriteHookXRoute[R]), List: resource.NoOpACLListHook, } return hooks } -func aclReadHookXRoute[R XRouteData](authorizer acl.Authorizer, _ *acl.AuthorizerContext, _ *pbresource.ID, res *pbresource.Resource) error { - if res == nil { - return resource.ErrNeedResource - } - - dec, err := resource.Decode[R](res) - if err != nil { - return err - } - - route := dec.Data - +func aclReadHookXRoute[R XRouteData](authorizer acl.Authorizer, _ *acl.AuthorizerContext, res *resource.DecodedResource[R]) error { // Need service:read on ALL of the services this is controlling traffic for. - for _, parentRef := range route.GetParentRefs() { + for _, parentRef := range res.Data.GetParentRefs() { parentAuthzContext := resource.AuthorizerContext(parentRef.Ref.GetTenancy()) parentServiceName := parentRef.Ref.GetName() @@ -321,16 +310,9 @@ func aclReadHookXRoute[R XRouteData](authorizer acl.Authorizer, _ *acl.Authorize return nil } -func aclWriteHookXRoute[R XRouteData](authorizer acl.Authorizer, _ *acl.AuthorizerContext, res *pbresource.Resource) error { - dec, err := resource.Decode[R](res) - if err != nil { - return err - } - - route := dec.Data - +func aclWriteHookXRoute[R XRouteData](authorizer acl.Authorizer, _ *acl.AuthorizerContext, res *resource.DecodedResource[R]) error { // Need service:write on ALL of the services this is controlling traffic for. - for _, parentRef := range route.GetParentRefs() { + for _, parentRef := range res.Data.GetParentRefs() { parentAuthzContext := resource.AuthorizerContext(parentRef.Ref.GetTenancy()) parentServiceName := parentRef.Ref.GetName() @@ -340,7 +322,7 @@ func aclWriteHookXRoute[R XRouteData](authorizer acl.Authorizer, _ *acl.Authoriz } // Need service:read on ALL of the services this directs traffic at. - for _, backendRef := range route.GetUnderlyingBackendRefs() { + for _, backendRef := range res.Data.GetUnderlyingBackendRefs() { backendAuthzContext := resource.AuthorizerContext(backendRef.Ref.GetTenancy()) backendServiceName := backendRef.Ref.GetName() diff --git a/internal/resource/decode.go b/internal/resource/decode.go index ba9abd87d60d..d96cb79a9a7c 100644 --- a/internal/resource/decode.go +++ b/internal/resource/decode.go @@ -16,8 +16,10 @@ import ( // DecodedResource is a generic holder to contain an original Resource and its // decoded contents. type DecodedResource[T proto.Message] struct { - Resource *pbresource.Resource - Data T + // Embedding here allows us to shadow the Resource.Data Any field to fake out + // using a single struct with inlined data. + *pbresource.Resource + Data T } func (d *DecodedResource[T]) GetResource() *pbresource.Resource { diff --git a/internal/resource/hooks.go b/internal/resource/hooks.go new file mode 100644 index 000000000000..e722a4afc6cd --- /dev/null +++ b/internal/resource/hooks.go @@ -0,0 +1,104 @@ +package resource + +import ( + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/proto-public/pbresource" + "google.golang.org/protobuf/proto" +) + +// DecodedValidationHook is the function signature needed for usage with the DecodeAndValidate function +type DecodedValidationHook[T proto.Message] func(*DecodedResource[T]) error + +// DecodeAndValidate will generate a validation hook function that decodes the specified type and +// passes it off to another validation hook. This is mainly a convenience to avoid many other +// validation hooks needing to attempt decoding the data and erroring in a consistent manner. +func DecodeAndValidate[T proto.Message](fn DecodedValidationHook[T]) ValidationHook { + return func(res *pbresource.Resource) error { + decoded, err := Decode[T](res) + if err != nil { + return err + } + + return fn(decoded) + } +} + +// DecodedMutationHook is the function signature needed for usage with the DecodeAndMutate function +// The boolean return value indicates whether the Data field within the DecodedResource was modified. +// When true, the DecodeAndMutate hook function will automatically re-encode the Any data and store +// it on the internal Resource's Data field. +type DecodedMutationHook[T proto.Message] func(*DecodedResource[T]) (bool, error) + +// DecodeAndMutate will generate a MutationHook that decodes the specified type and passes it +// off to another mutation hook. This is mainly a convenience to avoid other mutation hooks +// needing to decode and potentially reencode the Any data. When the inner mutation hook returns +// no error and that the Data was modified (true for the boolean return value), the generated +// hook will reencode the Any data back into the Resource wrapper +func DecodeAndMutate[T proto.Message](fn DecodedMutationHook[T]) MutationHook { + return func(res *pbresource.Resource) error { + decoded, err := Decode[T](res) + if err != nil { + return err + } + + modified, err := fn(decoded) + if err != nil { + return err + } + + if modified { + return decoded.Resource.Data.MarshalFrom(decoded.Data) + } + return nil + } +} + +// DecodedAuthorizationHook is the function signature needed for usage with the DecodeAndAuthorizeWrite +// and DecodeAndAuthorizeRead functions. +type DecodedAuthorizationHook[T proto.Message] func(acl.Authorizer, *acl.AuthorizerContext, *DecodedResource[T]) error + +// DecodeAndAuthorizeWrite will generate an ACLAuthorizeWriteHook that decodes the specified type and passes +// it off to another authorization hook. This is mainly a convenience to avoid many other write authorization +// hooks needing to attempt decoding the data and erroring in a consistent manner. +func DecodeAndAuthorizeWrite[T proto.Message](fn DecodedAuthorizationHook[T]) ACLAuthorizeWriteHook { + return func(authz acl.Authorizer, ctx *acl.AuthorizerContext, res *pbresource.Resource) error { + decoded, err := Decode[T](res) + if err != nil { + return err + } + + return fn(authz, ctx, decoded) + } +} + +// DecodeAndAuthorizeRead will generate an ACLAuthorizeReadHook that decodes the specified type and passes +// it off to another authorization hook. This is mainly a convenience to avoid many other read authorization +// hooks needing to attempt decoding the data and erroring in a consistent manner. +func DecodeAndAuthorizeRead[T proto.Message](fn DecodedAuthorizationHook[T]) ACLAuthorizeReadHook { + return func(authz acl.Authorizer, ctx *acl.AuthorizerContext, _ *pbresource.ID, res *pbresource.Resource) error { + if res == nil { + return ErrNeedResource + } + + decoded, err := Decode[T](res) + if err != nil { + return err + } + + return fn(authz, ctx, decoded) + } +} + +type ReadAuthorizationWithResourceHook func(acl.Authorizer, *acl.AuthorizerContext, *pbresource.Resource) error + +// AuthorizeReadWithResource is a small wrapper to ensure that the authorization function is +// invoked with the full resource being read instead of just an id. +func AuthorizeReadWithResource(fn ReadAuthorizationWithResourceHook) ACLAuthorizeReadHook { + return func(authz acl.Authorizer, ctx *acl.AuthorizerContext, id *pbresource.ID, res *pbresource.Resource) error { + if res == nil { + return ErrNeedResource + } + + return fn(authz, ctx, res) + } +} diff --git a/internal/resource/hooks_test.go b/internal/resource/hooks_test.go new file mode 100644 index 000000000000..4397ff8e4c17 --- /dev/null +++ b/internal/resource/hooks_test.go @@ -0,0 +1,240 @@ +package resource_test + +import ( + "fmt" + "testing" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/demo" + rtest "github.com/hashicorp/consul/internal/resource/resourcetest" + "github.com/hashicorp/consul/proto-public/pbresource" + pbdemo "github.com/hashicorp/consul/proto/private/pbdemo/v2" + "github.com/stretchr/testify/require" +) + +func TestDecodeAndValidate(t *testing.T) { + res := rtest.Resource(demo.TypeV2Artist, "babypants"). + WithData(t, &pbdemo.Artist{Name: "caspar babypants"}). + Build() + + t.Run("ok", func(t *testing.T) { + err := resource.DecodeAndValidate[*pbdemo.Artist](func(dec *resource.DecodedResource[*pbdemo.Artist]) error { + require.NotNil(t, dec.Resource) + require.NotNil(t, dec.Data) + + return nil + })(res) + + require.NoError(t, err) + }) + + t.Run("inner-validation-error", func(t *testing.T) { + fakeErr := fmt.Errorf("fake") + + err := resource.DecodeAndValidate[*pbdemo.Artist](func(dec *resource.DecodedResource[*pbdemo.Artist]) error { + return fakeErr + })(res) + + require.Error(t, err) + require.Equal(t, fakeErr, err) + }) + + t.Run("decode-error", func(t *testing.T) { + err := resource.DecodeAndValidate[*pbdemo.Album](func(dec *resource.DecodedResource[*pbdemo.Album]) error { + require.Fail(t, "callback should not be called when decoding fails") + return nil + })(res) + + require.Error(t, err) + require.ErrorAs(t, err, &resource.ErrDataParse{}) + }) +} + +func TestDecodeAndMutate(t *testing.T) { + res := rtest.Resource(demo.TypeV2Artist, "babypants"). + WithData(t, &pbdemo.Artist{Name: "caspar babypants"}). + Build() + + t.Run("no-writeback", func(t *testing.T) { + original := res.Data.Value + + err := resource.DecodeAndMutate[*pbdemo.Artist](func(dec *resource.DecodedResource[*pbdemo.Artist]) (bool, error) { + require.NotNil(t, dec.Resource) + require.NotNil(t, dec.Data) + + // we are going to change the data but not tell the outer hook about it + dec.Data.Name = "changed" + + return false, nil + })(res) + + require.NoError(t, err) + // Ensure that the outer hook didn't overwrite the resources data because we told it not to + require.Equal(t, original, res.Data.Value) + }) + + t.Run("writeback", func(t *testing.T) { + original := res.Data.Value + + err := resource.DecodeAndMutate[*pbdemo.Artist](func(dec *resource.DecodedResource[*pbdemo.Artist]) (bool, error) { + require.NotNil(t, dec.Resource) + require.NotNil(t, dec.Data) + + dec.Data.Name = "changed" + + return true, nil + })(res) + + require.NoError(t, err) + // Ensure that the outer hook reencoded the Any data because we told it to. + require.NotEqual(t, original, res.Data.Value) + }) + + t.Run("inner-mutation-error", func(t *testing.T) { + fakeErr := fmt.Errorf("fake") + + err := resource.DecodeAndMutate[*pbdemo.Artist](func(dec *resource.DecodedResource[*pbdemo.Artist]) (bool, error) { + return false, fakeErr + })(res) + + require.Error(t, err) + require.Equal(t, fakeErr, err) + }) + + t.Run("decode-error", func(t *testing.T) { + err := resource.DecodeAndMutate[*pbdemo.Album](func(dec *resource.DecodedResource[*pbdemo.Album]) (bool, error) { + require.Fail(t, "callback should not be called when decoding fails") + return false, nil + })(res) + + require.Error(t, err) + require.ErrorAs(t, err, &resource.ErrDataParse{}) + }) +} + +func TestDecodeAndAuthorizeWrite(t *testing.T) { + res := rtest.Resource(demo.TypeV2Artist, "babypants"). + WithData(t, &pbdemo.Artist{Name: "caspar babypants"}). + Build() + + t.Run("allowed", func(t *testing.T) { + err := resource.DecodeAndAuthorizeWrite[*pbdemo.Artist](func(a acl.Authorizer, c *acl.AuthorizerContext, dec *resource.DecodedResource[*pbdemo.Artist]) error { + require.NotNil(t, a) + require.NotNil(t, c) + require.NotNil(t, dec.Resource) + require.NotNil(t, dec.Data) + + // access allowed + return nil + })(acl.DenyAll(), &acl.AuthorizerContext{}, res) + + require.NoError(t, err) + }) + + t.Run("denied", func(t *testing.T) { + err := resource.DecodeAndAuthorizeWrite[*pbdemo.Artist](func(a acl.Authorizer, c *acl.AuthorizerContext, dec *resource.DecodedResource[*pbdemo.Artist]) error { + return acl.PermissionDenied("fake") + })(acl.DenyAll(), nil, res) + + require.Error(t, err) + require.True(t, acl.IsErrPermissionDenied(err)) + }) + + t.Run("decode-error", func(t *testing.T) { + err := resource.DecodeAndAuthorizeWrite[*pbdemo.Album](func(a acl.Authorizer, c *acl.AuthorizerContext, dec *resource.DecodedResource[*pbdemo.Album]) error { + require.Fail(t, "callback should not be called when decoding fails") + return nil + })(acl.DenyAll(), &acl.AuthorizerContext{}, res) + + require.Error(t, err) + require.ErrorAs(t, err, &resource.ErrDataParse{}) + }) +} + +func TestDecodeAndAuthorizeRead(t *testing.T) { + res := rtest.Resource(demo.TypeV2Artist, "babypants"). + WithData(t, &pbdemo.Artist{Name: "caspar babypants"}). + Build() + + t.Run("allowed", func(t *testing.T) { + err := resource.DecodeAndAuthorizeRead[*pbdemo.Artist](func(a acl.Authorizer, c *acl.AuthorizerContext, dec *resource.DecodedResource[*pbdemo.Artist]) error { + require.NotNil(t, a) + require.NotNil(t, c) + require.NotNil(t, dec.Resource) + require.NotNil(t, dec.Data) + + // access allowed + return nil + })(acl.DenyAll(), &acl.AuthorizerContext{}, nil, res) + + require.NoError(t, err) + }) + + t.Run("denied", func(t *testing.T) { + err := resource.DecodeAndAuthorizeRead[*pbdemo.Artist](func(a acl.Authorizer, c *acl.AuthorizerContext, dec *resource.DecodedResource[*pbdemo.Artist]) error { + return acl.PermissionDenied("fake") + })(acl.DenyAll(), nil, nil, res) + + require.Error(t, err) + require.True(t, acl.IsErrPermissionDenied(err)) + }) + + t.Run("decode-error", func(t *testing.T) { + err := resource.DecodeAndAuthorizeRead[*pbdemo.Album](func(a acl.Authorizer, c *acl.AuthorizerContext, dec *resource.DecodedResource[*pbdemo.Album]) error { + require.Fail(t, "callback should not be called when decoding fails") + return nil + })(acl.DenyAll(), &acl.AuthorizerContext{}, nil, res) + + require.Error(t, err) + require.ErrorAs(t, err, &resource.ErrDataParse{}) + }) + + t.Run("err-need-resource", func(t *testing.T) { + err := resource.DecodeAndAuthorizeRead[*pbdemo.Artist](func(a acl.Authorizer, c *acl.AuthorizerContext, dec *resource.DecodedResource[*pbdemo.Artist]) error { + require.Fail(t, "callback should not be called when no resource was provided to be decoded") + return nil + })(acl.DenyAll(), &acl.AuthorizerContext{}, nil, nil) + + require.Error(t, err) + require.ErrorIs(t, err, resource.ErrNeedResource) + }) +} + +func TestAuthorizeReadWithResource(t *testing.T) { + res := rtest.Resource(demo.TypeV2Artist, "babypants"). + WithData(t, &pbdemo.Artist{Name: "caspar babypants"}). + Build() + + t.Run("allowed", func(t *testing.T) { + err := resource.AuthorizeReadWithResource(func(a acl.Authorizer, c *acl.AuthorizerContext, res *pbresource.Resource) error { + require.NotNil(t, a) + require.NotNil(t, c) + require.NotNil(t, res) + + // access allowed + return nil + })(acl.DenyAll(), &acl.AuthorizerContext{}, nil, res) + + require.NoError(t, err) + }) + + t.Run("denied", func(t *testing.T) { + err := resource.AuthorizeReadWithResource(func(a acl.Authorizer, c *acl.AuthorizerContext, res *pbresource.Resource) error { + return acl.PermissionDenied("fake") + })(acl.DenyAll(), nil, nil, res) + + require.Error(t, err) + require.True(t, acl.IsErrPermissionDenied(err)) + }) + + t.Run("err-need-resource", func(t *testing.T) { + err := resource.AuthorizeReadWithResource(func(a acl.Authorizer, c *acl.AuthorizerContext, res *pbresource.Resource) error { + require.Fail(t, "callback should not be called when no resource was provided to be decoded") + return nil + })(acl.DenyAll(), &acl.AuthorizerContext{}, nil, nil) + + require.Error(t, err) + require.ErrorIs(t, err, resource.ErrNeedResource) + }) +} diff --git a/internal/resource/registry.go b/internal/resource/registry.go index 20c1f4dc41a8..7897ffb1b4bc 100644 --- a/internal/resource/registry.go +++ b/internal/resource/registry.go @@ -42,6 +42,17 @@ type Registry interface { Types() []Registration } +// ValidationHook is the function signature for a validation hook. These hooks can inspect +// the data as they see fit but are expected to not mutate the data in any way. If Go +// supported it, we would pass something akin to a const pointer into the callback to have +// the compiler enforce this immutability. +type ValidationHook func(*pbresource.Resource) error + +// MutationHook is the function signature for a validation hook. These hooks can inspect +// and mutate the resource. If modifying the resources Data, the hook needs to ensure that +// the data gets reencoded and stored back to the Data field. +type MutationHook func(*pbresource.Resource) error + type Registration struct { // Type is the GVK of the resource type. Type *pbresource.Type @@ -56,13 +67,13 @@ type Registration struct { // Validate is called to structurally validate the resource (e.g. // check for required fields). Validate can assume that Mutate // has been called. - Validate func(*pbresource.Resource) error + Validate ValidationHook // Mutate is called to fill out any autogenerated fields (e.g. UUIDs) or // apply defaults before validation. Mutate can assume that // Resource.ID is populated and has non-empty tenancy fields. This does // not mean those tenancy fields actually exist. - Mutate func(*pbresource.Resource) error + Mutate MutationHook // Scope describes the tenancy scope of a resource. Scope Scope @@ -70,6 +81,10 @@ type Registration struct { var ErrNeedResource = errors.New("authorization check requires the entire resource") +type ACLAuthorizeReadHook func(acl.Authorizer, *acl.AuthorizerContext, *pbresource.ID, *pbresource.Resource) error +type ACLAuthorizeWriteHook func(acl.Authorizer, *acl.AuthorizerContext, *pbresource.Resource) error +type ACLAuthorizeListHook func(acl.Authorizer, *acl.AuthorizerContext) error + type ACLHooks struct { // Read is used to authorize Read RPCs and to filter results in List // RPCs. @@ -79,17 +94,17 @@ type ACLHooks struct { // check will be deferred until the data is fetched from the storage layer. // // If it is omitted, `operator:read` permission is assumed. - Read func(acl.Authorizer, *acl.AuthorizerContext, *pbresource.ID, *pbresource.Resource) error + Read ACLAuthorizeReadHook // Write is used to authorize Write and Delete RPCs. // // If it is omitted, `operator:write` permission is assumed. - Write func(acl.Authorizer, *acl.AuthorizerContext, *pbresource.Resource) error + Write ACLAuthorizeWriteHook // List is used to authorize List RPCs. // // If it is omitted, we only filter the results using Read. - List func(acl.Authorizer, *acl.AuthorizerContext) error + List ACLAuthorizeListHook } // Resource type registry diff --git a/internal/tenancy/exports.go b/internal/tenancy/exports.go index b126e7445f74..c07b25903c39 100644 --- a/internal/tenancy/exports.go +++ b/internal/tenancy/exports.go @@ -9,19 +9,6 @@ import ( "github.com/hashicorp/consul/internal/tenancy/internal/types" ) -var ( - // API Group Information - - APIGroup = types.GroupName - VersionV2Beta1 = types.VersionV2Beta1 - CurrentVersion = types.CurrentVersion - - // Resource Kind Names. - - NamespaceKind = types.NamespaceKind - NamespaceV2Beta1Type = types.NamespaceV2Beta1Type -) - type ( V2TenancyBridge = bridge.V2TenancyBridge ) diff --git a/internal/tenancy/internal/bridge/tenancy_bridge.go b/internal/tenancy/internal/bridge/tenancy_bridge.go index 93e85ded49f2..db6a4dd53a17 100644 --- a/internal/tenancy/internal/bridge/tenancy_bridge.go +++ b/internal/tenancy/internal/bridge/tenancy_bridge.go @@ -5,8 +5,9 @@ package bridge import ( "context" - "github.com/hashicorp/consul/internal/tenancy/internal/types" + "github.com/hashicorp/consul/proto-public/pbresource" + pbtenancy "github.com/hashicorp/consul/proto-public/pbtenancy/v2beta1" ) // V2TenancyBridge is used by the resource service to access V2 implementations of @@ -34,7 +35,7 @@ func (b *V2TenancyBridge) NamespaceExists(partition, namespace string) (bool, er Tenancy: &pbresource.Tenancy{ Partition: partition, }, - Type: types.NamespaceType, + Type: pbtenancy.NamespaceType, }, }) return read != nil && read.Resource != nil, err @@ -47,7 +48,7 @@ func (b *V2TenancyBridge) IsNamespaceMarkedForDeletion(partition, namespace stri Tenancy: &pbresource.Tenancy{ Partition: partition, }, - Type: types.NamespaceType, + Type: pbtenancy.NamespaceType, }, }) return read.Resource != nil, err diff --git a/internal/tenancy/internal/types/namespace.go b/internal/tenancy/internal/types/namespace.go index 1bb016bf3c4d..88bf21512558 100644 --- a/internal/tenancy/internal/types/namespace.go +++ b/internal/tenancy/internal/types/namespace.go @@ -13,22 +13,9 @@ import ( pbtenancy "github.com/hashicorp/consul/proto-public/pbtenancy/v2beta1" ) -const ( - NamespaceKind = "Namespace" -) - -var ( - NamespaceV2Beta1Type = &pbresource.Type{ - Group: GroupName, - GroupVersion: VersionV2Beta1, - Kind: NamespaceKind, - } - NamespaceType = NamespaceV2Beta1Type -) - func RegisterNamespace(r resource.Registry) { r.Register(resource.Registration{ - Type: NamespaceType, + Type: pbtenancy.NamespaceType, Proto: &pbtenancy.Namespace{}, Scope: resource.ScopePartition, Validate: ValidateNamespace, diff --git a/internal/tenancy/internal/types/types_test.go b/internal/tenancy/internal/types/types_test.go index 4a089c28d8b5..a82d5b9e6c5a 100644 --- a/internal/tenancy/internal/types/types_test.go +++ b/internal/tenancy/internal/types/types_test.go @@ -5,6 +5,8 @@ package types import ( "errors" + "testing" + "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" "github.com/hashicorp/consul/proto-public/pbresource" @@ -12,13 +14,12 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/anypb" - "testing" ) func createNamespaceResource(t *testing.T, data protoreflect.ProtoMessage) *pbresource.Resource { res := &pbresource.Resource{ Id: &pbresource.ID{ - Type: NamespaceV2Beta1Type, + Type: pbtenancy.NamespaceType, Tenancy: resource.DefaultPartitionedTenancy(), Name: "ns1234", },