Skip to content

Commit

Permalink
metadata: Use strings.EqualFold for ValueFromIncomingContext (#6743)
Browse files Browse the repository at this point in the history
  • Loading branch information
evanj authored Oct 30, 2023
1 parent 8cb9846 commit d7ea67b
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 12 deletions.
18 changes: 10 additions & 8 deletions metadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,16 @@ func Join(mds ...MD) MD {
type mdIncomingKey struct{}
type mdOutgoingKey struct{}

// NewIncomingContext creates a new context with incoming md attached.
// NewIncomingContext creates a new context with incoming md attached. md must
// not be modified after calling this function.
func NewIncomingContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdIncomingKey{}, md)
}

// NewOutgoingContext creates a new context with outgoing md attached. If used
// in conjunction with AppendToOutgoingContext, NewOutgoingContext will
// overwrite any previously-appended metadata.
// overwrite any previously-appended metadata. md must not be modified after
// calling this function.
func NewOutgoingContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md})
}
Expand Down Expand Up @@ -203,7 +205,8 @@ func FromIncomingContext(ctx context.Context) (MD, bool) {
}

// ValueFromIncomingContext returns the metadata value corresponding to the metadata
// key from the incoming metadata if it exists. Key must be lower-case.
// key from the incoming metadata if it exists. Keys are matched in a case insensitive
// manner.
//
// # Experimental
//
Expand All @@ -219,17 +222,16 @@ func ValueFromIncomingContext(ctx context.Context, key string) []string {
return copyOf(v)
}
for k, v := range md {
// We need to manually convert all keys to lower case, because MD is a
// map, and there's no guarantee that the MD attached to the context is
// created using our helper functions.
if strings.ToLower(k) == key {
// Case insenitive comparison: MD is a map, and there's no guarantee
// that the MD attached to the context is created using our helper
// functions.
if strings.EqualFold(k, key) {
return copyOf(v)
}
}
return nil
}

// the returned slice must not be modified in place
func copyOf(v []string) []string {
vals := make([]string, len(v))
copy(vals, v)
Expand Down
61 changes: 57 additions & 4 deletions metadata/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,48 @@ func (s) TestDelete(t *testing.T) {
}
}

func (s) TestFromIncomingContext(t *testing.T) {
md := Pairs(
"X-My-Header-1", "42",
)
// Verify that we lowercase if callers directly modify md
md["X-INCORRECT-UPPERCASE"] = []string{"foo"}
ctx := NewIncomingContext(context.Background(), md)

result, found := FromIncomingContext(ctx)
if !found {
t.Fatal("FromIncomingContext must return metadata")
}
expected := MD{
"x-my-header-1": []string{"42"},
"x-incorrect-uppercase": []string{"foo"},
}
if !reflect.DeepEqual(result, expected) {
t.Errorf("FromIncomingContext returned %#v, expected %#v", result, expected)
}

// ensure modifying result does not modify the value in the context
result["new_key"] = []string{"foo"}
result["x-my-header-1"][0] = "mutated"

result2, found := FromIncomingContext(ctx)
if !found {
t.Fatal("FromIncomingContext must return metadata")
}
if !reflect.DeepEqual(result2, expected) {
t.Errorf("FromIncomingContext after modifications returned %#v, expected %#v", result2, expected)
}
}

func (s) TestValueFromIncomingContext(t *testing.T) {
md := Pairs(
"X-My-Header-1", "42",
"X-My-Header-2", "43-1",
"X-My-Header-2", "43-2",
"x-my-header-3", "44",
)
// Verify that we lowercase if callers directly modify md
md["X-INCORRECT-UPPERCASE"] = []string{"foo"}
ctx := NewIncomingContext(context.Background(), md)

for _, test := range []struct {
Expand All @@ -227,6 +262,10 @@ func (s) TestValueFromIncomingContext(t *testing.T) {
key: "x-unknown",
want: nil,
},
{
key: "x-incorrect-uppercase",
want: []string{"foo"},
},
} {
v := ValueFromIncomingContext(ctx, test.key)
if !reflect.DeepEqual(v, test.want) {
Expand Down Expand Up @@ -348,8 +387,22 @@ func BenchmarkFromIncomingContext(b *testing.B) {
func BenchmarkValueFromIncomingContext(b *testing.B) {
md := Pairs("X-My-Header-1", "42")
ctx := NewIncomingContext(context.Background(), md)
b.ResetTimer()
for n := 0; n < b.N; n++ {
ValueFromIncomingContext(ctx, "x-my-header-1")
}

b.Run("key-found", func(b *testing.B) {
for n := 0; n < b.N; n++ {
result := ValueFromIncomingContext(ctx, "x-my-header-1")
if len(result) != 1 {
b.Fatal("ensures not optimized away")
}
}
})

b.Run("key-not-found", func(b *testing.B) {
for n := 0; n < b.N; n++ {
result := ValueFromIncomingContext(ctx, "key-not-found")
if len(result) != 0 {
b.Fatal("ensures not optimized away")
}
}
})
}

0 comments on commit d7ea67b

Please sign in to comment.