Skip to content

Commit

Permalink
Merge pull request #479 from basvanbeek/master
Browse files Browse the repository at this point in the history
Improved gRPC transport with some minor breaking changes.

gRPC Server Transport
=====================
Existing customers of the gRPC Server transport will have to get rid of the service scoped context as this has become kind of an anti-pattern.

This changes the gRPC server construction signature from:
  NewServer(ctx, endpoint, decoder, encoder, ...serverOptions)

to:
  NewServer(endpoint, decoder, encoder, ...serverOptions)

Server customers using `ResponseFunc` will need to update their signatures from:
  ResponseFunc func(context.Context, *metadata.MD)

to:
  ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context

PLEASE NOTE! Next to this signature update it might be that certain logic needs to move to the ServerRequestFunc if the logic was reading metadata from the originating request. This was incorrect behavior anyway as this is what ServerRequestFuncs are intended to handle. Use context to pass details if needing to manipulate response data based on incoming request metadata.


gRPC Client Transport
=====================
Customers of the gRPC Client transport will benefit from the newly added ClientResponseFunc, which will allow them to pick up metadata from both Headers and Trailers returned by a gRPC Server.

Consumers referencing RequestFunc will need to rename to ClientRequestFunc due to RequestFunc having been split up for Client and Server side.

ServerBefore() and ServerAfter() now append the referenced functions so next to variadic function parameters you can also call ServerBefore() and ServerAfter() multiple times without overwriting previous invocations. This is in line with the same change that already occurred server side.

Context
=======
Similar to the HTTP Server and Client Transport, the gRPC Server and Client Transport now feature passing of context.Context for all Before and After Funcs. This can be helpful if needing to pass data along to chained middlewares or when response header logic might need to signal changes needed in the regular response payload which now can be done by passing details in context and picking them up in the response encoder step.
  • Loading branch information
