Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to validate and update a message using a collection o… #167

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,17 @@ func (msg *Message) SetIncludeSpans(value bool) *Message {
return msg
}

// TrimmedPartnerIDs returns a copy of the PartnerIDs field with all empty strings removed.
func (msg *Message) TrimmedPartnerIDs() []string {
trimmed := make([]string, 0, len(msg.PartnerIDs))
for _, id := range msg.PartnerIDs {
if id != "" {
trimmed = append(trimmed, id)
}
}
return trimmed
}

// SimpleRequestResponse represents a WRP message of type SimpleRequestResponseMessageType.
//
// https://github.com/xmidt-org/wrp-c/wiki/Web-Routing-Protocol#simple-request-response-definition
Expand Down
31 changes: 31 additions & 0 deletions messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,3 +765,34 @@ func TestUnknown(t *testing.T) {
})
}
}

func TestMessage_TrimmedPartnerIDs(t *testing.T) {
tests := []struct {
description string
partners []string
want []string
}{
{
description: "empty partner list",
partners: []string{},
want: []string{},
}, {
description: "normal partner list",
partners: []string{"foo", "bar", "baz"},
want: []string{"foo", "bar", "baz"},
}, {
description: "partner list with empty strings",
partners: []string{"", "foo", "", "bar", "", "baz", ""},
want: []string{"foo", "bar", "baz"},
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
assert := assert.New(t)
msg := &Message{
PartnerIDs: tc.partners,
}
assert.Equal(tc.want, msg.TrimmedPartnerIDs())
})
}
}
299 changes: 299 additions & 0 deletions normify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package wrp

import (
"errors"
"strconv"
"time"

"github.com/google/uuid"
)

var (
ErrInvalidMessageType = errors.New("invalid message type")
ErrInvalidPartnerID = errors.New("invalid partner ID")
ErrInvalidSource = errors.New("invalid source locator")
ErrInvalidDest = errors.New("invalid destination locator")
ErrInvalidString = errors.New("invalid UTF-8 string")
)

// Normifier applies a series of normalizing options to a WRP message.
type Normifier struct {
opts []NormifierOption
}

// NormifierOption is a functional option for normalizing a WRP message.
type NormifierOption interface {
// normify applies the option to the given message.
normify(*Message) error
}

// optionFunc is an adapter to allow the use of ordinary functions as
// normalizing options.
type optionFunc func(*Message) error

var _ NormifierOption = optionFunc(nil)

func (f optionFunc) normify(m *Message) error {
return f(m)
}

// New creates a new Correctifier with the given options.
func NewNormifier(opts ...NormifierOption) *Normifier {
return &Normifier{
opts: opts,
}
}

// Normify applies the normalizing and validating options to the message. It
// returns an error if any of the options fail.
func (n *Normifier) Normify(m *Message) error {
for _, opt := range n.opts {
if opt != nil {
if err := opt.normify(m); err != nil {
return err
}
}
}
return nil
}

// errorOption returns an option that always returns the given error.
func errorOption(err error) NormifierOption {
return optionFunc(func(*Message) error {
return err
})
}

// options returns a new option that applies all of the given options in order.
func options(opts ...NormifierOption) NormifierOption {
return optionFunc(func(m *Message) error {
for _, opt := range opts {
if opt != nil {
if err := opt.normify(m); err != nil {
return err
}
}
}
return nil
})
}

// -- Normalizers --------------------------------------------------------------

// ReplaceAnySelfLocator replaces any `self:` based locator with the scheme and
// authority of the given locator. If the given locator is not valid, the
// option returns an error.
func ReplaceAnySelfLocator(me string) NormifierOption {
return options(
ReplaceSourceSelfLocator(me),
ReplaceDestinationSelfLocator(me),
)
}

// ReplaceSourceSelfLocator replaces a `self:` based source locator with the
// scheme and authority of the given locator. If the given locator is not valid,
// the option returns an error.
func ReplaceSourceSelfLocator(me string) NormifierOption {
full, err := ParseLocator(me)
if err != nil {
return errorOption(err)
}

return optionFunc(func(m *Message) error {
src, err := ParseLocator(m.Source)
if err != nil {
return err
}

if src.Scheme == "self" {
src.Scheme = full.Scheme
src.Authority = full.Authority
m.Source = src.String()
}

return nil
})
}

// ReplaceDestinationSelfLocator replaces the destination of the message with the
// given locator if the destination is a `self:` based locator. If the given
// locator is not valid, the option returns an error.
func ReplaceDestinationSelfLocator(me string) NormifierOption {
full, err := ParseLocator(me)
if err != nil {
return errorOption(err)
}

return optionFunc(func(m *Message) error {
dst, err := ParseLocator(m.Destination)
if err != nil {
return err
}

if dst.Scheme == "self" {
dst.Scheme = full.Scheme
dst.Authority = full.Authority
m.Destination = dst.String()
}

return nil
})
}

