Skip to content

Commit

Permalink
Merge pull request #391 from ecordell/universal-consistency
Browse files Browse the repository at this point in the history
add universal consistency middleware
  • Loading branch information
ecordell authored Feb 3, 2022
2 parents a906b7c + de46c9b commit b536bd9
Show file tree
Hide file tree
Showing 31 changed files with 616 additions and 679 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
with:
go-version: "^1.17"
- name: "Test"
run: "go test -race ./..."
run: "go test -race -timeout 20m ./..."

integration:
name: "Integration"
Expand Down
2 changes: 1 addition & 1 deletion internal/dispatch/combined/combined.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func UpstreamAddr(addr string) Option {
}
}

// UpstreamAddr sets the optional cluster dispatching upstream certificate
// UpstreamCAPath sets the optional cluster dispatching upstream certificate
// authority.
func UpstreamCAPath(path string) Option {
return func(state *optionState) {
Expand Down
112 changes: 104 additions & 8 deletions internal/middleware/consistency/consistency.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"errors"
"fmt"
"strings"

v0 "github.com/authzed/authzed-go/proto/authzed/api/v0"
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/rs/zerolog/log"
"github.com/shopspring/decimal"
Expand All @@ -13,14 +15,20 @@ import (
"google.golang.org/grpc/status"

"github.com/authzed/spicedb/internal/datastore"
datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
"github.com/authzed/spicedb/internal/services/serviceerrors"
"github.com/authzed/spicedb/pkg/zedtoken"
"github.com/authzed/spicedb/pkg/zookie"
)

type hasConsistency interface {
GetConsistency() *v1.Consistency
}

type hasAtRevision interface {
GetAtRevision() *v0.Zookie
}

type ctxKeyType struct{}

var revisionKey ctxKeyType = struct{}{}
Expand Down Expand Up @@ -61,18 +69,41 @@ func MustRevisionFromContext(ctx context.Context) (decimal.Decimal, *v1.ZedToken
// AddRevisionToContext adds a revision to the given context, based on the consistency block found
// in the given request (if applicable).
func AddRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore) error {
reqWithConsistency, ok := req.(hasConsistency)
if !ok {
switch req := req.(type) {
case hasConsistency:
return addRevisionToContextFromConsistency(ctx, req, ds)
case hasAtRevision:
return addRevisionToContextFromAtRevision(ctx, req, ds)
default:
return addHeadRevision(ctx, ds)
}
}

// addHeadRevision sets the value of the revision in the context to the current head revision in the datastore
func addHeadRevision(ctx context.Context, ds datastore.Datastore) error {
handle := ctx.Value(revisionKey)
if handle == nil {
return nil
}

revision, err := ds.HeadRevision(ctx)
if err != nil {
return rewriteDatastoreError(ctx, err)
}
handle.(*revisionHandle).revision = revision
return nil
}

// addRevisionToContextFromConsistency adds a revision to the given context, based on the consistency block found
// in the given request (if applicable).
func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency, ds datastore.Datastore) error {
handle := ctx.Value(revisionKey)
if handle == nil {
return nil
}

var revision decimal.Decimal
consistency := reqWithConsistency.GetConsistency()
consistency := req.GetConsistency()

switch {
case consistency == nil || consistency.GetMinimizeLatency():
Expand Down Expand Up @@ -122,10 +153,48 @@ func AddRevisionToContext(ctx context.Context, req interface{}, ds datastore.Dat
return nil
}

// addRevisionToContextFromAtRevision adds a revision to the given context, based on the AtRevision field (v0 api only)
func addRevisionToContextFromAtRevision(ctx context.Context, req hasAtRevision, ds datastore.Datastore) error {
handle := ctx.Value(revisionKey)
if handle == nil {
return nil
}

// Read should attempt to use the exact revision requested
if req, ok := req.(*v0.ReadRequest); ok && req.AtRevision != nil {
decoded, err := zookie.DecodeRevision(req.AtRevision)
if err != nil {
return status.Errorf(codes.InvalidArgument, "bad request revision: %s", err)
}

handle.(*revisionHandle).revision = decoded
return nil
}

// all other requests pick a revision
revision, err := pickBestRevisionV0(ctx, req.GetAtRevision(), ds)
if err != nil {
return status.Errorf(codes.InvalidArgument, err.Error())
}
handle.(*revisionHandle).revision = revision
return nil
}

var bypassServiceWhitelist = map[string]struct{}{
"/grpc.reflection.v1alpha.ServerReflection/": {},
"/grpc.health.v1.Health/": {},
}

// UnaryServerInterceptor returns a new unary server interceptor that performs per-request exchange of
// the specified consistency configuration for the revision at which to perform the request.
func UnaryServerInterceptor(ds datastore.Datastore) grpc.UnaryServerInterceptor {
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
for bypass := range bypassServiceWhitelist {
if strings.HasPrefix(info.FullMethod, bypass) {
return handler(ctx, req)
}
}
ds := datastoremw.MustFromContext(ctx)
newCtx := ContextWithHandle(ctx)
if err := AddRevisionToContext(newCtx, req, ds); err != nil {
return nil, err
Expand All @@ -137,16 +206,20 @@ func UnaryServerInterceptor(ds datastore.Datastore) grpc.UnaryServerInterceptor

// StreamServerInterceptor returns a new stream server interceptor that performs per-request exchange of
// the specified consistency configuration for the revision at which to perform the request.
func StreamServerInterceptor(ds datastore.Datastore) grpc.StreamServerInterceptor {
func StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wrapper := &recvWrapper{stream, ds, ContextWithHandle(stream.Context())}
for bypass := range bypassServiceWhitelist {
if strings.HasPrefix(info.FullMethod, bypass) {
return handler(srv, stream)
}
}
wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context())}
return handler(srv, wrapper)
}
}

type recvWrapper struct {
grpc.ServerStream
ds datastore.Datastore
ctx context.Context
}

Expand All @@ -158,8 +231,9 @@ func (s *recvWrapper) RecvMsg(m interface{}) error {
if err := s.ServerStream.RecvMsg(m); err != nil {
return err
}
ds := datastoremw.MustFromContext(s.ctx)

if err := AddRevisionToContext(s.ctx, m, s.ds); err != nil {
if err := AddRevisionToContext(s.ctx, m, ds); err != nil {
return err
}

Expand Down Expand Up @@ -188,6 +262,28 @@ func pickBestRevision(ctx context.Context, requested *v1.ZedToken, ds datastore.
return databaseRev, nil
}

func pickBestRevisionV0(ctx context.Context, requested *v0.Zookie, ds datastore.Datastore) (decimal.Decimal, error) {
// Calculate a revision as we see fit
databaseRev, err := ds.OptimizedRevision(ctx)
if err != nil {
return decimal.Zero, err
}

if requested != nil {
requestedRev, err := zookie.DecodeRevision(requested)
if err != nil {
return decimal.Zero, errInvalidZedToken
}

if requestedRev.GreaterThan(databaseRev) {
return requestedRev, nil
}
return databaseRev, nil
}

return databaseRev, nil
}

func rewriteDatastoreError(ctx context.Context, err error) error {
switch {
case errors.As(err, &datastore.ErrPreconditionFailed{}):
Expand Down
58 changes: 56 additions & 2 deletions internal/middleware/consistency/consistency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"testing"

v0 "github.com/authzed/authzed-go/proto/authzed/api/v0"
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing"
pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
Expand All @@ -16,7 +17,9 @@ import (
"google.golang.org/grpc"

"github.com/authzed/spicedb/internal/datastore/memdb"
datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
"github.com/authzed/spicedb/pkg/zedtoken"
"github.com/authzed/spicedb/pkg/zookie"
)

func TestAddRevisionToContextNoneSupplied(t *testing.T) {
Expand Down Expand Up @@ -135,6 +138,51 @@ func TestAddRevisionToContextAtInvalidExactSnapshot(t *testing.T) {
require.Error(err)
}

func TestAddRevisionToContextV0AtRevision(t *testing.T) {
require := require.New(t)

ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC, 0)
require.NoError(err)

databaseRev, err := ds.HeadRevision(context.Background())
require.NoError(err)

updated := ContextWithHandle(context.Background())
err = AddRevisionToContext(updated, &v0.ReadRequest{AtRevision: zookie.NewFromRevision(databaseRev)}, ds)
require.NoError(err)
require.Equal(databaseRev.BigInt(), RevisionFromContext(updated).BigInt())
}

func TestAddRevisionToContextV0NoAtRevision(t *testing.T) {
require := require.New(t)

ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC, 0)
require.NoError(err)

databaseRev, err := ds.HeadRevision(context.Background())
require.NoError(err)

updated := ContextWithHandle(context.Background())
err = AddRevisionToContext(updated, &v0.ReadRequest{}, ds)
require.NoError(err)
require.Equal(databaseRev.BigInt(), RevisionFromContext(updated).BigInt())
}

func TestAddRevisionToContextAPIAlwaysFullyConsistent(t *testing.T) {
require := require.New(t)

ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC, 0)
require.NoError(err)

databaseRev, err := ds.HeadRevision(context.Background())
require.NoError(err)

updated := ContextWithHandle(context.Background())
err = AddRevisionToContext(updated, &v1.WriteSchemaRequest{}, ds)
require.NoError(err)
require.Equal(databaseRev.BigInt(), RevisionFromContext(updated).BigInt())
}

func TestConsistencyTestSuite(t *testing.T) {
require := require.New(t)

Expand All @@ -144,8 +192,14 @@ func TestConsistencyTestSuite(t *testing.T) {
s := &ConsistencyTestSuite{
InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(StreamServerInterceptor(ds)),
grpc.UnaryInterceptor(UnaryServerInterceptor(ds)),
grpc.ChainStreamInterceptor(
datastoremw.StreamServerInterceptor(ds),
StreamServerInterceptor(),
),
grpc.ChainUnaryInterceptor(
datastoremw.UnaryServerInterceptor(ds),
UnaryServerInterceptor(),
),
},
},
}
Expand Down
Loading

0 comments on commit b536bd9

Please sign in to comment.