Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

caveat in MemDB datastore #807

Merged
merged 18 commits into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions internal/datastore/memdb/caveat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package memdb

import (
"errors"
"fmt"
"time"

"github.com/authzed/spicedb/pkg/datastore"
core "github.com/authzed/spicedb/pkg/proto/core/v1"

"github.com/hashicorp/go-memdb"
)

const tableCaveats = "caveats"

type caveat struct {
id datastore.CaveatID
name string
expression []byte
}

func (c *caveat) CoreCaveat() *core.Caveat {
vroldanbet marked this conversation as resolved.
Show resolved Hide resolved
return &core.Caveat{
Name: c.name,
Expression: c.expression,
}
}

func (r *memdbReader) ReadCaveatByName(name string) (*core.Caveat, error) {
r.lockOrPanic()
defer r.Unlock()

tx, err := r.txSource()
if err != nil {
return nil, err
}
return r.readCaveatByName(tx, name)
}

func (r *memdbReader) ReadCaveatByID(ID datastore.CaveatID) (*core.Caveat, error) {
r.lockOrPanic()
defer r.Unlock()

tx, err := r.txSource()
if err != nil {
return nil, err
}
return r.readCaveatByID(tx, ID)
}

func (r *memdbReader) readCaveatByID(tx *memdb.Txn, ID datastore.CaveatID) (*core.Caveat, error) {
found, err := tx.First(tableCaveats, indexID, ID)
if err != nil {
return nil, err
}
if found == nil {
return nil, fmt.Errorf("caveat with id %d not found: %w", ID, datastore.ErrCaveatNotFound)
}
c := found.(*caveat)
return c.CoreCaveat(), nil
}

func (r *memdbReader) readCaveatByName(tx *memdb.Txn, name string) (*core.Caveat, error) {
found, err := tx.First(tableCaveats, indexName, name)
if err != nil {
return nil, err
}
if found == nil {
return nil, fmt.Errorf("caveat with name %s not found: %w", name, datastore.ErrCaveatNotFound)
}
c := found.(*caveat)
return c.CoreCaveat(), nil
}

func (rwt *memdbReadWriteTx) WriteCaveats(caveats []*core.Caveat) ([]datastore.CaveatID, error) {
rwt.lockOrPanic()
defer rwt.Unlock()
tx, err := rwt.txSource()
if err != nil {
return nil, err
}
return rwt.writeCaveat(tx, caveats)
}

func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.Caveat) ([]datastore.CaveatID, error) {
ids := make([]datastore.CaveatID, 0, len(caveats))
vroldanbet marked this conversation as resolved.
Show resolved Hide resolved
for _, coreCaveat := range caveats {
id := datastore.CaveatID(time.Now().UnixNano())
c := caveat{
id: id,
name: coreCaveat.Name,
expression: coreCaveat.Expression,
}
// TODO(vroldanbet) why does go-memdb not honor unique index name?
found, err := rwt.readCaveatByName(tx, coreCaveat.Name)
vroldanbet marked this conversation as resolved.
Show resolved Hide resolved
if err != nil && !errors.Is(err, datastore.ErrCaveatNotFound) {
return nil, err
}
if found != nil {
return nil, fmt.Errorf("duplicated caveat with name %s", coreCaveat.Name)
}
if err = tx.Insert(tableCaveats, &c); err != nil {
return nil, err
}
ids = append(ids, id)
}
return ids, nil
}

func (rwt *memdbReadWriteTx) DeleteCaveats(caveats []*core.Caveat) error {
rwt.lockOrPanic()
defer rwt.Unlock()
tx, err := rwt.txSource()
if err != nil {
return err
}
return tx.Delete(tableCaveats, caveats)
}
142 changes: 142 additions & 0 deletions internal/datastore/memdb/caveat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package memdb

import (
"context"
"math"
"testing"
"time"

"github.com/authzed/spicedb/internal/datastore/common"
"github.com/authzed/spicedb/internal/testfixtures"
"github.com/authzed/spicedb/pkg/caveats"
"github.com/authzed/spicedb/pkg/datastore"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
"github.com/authzed/spicedb/pkg/tuple"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/structpb"
)

