Skip to content

Commit

Permalink
Add caching to Lookup dispatcher
Browse files Browse the repository at this point in the history
Also adds cached dispatcher tests to the overall consistency test suite, and an more complicated lookup example
  • Loading branch information
josephschorr committed Oct 26, 2021
1 parent bc40650 commit c67714b
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 49 deletions.
72 changes: 65 additions & 7 deletions internal/dispatch/caching/caching.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,27 @@ type cachingDispatcher struct {
d dispatch.Dispatcher
c *ristretto.Cache

checkTotalCounter prometheus.Counter
checkFromCacheCounter prometheus.Counter
checkTotalCounter prometheus.Counter

checkFromCacheCounter prometheus.Counter
lookupTotalCounter prometheus.Counter
lookupFromCacheCounter prometheus.Counter
}

type checkResultEntry struct {
result *v1.DispatchCheckResponse
computedWithDepthRemaining uint32
}

var checkResultEntryCost = int64(unsafe.Sizeof(checkResultEntry{}))
type lookupResultEntry struct {
result *v1.DispatchLookupResponse
computedWithDepthRemaining uint32
}

var (
checkResultEntryCost = int64(unsafe.Sizeof(checkResultEntry{}))
lookupResultEntryCost = int64(unsafe.Sizeof(lookupResultEntry{}))
)

// NewCachingDispatcher creates a new dispatch.Dispatcher which delegates dispatch requests
// and caches the responses when possible and desirable.
Expand Down Expand Up @@ -66,6 +77,17 @@ func NewCachingDispatcher(
Name: "check_from_cache_total",
})

lookupTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Subsystem: prometheusSubsystem,
Name: "lookup_total",
})
lookupFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Subsystem: prometheusSubsystem,
Name: "lookup_from_cache_total",
})

