Skip to content

Commit

Permalink
Fix #392:
Browse files Browse the repository at this point in the history
- Move util/metautils to root-level package metadata
- Rename NiceMD to MD, which is a wrapper for grpc/metadata.MD
  • Loading branch information
rahulkhairwar committed Dec 5, 2021
1 parent dd1540e commit 3267a84
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 171 deletions.
10 changes: 5 additions & 5 deletions interceptors/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ import (
"testing"
"time"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/metadata"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/oauth"
"google.golang.org/grpc/metadata"
grpcMetadata "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/util/metautils"
)

var authedMarker struct{}
Expand Down Expand Up @@ -66,9 +67,8 @@ func (s *assertingPingService) PingList(ping *testpb.PingListRequest, stream tes
}

func ctxWithToken(ctx context.Context, scheme string, token string) context.Context {
md := metadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token))
nCtx := metautils.NiceMD(md).ToOutgoing(ctx)
return nCtx
md := grpcMetadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token))
return metadata.MD(md).ToOutgoing(ctx)
}

func TestAuthTestSuite(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions interceptors/auth/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (
"context"
"strings"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/metadata"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/util/metautils"
)

var (
Expand All @@ -23,7 +23,7 @@ var (
// case-insensitive format (see rfc2617, sec 1.2). If no such authorization is found, or the token
// is of wrong scheme, an error with gRPC status `Unauthenticated` is returned.
func AuthFromMD(ctx context.Context, expectedScheme string) (string, error) {
val := metautils.ExtractIncoming(ctx).Get(headerAuthorize)
val := metadata.ExtractIncoming(ctx).Get(headerAuthorize)
if val == "" {
return "", status.Errorf(codes.Unauthenticated, "Request unauthenticated with "+expectedScheme)
}
Expand Down
24 changes: 12 additions & 12 deletions interceptors/auth/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,62 +7,62 @@ import (
"context"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/metadata"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
grpcMetadata "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/util/metautils"
)

func TestAuthFromMD(t *testing.T) {
for _, run := range []struct {
md metadata.MD
md grpcMetadata.MD
value string
errCode codes.Code
msg string
}{
{
md: metadata.Pairs("authorization", "bearer some_token"),
md: grpcMetadata.Pairs("authorization", "bearer some_token"),
value: "some_token",
msg: "must extract simple bearer tokens without case checking",
},
{
md: metadata.Pairs("authorization", "Bearer some_token"),
md: grpcMetadata.Pairs("authorization", "Bearer some_token"),
value: "some_token",
msg: "must extract simple bearer tokens with case checking",
},
{
md: metadata.Pairs("authorization", "Bearer some multi string bearer"),
md: grpcMetadata.Pairs("authorization", "Bearer some multi string bearer"),
value: "some multi string bearer",
msg: "must handle string based bearers",
},
{
md: metadata.Pairs("authorization", "Basic login:passwd"),
md: grpcMetadata.Pairs("authorization", "Basic login:passwd"),
value: "",
errCode: codes.Unauthenticated,
msg: "must check authentication type",
},
{
md: metadata.Pairs("authorization", "Basic login:passwd", "authorization", "bearer some_token"),
md: grpcMetadata.Pairs("authorization", "Basic login:passwd", "authorization", "bearer some_token"),
value: "",
errCode: codes.Unauthenticated,
msg: "must not allow multiple authentication methods",
},
{
md: metadata.Pairs("authorization", ""),
md: grpcMetadata.Pairs("authorization", ""),
value: "",
errCode: codes.Unauthenticated,
msg: "authorization string must not be empty",
},
{
md: metadata.Pairs("authorization", "Bearer"),
md: grpcMetadata.Pairs("authorization", "Bearer"),
value: "",
errCode: codes.Unauthenticated,
msg: "bearer token must not be empty",
},
} {
ctx := metautils.NiceMD(run.md).ToIncoming(context.TODO())
ctx := metadata.MD(run.md).ToIncoming(context.TODO())
out, err := AuthFromMD(ctx, "bearer")
if run.errCode != codes.OK {
assert.Equal(t, run.errCode, status.Code(err), run.msg)
Expand Down
12 changes: 6 additions & 6 deletions interceptors/retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ import (
"sync"
"time"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/metadata"

"golang.org/x/net/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
grpcMetadata "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/util/metautils"
)

const (
Expand Down Expand Up @@ -170,11 +170,11 @@ func (s *serverStreamingRetryingStream) CloseSend() error {
return s.getStream().CloseSend()
}

func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) {
func (s *serverStreamingRetryingStream) Header() (grpcMetadata.MD, error) {
return s.getStream().Header()
}

func (s *serverStreamingRetryingStream) Trailer() metadata.MD {
func (s *serverStreamingRetryingStream) Trailer() grpcMetadata.MD {
return s.getStream().Trailer()
}

Expand Down Expand Up @@ -296,7 +296,7 @@ func perCallContext(parentCtx context.Context, callOpts *options, attempt uint)
ctx, cancel = context.WithTimeout(ctx, callOpts.perCallTimeout)
}
if attempt > 0 && callOpts.includeHeader {
mdClone := metautils.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, fmt.Sprintf("%d", attempt))
mdClone := metadata.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, fmt.Sprintf("%d", attempt))
ctx = mdClone.ToOutgoing(ctx)
}
return ctx, cancel
Expand Down
18 changes: 9 additions & 9 deletions interceptors/skip/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
grpcMetadata "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
Expand Down Expand Up @@ -42,7 +42,7 @@ type skipPingService struct {
}

func checkMetadata(ctx context.Context, grpcType interceptors.GRPCType, service string, method string) error {
m, _ := metadata.FromIncomingContext(ctx)
m, _ := grpcMetadata.FromIncomingContext(ctx)
if typeFromMetadata := m.Get(keyGRPCType)[0]; typeFromMetadata != string(grpcType) {
return status.Errorf(codes.Internal, fmt.Sprintf("expected grpc type %s, got: %s", grpcType, typeFromMetadata))
}
Expand Down Expand Up @@ -82,7 +82,7 @@ func (s *skipPingService) PingList(_ *testpb.PingListRequest, stream testpb.Test
}

func filter(ctx context.Context, gRPCType interceptors.GRPCType, service string, method string) bool {
m, _ := metadata.FromIncomingContext(ctx)
m, _ := grpcMetadata.FromIncomingContext(ctx)
// Set parameters into metadata
m.Set(keyGRPCType, string(gRPCType))
m.Set(keyService, service)
Expand Down Expand Up @@ -144,14 +144,14 @@ func (s *SkipSuite) TestPing() {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var m metadata.MD
var m grpcMetadata.MD
if tc.skip {
m = metadata.New(map[string]string{
m = grpcMetadata.New(map[string]string{
"skip": "true",
})
}

resp, err := s.Client.Ping(metadata.NewOutgoingContext(s.SimpleCtx(), m), testpb.GoodPing)
resp, err := s.Client.Ping(grpcMetadata.NewOutgoingContext(s.SimpleCtx(), m), testpb.GoodPing)
require.NoError(t, err)

var value string
Expand Down Expand Up @@ -182,14 +182,14 @@ func (s *SkipSuite) TestPingList() {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var m metadata.MD
var m grpcMetadata.MD
if tc.skip {
m = metadata.New(map[string]string{
m = grpcMetadata.New(map[string]string{
"skip": "true",
})
}

stream, err := s.Client.PingList(metadata.NewOutgoingContext(s.SimpleCtx(), m), testpb.GoodPingList)
stream, err := s.Client.PingList(grpcMetadata.NewOutgoingContext(s.SimpleCtx(), m), testpb.GoodPingList)
require.NoError(t, err)

for {
Expand Down
10 changes: 5 additions & 5 deletions util/metautils/doc.go → metadata/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
// Licensed under the Apache License 2.0.

/*
Package `metautils` provides convenience functions for dealing with gRPC metadata.MD objects inside
Package `metadata` provides convenience functions for dealing with gRPC metadata.MD objects inside
Context handlers.
While the upstream grpc-go package contains decent functionality (see https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md)
they are hard to use.
The majority of functions center around the NiceMD, which is a convenience wrapper around metadata.MD. For example
The majority of functions center around the MD, which is a convenience wrapper around metadata.MD. For example
the following code allows you to easily extract incoming metadata (server handler) and put it into a new client context
metadata.
nmd := metautils.ExtractIncoming(serverCtx).Clone(":authorization", ":custom")
clientCtx := nmd.Set("x-client-header", "2").Set("x-another", "3").ToOutgoing(ctx)
md := metadata.ExtractIncoming(serverCtx).Clone(":authorization", ":custom")
clientCtx := md.Set("x-client-header", "2").Set("x-another", "3").ToOutgoing(ctx)
*/

package metautils
package metadata
58 changes: 29 additions & 29 deletions util/metautils/nicemd.go → metadata/metadata.go
Original file line number Diff line number Diff line change
@@ -1,48 +1,48 @@
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.

package metautils
package metadata

import (
"context"
"strings"

"google.golang.org/grpc/metadata"
grpcMetadata "google.golang.org/grpc/metadata"
)

// NiceMD is a convenience wrapper defining extra functions on the metadata.
type NiceMD metadata.MD
// MD is a convenience wrapper defining extra functions on the metadata.
type MD grpcMetadata.MD

// ExtractIncoming extracts an inbound metadata from the server-side context.
//
// This function always returns a NiceMD wrapper of the metadata.MD, in case the context doesn't have metadata it returns
// a new empty NiceMD.
func ExtractIncoming(ctx context.Context) NiceMD {
md, ok := metadata.FromIncomingContext(ctx)
// This function always returns a MD wrapper of the grpcMetadata.MD, in case the context doesn't have metadata it returns
// a new empty MD.
func ExtractIncoming(ctx context.Context) MD {
md, ok := grpcMetadata.FromIncomingContext(ctx)
if !ok {
return NiceMD(metadata.Pairs())
return MD(grpcMetadata.Pairs())
}
return NiceMD(md)
return MD(md)
}

// ExtractOutgoing extracts an outbound metadata from the client-side context.
//
// This function always returns a NiceMD wrapper of the metadata.MD, in case the context doesn't have metadata it returns
// a new empty NiceMD.
func ExtractOutgoing(ctx context.Context) NiceMD {
md, ok := metadata.FromOutgoingContext(ctx)
// This function always returns a MD wrapper of the grpcMetadata.MD, in case the context doesn't have metadata it returns
// a new empty MD.
func ExtractOutgoing(ctx context.Context) MD {
md, ok := grpcMetadata.FromOutgoingContext(ctx)
if !ok {
return NiceMD(metadata.Pairs())
return MD(grpcMetadata.Pairs())
}
return NiceMD(md)
return MD(md)
}

// Clone performs a *deep* copy of the metadata.MD.
// Clone performs a *deep* copy of the grpcMetadata.MD.
//
// You can specify the lower-case copiedKeys to only copy certain whitelisted keys. If no keys are explicitly whitelisted
// all keys get copied.
func (m NiceMD) Clone(copiedKeys ...string) NiceMD {
newMd := NiceMD(metadata.Pairs())
func (m MD) Clone(copiedKeys ...string) MD {
newMd := MD(grpcMetadata.Pairs())
for k, vv := range m {
found := false
if len(copiedKeys) == 0 {
Expand All @@ -64,16 +64,16 @@ func (m NiceMD) Clone(copiedKeys ...string) NiceMD {
return newMd
}

// ToOutgoing sets the given NiceMD as a client-side context for dispatching.
func (m NiceMD) ToOutgoing(ctx context.Context) context.Context {
return metadata.NewOutgoingContext(ctx, metadata.MD(m))
// ToOutgoing sets the given MD as a client-side context for dispatching.
func (m MD) ToOutgoing(ctx context.Context) context.Context {
return grpcMetadata.NewOutgoingContext(ctx, grpcMetadata.MD(m))
}

// ToIncoming sets the given NiceMD as a server-side context for dispatching.
// ToIncoming sets the given MD as a server-side context for dispatching.
//
// This is mostly useful in ServerInterceptors.
func (m NiceMD) ToIncoming(ctx context.Context) context.Context {
return metadata.NewIncomingContext(ctx, metadata.MD(m))
func (m MD) ToIncoming(ctx context.Context) context.Context {
return grpcMetadata.NewIncomingContext(ctx, grpcMetadata.MD(m))
}

// Get retrieves a single value from the metadata.
Expand All @@ -82,7 +82,7 @@ func (m NiceMD) ToIncoming(ctx context.Context) context.Context {
// an empty string is returned.
//
// The function is binary-key safe.
func (m NiceMD) Get(key string) string {
func (m MD) Get(key string) string {
k, _ := encodeKeyValue(key, "")
vv, ok := m[k]
if !ok {
Expand All @@ -97,7 +97,7 @@ func (m NiceMD) Get(key string) string {
//
// The function is binary-key safe.

func (m NiceMD) Del(key string) NiceMD {
func (m MD) Del(key string) MD {
k, _ := encodeKeyValue(key, "")
delete(m, k)
return m
Expand All @@ -108,7 +108,7 @@ func (m NiceMD) Del(key string) NiceMD {
// It works analogously to http.Header.Set, overwriting all previous metadata values.
//
// The function is binary-key safe.
func (m NiceMD) Set(key string, value string) NiceMD {
func (m MD) Set(key string, value string) MD {
k, v := encodeKeyValue(key, value)
m[k] = []string{v}
return m
Expand All @@ -119,7 +119,7 @@ func (m NiceMD) Set(key string, value string) NiceMD {
// It works analogously to http.Header.Add, as it appends to any existing values associated with key.
//
// The function is binary-key safe.
func (m NiceMD) Add(key string, value string) NiceMD {
func (m MD) Add(key string, value string) MD {
k, v := encodeKeyValue(key, value)
m[k] = append(m[k], v)
return m
Expand Down
Loading

0 comments on commit 3267a84

Please sign in to comment.