func TestWriteReadCaveat(t *testing.T) {
req := require.New(t)

ds, err := NewMemdbDatastore(0, 1*time.Hour, 1*time.Hour)
req.NoError(err)
coreCaveat := createCoreCaveat(t)
ctx := context.Background()

// Fails to write dupes in the same transaction
_, _, err = writeCaveat(ctx, ds, coreCaveat, coreCaveat)
req.Error(err)

// Succeeds writing a caveat
rev, ID, err := writeCaveat(ctx, ds, coreCaveat)
req.NoError(err)

// fails to write caveat with the same name in different tx
vroldanbet marked this conversation as resolved.
Show resolved Hide resolved
_, _, err = writeCaveat(ctx, ds, coreCaveat)
req.Error(err)

// the caveat can be looked up by name
cr, ok := ds.SnapshotReader(rev).(datastore.CaveatReader)
req.True(ok, "expected a CaveatStorer value")
cv, err := cr.ReadCaveatByName(coreCaveat.Name)
req.NoError(err)
req.Equal(coreCaveat, cv)
req.NoError(err)

// the caveat can be looked up by ID
cv, err = cr.ReadCaveatByID(ID)
req.NoError(err)
req.Equal(coreCaveat, cv)
req.NoError(err)

// returns an error if caveat name or ID does not exist
_, err = cr.ReadCaveatByName("doesnotexist")
req.ErrorIs(err, datastore.ErrCaveatNotFound)
_, err = cr.ReadCaveatByID(math.MaxUint64)
req.ErrorIs(err, datastore.ErrCaveatNotFound)
}

func TestWriteCaveatedTuple(t *testing.T) {
req := require.New(t)
ctx := context.Background()

ds, err := NewMemdbDatastore(0, 1*time.Hour, 1*time.Hour)
req.NoError(err)
sds, _ := testfixtures.StandardDatastoreWithSchema(ds, req)

// store caveat, write caveated tuple and read back same value
coreCaveat := createCoreCaveat(t)
_, cavID, err := writeCaveat(ctx, ds, coreCaveat)
req.NoError(err)
vroldanbet marked this conversation as resolved.
Show resolved Hide resolved
tpl := createTestCaveatedTuple(t, "document:companyplan#parent@folder:company#...", cavID)
rev, err := common.WriteTuples(ctx, sds, core.RelationTupleUpdate_CREATE, tpl)
req.NoError(err)
iter, err := ds.SnapshotReader(rev).QueryRelationships(ctx, datastore.RelationshipsFilter{
ResourceType: tpl.ResourceAndRelation.Namespace,
})
req.NoError(err)
defer iter.Close()
readTpl := iter.Next()
req.Equal(tpl, readTpl)

// caveated tuple can reference non-existing caveat - controller layer is responsible for validation
tpl = createTestCaveatedTuple(t, "document:rando#parent@folder:company#...", math.MaxUint64)
_, err = common.WriteTuples(ctx, sds, core.RelationTupleUpdate_CREATE, tpl)
req.NoError(err)
}

func createTestCaveatedTuple(t *testing.T, tplString string, id datastore.CaveatID) *core.RelationTuple {
tpl := tuple.MustParse(tplString)
st, err := structpb.NewStruct(map[string]interface{}{"a": 1, "b": "test"})
require.NoError(t, err)
tpl.Caveat = &core.ContextualizedCaveat{
CaveatId: uint64(id),
Context: st,
}
return tpl
}

func writeCaveat(ctx context.Context, ds datastore.Datastore, coreCaveat ...*core.Caveat) (datastore.Revision, datastore.CaveatID, error) {
var IDs []datastore.CaveatID
rev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, tx datastore.ReadWriteTransaction) error {
cs, ok := tx.(datastore.CaveatStorer)
if !ok {
panic("expected a CaveatStorer value")
}
var err error
IDs, err = cs.WriteCaveats(coreCaveat)
return err
})
if err != nil {
return datastore.NoRevision, 0, err
}
return rev, IDs[0], err
}

func createCoreCaveat(t *testing.T) *core.Caveat {
t.Helper()
c := createCompiledCaveat(t)
cBytes, err := c.Serialize()
require.NoError(t, err)
coreCaveat := &core.Caveat{
Name: c.Name(),
Expression: cBytes,
}
require.NoError(t, err)
return coreCaveat
}