basvanbeek authored Mar 5, 2017
2 parents 46c895b + 8864ce8 commit 322a9e0
Show file tree
Hide file tree
Showing 19 changed files with 654 additions and 63 deletions.
8 changes: 4 additions & 4 deletions auth/jwt/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ func FromHTTPContext() http.RequestFunc {

// ToGRPCContext moves JWT token from grpc metadata to context. Particularly
// userful for servers.
func ToGRPCContext() grpc.RequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
func ToGRPCContext() grpc.ServerRequestFunc {
return func(ctx context.Context, md metadata.MD) context.Context {
// capital "Key" is illegal in HTTP/2.
authHeader, ok := (*md)["authorization"]
authHeader, ok := md["authorization"]
if !ok {
return ctx
}
Expand All @@ -63,7 +63,7 @@ func ToGRPCContext() grpc.RequestFunc {

// FromGRPCContext moves JWT token from context to grpc metadata. Particularly
// useful for clients.
func FromGRPCContext() grpc.RequestFunc {
func FromGRPCContext() grpc.ClientRequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
token, ok := ctx.Value(JWTTokenContextKey).(string)
if ok {
Expand Down
6 changes: 3 additions & 3 deletions auth/jwt/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,23 @@ func TestToGRPCContext(t *testing.T) {
reqFunc := ToGRPCContext()

// No Authorization header is passed
ctx := reqFunc(context.Background(), &md)
ctx := reqFunc(context.Background(), md)
token := ctx.Value(JWTTokenContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
}

// Invalid Authorization header is passed
md["authorization"] = []string{fmt.Sprintf("%s", signedKey)}
ctx = reqFunc(context.Background(), &md)
ctx = reqFunc(context.Background(), md)
token = ctx.Value(JWTTokenContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
}

// Authorization header is correct
md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)}
ctx = reqFunc(context.Background(), &md)
ctx = reqFunc(context.Background(), md)
token, ok := ctx.Value(JWTTokenContextKey).(string)
if !ok {
t.Fatal("JWT Token not passed to context correctly")
Expand Down
2 changes: 1 addition & 1 deletion examples/addsvc/cmd/addsvc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func main() {
return
}

srv := addsvc.MakeGRPCServer(ctx, endpoints, tracer, logger)
srv := addsvc.MakeGRPCServer(endpoints, tracer, logger)
s := grpc.NewServer()
pb.RegisterAddServer(s, srv)

Expand Down
4 changes: 1 addition & 3 deletions examples/addsvc/transport_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,18 @@ import (
)

// MakeGRPCServer makes a set of endpoints available as a gRPC AddServer.
func MakeGRPCServer(ctx context.Context, endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer {
func MakeGRPCServer(endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer {
options := []grpctransport.ServerOption{
grpctransport.ServerErrorLogger(logger),
}
return &grpcServer{
sum: grpctransport.NewServer(
ctx,
endpoints.SumEndpoint,
DecodeGRPCSumRequest,
EncodeGRPCSumResponse,
append(options, grpctransport.ServerBefore(opentracing.FromGRPCRequest(tracer, "Sum", logger)))...,
),
concat: grpctransport.NewServer(
ctx,
endpoints.ConcatEndpoint,
DecodeGRPCConcatRequest,
EncodeGRPCConcatResponse,
Expand Down
6 changes: 3 additions & 3 deletions tracing/opentracing/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func ToGRPCRequest(tracer opentracing.Tracer, logger log.Logger) func(ctx contex
// `operationName` accordingly. If no trace could be found in `req`, the Span
// will be a trace root. The Span is incorporated in the returned Context and
// can be retrieved with opentracing.SpanFromContext(ctx).
func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md *metadata.MD) context.Context {
return func(ctx context.Context, md *metadata.MD) context.Context {
func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md metadata.MD) context.Context {
return func(ctx context.Context, md metadata.MD) context.Context {
var span opentracing.Span
wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{md})
wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{&md})
if err != nil && err != opentracing.ErrSpanContextNotFound {
logger.Log("err", err)
}
Expand Down
2 changes: 1 addition & 1 deletion tracing/opentracing/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestTraceGRPCRequestRoundtrip(t *testing.T) {

// Use FromGRPCRequest to verify that we can join with the trace given MD.
fromGRPCFunc := kitot.FromGRPCRequest(tracer, "joined", logger)
joinCtx := fromGRPCFunc(afterCtx, &md)
joinCtx := fromGRPCFunc(afterCtx, md)
joinedSpan := opentracing.SpanFromContext(joinCtx).(*mocktracer.MockSpan)

joinedContext := joinedSpan.Context().(mocktracer.MockSpanContext)
Expand Down
50 changes: 50 additions & 0 deletions transport/grpc/_grpc_test/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package test

import (
"context"

"google.golang.org/grpc"

"github.com/go-kit/kit/endpoint"
grpctransport "github.com/go-kit/kit/transport/grpc"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)

type clientBinding struct {
test endpoint.Endpoint
}

func (c *clientBinding) Test(ctx context.Context, a string, b int64) (context.Context, string, error) {
response, err := c.test(ctx, TestRequest{A: a, B: b})
if err != nil {
return nil, "", err
}
r := response.(*TestResponse)
return r.Ctx, r.V, nil
}

func NewClient(cc *grpc.ClientConn) Service {
return &clientBinding{
test: grpctransport.NewClient(
cc,
"pb.Test",
"Test",
encodeRequest,
decodeResponse,
&pb.TestResponse{},
grpctransport.ClientBefore(
injectCorrelationID,
),
grpctransport.ClientBefore(
displayClientRequestHeaders,
),
grpctransport.ClientAfter(
displayClientResponseHeaders,
displayClientResponseTrailers,
),
grpctransport.ClientAfter(
extractConsumedCorrelationID,
),
).Endpoint(),
}
}
141 changes: 141 additions & 0 deletions transport/grpc/_grpc_test/context_metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package test

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
)

type metaContext string

const (
correlationID metaContext = "correlation-id"
responseHDR metaContext = "my-response-header"
responseTRLR metaContext = "my-response-trailer"
correlationIDTRLR metaContext = "correlation-id-consumed"
)

/* client before functions */

func injectCorrelationID(ctx context.Context, md *metadata.MD) context.Context {
if hdr, ok := ctx.Value(correlationID).(string); ok {
fmt.Printf("\tClient found correlationID %q in context, set metadata header\n", hdr)
(*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr)
}
return ctx
}

func displayClientRequestHeaders(ctx context.Context, md *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tClient >> Request Headers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

/* server before functions */

func extractCorrelationID(ctx context.Context, md metadata.MD) context.Context {
if hdr, ok := md[string(correlationID)]; ok {
cID := hdr[len(hdr)-1]
ctx = context.WithValue(ctx, correlationID, cID)
fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID)
}
return ctx
}

func displayServerRequestHeaders(ctx context.Context, md metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tServer << Request Headers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

/* server after functions */

func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
*md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value"))
return ctx
}

func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tServer >> Response Headers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
*md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too"))
return ctx
}

func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
if hdr, ok := ctx.Value(correlationID).(string); ok {
fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr)
*md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr))
}
return ctx
}

func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tServer >> Response Trailers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

/* client after functions */

func displayClientResponseHeaders(ctx context.Context, md metadata.MD, _ metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tClient << Response Headers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

func displayClientResponseTrailers(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tClient << Response Trailers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

func extractConsumedCorrelationID(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context {
if hdr, ok := md[string(correlationIDTRLR)]; ok {
fmt.Printf("\tClient received consumed correlationID %q in metadata trailer, set context\n", hdr[len(hdr)-1])
ctx = context.WithValue(ctx, correlationIDTRLR, hdr[len(hdr)-1])
}
return ctx
}

/* CorrelationID context handlers */

func SetCorrelationID(ctx context.Context, v string) context.Context {
return context.WithValue(ctx, correlationID, v)
}

func GetConsumedCorrelationID(ctx context.Context) string {
if trlr, ok := ctx.Value(correlationIDTRLR).(string); ok {
return trlr
}
return ""
}
3 changes: 3 additions & 0 deletions transport/grpc/_grpc_test/pb/generate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package pb

//go:generate protoc test.proto --go_out=plugins=grpc:.
Loading

0 comments on commit 322a9e0

Please sign in to comment.