Skip to content

Commit

Permalink
fix: validate proto messages before converting them to anypb.Any (#4499)
Browse files Browse the repository at this point in the history
* validate proto message before converting to any

Signed-off-by: Huabing Zhao <zhaohuabing@gmail.com>
  • Loading branch information
zhaohuabing authored Oct 29, 2024
1 parent b877bac commit 05817fc
Show file tree
Hide file tree
Showing 19 changed files with 184 additions and 112 deletions.
36 changes: 18 additions & 18 deletions internal/utils/protocov/protocov.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,30 @@ import (
"google.golang.org/protobuf/types/known/anypb"
)

const (
APIPrefix = "type.googleapis.com/"
)

var marshalOpts = proto.MarshalOptions{}
// Deprecated: error should not be ignored, use ToAnyWithValidation instead.
func ToAny(msg proto.Message) *anypb.Any {
res, err := ToAnyWithValidation(msg)
if err != nil {
return nil
}
return res
}

func ToAnyWithError(msg proto.Message) (*anypb.Any, error) {
func ToAnyWithValidation(msg proto.Message) (*anypb.Any, error) {
if msg == nil {
return nil, errors.New("empty message received")
}
b, err := marshalOpts.Marshal(msg)
if err != nil {
return nil, err

// If the message has a ValidateAll method, call it before marshaling.
if validator, ok := msg.(interface{ ValidateAll() error }); ok {
if err := validator.ValidateAll(); err != nil {
return nil, err
}
}
return &anypb.Any{
TypeUrl: APIPrefix + string(msg.ProtoReflect().Descriptor().FullName()),
Value: b,
}, nil
}

func ToAny(msg proto.Message) *anypb.Any {
res, err := ToAnyWithError(msg)
any, err := anypb.New(msg)
if err != nil {
return nil
return nil, err
}
return res
return any, nil
}
33 changes: 23 additions & 10 deletions internal/xds/translator/accesslog.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/envoyproxy/go-control-plane/pkg/wellknown"
otlpcommonv1 "go.opentelemetry.io/proto/otlp/common/v1"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
Expand Down Expand Up @@ -90,9 +89,9 @@ var (
}
)

func buildXdsAccessLog(al *ir.AccessLog, accessLogType ir.ProxyAccessLogType) []*accesslog.AccessLog {
func buildXdsAccessLog(al *ir.AccessLog, accessLogType ir.ProxyAccessLogType) ([]*accesslog.AccessLog, error) {
if al == nil {
return nil
return nil, nil
}

totalLen := len(al.Text) + len(al.JSON) + len(al.OpenTelemetry)
Expand Down Expand Up @@ -133,8 +132,10 @@ func buildXdsAccessLog(al *ir.AccessLog, accessLogType ir.ProxyAccessLogType) []
filelog.GetLogFormat().Formatters = formatters
}

// TODO: find a better way to handle this
accesslogAny, _ := anypb.New(filelog)
accesslogAny, err := protocov.ToAnyWithValidation(filelog)
if err != nil {
return nil, err
}
accessLogs = append(accessLogs, &accesslog.AccessLog{
Name: wellknown.FileAccessLog,
ConfigType: &accesslog.AccessLog_TypedConfig{
Expand Down Expand Up @@ -185,7 +186,10 @@ func buildXdsAccessLog(al *ir.AccessLog, accessLogType ir.ProxyAccessLogType) []
filelog.GetLogFormat().Formatters = formatters
}

accesslogAny, _ := anypb.New(filelog)
accesslogAny, err := protocov.ToAnyWithValidation(filelog)
if err != nil {
return nil, err
}
accessLogs = append(accessLogs, &accesslog.AccessLog{
Name: wellknown.FileAccessLog,
ConfigType: &accesslog.AccessLog_TypedConfig{
Expand Down Expand Up @@ -228,7 +232,10 @@ func buildXdsAccessLog(al *ir.AccessLog, accessLogType ir.ProxyAccessLogType) []
alCfg.AdditionalResponseTrailersToLog = als.HTTP.ResponseTrailers
}

accesslogAny, _ := anypb.New(alCfg)
accesslogAny, err := protocov.ToAnyWithValidation(alCfg)
if err != nil {
return nil, err
}
accessLogs = append(accessLogs, &accesslog.AccessLog{
Name: wellknown.HTTPGRPCAccessLog,
ConfigType: &accesslog.AccessLog_TypedConfig{
Expand All @@ -241,7 +248,10 @@ func buildXdsAccessLog(al *ir.AccessLog, accessLogType ir.ProxyAccessLogType) []
CommonConfig: cc,
}

accesslogAny, _ := anypb.New(alCfg)
accesslogAny, err := protocov.ToAnyWithValidation(alCfg)
if err != nil {
return nil, err
}
accessLogs = append(accessLogs, &accesslog.AccessLog{
Name: tcpGRPCAccessLog,
ConfigType: &accesslog.AccessLog_TypedConfig{
Expand Down Expand Up @@ -297,7 +307,10 @@ func buildXdsAccessLog(al *ir.AccessLog, accessLogType ir.ProxyAccessLogType) []
al.Formatters = formatters
}

accesslogAny, _ := anypb.New(al)
accesslogAny, err := protocov.ToAnyWithValidation(al)
if err != nil {
return nil, err
}
accessLogs = append(accessLogs, &accesslog.AccessLog{
Name: otelAccessLog,
ConfigType: &accesslog.AccessLog_TypedConfig{
Expand All @@ -307,7 +320,7 @@ func buildXdsAccessLog(al *ir.AccessLog, accessLogType ir.ProxyAccessLogType) []
})
}

return accessLogs
return accessLogs, nil
}

func celAccessLogFilter(expr string) *accesslog.AccessLogFilter {
Expand Down
26 changes: 11 additions & 15 deletions internal/xds/translator/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/envoyproxy/gateway/internal/ir"
"github.com/envoyproxy/gateway/internal/utils/protocov"
"github.com/envoyproxy/gateway/internal/xds/types"
)

Expand Down Expand Up @@ -75,7 +76,7 @@ func (*rbac) patchHCM(
// buildHCMRBACFilter returns a RBAC filter from the provided IR listener.
func buildHCMRBACFilter() (*hcmv3.HttpFilter, error) {
rbacProto := &rbacv3.RBAC{}
rbacAny, err := anypb.New(rbacProto)
rbacAny, err := protocov.ToAnyWithValidation(rbacProto)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -133,7 +134,7 @@ func (*rbac) patchRoute(route *routev3.Route, irRoute *ir.HTTPRoute) error {
return err
}

if cfgAny, err = anypb.New(rbacPerRoute); err != nil {
if cfgAny, err = protocov.ToAnyWithValidation(rbacPerRoute); err != nil {
return err
}

Expand All @@ -159,15 +160,15 @@ func buildRBACPerRoute(authorization *ir.Authorization) (*rbacv3.RBACPerRoute, e
Name: "ALLOW",
Action: rbacconfigv3.RBAC_ALLOW,
}
if allowAction, err = anypb.New(allow); err != nil {
if allowAction, err = protocov.ToAnyWithValidation(allow); err != nil {
return nil, err
}

deny := &rbacconfigv3.Action{
Name: "DENY",
Action: rbacconfigv3.RBAC_DENY,
}
if denyAction, err = anypb.New(deny); err != nil {
if denyAction, err = protocov.ToAnyWithValidation(deny); err != nil {
return nil, err
}

Expand Down Expand Up @@ -287,11 +288,6 @@ func buildRBACPerRoute(authorization *ir.Authorization) (*rbacv3.RBACPerRoute, e
rbac.Rbac.Matcher.MatcherType = nil
}

// We need to validate the RBACPerRoute message before converting it to an Any.
if err = rbac.ValidateAll(); err != nil {
return nil, err
}

return rbac, nil
}

Expand All @@ -316,11 +312,11 @@ func buildIPPredicate(clientCIDRs []*ir.CIDRMatch) (*matcherv3.Matcher_MatcherLi
})
}

if ipMatcher, err = anypb.New(ipRangeMatcher); err != nil {
if ipMatcher, err = protocov.ToAnyWithValidation(ipRangeMatcher); err != nil {
return nil, err
}

if sourceIPInput, err = anypb.New(&networkinput.SourceIPInput{}); err != nil {
if sourceIPInput, err = protocov.ToAnyWithValidation(&networkinput.SourceIPInput{}); err != nil {
return nil, err
}

Expand Down Expand Up @@ -389,11 +385,11 @@ func buildJWTPredicate(jwt egv1a1.JWTPrincipal) ([]*matcherv3.Matcher_MatcherLis
},
}

if inputPb, err = anypb.New(input); err != nil {
if inputPb, err = protocov.ToAnyWithValidation(input); err != nil {
return nil, err
}

if matcherPb, err = anypb.New(scopeMatcher); err != nil {
if matcherPb, err = protocov.ToAnyWithValidation(scopeMatcher); err != nil {
return nil, err
}

Expand Down Expand Up @@ -454,7 +450,7 @@ func buildJWTPredicate(jwt egv1a1.JWTPrincipal) ([]*matcherv3.Matcher_MatcherLis
Path: path,
}

if inputPb, err = anypb.New(input); err != nil {
if inputPb, err = protocov.ToAnyWithValidation(input); err != nil {
return nil, err
}

Expand Down Expand Up @@ -492,7 +488,7 @@ func buildJWTPredicate(jwt egv1a1.JWTPrincipal) ([]*matcherv3.Matcher_MatcherLis
}
}

if matcherPb, err = anypb.New(&metadatav3.Metadata{
if matcherPb, err = protocov.ToAnyWithValidation(&metadatav3.Metadata{
Value: valueMatcher,
}); err != nil {
return nil, err
Expand Down
5 changes: 3 additions & 2 deletions internal/xds/translator/basicauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/envoyproxy/gateway/internal/ir"
"github.com/envoyproxy/gateway/internal/utils/protocov"
"github.com/envoyproxy/gateway/internal/xds/types"
)

Expand Down Expand Up @@ -84,7 +85,7 @@ func buildHCMBasicAuthFilter(basicAuth *ir.BasicAuth) (*hcmv3.HttpFilter, error)
if err = basicAuthProto.ValidateAll(); err != nil {
return nil, err
}
if basicAuthAny, err = anypb.New(basicAuthProto); err != nil {
if basicAuthAny, err = protocov.ToAnyWithValidation(basicAuthProto); err != nil {
return nil, err
}

Expand Down Expand Up @@ -134,7 +135,7 @@ func (*basicAuth) patchRoute(route *routev3.Route, irRoute *ir.HTTPRoute) error
return err
}

if basicAuthAny, err = anypb.New(basicAuthProto); err != nil {
if basicAuthAny, err = protocov.ToAnyWithValidation(basicAuthProto); err != nil {
return err
}

Expand Down
9 changes: 5 additions & 4 deletions internal/xds/translator/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/envoyproxy/gateway/internal/ir"
"github.com/envoyproxy/gateway/internal/utils/protocov"
)

const (
Expand Down Expand Up @@ -509,7 +510,7 @@ func buildTypedExtensionProtocolOptions(args *xdsClusterArgs) map[string]*anypb.
if args.http1Settings != nil {
http1opts.EnableTrailers = args.http1Settings.EnableTrailers
if args.http1Settings.PreserveHeaderCase {
preservecaseAny, _ := anypb.New(&preservecasev3.PreserveCaseFormatterConfig{})
preservecaseAny, _ := protocov.ToAnyWithValidation(&preservecasev3.PreserveCaseFormatterConfig{})
http1opts.HeaderKeyFormat = &corev3.Http1ProtocolOptions_HeaderKeyFormat{
HeaderFormat: &corev3.Http1ProtocolOptions_HeaderKeyFormat_StatefulFormatter{
StatefulFormatter: &corev3.TypedExtensionConfig{
Expand Down Expand Up @@ -562,7 +563,7 @@ func buildTypedExtensionProtocolOptions(args *xdsClusterArgs) map[string]*anypb.
}
}

anyProtocolOptions, _ := anypb.New(&protocolOptions)
anyProtocolOptions, _ := protocov.ToAnyWithValidation(&protocolOptions)

extensionOptions := map[string]*anypb.Any{
extensionOptionsKey: anyProtocolOptions,
Expand Down Expand Up @@ -593,7 +594,7 @@ func buildProxyProtocolSocket(proxyProtocol *ir.ProxyProtocol, tSocket *corev3.T
// If existing transport socket does not exist wrap around raw buffer
if tSocket == nil {
rawCtx := &rawbufferv3.RawBuffer{}
rawCtxAny, err := anypb.New(rawCtx)
rawCtxAny, err := protocov.ToAnyWithValidation(rawCtx)
if err != nil {
return nil
}
Expand All @@ -608,7 +609,7 @@ func buildProxyProtocolSocket(proxyProtocol *ir.ProxyProtocol, tSocket *corev3.T
ppCtx.TransportSocket = tSocket
}

ppCtxAny, err := anypb.New(ppCtx)
ppCtxAny, err := protocov.ToAnyWithValidation(ppCtx)
if err != nil {
return nil
}
Expand Down
11 changes: 6 additions & 5 deletions internal/xds/translator/custom_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/envoyproxy/gateway/internal/ir"
"github.com/envoyproxy/gateway/internal/utils/protocov"
"github.com/envoyproxy/gateway/internal/xds/types"
)

Expand Down Expand Up @@ -85,7 +86,7 @@ func (c *customResponse) buildHCMCustomResponseFilter(ro *ir.ResponseOverride) (
return nil, err
}

any, err := anypb.New(proto)
any, err := protocov.ToAnyWithValidation(proto)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -237,7 +238,7 @@ func (c *customResponse) buildHTTPAttributeCELInput() (*cncfv3.TypedExtensionCon
err error
)

if pb, err = anypb.New(&matcherv3.HttpAttributesCelMatchInput{}); err != nil {
if pb, err = protocov.ToAnyWithValidation(&matcherv3.HttpAttributesCelMatchInput{}); err != nil {
return nil, err
}

Expand All @@ -253,7 +254,7 @@ func (c *customResponse) buildStatusCodeInput() (*cncfv3.TypedExtensionConfig, e
err error
)

if pb, err = anypb.New(&envoymatcherv3.HttpResponseStatusCodeMatchInput{}); err != nil {
if pb, err = protocov.ToAnyWithValidation(&envoymatcherv3.HttpResponseStatusCodeMatchInput{}); err != nil {
return nil, err
}

Expand Down Expand Up @@ -364,7 +365,7 @@ func (c *customResponse) buildStatusCodeCELMatcher(codeRange ir.StatusCodeRange)
return nil, err
}

if pb, err = anypb.New(matcher); err != nil {
if pb, err = protocov.ToAnyWithValidation(matcher); err != nil {
return nil, err
}

Expand Down Expand Up @@ -403,7 +404,7 @@ func (c *customResponse) buildAction(r ir.ResponseOverrideRule) (*matcherv3.Matc
return nil, err
}

if pb, err = anypb.New(response); err != nil {
if pb, err = protocov.ToAnyWithValidation(response); err != nil {
return nil, err
}

Expand Down
5 changes: 3 additions & 2 deletions internal/xds/translator/fault.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/envoyproxy/gateway/internal/ir"
"github.com/envoyproxy/gateway/internal/utils/protocov"
"github.com/envoyproxy/gateway/internal/xds/types"
)

Expand Down Expand Up @@ -71,7 +72,7 @@ func buildHCMFaultFilter() (*hcmv3.HttpFilter, error) {
return nil, err
}

faultAny, err := anypb.New(faultProto)
faultAny, err := protocov.ToAnyWithValidation(faultProto)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -165,7 +166,7 @@ func (*fault) patchRoute(route *routev3.Route, irRoute *ir.HTTPRoute) error {
return nil
}

routeCfgAny, err := anypb.New(routeCfgProto)
routeCfgAny, err := protocov.ToAnyWithValidation(routeCfgProto)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 05817fc

Please sign in to comment.