func createCompiledCaveat(t *testing.T) *caveats.CompiledCaveat {
t.Helper()
env, err := caveats.EnvForVariables(map[string]caveats.VariableType{
"a": caveats.IntType,
"b": caveats.IntType,
})
require.NoError(t, err)
c, err := caveats.CompileCaveatWithName(env, "a == b", uuid.New().String())
require.NoError(t, err)
return c
}
12 changes: 10 additions & 2 deletions internal/datastore/memdb/memdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,23 @@ func (mdb *memdbDatastore) ReadWriteTx(
for _, change := range tx.Changes() {
if change.Table == tableRelationship {
if change.After != nil {
rt, err := change.After.(*relationship).RelationTuple()
if err != nil {
return datastore.NoRevision, err
}
newChanges.Changes = append(newChanges.Changes, &corev1.RelationTupleUpdate{
Operation: corev1.RelationTupleUpdate_TOUCH,
Tuple: change.After.(*relationship).RelationTuple(),
Tuple: rt,
})
}
if change.After == nil && change.Before != nil {
rt, err := change.Before.(*relationship).RelationTuple()
if err != nil {
return datastore.NoRevision, err
}
newChanges.Changes = append(newChanges.Changes, &corev1.RelationTupleUpdate{
Operation: corev1.RelationTupleUpdate_DELETE,
Tuple: change.Before.(*relationship).RelationTuple(),
Tuple: rt,
})
}
}
Expand Down
10 changes: 8 additions & 2 deletions internal/datastore/memdb/readonly.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ type memdbTupleIterator struct {
it memdb.ResultIterator
limit *uint64
count uint64
err error
}

func (mti *memdbTupleIterator) Next() *core.RelationTuple {
Expand All @@ -293,11 +294,16 @@ func (mti *memdbTupleIterator) Next() *core.RelationTuple {
}
mti.count++

return foundRaw.(*relationship).RelationTuple()
rt, err := foundRaw.(*relationship).RelationTuple()
if err != nil {
mti.err = err
return nil
}
return rt
}

func (mti *memdbTupleIterator) Err() error {
return nil
return mti.err
}

func (mti *memdbTupleIterator) Close() {
Expand Down
24 changes: 22 additions & 2 deletions internal/datastore/memdb/readwrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func (rwt *memdbReadWriteTx) write(tx *memdb.Txn, mutations ...*core.RelationTup
mutation.Tuple.Subject.Namespace,
mutation.Tuple.Subject.ObjectId,
mutation.Tuple.Subject.Relation,
rwt.toCaveatReference(mutation),
}

found, err := tx.First(
Expand All @@ -66,7 +67,11 @@ func (rwt *memdbReadWriteTx) write(tx *memdb.Txn, mutations ...*core.RelationTup
switch mutation.Operation {
case core.RelationTupleUpdate_CREATE:
if existing != nil {
return common.NewCreateRelationshipExistsError(existing.RelationTuple())
rt, err := existing.RelationTuple()
if err != nil {
return err
}
return common.NewCreateRelationshipExistsError(rt)
}
fallthrough
case core.RelationTupleUpdate_TOUCH:
Expand All @@ -87,6 +92,17 @@ func (rwt *memdbReadWriteTx) write(tx *memdb.Txn, mutations ...*core.RelationTup
return nil
}

func (rwt *memdbReadWriteTx) toCaveatReference(mutation *core.RelationTupleUpdate) *contextualizedCaveat {
var cr *contextualizedCaveat
if mutation.Tuple.Caveat != nil {
vroldanbet marked this conversation as resolved.
Show resolved Hide resolved
cr = &contextualizedCaveat{
caveatID: datastore.CaveatID(mutation.Tuple.Caveat.CaveatId),
context: mutation.Tuple.Caveat.Context.AsMap(),
}
}
return cr
}

func (rwt *memdbReadWriteTx) DeleteRelationships(filter *v1.RelationshipFilter) error {
rwt.lockOrPanic()
defer rwt.Unlock()
Expand All @@ -111,7 +127,11 @@ func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.Relationsh
// Collect the tuples into a slice of mutations for the changelog
var mutations []*core.RelationTupleUpdate
for row := filteredIter.Next(); row != nil; row = filteredIter.Next() {
mutations = append(mutations, tuple.Delete(row.(*relationship).RelationTuple()))
rt, err := row.(*relationship).RelationTuple()
if err != nil {
return err
}
mutations = append(mutations, tuple.Delete(rt))
}

return rwt.write(tx, mutations...)
Expand Down
Loading