Skip to content

Commit

Permalink
Add skip interceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
XSAM committed Nov 24, 2020
1 parent 354e77b commit cfb8ee6
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 0 deletions.
6 changes: 6 additions & 0 deletions interceptors/skip/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/*
`skip` allow users to skip interceptors in certain condition.
Users can use grpc type, service name, method name and metadata to determine whether to skip the interceptor.
*/
package skip
30 changes: 30 additions & 0 deletions interceptors/skip/examples_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package skip_test

import (
"context"

"google.golang.org/grpc"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/skip"
)

// Simple example of skipping auth interceptor in the reflection method.
func Example_initialization() {
_ = grpc.NewServer(
grpc.UnaryInterceptor(skip.UnaryServerInterceptor(auth.UnaryServerInterceptor(exampleAuthFunc), ReflectionFilter)),
grpc.StreamInterceptor(skip.StreamServerInterceptor(auth.StreamServerInterceptor(exampleAuthFunc), ReflectionFilter)),
)
}

func exampleAuthFunc(ctx context.Context) (context.Context, error) {
return ctx, nil
}

func ReflectionFilter(ctx context.Context, gRPCType interceptors.GRPCType, service string, method string) bool {
if service == "grpc.reflection.v1alpha.ServerReflection" {
return true
}
return false
}
43 changes: 43 additions & 0 deletions interceptors/skip/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package skip

import (
"context"

"google.golang.org/grpc"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
)

type Filter func(ctx context.Context, gRPCType interceptors.GRPCType, service string, method string) bool

// UnaryServerInterceptor returns a new unary server interceptor that determines whether to skip the input interceptor.
func UnaryServerInterceptor(in grpc.UnaryServerInterceptor, filter Filter) grpc.UnaryServerInterceptor {
if filter == nil {
return in
}

return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
service, method := interceptors.SplitMethodName(info.FullMethod)
if filter(ctx, interceptors.Unary, service, method) {
// Skip interceptor
return handler(ctx, req)
}
return in(ctx, req, info, handler)
}
}

// StreamServerInterceptor returns a new streaming server interceptor that determines whether to skip the input interceptor.
func StreamServerInterceptor(in grpc.StreamServerInterceptor, filter Filter) grpc.StreamServerInterceptor {
if filter == nil {
return in
}

return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
service, method := interceptors.SplitMethodName(info.FullMethod)
if filter(ss.Context(), interceptors.StreamRPCType(info), service, method) {
// Skip interceptor
return handler(srv, ss)
}
return in(srv, ss, info, handler)
}
}
198 changes: 198 additions & 0 deletions interceptors/skip/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package skip_test

import (
"context"
"fmt"
"io"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/grpctesting"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/grpctesting/testpb"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/skip"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/tags"
)

var (
goodPing = &testpb.PingRequest{Value: "something", SleepTimeMs: 9999}
)

const (
keyGRPCType = "skip.grpc_type"
keyService = "skip.service"
keyMethod = "skip.method"
)

func skipped(ctx context.Context) bool {
if len(tags.Extract(ctx).Values()) > 0 {
return false
}
return true
}

type skipPingService struct {
testpb.TestServiceServer
}

func checkMetadata(ctx context.Context, grpcType interceptors.GRPCType, service string, method string) error {
m, _ := metadata.FromIncomingContext(ctx)
if typeFromMetadata := m.Get(keyGRPCType)[0]; typeFromMetadata != string(grpcType) {
return status.Errorf(codes.Internal, fmt.Sprintf("expected grpc type %s, got: %s", grpcType, typeFromMetadata))
}
if serviceFromMetadata := m.Get(keyService)[0]; serviceFromMetadata != service {
return status.Errorf(codes.Internal, fmt.Sprintf("expected service %s, got: %s", service, serviceFromMetadata))
}
if methodFromMetadata := m.Get(keyMethod)[0]; methodFromMetadata != method {
return status.Errorf(codes.Internal, fmt.Sprintf("expected method %s, got: %s", method, methodFromMetadata))
}
return nil
}

func (s *skipPingService) Ping(ctx context.Context, _ *testpb.PingRequest) (*testpb.PingResponse, error) {
err := checkMetadata(ctx, interceptors.Unary, "grpc_middleware.testpb.TestService", "Ping")
if err != nil {
return nil, err
}

if skipped(ctx) {
return &testpb.PingResponse{Value: "skipped"}, nil
}

return &testpb.PingResponse{}, nil
}

func (s *skipPingService) PingList(_ *testpb.PingRequest, stream testpb.TestService_PingListServer) error {
err := checkMetadata(stream.Context(), interceptors.ServerStream, "grpc_middleware.testpb.TestService", "PingList")
if err != nil {
return err
}

var out testpb.PingResponse
if skipped(stream.Context()) {
out.Value = "skipped"
}
return stream.Send(&out)
}

func filter(ctx context.Context, gRPCType interceptors.GRPCType, service string, method string) bool {
m, _ := metadata.FromIncomingContext(ctx)
// Set parameters into metadata
m.Set(keyGRPCType, string(gRPCType))
m.Set(keyService, service)
m.Set(keyMethod, method)

if v := m.Get("skip"); len(v) > 0 && v[0] == "true" {

return true
}
return false
}

func TestSkipSuite(t *testing.T) {
s := &SkipSuite{
InterceptorTestSuite: &grpctesting.InterceptorTestSuite{
TestService: &skipPingService{&grpctesting.TestPingService{T: t}},
ServerOpts: []grpc.ServerOption{
grpc.UnaryInterceptor(skip.UnaryServerInterceptor(tags.UnaryServerInterceptor(), filter)),
grpc.StreamInterceptor(skip.StreamServerInterceptor(tags.StreamServerInterceptor(), filter)),
},
},
}
suite.Run(t, s)
}

type SkipSuite struct {
*grpctesting.InterceptorTestSuite
}

func (s *SkipSuite) TestPing() {
t := s.T()

testCases := []struct {
name string
skip bool
}{
{
name: "skip tags interceptor",
skip: true,
},
{
name: "do not skip",
skip: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var m metadata.MD
if tc.skip {
m = metadata.New(map[string]string{
"skip": "true",
})
}

resp, err := s.Client.Ping(metadata.NewOutgoingContext(s.SimpleCtx(), m), goodPing)
require.NoError(t, err)

var value string
if tc.skip {
value = "skipped"
}
assert.Equal(t, value, resp.Value)
})
}
}

func (s *SkipSuite) TestPingList() {
t := s.T()

testCases := []struct {
name string
skip bool
}{
{
name: "skip tags interceptor",
skip: true,
},
{
name: "do not skip",
skip: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var m metadata.MD
if tc.skip {
m = metadata.New(map[string]string{
"skip": "true",
})
}

stream, err := s.Client.PingList(metadata.NewOutgoingContext(s.SimpleCtx(), m), goodPing)
require.NoError(t, err)

for {
resp, err := stream.Recv()
if err == io.EOF {
break
}
require.NoError(s.T(), err)

var value string
if tc.skip {
value = "skipped"
}
assert.Equal(t, value, resp.Value)
}
})
}
}

0 comments on commit cfb8ee6

Please sign in to comment.