Skip to content

Commit

Permalink
updates to reflect change in ServerRequestFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
basvanbeek committed Mar 5, 2017
1 parent 4079d84 commit 8864ce8
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
8 changes: 4 additions & 4 deletions auth/jwt/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ func FromHTTPContext() http.RequestFunc {

// ToGRPCContext moves JWT token from grpc metadata to context. Particularly
// userful for servers.
func ToGRPCContext() grpc.RequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
func ToGRPCContext() grpc.ServerRequestFunc {
return func(ctx context.Context, md metadata.MD) context.Context {
// capital "Key" is illegal in HTTP/2.
authHeader, ok := (*md)["authorization"]
authHeader, ok := md["authorization"]
if !ok {
return ctx
}
Expand All @@ -63,7 +63,7 @@ func ToGRPCContext() grpc.RequestFunc {

// FromGRPCContext moves JWT token from context to grpc metadata. Particularly
// useful for clients.
func FromGRPCContext() grpc.RequestFunc {
func FromGRPCContext() grpc.ClientRequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
token, ok := ctx.Value(JWTTokenContextKey).(string)
if ok {
Expand Down
6 changes: 3 additions & 3 deletions auth/jwt/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,23 @@ func TestToGRPCContext(t *testing.T) {
reqFunc := ToGRPCContext()

// No Authorization header is passed
ctx := reqFunc(context.Background(), &md)
ctx := reqFunc(context.Background(), md)
token := ctx.Value(JWTTokenContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
}

// Invalid Authorization header is passed
md["authorization"] = []string{fmt.Sprintf("%s", signedKey)}
ctx = reqFunc(context.Background(), &md)
ctx = reqFunc(context.Background(), md)
token = ctx.Value(JWTTokenContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
}

// Authorization header is correct
md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)}
ctx = reqFunc(context.Background(), &md)
ctx = reqFunc(context.Background(), md)
token, ok := ctx.Value(JWTTokenContextKey).(string)
if !ok {
t.Fatal("JWT Token not passed to context correctly")
Expand Down
6 changes: 3 additions & 3 deletions tracing/opentracing/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func ToGRPCRequest(tracer opentracing.Tracer, logger log.Logger) func(ctx contex
// `operationName` accordingly. If no trace could be found in `req`, the Span
// will be a trace root. The Span is incorporated in the returned Context and
// can be retrieved with opentracing.SpanFromContext(ctx).
func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md *metadata.MD) context.Context {
return func(ctx context.Context, md *metadata.MD) context.Context {
func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md metadata.MD) context.Context {
return func(ctx context.Context, md metadata.MD) context.Context {
var span opentracing.Span
wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{md})
wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{&md})
if err != nil && err != opentracing.ErrSpanContextNotFound {
logger.Log("err", err)
}
Expand Down
2 changes: 1 addition & 1 deletion tracing/opentracing/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestTraceGRPCRequestRoundtrip(t *testing.T) {

// Use FromGRPCRequest to verify that we can join with the trace given MD.
fromGRPCFunc := kitot.FromGRPCRequest(tracer, "joined", logger)
joinCtx := fromGRPCFunc(afterCtx, &md)
joinCtx := fromGRPCFunc(afterCtx, md)
joinedSpan := opentracing.SpanFromContext(joinCtx).(*mocktracer.MockSpan)

joinedContext := joinedSpan.Context().(mocktracer.MockSpanContext)
Expand Down

0 comments on commit 8864ce8

Please sign in to comment.