Skip to content

Commit

Permalink
Merge pull request #167 from xmidt-org/normify
Browse files Browse the repository at this point in the history
Add the ability to validate and update a message using a collection o…
  • Loading branch information
schmidtw authored Mar 13, 2024
2 parents 727d830 + 7f19f1b commit e53d23c
Show file tree
Hide file tree
Showing 4 changed files with 649 additions and 0 deletions.
11 changes: 11 additions & 0 deletions messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,17 @@ func (msg *Message) SetIncludeSpans(value bool) *Message {
return msg
}

// TrimmedPartnerIDs returns a copy of the PartnerIDs field with all empty strings removed.
func (msg *Message) TrimmedPartnerIDs() []string {
trimmed := make([]string, 0, len(msg.PartnerIDs))
for _, id := range msg.PartnerIDs {
if id != "" {
trimmed = append(trimmed, id)
}
}
return trimmed
}

// SimpleRequestResponse represents a WRP message of type SimpleRequestResponseMessageType.
//
// https://github.com/xmidt-org/wrp-c/wiki/Web-Routing-Protocol#simple-request-response-definition
Expand Down
31 changes: 31 additions & 0 deletions messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,3 +765,34 @@ func TestUnknown(t *testing.T) {
})
}
}

func TestMessage_TrimmedPartnerIDs(t *testing.T) {
tests := []struct {
description string
partners []string
want []string
}{
{
description: "empty partner list",
partners: []string{},
want: []string{},
}, {
description: "normal partner list",
partners: []string{"foo", "bar", "baz"},
want: []string{"foo", "bar", "baz"},
}, {
description: "partner list with empty strings",
partners: []string{"", "foo", "", "bar", "", "baz", ""},
want: []string{"foo", "bar", "baz"},
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
assert := assert.New(t)
msg := &Message{
PartnerIDs: tc.partners,
}
assert.Equal(tc.want, msg.TrimmedPartnerIDs())
})
}
}
299 changes: 299 additions & 0 deletions normify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package wrp

import (
"errors"
"strconv"
"time"

"github.com/google/uuid"
)

var (
ErrInvalidMessageType = errors.New("invalid message type")
ErrInvalidPartnerID = errors.New("invalid partner ID")
ErrInvalidSource = errors.New("invalid source locator")
ErrInvalidDest = errors.New("invalid destination locator")
ErrInvalidString = errors.New("invalid UTF-8 string")
)

// Normifier applies a series of normalizing options to a WRP message.
type Normifier struct {
opts []NormifierOption
}

// NormifierOption is a functional option for normalizing a WRP message.
type NormifierOption interface {
// normify applies the option to the given message.
normify(*Message) error
}

// optionFunc is an adapter to allow the use of ordinary functions as
// normalizing options.
type optionFunc func(*Message) error

var _ NormifierOption = optionFunc(nil)

func (f optionFunc) normify(m *Message) error {
return f(m)
}

// New creates a new Correctifier with the given options.
func NewNormifier(opts ...NormifierOption) *Normifier {
return &Normifier{
opts: opts,
}
}

// Normify applies the normalizing and validating options to the message. It
// returns an error if any of the options fail.
func (n *Normifier) Normify(m *Message) error {
for _, opt := range n.opts {
if opt != nil {
if err := opt.normify(m); err != nil {
return err
}
}
}
return nil
}

// errorOption returns an option that always returns the given error.
func errorOption(err error) NormifierOption {
return optionFunc(func(*Message) error {
return err
})
}

// options returns a new option that applies all of the given options in order.
func options(opts ...NormifierOption) NormifierOption {
return optionFunc(func(m *Message) error {
for _, opt := range opts {
if opt != nil {
if err := opt.normify(m); err != nil {
return err
}
}
}
return nil
})
}

// -- Normalizers --------------------------------------------------------------

// ReplaceAnySelfLocator replaces any `self:` based locator with the scheme and
// authority of the given locator. If the given locator is not valid, the
// option returns an error.
func ReplaceAnySelfLocator(me string) NormifierOption {
return options(
ReplaceSourceSelfLocator(me),
ReplaceDestinationSelfLocator(me),
)
}

// ReplaceSourceSelfLocator replaces a `self:` based source locator with the
// scheme and authority of the given locator. If the given locator is not valid,
// the option returns an error.
func ReplaceSourceSelfLocator(me string) NormifierOption {
full, err := ParseLocator(me)
if err != nil {
return errorOption(err)
}

return optionFunc(func(m *Message) error {
src, err := ParseLocator(m.Source)
if err != nil {
return err
}

if src.Scheme == "self" {
src.Scheme = full.Scheme
src.Authority = full.Authority
m.Source = src.String()
}

return nil
})
}

