diff --git a/cmd/spicedb/main.go b/cmd/spicedb/main.go
index db18aa26bf..2820529566 100644
--- a/cmd/spicedb/main.go
+++ b/cmd/spicedb/main.go
@@ -6,9 +6,19 @@ import (
 	"net/http/pprof"
 	"time"
 
+	"github.com/cespare/xxhash"
 	"github.com/jzelinskie/cobrautil"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/rs/zerolog"
+	"github.com/sercand/kuberesolver/v3"
+	"google.golang.org/grpc/balancer"
+
+	consistentbalancer "github.com/authzed/spicedb/pkg/balancer"
+)
+
+const (
+	hashringReplicationFactor = 20
+	backendsPerKey            = 1
 )
 
 var defaultPreRunE = cobrautil.CommandStack(
@@ -31,6 +41,11 @@ func metricsHandler() http.Handler {
 func main() {
 	rand.Seed(time.Now().UnixNano())
 
+	// enable kubernetes grpc resolver
+	kuberesolver.RegisterInCluster()
+	// enable consistent hashring grpc load balancer
+	balancer.Register(consistentbalancer.NewConsistentHashringBuilder(xxhash.Sum64, hashringReplicationFactor, backendsPerKey))
+
 	rootCmd := newRootCmd()
 	registerVersionCmd(rootCmd)
 	registerServeCmd(rootCmd)
diff --git a/cmd/spicedb/serve.go b/cmd/spicedb/serve.go
index a13a6184d9..850775948f 100644
--- a/cmd/spicedb/serve.go
+++ b/cmd/spicedb/serve.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"crypto/tls"
 	"fmt"
 	"os"
 	"os/signal"
@@ -9,6 +10,7 @@ import (
 	"time"
 
 	"github.com/alecthomas/units"
+	"github.com/authzed/grpcutil"
 	"github.com/fatih/color"
 	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
 	grpczerolog "github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2"
@@ -20,6 +22,7 @@ import (
 	"github.com/spf13/cobra"
 	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials"
 
 	"github.com/authzed/spicedb/internal/auth"
 	"github.com/authzed/spicedb/internal/dashboard"
@@ -36,12 +39,14 @@ import (
 	"github.com/authzed/spicedb/internal/gateway"
 	"github.com/authzed/spicedb/internal/middleware/servicespecific"
 	"github.com/authzed/spicedb/internal/namespace"
+	v1 "github.com/authzed/spicedb/internal/proto/dispatch/v1"
 	"github.com/authzed/spicedb/internal/services"
 	clusterdispatch "github.com/authzed/spicedb/internal/services/dispatch"
 	v1alpha1svc "github.com/authzed/spicedb/internal/services/v1alpha1"
 	logmw "github.com/authzed/spicedb/pkg/middleware/logging"
 	"github.com/authzed/spicedb/pkg/middleware/requestid"
 	"github.com/authzed/spicedb/pkg/validationfile"
+	"github.com/authzed/spicedb/pkg/x509util"
 )
 
 func registerServeCmd(rootCmd *cobra.Command) {
@@ -106,6 +111,7 @@ func registerServeCmd(rootCmd *cobra.Command) {
 	// Flags for configuring dispatch behavior
 	serveCmd.Flags().Uint32("dispatch-max-depth", 50, "maximum recursion depth for nested calls")
 	cobrautil.RegisterGrpcServerFlags(serveCmd.Flags(), "dispatch-cluster", "dispatch", ":50053", false)
+	serveCmd.Flags().String("dispatch-upstream-addr", "", "upstream grpc address to dispatch to")
 	serveCmd.Flags().String("dispatch-cluster-dns-name", "", "DNS SRV record name to resolve for cluster dispatch")
 	serveCmd.Flags().String("dispatch-cluster-service-name", "grpc", "DNS SRV record service name to resolve for cluster dispatch")
 	serveCmd.Flags().String("dispatch-peer-resolver-addr", "", "address used to connect to the peer endpoint resolver")
@@ -275,8 +281,9 @@ func serveRun(cmd *cobra.Command, args []string) {
 	}
 
 	redispatch := graph.NewLocalOnlyDispatcher(nsm, ds)
 
-	redispatchClientCtx, redispatchClientCancel := context.WithCancel(context.Background())
+	// servok redispatch configuration
+	redispatchClientCtx, redispatchClientCancel := context.WithCancel(context.Background())
 	redispatchTarget := cobrautil.MustGetStringExpanded(cmd, "dispatch-cluster-dns-name")
 	redispatchServiceName := cobrautil.MustGetStringExpanded(cmd, "dispatch-cluster-service-name")
 	if redispatchTarget != "" {
@@ -322,6 +329,31 @@ func serveRun(cmd *cobra.Command, args []string) {
 		redispatch = remote.NewClusterDispatcher(client)
 	}
 
+	// grpc consistent loadbalancer redispatch configuration
+	dispatchAddr := cobrautil.MustGetStringExpanded(cmd, "dispatch-upstream-addr")
+	if len(dispatchAddr) > 0 {
+		log.Info().Str("upstream", dispatchAddr).Msg("configuring grpc consistent load balancer for redispatch")
+
+		peerPSK := cobrautil.MustGetStringExpanded(cmd, "grpc-preshared-key")
+		peerCertPath := cobrautil.MustGetStringExpanded(cmd, "dispatch-cluster-tls-cert-path")
+		pool, err := x509util.CustomCertPool(peerCertPath)
+		if err != nil {
+			log.Fatal().Str("certpath", peerCertPath).Err(err).Msg("error loading certs for dispatch")
+		}
+		creds := credentials.NewTLS(&tls.Config{RootCAs: pool})
+
+		conn, err := grpc.Dial(dispatchAddr,
+			grpc.WithTransportCredentials(creds),
+			grpcutil.WithBearerToken(peerPSK),
+			grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor()),
+			grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"consistent-hashring"}`),
+		)
+		if err != nil {
+			log.Fatal().Str("endpoint", dispatchAddr).Err(err).Msg("error constructing client for endpoint")
+		}
+		redispatch = remote.NewClusterDispatcher(v1.NewDispatchServiceClient(conn))
+	}
+
 	cachingRedispatch, err := caching.NewCachingDispatcher(redispatch, nil, "dispatch_client")
 	if err != nil {
 		log.Fatal().Err(err).Msg("failed to initialize redispatcher cache")
diff --git a/go.mod b/go.mod
index 1a1616ca59..95c522413e 100644
--- a/go.mod
+++ b/go.mod
@@ -49,6 +49,7 @@ require (
 	github.com/prometheus/procfs v0.7.3 // indirect
 	github.com/rs/zerolog v1.26.0
 	github.com/scylladb/go-set v1.0.2
+	github.com/sercand/kuberesolver/v3 v3.1.0
 	github.com/shopspring/decimal v1.3.1
 	github.com/spf13/cobra v1.2.1
 	github.com/spf13/viper v1.9.0 // indirect
diff --git a/go.sum b/go.sum
index c9105196b0..978b348535 100644
--- a/go.sum
+++ b/go.sum
@@ -577,6 +577,8 @@ github.com/scylladb/go-set v1.0.2 h1:SkvlMCKhP0wyyct6j+0IHJkBkSZL+TDzZ4E7f7BCcRE
 github.com/scylladb/go-set v1.0.2/go.mod h1:DkpGd78rljTxKAnTDPFqXSGxvETQnJyuSOQwsHycqfs=
 github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
 github.com/seccomp/libseccomp-golang v0.9.1/go.mod h1:GbW5+tmTXfcxTToHLXlScSlAvWlF4P2Ca7zGrPiEpWo=
+github.com/sercand/kuberesolver/v3 v3.1.0 h1:Q6mbvkxvWH7LiwQkTfsHvFtx4aOtkCIXZ8Sxdm5wq7Y=
+github.com/sercand/kuberesolver/v3 v3.1.0/go.mod h1:OSHRdFT97s/dOQaqdb1FXP/xG84i/aalrrsMphNh12Q=
 github.com/shabbyrobe/gocovmerge v0.0.0-20180507124511-f6ea450bfb63 h1:J6qvD6rbmOil46orKqJaRPG+zTpoGlBTUdyv8ki63L0=
 github.com/shabbyrobe/gocovmerge v0.0.0-20180507124511-f6ea450bfb63/go.mod h1:n+VKSARF5y/tS9XFSP7vWDfS+GUC5vs/YT7M5XDTUEM=
 github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
diff --git a/internal/dispatch/remote/cluster.go b/internal/dispatch/remote/cluster.go
index 1f7de3aec8..22ce1f7cb3 100644
--- a/internal/dispatch/remote/cluster.go
+++ b/internal/dispatch/remote/cluster.go
@@ -2,13 +2,20 @@ package remote
 
 import (
 	"context"
+	"fmt"
 
 	"google.golang.org/grpc"
+	"google.golang.org/protobuf/proto"
 
 	"github.com/authzed/spicedb/internal/dispatch"
 	v1 "github.com/authzed/spicedb/internal/proto/dispatch/v1"
+	"github.com/authzed/spicedb/pkg/balancer"
 )
 
+const errComputingBackend = "unable to compute backend for request: %w"
+
+var protoMarshal = proto.MarshalOptions{Deterministic: true}
+
 type clusterClient interface {
 	DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest, opts ...grpc.CallOption) (*v1.DispatchCheckResponse, error)
 	DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest, opts ...grpc.CallOption) (*v1.DispatchExpandResponse, error)
@@ -31,6 +38,11 @@ func (cr *clusterDispatcher) DispatchCheck(ctx context.Context, req *v1.Dispatch
 		return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, err
 	}
 
+	requestKey, err := protoMarshal.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf(errComputingBackend, err)
+	}
+	ctx = context.WithValue(ctx, balancer.CtxKey, requestKey)
 	resp, err := cr.clusterClient.DispatchCheck(ctx, req)
 	if err != nil {
 		return &v1.DispatchCheckResponse{Metadata: requestFailureMetadata}, err
@@ -45,6 +57,11 @@ func (cr *clusterDispatcher) DispatchExpand(ctx context.Context, req *v1.Dispatc
 		return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err
 	}
 
+	requestKey, err := protoMarshal.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf(errComputingBackend, err)
+	}
+	ctx = context.WithValue(ctx, balancer.CtxKey, requestKey)
 	resp, err := cr.clusterClient.DispatchExpand(ctx, req)
 	if err != nil {
 		return &v1.DispatchExpandResponse{Metadata: requestFailureMetadata}, err
@@ -59,6 +76,11 @@ func (cr *clusterDispatcher) DispatchLookup(ctx context.Context, req *v1.Dispatc
 		return &v1.DispatchLookupResponse{Metadata: emptyMetadata}, err
 	}
 
+	requestKey, err := protoMarshal.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf(errComputingBackend, err)
+	}
+	ctx = context.WithValue(ctx, balancer.CtxKey, requestKey)
 	resp, err := cr.clusterClient.DispatchLookup(ctx, req)
 	if err != nil {
 		return &v1.DispatchLookupResponse{Metadata: requestFailureMetadata}, err
diff --git a/pkg/balancer/balancer.go b/pkg/balancer/balancer.go
new file mode 100644
index 0000000000..63debc0d47
--- /dev/null
+++ b/pkg/balancer/balancer.go
@@ -0,0 +1,95 @@
+package balancer
+
+import (
+	"math/rand"
+	"time"
+
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/balancer/base"
+	"google.golang.org/grpc/grpclog"
+
+	"github.com/authzed/spicedb/pkg/consistent"
+)
+
+type ctxKey string
+
+const (
+	// BalancerName is the name of consistent-hashring balancer.
+	BalancerName = "consistent-hashring"
+
+	// CtxKey is the key for the grpc request's context.Context which points to
+	// the key to hash for the request. The value it points to must be []byte
+	CtxKey ctxKey = "requestKey"
+)
+
+var (
+	logger = grpclog.Component("consistenthashring")
+	r      = rand.New(rand.NewSource(time.Now().UnixNano()))
+)
+
+// NewConsistentHashringBuilder creates a new ConsistentBalancerBuilder that
+// will create a balancer with the given config.
+// Before making a connection, register it with grpc with:
+// `balancer.Register(consistent.NewConsistentHashringBuilder(hasher, factor, spread))`
+func NewConsistentHashringBuilder(hasher consistent.HasherFunc, replicationFactor, spread uint8) balancer.Builder {
+	return base.NewBalancerBuilder(
+		BalancerName,
+		&consistentHashringPickerBuilder{hasher: hasher, replicationFactor: replicationFactor, spread: spread},
+		base.Config{HealthCheck: true},
+	)
+}
+
+type subConnMember struct {
+	balancer.SubConn
+	key string
+}
+
+// Key implements consistent.Member
+// This value is what will be hashed for placement on the consistent hash ring.
+func (s subConnMember) Key() string {
+	return s.key
+}
+
+var _ consistent.Member = &subConnMember{}
+
+type consistentHashringPickerBuilder struct {
+	hasher                    consistent.HasherFunc
+	replicationFactor, spread uint8
+}
+
+func (b *consistentHashringPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
+	logger.Infof("consistentHashringPicker: Build called with info: %v", info)
+	if len(info.ReadySCs) == 0 {
+		return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
+	}
+	hashring := consistent.NewHashring(b.hasher, b.replicationFactor)
+	for sc, scInfo := range info.ReadySCs {
+		if err := hashring.Add(subConnMember{
+			SubConn: sc,
+			key:     scInfo.Address.Addr + scInfo.Address.ServerName,
+		}); err != nil {
+			return base.NewErrPicker(err)
+		}
+	}
+	return &consistentHashringPicker{
+		hashring: hashring,
+		spread:   b.spread,
+	}
+}
+
+type consistentHashringPicker struct {
+	hashring *consistent.Hashring
+	spread   uint8
+}
+
+func (p *consistentHashringPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
+	key := info.Ctx.Value(CtxKey).([]byte)
+	members, err := p.hashring.FindN(key, p.spread)
+	if err != nil {
+		return balancer.PickResult{}, err
+	}
+	chosen := members[r.Intn(int(p.spread))].(subConnMember)
+	return balancer.PickResult{
+		SubConn: chosen.SubConn,
+	}, nil
+}
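
The wiring above is spread across main.go, serve.go, and cluster.go; the sketch below compresses it into one place, using only calls that appear in this diff. It is illustrative rather than authoritative: the target address and request key are placeholders, the replication factor (20) and spread (1) echo the constants in main.go, and the TLS and bearer-token dial options from serve.go are omitted.

package main

import (
	"context"

	"github.com/cespare/xxhash"
	"google.golang.org/grpc"
	"google.golang.org/grpc/balancer"

	consistentbalancer "github.com/authzed/spicedb/pkg/balancer"
)

func main() {
	// Register the hashring balancer once at startup, as cmd/spicedb/main.go does.
	balancer.Register(consistentbalancer.NewConsistentHashringBuilder(xxhash.Sum64, 20, 1))

	// Dial the upstream with a service config that selects the balancer by name,
	// mirroring the dispatch-upstream-addr path in serve.go; the address is a
	// placeholder and the credentials options used there are left out.
	conn, err := grpc.Dial("dispatch.example.local:50053",
		grpc.WithInsecure(),
		grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"consistent-hashring"}`),
	)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Every RPC sent over this connection must carry the bytes to hash in its
	// context; cluster.go derives them from a deterministic proto marshal of the
	// dispatch request so that equal requests land on the same backend.
	ctx := context.WithValue(context.Background(), consistentbalancer.CtxKey, []byte("example-request-key"))
	_ = ctx // pass ctx to the generated dispatch client's methods
}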
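
balancer.go also depends on github.com/authzed/spicedb/pkg/consistent, which this diff does not modify. Inferred purely from the call sites in the new picker, not from that package's source, the shapes it is expected to provide look roughly like this:

// Package consistentsketch is a reading aid for this diff, not the real
// pkg/consistent; every declaration here is inferred from usage in balancer.go.
package consistentsketch

// HasherFunc hashes a key onto the ring; xxhash.Sum64 is passed for it in main.go.
type HasherFunc func(data []byte) uint64

// Member is anything placeable on the ring; subConnMember satisfies it via Key().
type Member interface {
	Key() string
}

// hashring captures the methods balancer.go calls on *consistent.Hashring:
// Add registers a backend, FindN returns the n members nearest the hashed key.
type hashring interface {
	Add(member Member) error
	FindN(key []byte, num uint8) ([]Member, error)
}

With that surface, the picker's behavior follows directly: each ready SubConn is added under its address, FindN returns the `spread` nearest backends for the request key, and one of them is chosen at random per RPC.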