// EnsureTransactionUUID ensures that the message has a transaction UUID. If
// the message does not have a transaction UUID, a new one is generated and
// added to the message.
func EnsureTransactionUUID() NormifierOption {
return optionFunc(func(m *Message) error {
if m.TransactionUUID == "" {
id, err := uuid.NewRandom()
if err != nil {
return err
}

Check warning on line 155 in normify.go

View check run for this annotation

Codecov / codecov/patch

normify.go#L154-L155

Added lines #L154 - L155 were not covered by tests

m.TransactionUUID = id.String()
}
return nil
})
}

// EnsurePartnerID ensures that the message includes the given partner ID in
// the list. If not present, the partner ID is added to the list.
func EnsurePartnerID(partnerID string) NormifierOption {
return optionFunc(func(m *Message) error {
if m.PartnerIDs == nil {
m.PartnerIDs = make([]string, 0, 1)
}
for _, id := range m.PartnerIDs {
if id == partnerID {
return nil
}
}
m.PartnerIDs = append(m.PartnerIDs, partnerID)
return nil
})
}

// SetPartnerID ensures that the message has only the given partner ID. This
// will always set the partner ID, replacing any existing partner IDs.
func SetPartnerID(partnerID string) NormifierOption {
return optionFunc(func(m *Message) error {
m.PartnerIDs = []string{partnerID}
return nil
})
}

// SetSessionID ensures that the message has the given session ID. This will
// always set the session ID, replacing any existing session ID
func SetSessionID(sessionID string) NormifierOption {
return optionFunc(func(m *Message) error {
m.SessionID = sessionID
return nil
})
}

// EnsureMetadataString ensures that the message has the given string metadata.
// This will always set the metadata.
func EnsureMetadataString(key, value string) NormifierOption {
return optionFunc(func(m *Message) error {
if m.Metadata == nil {
m.Metadata = make(map[string]string)
}
m.Metadata[key] = value
return nil
})
}

// EnsureMetadataTime ensures that the message has the given time metadata.
// This will always set the metadata. The time is formatted using RFC3339.
func EnsureMetadataTime(key string, t time.Time) NormifierOption {
return EnsureMetadataString(key, t.Format(time.RFC3339))
}

// EnsureMetadataInt64 ensures that the message has the given integer metadata.
// This will always set the metadata. The integer is converted to a string
// using base 10.
func EnsureMetadataInt64(key string, i int64) NormifierOption {
return EnsureMetadataString(key, strconv.FormatInt(i, 10))
}

// -- Validators ---------------------------------------------------------------

// ValidateSource ensures that the source locator is valid.
func ValidateSource() NormifierOption {
return optionFunc(func(m *Message) error {
if _, err := ParseLocator(m.Source); err != nil {
return errors.Join(err, ErrInvalidSource)
}
return nil
})
}

// ValidateDestination ensures that the destination locator is valid.
func ValidateDestination() NormifierOption {
return optionFunc(func(m *Message) error {
if _, err := ParseLocator(m.Destination); err != nil {
return errors.Join(err, ErrInvalidDest)
}
return nil
})
}

// ValidateMessageType ensures that the message type is valid.
func ValidateMessageType() NormifierOption {
return optionFunc(func(m *Message) error {
if m.Type <= Invalid1MessageType || m.Type >= LastMessageType {
return ErrInvalidMessageType
}
return nil
})
}

// ValidateOnlyUTF8Strings ensures that all string fields in the message are
// valid UTF-8.
func ValidateOnlyUTF8Strings() NormifierOption {
return optionFunc(func(m *Message) error {
if err := UTF8(m); err != nil {
return errors.Join(err, ErrInvalidString)
}
return nil
})
}

// ValidateIsPartner ensures that the message has the given partner ID.
func ValidateIsPartner(partner string) NormifierOption {
return optionFunc(func(m *Message) error {
list := m.TrimmedPartnerIDs()
if len(list) != 1 || list[0] != partner {
return ErrInvalidPartnerID
}

return nil
})
}

// ValidateHasPartner ensures that the message has one of the given partner
// IDs.
func ValidateHasPartner(partners ...string) NormifierOption {
trimmed := make([]string, 0, len(partners))
for _, p := range partners {
if p != "" {
trimmed = append(trimmed, p)
}
}

return optionFunc(func(m *Message) error {
list := m.TrimmedPartnerIDs()
for _, p := range trimmed {
for _, id := range list {
if id == p {
return nil
}
}
}
return ErrInvalidPartnerID
})
}
Loading
Loading