Skip to content

Commit

Permalink
Move create entry "similarity" check inside of DataStore (#2475)
Browse files Browse the repository at this point in the history
The entry similarity check happens only on the API level, but not within the
same transaction that the entry creation takes place in. This causes a small
window where two concurrent requests for similar entries can result in both
being created.

This PR moves the similarity check down to the SQL DataStore layer inside of
the transaction that creates the entry. It also introduces a new DataStore
function CreateOrReturnRegistrationEntry that returns the existing entry, that
is useful for the API, but otherwise allows us to not have to update the
existing CreateRegistrationEntry callers at this time. Both
CreateRegistrationEntry and CreateOrReturnRegistrationEntry go through the same
code paths and differ only in how they treat the case when a similar entry
exists. CreateRegistrationEntry fails with AlreadyExists, while
CreateOrReturnRegistrationEntry returns the existing entry along with a bool
indicating that it is an existing entry.

This PR does NOT address the issue that UpdateRegistrationEntry can end up
creating two "similar" entries.

Signed-off-by: Andrew Harding <aharding@vmware.com>
  • Loading branch information
azdagron authored Aug 30, 2021
1 parent f8d3d83 commit c01d0b6
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 62 deletions.
6 changes: 6 additions & 0 deletions pkg/common/telemetry/server/datastore/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ func (w metricsWrapper) CreateRegistrationEntry(ctx context.Context, entry *comm
return w.ds.CreateRegistrationEntry(ctx, entry)
}

func (w metricsWrapper) CreateOrReturnRegistrationEntry(ctx context.Context, entry *common.RegistrationEntry) (_ *common.RegistrationEntry, _ bool, err error) {
callCounter := StartCreateRegistrationCall(w.m)
defer callCounter.Done(&err)
return w.ds.CreateOrReturnRegistrationEntry(ctx, entry)
}

func (w metricsWrapper) DeleteAttestedNode(ctx context.Context, spiffeID string) (_ *common.AttestedNode, err error) {
callCounter := StartDeleteNodeCall(w.m)
defer callCounter.Done(&err)
Expand Down
8 changes: 8 additions & 0 deletions pkg/common/telemetry/server/datastore/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ func TestWithMetrics(t *testing.T) {
key: "datastore.registration_entry.create",
methodName: "CreateRegistrationEntry",
},
{
key: "datastore.registration_entry.create",
methodName: "CreateOrReturnRegistrationEntry",
},
{
key: "datastore.node.delete",
methodName: "DeleteAttestedNode",
Expand Down Expand Up @@ -263,6 +267,10 @@ func (ds *fakeDataStore) CreateRegistrationEntry(context.Context, *common.Regist
return &common.RegistrationEntry{}, ds.err
}

func (ds *fakeDataStore) CreateOrReturnRegistrationEntry(context.Context, *common.RegistrationEntry) (*common.RegistrationEntry, bool, error) {
return &common.RegistrationEntry{}, true, ds.err
}

func (ds *fakeDataStore) DeleteAttestedNode(context.Context, string) (*common.AttestedNode, error) {
return &common.AttestedNode{}, ds.err
}
Expand Down
6 changes: 2 additions & 4 deletions pkg/server/api/bundle/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,9 +781,6 @@ func TestAppendBundle(t *testing.T) {
}

func TestBatchDeleteFederatedBundle(t *testing.T) {
test := setupServiceTest(t)
defer test.Cleanup()

td1 := spiffeid.RequireTrustDomainFromString("td1.org")
td2 := spiffeid.RequireTrustDomainFromString("td2.org")
td3 := spiffeid.RequireTrustDomainFromString("td3.org")
Expand Down Expand Up @@ -1108,7 +1105,8 @@ func TestBatchDeleteFederatedBundle(t *testing.T) {
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
test.logHook.Reset()
test := setupServiceTest(t)
defer test.Cleanup()

// Create all test bundles
for _, td := range dsBundles {
Expand Down
43 changes: 6 additions & 37 deletions pkg/server/api/entry/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,25 +207,14 @@ func (s *Service) createEntry(ctx context.Context, e *types.Entry, outputMask *t

log = log.WithField(telemetry.SPIFFEID, cEntry.SpiffeId)

existingEntry, err := s.getExistingEntry(ctx, cEntry)
if err != nil {
return &entryv1.BatchCreateEntryResponse_Result{
Status: api.MakeStatus(log, codes.Internal, "failed to list entries", err),
}
}

resultStatus := api.OK()
regEntry := existingEntry

if existingEntry == nil {
// Create entry
regEntry, err = s.ds.CreateRegistrationEntry(ctx, cEntry)
if err != nil {
return &entryv1.BatchCreateEntryResponse_Result{
Status: api.MakeStatus(log, codes.Internal, "failed to create entry", err),
}
regEntry, existing, err := s.ds.CreateOrReturnRegistrationEntry(ctx, cEntry)
switch {
case err != nil:
return &entryv1.BatchCreateEntryResponse_Result{
Status: api.MakeStatus(log, codes.Internal, "failed to create entry", err),
}
} else {
case existing:
resultStatus = api.CreateStatus(codes.AlreadyExists, "similar entry already exists")
}

Expand Down Expand Up @@ -391,26 +380,6 @@ func applyMask(e *types.Entry, mask *types.EntryMask) {
}
}

func (s *Service) getExistingEntry(ctx context.Context, e *common.RegistrationEntry) (*common.RegistrationEntry, error) {
resp, err := s.ds.ListRegistrationEntries(ctx, &datastore.ListRegistrationEntriesRequest{
BySpiffeID: e.SpiffeId,
ByParentID: e.ParentId,
BySelectors: &datastore.BySelectors{
Match: datastore.Exact,
Selectors: e.Selectors,
},
})

if err != nil {
return nil, err
}

if len(resp.Entries) > 0 {
return resp.Entries[0], nil
}
return nil, nil
}

func (s *Service) updateEntry(ctx context.Context, e *types.Entry, inputMask *types.EntryMask, outputMask *types.EntryMask) *entryv1.BatchUpdateEntryResponse_Result {
log := rpccontext.Logger(ctx)
log = log.WithField(telemetry.RegistrationID, e.Id)
Expand Down
16 changes: 9 additions & 7 deletions pkg/server/api/entry/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,7 @@ func TestBatchCreateEntry(t *testing.T) {
reqEntries []*types.Entry

// fake ds configurations
noCustomCreate bool
dsError error
dsResults map[string]*common.RegistrationEntry
expectDsEntries map[string]*common.RegistrationEntry
Expand Down Expand Up @@ -1820,6 +1821,7 @@ func TestBatchCreateEntry(t *testing.T) {
},
},
},
noCustomCreate: true,
},
{
name: "invalid entry",
Expand Down Expand Up @@ -1971,7 +1973,7 @@ func TestBatchCreateEntry(t *testing.T) {
defaultEntryID := createTestEntries(t, ds, defaultEntry)[defaultEntry.SpiffeId].EntryId

// Setup fake
ds.customCreate = true
ds.customCreate = !tt.noCustomCreate
ds.t = t
ds.expectEntries = tt.expectDsEntries
ds.results = tt.dsResults
Expand Down Expand Up @@ -3612,31 +3614,31 @@ func newFakeDS(t *testing.T) *fakeDS {
}
}

func (f *fakeDS) CreateRegistrationEntry(ctx context.Context, entry *common.RegistrationEntry) (*common.RegistrationEntry, error) {
func (f *fakeDS) CreateOrReturnRegistrationEntry(ctx context.Context, entry *common.RegistrationEntry) (*common.RegistrationEntry, bool, error) {
if !f.customCreate {
return f.DataStore.CreateRegistrationEntry(ctx, entry)
return f.DataStore.CreateOrReturnRegistrationEntry(ctx, entry)
}

if f.err != nil {
return nil, f.err
return nil, false, f.err
}
entryID := entry.EntryId

expect, ok := f.expectEntries[entryID]
assert.True(f.t, ok, "no expect entry found")
assert.True(f.t, ok, "no expect entry found for entry %q", entryID)

// Validate we get expected entry
spiretest.AssertProtoEqual(f.t, expect, entry)

// Return expect when no custom result configured
if len(f.results) == 0 {
return expect, nil
return expect, false, nil
}

res, ok := f.results[entryID]
assert.True(f.t, ok, "no result found")

return res, nil
return res, false, nil
}

type entryFetcher struct {
Expand Down
1 change: 1 addition & 0 deletions pkg/server/datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type DataStore interface {
// Entries
CountRegistrationEntries(context.Context) (int32, error)
CreateRegistrationEntry(context.Context, *common.RegistrationEntry) (*common.RegistrationEntry, error)
CreateOrReturnRegistrationEntry(context.Context, *common.RegistrationEntry) (*common.RegistrationEntry, bool, error)
DeleteRegistrationEntry(ctx context.Context, entryID string) (*common.RegistrationEntry, error)
FetchRegistrationEntry(ctx context.Context, entryID string) (*common.RegistrationEntry, error)
ListRegistrationEntries(context.Context, *ListRegistrationEntriesRequest) (*ListRegistrationEntriesResponse, error)
Expand Down
57 changes: 51 additions & 6 deletions pkg/server/datastore/sqlstore/sqlstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,18 +297,40 @@ func (ds *Plugin) ListNodeSelectors(ctx context.Context,
// CreateRegistrationEntry stores the given registration entry
func (ds *Plugin) CreateRegistrationEntry(ctx context.Context,
entry *common.RegistrationEntry) (registrationEntry *common.RegistrationEntry, err error) {
out, _, err := ds.createOrReturnRegistrationEntry(ctx, entry)
return out, err
}

// CreateOrReturnRegistrationEntry stores the given registration entry. If an
// entry already exists with the same (parentID, spiffeID, selector) tuple,
// that entry is returned instead.
func (ds *Plugin) CreateOrReturnRegistrationEntry(ctx context.Context,
entry *common.RegistrationEntry) (registrationEntry *common.RegistrationEntry, existing bool, err error) {
return ds.createOrReturnRegistrationEntry(ctx, entry)
}

func (ds *Plugin) createOrReturnRegistrationEntry(ctx context.Context,
entry *common.RegistrationEntry) (registrationEntry *common.RegistrationEntry, existing bool, err error) {
// TODO: Validations should be done in the ProtoBuf level [https://github.com/spiffe/spire/issues/44]
if err = validateRegistrationEntry(entry); err != nil {
return nil, err
return nil, false, err
}

if err = ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) {
registrationEntry, err = lookupSimilarEntry(ctx, ds.db, tx, entry)
if err != nil {
return err
}
if registrationEntry != nil {
existing = true
return nil
}
registrationEntry, err = createRegistrationEntry(tx, entry)
return err
}); err != nil {
return nil, err
return nil, false, err
}
return registrationEntry, nil
return registrationEntry, existing, nil
}

// FetchRegistrationEntry fetches an existing registration by entry ID
Expand Down Expand Up @@ -1996,7 +2018,7 @@ func listRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLog
// query returns rows that are completely filtered out. If that happens,
// keep querying until a page gets at least one result.
for {
resp, err := listRegistrationEntriesOnce(ctx, db, req)
resp, err := listRegistrationEntriesOnce(ctx, db.raw, db.databaseType, db.supportsCTE, req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -2057,8 +2079,12 @@ func filterEntriesBySelectorSet(entries []*common.RegistrationEntry, selectors [
return filtered
}

func listRegistrationEntriesOnce(ctx context.Context, db *sqlDB, req *datastore.ListRegistrationEntriesRequest) (*datastore.ListRegistrationEntriesResponse, error) {
query, args, err := buildListRegistrationEntriesQuery(db.databaseType, db.supportsCTE, req)
type queryContext interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}

func listRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseType string, supportsCTE bool, req *datastore.ListRegistrationEntriesRequest) (*datastore.ListRegistrationEntriesResponse, error) {
query, args, err := buildListRegistrationEntriesQuery(databaseType, supportsCTE, req)
if err != nil {
return nil, sqlError.Wrap(err)
}
Expand Down Expand Up @@ -3405,3 +3431,22 @@ func nullableUnixTimeToDBTime(unixTime int64) *time.Time {
dbTime := time.Unix(unixTime, 0)
return &dbTime
}

func lookupSimilarEntry(ctx context.Context, db *sqlDB, tx *gorm.DB, entry *common.RegistrationEntry) (*common.RegistrationEntry, error) {
resp, err := listRegistrationEntriesOnce(ctx, tx.CommonDB().(queryContext), db.databaseType, db.supportsCTE, &datastore.ListRegistrationEntriesRequest{
BySpiffeID: entry.SpiffeId,
ByParentID: entry.ParentId,
BySelectors: &datastore.BySelectors{
Match: datastore.Exact,
Selectors: entry.Selectors,
},
})
switch {
case err != nil:
return nil, err
case len(resp.Entries) > 0:
return resp.Entries[0], nil
default:
return nil, nil
}
}
28 changes: 20 additions & 8 deletions pkg/server/datastore/sqlstore/sqlstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ func (s *PluginSuite) SetupTest() {
}

func (s *PluginSuite) TearDownTest() {
s.ds.closeDB()
if s.ds != nil {
s.ds.closeDB()
}
}

func (s *PluginSuite) newPlugin() *Plugin {
Expand Down Expand Up @@ -1432,8 +1434,8 @@ func (s *PluginSuite) testListRegistrationEntries(dataConsistency datastore.Data
bazbarCB2.FederatesWith = []string{"spiffe://federated2.test"}
bazbarCD12 := makeEntry("baz", "bar", "C", "D")
bazbarCD12.FederatesWith = []string{"spiffe://federated1.test", "spiffe://federated2.test"}
bazbarAD3 := makeEntry("baz", "bar", "A", "D")
bazbarAD3.FederatesWith = []string{"spiffe://federated3.test"}
bazbarAE3 := makeEntry("baz", "bar", "A", "E")
bazbarAE3.FederatesWith = []string{"spiffe://federated3.test"}

bazbuzAB12 := makeEntry("baz", "buz", "A", "B")
bazbuzAB12.FederatesWith = []string{"spiffe://federated1.test", "spiffe://federated2.test"}
Expand Down Expand Up @@ -1717,7 +1719,7 @@ func (s *PluginSuite) testListRegistrationEntries(dataConsistency datastore.Data
},
{
test: "by parentID and federatesWith one match any",
entries: []*common.RegistrationEntry{foobarAB1, foobarAD12, foobarCB2, foobarCD12, zizzazX, bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12, bazbarAD3},
entries: []*common.RegistrationEntry{foobarAB1, foobarAD12, foobarCB2, foobarCD12, zizzazX, bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12, bazbarAE3},
byParentID: makeID("baz"),
byFederatesWith: byFederatesWith(datastore.MatchAny, "spiffe://federated1.test"),
expectEntriesOut: []*common.RegistrationEntry{bazbarAB1, bazbarAD12, bazbarCD12},
Expand All @@ -1726,7 +1728,7 @@ func (s *PluginSuite) testListRegistrationEntries(dataConsistency datastore.Data
},
{
test: "by parentID and federatesWith many match any",
entries: []*common.RegistrationEntry{foobarAB1, foobarAD12, foobarCB2, foobarCD12, zizzazX, bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12, bazbarAD3},
entries: []*common.RegistrationEntry{foobarAB1, foobarAD12, foobarCB2, foobarCD12, zizzazX, bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12, bazbarAE3},
byParentID: makeID("baz"),
byFederatesWith: byFederatesWith(datastore.MatchAny, "spiffe://federated1.test", "spiffe://federated2.test"),
expectEntriesOut: []*common.RegistrationEntry{bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12},
Expand All @@ -1735,7 +1737,7 @@ func (s *PluginSuite) testListRegistrationEntries(dataConsistency datastore.Data
},
{
test: "by parentID and federatesWith one superset",
entries: []*common.RegistrationEntry{foobarAB1, foobarAD12, foobarCB2, foobarCD12, zizzazX, bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12, bazbarAD3},
entries: []*common.RegistrationEntry{foobarAB1, foobarAD12, foobarCB2, foobarCD12, zizzazX, bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12, bazbarAE3},
byParentID: makeID("baz"),
byFederatesWith: byFederatesWith(datastore.Superset, "spiffe://federated1.test"),
expectEntriesOut: []*common.RegistrationEntry{bazbarAB1, bazbarAD12, bazbarCD12},
Expand All @@ -1744,7 +1746,7 @@ func (s *PluginSuite) testListRegistrationEntries(dataConsistency datastore.Data
},
{
test: "by parentID and federatesWith many superset",
entries: []*common.RegistrationEntry{foobarAB1, foobarAD12, foobarCB2, foobarCD12, zizzazX, bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12, bazbarAD3},
entries: []*common.RegistrationEntry{foobarAB1, foobarAD12, foobarCB2, foobarCD12, zizzazX, bazbarAB1, bazbarAD12, bazbarCB2, bazbarCD12, bazbarAE3},
byParentID: makeID("baz"),
byFederatesWith: byFederatesWith(datastore.Superset, "spiffe://federated1.test", "spiffe://federated2.test"),
expectEntriesOut: []*common.RegistrationEntry{bazbarAD12, bazbarCD12},
Expand Down Expand Up @@ -2176,6 +2178,8 @@ func (s *PluginSuite) TestUpdateRegistrationEntryWithMask() {
// Needed for the FederatesWith field to work
s.createBundle("spiffe://dom1.org")
s.createBundle("spiffe://dom2.org")

var id string
for _, testcase := range []struct {
name string
mask *common.RegistrationEntryMask
Expand Down Expand Up @@ -2295,8 +2299,11 @@ func (s *PluginSuite) TestUpdateRegistrationEntryWithMask() {
} {
tt := testcase
s.Run(tt.name, func() {
if id != "" {
s.deleteRegistrationEntry(id)
}
registrationEntry := s.createRegistrationEntry(oldEntry)
id := registrationEntry.EntryId
id = registrationEntry.EntryId

updateEntry := &common.RegistrationEntry{}
tt.update(updateEntry)
Expand Down Expand Up @@ -3352,6 +3359,11 @@ func (s *PluginSuite) createRegistrationEntry(entry *common.RegistrationEntry) *
return registrationEntry
}

func (s *PluginSuite) deleteRegistrationEntry(entryID string) {
_, err := s.ds.DeleteRegistrationEntry(ctx, entryID)
s.Require().NoError(err)
}

func (s *PluginSuite) fetchRegistrationEntry(entryID string) *common.RegistrationEntry {
registrationEntry, err := s.ds.FetchRegistrationEntry(ctx, entryID)
s.Require().NoError(err)
Expand Down
7 changes: 7 additions & 0 deletions test/fakes/fakedatastore/fakedatastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ func (s *DataStore) CreateRegistrationEntry(ctx context.Context, entry *common.R
return s.ds.CreateRegistrationEntry(ctx, entry)
}

func (s *DataStore) CreateOrReturnRegistrationEntry(ctx context.Context, entry *common.RegistrationEntry) (*common.RegistrationEntry, bool, error) {
if err := s.getNextError(); err != nil {
return nil, false, err
}
return s.ds.CreateOrReturnRegistrationEntry(ctx, entry)
}

func (s *DataStore) FetchRegistrationEntry(ctx context.Context, entryID string) (*common.RegistrationEntry, error) {
if err := s.getNextError(); err != nil {
return nil, err
Expand Down

0 comments on commit c01d0b6

Please sign in to comment.