Skip to content

Commit

Permalink
Backport of mesh: add more validations to Destinations resource into …
Browse files Browse the repository at this point in the history
…release/1.17.x (#19211)

backport of commit f6c7c4d

Co-authored-by: Iryna Shustava <iryna@hashicorp.com>
  • Loading branch information
hc-github-team-consul-core and ishustava authored Oct 13, 2023
1 parent 9ceec77 commit 41a986c
Show file tree
Hide file tree
Showing 13 changed files with 288 additions and 36 deletions.
1 change: 1 addition & 0 deletions command/resource/testdata/nested_data.hcl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Data {
DestinationPort = "tcp"

IpPort = {
Ip = "127.0.0.1"
Port = 1234
}
}
Expand Down
8 changes: 8 additions & 0 deletions internal/catalog/exports.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,11 @@ func ValidateLocalServiceRefNoSection(ref *pbresource.Reference, wrapErr func(er
func ValidateSelector(sel *pbcatalog.WorkloadSelector, allowEmpty bool) error {
return types.ValidateSelector(sel, allowEmpty)
}

func ValidatePortName(name string) error {
return types.ValidatePortName(name)
}

func IsValidUnixSocketPath(host string) bool {
return types.IsValidUnixSocketPath(host)
}
4 changes: 2 additions & 2 deletions internal/catalog/internal/types/failover_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func ValidateFailoverPolicy(res *pbresource.Resource) error {
Wrapped: err,
}
}
if portNameErr := validatePortName(portName); portNameErr != nil {
if portNameErr := ValidatePortName(portName); portNameErr != nil {
merr = multierror.Append(merr, resource.ErrInvalidMapKey{
Map: "port_configs",
Key: portName,
Expand Down Expand Up @@ -245,7 +245,7 @@ func validateFailoverPolicyDestination(dest *pbcatalog.FailoverDestination, port
// assumed and will be reconciled.
if dest.Port != "" {
if ported {
if portNameErr := validatePortName(dest.Port); portNameErr != nil {
if portNameErr := ValidatePortName(dest.Port); portNameErr != nil {
merr = multierror.Append(merr, wrapErr(resource.ErrInvalidField{
Name: "port",
Wrapped: portNameErr,
Expand Down
2 changes: 1 addition & 1 deletion internal/catalog/internal/types/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func ValidateService(res *pbresource.Resource) error {
}

// validate the target port
if nameErr := validatePortName(port.TargetPort); nameErr != nil {
if nameErr := ValidatePortName(port.TargetPort); nameErr != nil {
err = multierror.Append(err, resource.ErrInvalidListElement{
Name: "ports",
Index: idx,
Expand Down
2 changes: 1 addition & 1 deletion internal/catalog/internal/types/service_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func validateEndpoint(endpoint *pbcatalog.Endpoint, res *pbresource.Resource) er
// Validate the endpoints ports
for portName, port := range endpoint.Ports {
// Port names must be DNS labels
if portNameErr := validatePortName(portName); portNameErr != nil {
if portNameErr := ValidatePortName(portName); portNameErr != nil {
err = multierror.Append(err, resource.ErrInvalidMapKey{
Map: "ports",
Key: portName,
Expand Down
8 changes: 4 additions & 4 deletions internal/catalog/internal/types/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func isValidDNSLabel(label string) bool {
return dnsLabelMatcher.Match([]byte(label))
}

func isValidUnixSocketPath(host string) bool {
func IsValidUnixSocketPath(host string) bool {
if len(host) > maxUnixSocketPathLen || !strings.HasPrefix(host, "unix://") || strings.Contains(host, "\000") {
return false
}
Expand All @@ -71,7 +71,7 @@ func validateWorkloadHost(host string) error {
}

// Check if the host represents an IP address, unix socket path or a DNS name
if !isValidIPAddress(host) && !isValidUnixSocketPath(host) && !isValidDNSName(host) {
if !isValidIPAddress(host) && !IsValidUnixSocketPath(host) && !isValidDNSName(host) {
return errInvalidWorkloadHostFormat{Host: host}
}

Expand Down Expand Up @@ -139,7 +139,7 @@ func validateIPAddress(ip string) error {
return nil
}

func validatePortName(name string) error {
func ValidatePortName(name string) error {
if name == "" {
return resource.ErrEmpty
}
Expand Down Expand Up @@ -184,7 +184,7 @@ func validateWorkloadAddress(addr *pbcatalog.WorkloadAddress, ports map[string]*

// Ensure that unix sockets reference exactly 1 port. They may also indirectly reference 1 port
// by the workload having only a single port and omitting any explicit port assignment.
if isValidUnixSocketPath(addr.Host) &&
if IsValidUnixSocketPath(addr.Host) &&
(len(addr.Ports) > 1 || (len(addr.Ports) == 0 && len(ports) > 1)) {
err = multierror.Append(err, errUnixSocketMultiport)
}
Expand Down
8 changes: 4 additions & 4 deletions internal/catalog/internal/types/validators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestIsValidUnixSocketPath(t *testing.T) {

for name, tcase := range cases {
t.Run(name, func(t *testing.T) {
require.Equal(t, tcase.valid, isValidUnixSocketPath(tcase.name))
require.Equal(t, tcase.valid, IsValidUnixSocketPath(tcase.name))
})
}
}
Expand Down Expand Up @@ -361,15 +361,15 @@ func TestValidatePortName(t *testing.T) {
// test for the isValidDNSLabel function.

t.Run("empty", func(t *testing.T) {
require.Equal(t, resource.ErrEmpty, validatePortName(""))
require.Equal(t, resource.ErrEmpty, ValidatePortName(""))
})

t.Run("invalid", func(t *testing.T) {
require.Equal(t, errNotDNSLabel, validatePortName("foo.com"))
require.Equal(t, errNotDNSLabel, ValidatePortName("foo.com"))
})

t.Run("ok", func(t *testing.T) {
require.NoError(t, validatePortName("http"))
require.NoError(t, ValidatePortName("http"))
})
}

Expand Down
2 changes: 1 addition & 1 deletion internal/catalog/internal/types/workload.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func ValidateWorkload(res *pbresource.Resource) error {

// Validate the Workload Ports
for portName, port := range workload.Ports {
if portNameErr := validatePortName(portName); portNameErr != nil {
if portNameErr := ValidatePortName(portName); portNameErr != nil {
err = multierror.Append(err, resource.ErrInvalidMapKey{
Map: "ports",
Key: portName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ func (suite *controllerTestSuite) TestController() {
}).Write(suite.T(), suite.client)

testutil.RunStep(suite.T(), "add explicit destinations and check that new proxy state is generated", func(t *testing.T) {
webProxyStateTemplate = suite.client.WaitForNewVersion(t, webProxyStateTemplateID, webProxyStateTemplate.Version)
webProxyStateTemplate = suite.client.WaitForNewVersion(suite.T(), webProxyStateTemplateID, webProxyStateTemplate.Version)

requireExplicitDestinationsFound(t, "api", webProxyStateTemplate)
})
Expand Down Expand Up @@ -613,7 +613,7 @@ func (suite *controllerTestSuite) TestController() {
})

// We should get a new web proxy template resource because this destination should be removed.
webProxyStateTemplate = suite.client.WaitForNewVersion(t, webProxyStateTemplateID, webProxyStateTemplate.Version)
webProxyStateTemplate = suite.client.WaitForNewVersion(suite.T(), webProxyStateTemplateID, webProxyStateTemplate.Version)

requireExplicitDestinationsNotFound(t, "api", webProxyStateTemplate)
})
Expand Down
81 changes: 77 additions & 4 deletions internal/mesh/internal/types/destinations.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package types

import (
"net"

"github.com/hashicorp/go-multierror"
"google.golang.org/protobuf/proto"

Expand Down Expand Up @@ -73,18 +75,24 @@ func ValidateDestinations(res *pbresource.Resource) error {

var merr error

// Validate the workload selector
if selErr := catalog.ValidateSelector(destinations.Workloads, false); selErr != nil {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "workloads",
Wrapped: selErr,
})
}

if destinations.GetPqDestinations() != nil {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "pq_destinations",
Wrapped: resource.ErrUnsupported,
})
}

for i, dest := range destinations.Destinations {
wrapDestErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "upstreams",
Name: "destinations",
Index: i,
Wrapped: err,
}
Expand All @@ -101,8 +109,73 @@ func ValidateDestinations(res *pbresource.Resource) error {
merr = multierror.Append(merr, refErr)
}

// TODO(v2): validate port name using catalog validator
// TODO(v2): validate ListenAddr
if portErr := catalog.ValidatePortName(dest.DestinationPort); portErr != nil {
merr = multierror.Append(merr, wrapDestErr(resource.ErrInvalidField{
Name: "destination_port",
Wrapped: portErr,
}))
}

if dest.GetDatacenter() != "" {
merr = multierror.Append(merr, wrapDestErr(resource.ErrInvalidField{
Name: "datacenter",
Wrapped: resource.ErrUnsupported,
}))
}

if listenAddrErr := validateListenAddr(dest); listenAddrErr != nil {
merr = multierror.Append(merr, wrapDestErr(listenAddrErr))
}
}

return merr
}

func validateListenAddr(dest *pbmesh.Destination) error {
var merr error

if dest.GetListenAddr() == nil {
return multierror.Append(merr, resource.ErrInvalidFields{
Names: []string{"ip_port", "unix"},
Wrapped: resource.ErrMissingOneOf,
})
}

switch listenAddr := dest.GetListenAddr().(type) {
case *pbmesh.Destination_IpPort:
if ipPortErr := validateIPPort(listenAddr.IpPort); ipPortErr != nil {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "ip_port",
Wrapped: ipPortErr,
})
}
case *pbmesh.Destination_Unix:
if !catalog.IsValidUnixSocketPath(listenAddr.Unix.GetPath()) {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "unix",
Wrapped: resource.ErrInvalidField{
Name: "path",
Wrapped: errInvalidUnixSocketPath,
},
})
}
}

return merr
}

func validateIPPort(ipPort *pbmesh.IPPortAddress) error {
var merr error

if listenPortErr := validatePort(ipPort.GetPort(), "port"); listenPortErr != nil {
merr = multierror.Append(merr, listenPortErr)
}

if net.ParseIP(ipPort.GetIp()) == nil {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "ip",
Wrapped: errInvalidIP,
})
}

return merr
Expand Down
Loading

0 comments on commit 41a986c

Please sign in to comment.