// ReplaceDestinationSelfLocator replaces the destination of the message with the
// given locator if the destination is a `self:` based locator. If the given
// locator is not valid, the option returns an error.
func ReplaceDestinationSelfLocator(me string) NormifierOption {
full, err := ParseLocator(me)
if err != nil {
return errorOption(err)
}

return optionFunc(func(m *Message) error {
dst, err := ParseLocator(m.Destination)
if err != nil {
return err
}

if dst.Scheme == "self" {
dst.Scheme = full.Scheme
dst.Authority = full.Authority
m.Destination = dst.String()
}

return nil
})
}

// EnsureTransactionUUID ensures that the message has a transaction UUID. If
// the message does not have a transaction UUID, a new one is generated and
// added to the message.
func EnsureTransactionUUID() NormifierOption {
return optionFunc(func(m *Message) error {
if m.TransactionUUID == "" {
id, err := uuid.NewRandom()
if err != nil {
return err
}

m.TransactionUUID = id.String()
}
return nil
})
}

// EnsurePartnerID ensures that the message includes the given partner ID in
// the list. If not present, the partner ID is added to the list.
func EnsurePartnerID(partnerID string) NormifierOption {
return optionFunc(func(m *Message) error {
if m.PartnerIDs == nil {
m.PartnerIDs = make([]string, 0, 1)
}
for _, id := range m.PartnerIDs {
if id == partnerID {
return nil
}
}
m.PartnerIDs = append(m.PartnerIDs, partnerID)
return nil
})
}

// SetPartnerID ensures that the message has only the given partner ID. This
// will always set the partner ID, replacing any existing partner IDs.
func SetPartnerID(partnerID string) NormifierOption {
return optionFunc(func(m *Message) error {
m.PartnerIDs = []string{partnerID}
return nil
})
}

// SetSessionID ensures that the message has the given session ID. This will
// always set the session ID, replacing any existing session ID
func SetSessionID(sessionID string) NormifierOption {
return optionFunc(func(m *Message) error {
m.SessionID = sessionID
return nil
})
}

// EnsureMetadataString ensures that the message has the given string metadata.
// This will always set the metadata.
func EnsureMetadataString(key, value string) NormifierOption {
return optionFunc(func(m *Message) error {
if m.Metadata == nil {
m.Metadata = make(map[string]string)
}
m.Metadata[key] = value
return nil
})
}

// EnsureMetadataTime ensures that the message has the given time metadata.
// This will always set the metadata. The time is formatted using RFC3339.
func EnsureMetadataTime(key string, t time.Time) NormifierOption {
return EnsureMetadataString(key, t.Format(time.RFC3339))
}

// EnsureMetadataInt64 ensures that the message has the given integer metadata.
// This will always set the metadata. The integer is converted to a string
// using base 10.
func EnsureMetadataInt64(key string, i int64) NormifierOption {
return EnsureMetadataString(key, strconv.FormatInt(i, 10))
}

// -- Validators ---------------------------------------------------------------

// ValidateSource ensures that the source locator is valid.
func ValidateSource() NormifierOption {
return optionFunc(func(m *Message) error {
if _, err := ParseLocator(m.Source); err != nil {
return errors.Join(err, ErrInvalidSource)
}
return nil
})
}

// ValidateDestination ensures that the destination locator is valid.
func ValidateDestination() NormifierOption {
return optionFunc(func(m *Message) error {
if _, err := ParseLocator(m.Destination); err != nil {
return errors.Join(err, ErrInvalidDest)
}
return nil
})
}

// ValidateMessageType ensures that the message type is valid.
func ValidateMessageType() NormifierOption {
return optionFunc(func(m *Message) error {
if m.Type <= Invalid1MessageType || m.Type >= LastMessageType {
return ErrInvalidMessageType
}
return nil
})
}

// ValidateOnlyUTF8Strings ensures that all string fields in the message are
// valid UTF-8.
func ValidateOnlyUTF8Strings() NormifierOption {
return optionFunc(func(m *Message) error {
if err := UTF8(m); err != nil {
return errors.Join(err, ErrInvalidString)
}
return nil
})
}

// ValidateIsPartner ensures that the message has the given partner ID.
func ValidateIsPartner(partner string) NormifierOption {
return optionFunc(func(m *Message) error {
list := m.TrimmedPartnerIDs()
if len(list) != 1 || list[0] != partner {
return ErrInvalidPartnerID
}

return nil
})
}

// ValidateHasPartner ensures that the message has one of the given partner
// IDs.
func ValidateHasPartner(partners ...string) NormifierOption {
trimmed := make([]string, 0, len(partners))
for _, p := range partners {
if p != "" {
trimmed = append(trimmed, p)
}
}

return optionFunc(func(m *Message) error {
list := m.TrimmedPartnerIDs()
for _, p := range trimmed {
for _, id := range list {
if id == p {
return nil
}
}
}
return ErrInvalidPartnerID
})
}
Loading

0 comments on commit e53d23c

Please sign in to comment.