if prometheusSubsystem != "" {
err = prometheus.Register(checkTotalCounter)
if err != nil {
Expand All @@ -77,6 +99,16 @@ func NewCachingDispatcher(
return nil, fmt.Errorf(errCachingInitialization, err)
}

err = prometheus.Register(lookupTotalCounter)
if err != nil {
return nil, fmt.Errorf(errCachingInitialization, err)
}

err = prometheus.Register(lookupFromCacheCounter)
if err != nil {
return nil, fmt.Errorf(errCachingInitialization, err)
}

// Export some ristretto metrics
err = registerMetricsFunc("cache_hits_total", prometheusSubsystem, cache.Metrics.Hits)
if err != nil {
Expand All @@ -99,7 +131,7 @@ func NewCachingDispatcher(
}
}

return &cachingDispatcher{delegate, cache, checkTotalCounter, checkFromCacheCounter}, nil
return &cachingDispatcher{delegate, cache, checkTotalCounter, checkFromCacheCounter, lookupTotalCounter, lookupFromCacheCounter}, nil
}

func registerMetricsFunc(name string, subsystem string, metricsFunc func() uint64) error {
Expand All @@ -115,7 +147,7 @@ func registerMetricsFunc(name string, subsystem string, metricsFunc func() uint6
// DispatchCheck implements dispatch.Check interface
func (cd *cachingDispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) {
cd.checkTotalCounter.Inc()
requestKey := requestToKey(req)
requestKey := checkRequestToKey(req)

if cachedResultRaw, found := cd.c.Get(requestKey); found {
cachedResult := cachedResultRaw.(checkResultEntry)
Expand Down Expand Up @@ -146,7 +178,29 @@ func (cd *cachingDispatcher) DispatchExpand(ctx context.Context, req *v1.Dispatc

// DispatchLookup implements dispatch.Lookup interface and does not do any caching yet.
func (cd *cachingDispatcher) DispatchLookup(ctx context.Context, req *v1.DispatchLookupRequest) (*v1.DispatchLookupResponse, error) {
return cd.d.DispatchLookup(ctx, req)
cd.lookupTotalCounter.Inc()
if req.Metadata.DepthRemaining > 0 {
requestKey := lookupRequestToKey(req)
if cachedResultRaw, found := cd.c.Get(requestKey); found {
cachedResult := cachedResultRaw.(lookupResultEntry)
cd.lookupFromCacheCounter.Inc()
return cachedResult.result, nil
}
}

computed, err := cd.d.DispatchLookup(ctx, req)

// We only want to cache the result if there was no error
if err == nil {
requestKey := lookupRequestToKey(req)
toCache := lookupResultEntry{computed, req.Metadata.DepthRemaining}
toCache.result.Metadata.DispatchCount = 0
cd.c.Set(requestKey, toCache, lookupResultEntryCost)
}

// Return both the computed and err in ALL cases: computed contains resolved metadata even
// if there was an error.
return computed, err
}

func (cd *cachingDispatcher) Close() error {
Expand All @@ -158,6 +212,10 @@ func (cd *cachingDispatcher) Close() error {
return nil
}

func requestToKey(req *v1.DispatchCheckRequest) string {
func checkRequestToKey(req *v1.DispatchCheckRequest) string {
return fmt.Sprintf("%s@%s@%s", tuple.StringONR(req.ObjectAndRelation), tuple.StringONR(req.Subject), req.Metadata.AtRevision)
}

func lookupRequestToKey(req *v1.DispatchLookupRequest) string {
return fmt.Sprintf("lookup//%s#%s@%s@%s", req.ObjectRelation.Namespace, req.ObjectRelation.Relation, tuple.StringONR(req.Subject), req.Metadata.AtRevision)
}
20 changes: 20 additions & 0 deletions internal/dispatch/graph/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,26 @@ func TestSimpleLookup(t *testing.T) {

require.NoError(err)
require.ElementsMatch(tc.resolvedObjects, lookupResult.ResolvedOnrs)

// We have to sleep a while to let the cache converge:
// https://github.com/dgraph-io/ristretto/blob/01b9f37dd0fd453225e042d6f3a27cd14f252cd0/cache_test.go#L17
time.Sleep(10 * time.Millisecond)

// Run again with the cache available.
lookupResult, err = dispatch.DispatchLookup(context.Background(), &v1.DispatchLookupRequest{
ObjectRelation: tc.start,
Subject: tc.target,
Metadata: &v1.ResolverMeta{
AtRevision: revision.String(),
DepthRemaining: 50,
},
Limit: 10,
DirectStack: nil,
TtuStack: nil,
})

require.NoError(err)
require.ElementsMatch(tc.resolvedObjects, lookupResult.ResolvedOnrs)
})
}
}
Expand Down
94 changes: 53 additions & 41 deletions internal/services/consistency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/authzed/spicedb/internal/datastore/memdb"
"github.com/authzed/spicedb/internal/dispatch"
"github.com/authzed/spicedb/internal/dispatch/caching"
"github.com/authzed/spicedb/internal/dispatch/graph"
"github.com/authzed/spicedb/internal/namespace"
v1 "github.com/authzed/spicedb/internal/proto/dispatch/v1"
Expand Down Expand Up @@ -56,47 +57,58 @@ func TestConsistency(t *testing.T) {
t.Run(fmt.Sprintf("fuzz%d", delta/time.Millisecond), func(t *testing.T) {
for _, filePath := range consistencyTestFiles {
t.Run(path.Base(filePath), func(t *testing.T) {
lrequire := require.New(t)

unvalidated, err := memdb.NewMemdbDatastore(0, delta, memdb.DisableGC, 0)
lrequire.NoError(err)

ds := testfixtures.NewValidatingDatastore(unvalidated)
// defer ds.Close()

fullyResolved, revision, err := validationfile.PopulateFromFiles(ds, []string{filePath})
lrequire.NoError(err)

ns, err := namespace.NewCachingNamespaceManager(ds, 1*time.Second, nil)
lrequire.NoError(err)

dispatch := graph.NewLocalOnlyDispatcher(ns, ds)

// Validate the type system for each namespace.
for _, nsDef := range fullyResolved.NamespaceDefinitions {
_, ts, _, err := ns.ReadNamespaceAndTypes(context.Background(), nsDef.Name)
lrequire.NoError(err)

err = ts.Validate(context.Background())
lrequire.NoError(err)
}

// Build the list of tuples per namespace.
tuplesPerNamespace := slicemultimap.New()
for _, tpl := range fullyResolved.Tuples {
tuplesPerNamespace.Put(tpl.ObjectAndRelation.Namespace, tpl)
}

// Run the consistency tests for each service.
v1permclient, _ := v1svc.RunForTesting(t, ds, ns, dispatch, 50)
testers := []serviceTester{
v0ServiceTester{v0svc.NewACLServer(ds, ns, dispatch, 50)},
v1ServiceTester{v1permclient},
}

for _, tester := range testers {
t.Run(tester.Name(), func(t *testing.T) {
runConsistencyTests(t, tester, dispatch, fullyResolved, tuplesPerNamespace, revision)
for _, dispatcherKind := range []string{"local", "caching"} {
t.Run(dispatcherKind, func(t *testing.T) {
lrequire := require.New(t)

unvalidated, err := memdb.NewMemdbDatastore(0, delta, memdb.DisableGC, 0)
lrequire.NoError(err)

ds := testfixtures.NewValidatingDatastore(unvalidated)

fullyResolved, revision, err := validationfile.PopulateFromFiles(ds, []string{filePath})
lrequire.NoError(err)

ns, err := namespace.NewCachingNamespaceManager(ds, 1*time.Second, nil)
lrequire.NoError(err)

localOnlyDispatcher := graph.NewLocalOnlyDispatcher(ns, ds)

// Validate the type system for each namespace.
for _, nsDef := range fullyResolved.NamespaceDefinitions {
_, ts, _, err := ns.ReadNamespaceAndTypes(context.Background(), nsDef.Name)
lrequire.NoError(err)

err = ts.Validate(context.Background())
lrequire.NoError(err)
}

// Build the list of tuples per namespace.
tuplesPerNamespace := slicemultimap.New()
for _, tpl := range fullyResolved.Tuples {
tuplesPerNamespace.Put(tpl.ObjectAndRelation.Namespace, tpl)
}

// Run the consistency tests for each service.
dispatcher := localOnlyDispatcher
if dispatcherKind == "caching" {
cachingDispatcher, err := caching.NewCachingDispatcher(localOnlyDispatcher, nil, "")
lrequire.NoError(err)
defer cachingDispatcher.Close()
dispatcher = cachingDispatcher
}

v1permclient, _ := v1svc.RunForTesting(t, ds, ns, dispatcher, 50)
testers := []serviceTester{
v0ServiceTester{v0svc.NewACLServer(ds, ns, dispatcher, 50)},
v1ServiceTester{v1permclient},
}

for _, tester := range testers {
t.Run(tester.Name(), func(t *testing.T) {
runConsistencyTests(t, tester, dispatcher, fullyResolved, tuplesPerNamespace, revision)
})
}
})
}
})
Expand Down
34 changes: 34 additions & 0 deletions internal/services/testconfigs/lookupsametypes.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
---
schema: >-
definition test/user {}
definition test/resource {
relation viewer: test/user | test/usergroup#member | test/usergroup#manager
relation editor: test/user | test/usergroup#member | test/usergroup#manager
relation creator: test/user | test/usergroup#member | test/usergroup#manager
relation owner: test/user | test/usergroup#member | test/usergroup#manager
permission view = viewer + editor + creator + owner
}
definition test/usergroup {
relation direct_member: test/user | test/usergroup#member | test/usergroup#manager | test/usergroup#contributor
relation contributor: test/user | test/usergroup#member | test/usergroup#contributor | test/usergroup#manager
relation manager: test/user | test/usergroup#member | test/usergroup#manager
permission member = direct_member + contributor + manager
}
relationships: |
test/usergroup:productname#manager@test/user:an_eng_manager#...
test/usergroup:productname#direct_member@test/user:an_engineer#...
test/usergroup:applications#manager@test/user:an_eng_director#...
test/usergroup:engineering#manager@test/user:cto#...
test/usergroup:csuite#manager@test/user:ceo#...
test/usergroup:csuite#direct_member@test/user:cto#...
test/usergroup:other#direct_member@test/user:denied#...
test/usergroup:engineering#direct_member@test/usergroup:applications#member
test/usergroup:applications#direct_member@test/usergroup:productname#member
test/usergroup:engineering#direct_member@test/usergroup:csuite#member
test/resource:promserver#creator@test/user:an_engineer#...
test/resource:promserver#viewer@test/usergroup:engineering#member
test/resource:jira#viewer@test/usergroup:engineering#member
test/resource:promserver#viewer@test/user:an_external_test/user#...
2 changes: 1 addition & 1 deletion internal/services/v1/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
func RunForTesting(t *testing.T, ds datastore.Datastore, nsm namespace.Manager, dispatch dispatch.Dispatcher, defaultDepth uint32) (v1.PermissionsServiceClient, v1.SchemaServiceClient) {
lis := bufconn.Listen(1024 * 1024)
s := tf.NewTestServer()
v1.RegisterPermissionsServiceServer(s, NewPermissionsServer(ds, nsm, dispatch, 50))
v1.RegisterPermissionsServiceServer(s, NewPermissionsServer(ds, nsm, dispatch, defaultDepth))
v1.RegisterSchemaServiceServer(s, NewSchemaServer(ds))

go func() {
Expand Down

0 comments on commit c67714b

Please sign in to comment.