From 23fc4ecc5937311b94b1702e039ea07eb77df93a Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Tue, 9 Apr 2024 10:59:10 -0700 Subject: [PATCH 01/14] move db txn to context. add db session helper --- client/db.go | 3 - db/collection_update.go | 2 +- db/db.go | 13 +- db/db_test.go | 2 +- db/index_test.go | 2 +- db/request.go | 2 +- db/session.go | 63 +++ db/session_test.go | 53 +++ db/store.go | 285 ++++++++++++ db/subscriptions.go | 2 +- db/txn_db.go | 422 ------------------ http/handler.go | 2 - http/handler_ccip.go | 4 +- http/handler_collection.go | 4 +- http/handler_lens.go | 16 +- http/handler_store.go | 48 +- http/middleware.go | 63 +-- net/peer_collection.go | 11 +- net/peer_replicator.go | 12 +- net/server.go | 11 +- tests/bench/query/planner/utils.go | 14 +- .../events/simple/with_create_txn_test.go | 7 +- tests/integration/lens.go | 6 +- tests/integration/utils2.go | 36 +- 24 files changed, 510 insertions(+), 573 deletions(-) create mode 100644 db/session.go create mode 100644 db/session_test.go create mode 100644 db/store.go delete mode 100644 db/txn_db.go diff --git a/client/db.go b/client/db.go index a5d855f137..cedd63d492 100644 --- a/client/db.go +++ b/client/db.go @@ -42,9 +42,6 @@ type DB interface { // can safely operate on it concurrently. NewConcurrentTxn(context.Context, bool) (datastore.Txn, error) - // WithTxn returns a new [client.Store] that respects the given transaction. - WithTxn(datastore.Txn) Store - // Root returns the underlying root store, within which all data managed by DefraDB is held. Root() datastore.RootStore diff --git a/db/collection_update.go b/db/collection_update.go index dcc3ba6cba..1a6371b94a 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -439,7 +439,7 @@ func (c *collection) makeSelectionPlan( ctx, identity, c.db.acp, - c.db.WithTxn(txn), + &store{c.db}, txn, ) diff --git a/db/db.go b/db/db.go index 239b26f9a7..5c3269fb59 100644 --- a/db/db.go +++ b/db/db.go @@ -89,7 +89,7 @@ func newDB( ctx context.Context, rootstore datastore.RootStore, options ...Option, -) (*implicitTxnDB, error) { +) (*store, error) { multistore := datastore.MultiStoreFrom(rootstore) parser, err := graphql.NewParser() @@ -119,7 +119,7 @@ func newDB( return nil, err } - return &implicitTxnDB{db}, nil + return &store{db}, nil } // NewTxn creates a new transaction. @@ -134,15 +134,6 @@ func (db *db) NewConcurrentTxn(ctx context.Context, readonly bool) (datastore.Tx return datastore.NewConcurrentTxnFrom(ctx, db.rootstore, txnId, readonly) } -// WithTxn returns a new [client.Store] that respects the given transaction. -func (db *db) WithTxn(txn datastore.Txn) client.Store { - return &explicitTxnDB{ - db: db, - txn: txn, - lensRegistry: db.lensRegistry.WithTxn(txn), - } -} - // Root returns the root datastore. func (db *db) Root() datastore.RootStore { return db.rootstore diff --git a/db/db_test.go b/db/db_test.go index 237a1f21ed..89e5aa9c6b 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -19,7 +19,7 @@ import ( badgerds "github.com/sourcenetwork/defradb/datastore/badger/v4" ) -func newMemoryDB(ctx context.Context) (*implicitTxnDB, error) { +func newMemoryDB(ctx context.Context) (*store, error) { opts := badgerds.Options{Options: badger.DefaultOptions("").WithInMemory(true)} rootstore, err := badgerds.NewDatastore("", &opts) if err != nil { diff --git a/db/index_test.go b/db/index_test.go index 44c2e45f52..56f90fa35d 100644 --- a/db/index_test.go +++ b/db/index_test.go @@ -53,7 +53,7 @@ const ( type indexTestFixture struct { ctx context.Context - db *implicitTxnDB + db *store txn datastore.Txn users client.Collection t *testing.T diff --git a/db/request.go b/db/request.go index 2905ee4de2..21474da089 100644 --- a/db/request.go +++ b/db/request.go @@ -59,7 +59,7 @@ func (db *db) execRequest( ctx, identity, db.acp, - db.WithTxn(txn), + &store{db}, txn, ) diff --git a/db/session.go b/db/session.go new file mode 100644 index 0000000000..333deca6a4 --- /dev/null +++ b/db/session.go @@ -0,0 +1,63 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package db + +import ( + "context" + + "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/datastore" +) + +type contextKey string + +const ( + txnContextKey = contextKey("txn") +) + +// Session wraps a context to make it easier to pass request scoped +// parameters such as transactions. +type Session struct { + context.Context +} + +// NewSession returns a session that wraps the given context. +func NewSession(ctx context.Context) *Session { + return &Session{ctx} +} + +// WithTxn returns a new session with the transaction value set. +func (s *Session) WithTxn(txn datastore.Txn) *Session { + return &Session{context.WithValue(s, txnContextKey, txn)} +} + +// explicitTxn is a transaction that is managed outside of the session. +type explicitTxn struct { + datastore.Txn +} + +func (t *explicitTxn) Commit(ctx context.Context) error { + return nil // do nothing +} + +func (t *explicitTxn) Discard(ctx context.Context) { + // do nothing +} + +// getContextTxn returns the explicit transaction from +// the context or creates a new implicit one. +func getContextTxn(ctx context.Context, db client.DB, readOnly bool) (datastore.Txn, error) { + txn, ok := ctx.Value(txnContextKey).(datastore.Txn) + if ok { + return &explicitTxn{txn}, nil + } + return db.NewTxn(ctx, readOnly) +} diff --git a/db/session_test.go b/db/session_test.go new file mode 100644 index 0000000000..0808ff0620 --- /dev/null +++ b/db/session_test.go @@ -0,0 +1,53 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSessionWithTxn(t *testing.T) { + ctx := context.Background() + + db, err := newMemoryDB(ctx) + require.NoError(t, err) + + txn, err := db.NewTxn(ctx, true) + require.NoError(t, err) + + session := NewSession(ctx).WithTxn(txn) + + // get txn from session + out, err := getContextTxn(session, db, true) + require.NoError(t, err) + + // txn should be explicit + _, ok := out.(*explicitTxn) + assert.True(t, ok) +} + +func TestGetContextTxn(t *testing.T) { + ctx := context.Background() + + db, err := newMemoryDB(ctx) + require.NoError(t, err) + + txn, err := getContextTxn(ctx, db, true) + require.NoError(t, err) + + // txn should not be explicit + _, ok := txn.(*explicitTxn) + assert.False(t, ok) +} diff --git a/db/store.go b/db/store.go new file mode 100644 index 0000000000..4f279a2fac --- /dev/null +++ b/db/store.go @@ -0,0 +1,285 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package db + +import ( + "context" + + "github.com/lens-vm/lens/host-go/config/model" + + "github.com/sourcenetwork/immutable" + + "github.com/sourcenetwork/defradb/client" +) + +var _ client.Store = (*store)(nil) + +type store struct { + *db +} + +// ExecRequest executes a request against the database. +func (s *store) ExecRequest( + ctx context.Context, + identity immutable.Option[string], + request string, +) *client.RequestResult { + txn, err := getContextTxn(ctx, s, false) + if err != nil { + res := &client.RequestResult{} + res.GQL.Errors = []error{err} + return res + } + defer txn.Discard(ctx) + + res := s.db.execRequest(ctx, identity, request, txn) + if len(res.GQL.Errors) > 0 { + return res + } + + if err := txn.Commit(ctx); err != nil { + res.GQL.Errors = []error{err} + return res + } + + return res +} + +// GetCollectionByName returns an existing collection within the database. +func (s *store) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { + txn, err := getContextTxn(ctx, s, true) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + return s.db.getCollectionByName(ctx, txn, name) +} + +// GetCollections gets all the currently defined collections. +func (s *store) GetCollections( + ctx context.Context, + options client.CollectionFetchOptions, +) ([]client.Collection, error) { + txn, err := getContextTxn(ctx, s, true) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + return s.db.getCollections(ctx, txn, options) +} + +// GetSchemaByVersionID returns the schema description for the schema version of the +// ID provided. +// +// Will return an error if it is not found. +func (s *store) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { + txn, err := getContextTxn(ctx, s, true) + if err != nil { + return client.SchemaDescription{}, err + } + defer txn.Discard(ctx) + + return s.db.getSchemaByVersionID(ctx, txn, versionID) +} + +// GetSchemas returns all schema versions that currently exist within +// this [Store]. +func (s *store) GetSchemas( + ctx context.Context, + options client.SchemaFetchOptions, +) ([]client.SchemaDescription, error) { + txn, err := getContextTxn(ctx, s, true) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + return s.db.getSchemas(ctx, txn, options) +} + +// GetAllIndexes gets all the indexes in the database. +func (s *store) GetAllIndexes( + ctx context.Context, +) (map[client.CollectionName][]client.IndexDescription, error) { + txn, err := getContextTxn(ctx, s, true) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + return s.db.getAllIndexDescriptions(ctx, txn) +} + +// AddSchema takes the provided GQL schema in SDL format, and applies it to the database, +// creating the necessary collections, request types, etc. +// +// All schema types provided must not exist prior to calling this, and they may not reference existing +// types previously defined. +func (s *store) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { + txn, err := getContextTxn(ctx, s, false) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + cols, err := s.db.addSchema(ctx, txn, schemaString) + if err != nil { + return nil, err + } + + if err := txn.Commit(ctx); err != nil { + return nil, err + } + return cols, nil +} + +// PatchSchema takes the given JSON patch string and applies it to the set of CollectionDescriptions +// present in the database. +// +// It will also update the GQL types used by the query system. It will error and not apply any of the +// requested, valid updates should the net result of the patch result in an invalid state. The +// individual operations defined in the patch do not need to result in a valid state, only the net result +// of the full patch. +// +// The collections (including the schema version ID) will only be updated if any changes have actually +// been made, if the net result of the patch matches the current persisted description then no changes +// will be applied. +func (s *store) PatchSchema( + ctx context.Context, + patchString string, + migration immutable.Option[model.Lens], + setAsDefaultVersion bool, +) error { + txn, err := getContextTxn(ctx, s, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = s.db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (s *store) PatchCollection( + ctx context.Context, + patchString string, +) error { + txn, err := getContextTxn(ctx, s, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = s.db.patchCollection(ctx, txn, patchString) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (s *store) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { + txn, err := getContextTxn(ctx, s, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = s.db.setActiveSchemaVersion(ctx, txn, schemaVersionID) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (s *store) SetMigration(ctx context.Context, cfg client.LensConfig) error { + txn, err := getContextTxn(ctx, s, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = s.db.setMigration(ctx, txn, cfg) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (s *store) AddView( + ctx context.Context, + query string, + sdl string, + transform immutable.Option[model.Lens], +) ([]client.CollectionDefinition, error) { + txn, err := getContextTxn(ctx, s, false) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + defs, err := s.db.addView(ctx, txn, query, sdl, transform) + if err != nil { + return nil, err + } + + err = txn.Commit(ctx) + if err != nil { + return nil, err + } + + return defs, nil +} + +// BasicImport imports a json dataset. +// filepath must be accessible to the node. +func (s *store) BasicImport(ctx context.Context, filepath string) error { + txn, err := getContextTxn(ctx, s, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = s.db.basicImport(ctx, txn, filepath) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +// BasicExport exports the current data or subset of data to file in json format. +func (s *store) BasicExport(ctx context.Context, config *client.BackupConfig) error { + txn, err := getContextTxn(ctx, s, true) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = s.db.basicExport(ctx, txn, config) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (s *store) LensRegistry() client.LensRegistry { + return s.db.lensRegistry +} diff --git a/db/subscriptions.go b/db/subscriptions.go index f6f187c54f..5958d567be 100644 --- a/db/subscriptions.go +++ b/db/subscriptions.go @@ -80,7 +80,7 @@ func (db *db) handleEvent( ctx, identity, db.acp, - db.WithTxn(txn), + &store{db}, txn, ) diff --git a/db/txn_db.go b/db/txn_db.go deleted file mode 100644 index e77176b433..0000000000 --- a/db/txn_db.go +++ /dev/null @@ -1,422 +0,0 @@ -// Copyright 2023 Democratized Data Foundation -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package db - -import ( - "context" - - "github.com/lens-vm/lens/host-go/config/model" - - "github.com/sourcenetwork/immutable" - - "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" -) - -var _ client.DB = (*implicitTxnDB)(nil) -var _ client.DB = (*explicitTxnDB)(nil) -var _ client.Store = (*implicitTxnDB)(nil) -var _ client.Store = (*explicitTxnDB)(nil) - -type implicitTxnDB struct { - *db -} - -type explicitTxnDB struct { - *db - txn datastore.Txn - lensRegistry client.LensRegistry -} - -// ExecRequest executes a request against the database. -func (db *implicitTxnDB) ExecRequest( - ctx context.Context, - identity immutable.Option[string], - request string, -) *client.RequestResult { - txn, err := db.NewTxn(ctx, false) - if err != nil { - res := &client.RequestResult{} - res.GQL.Errors = []error{err} - return res - } - defer txn.Discard(ctx) - - res := db.execRequest(ctx, identity, request, txn) - if len(res.GQL.Errors) > 0 { - return res - } - - if err := txn.Commit(ctx); err != nil { - res.GQL.Errors = []error{err} - return res - } - - return res -} - -// ExecRequest executes a transaction request against the database. -func (db *explicitTxnDB) ExecRequest( - ctx context.Context, - identity immutable.Option[string], - request string, -) *client.RequestResult { - return db.execRequest(ctx, identity, request, db.txn) -} - -// GetCollectionByName returns an existing collection within the database. -func (db *implicitTxnDB) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - return db.getCollectionByName(ctx, txn, name) -} - -// GetCollectionByName returns an existing collection within the database. -func (db *explicitTxnDB) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { - col, err := db.getCollectionByName(ctx, db.txn, name) - if err != nil { - return nil, err - } - - return col.WithTxn(db.txn), nil -} - -// GetCollections gets all the currently defined collections. -func (db *implicitTxnDB) GetCollections( - ctx context.Context, - options client.CollectionFetchOptions, -) ([]client.Collection, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - return db.getCollections(ctx, txn, options) -} - -// GetCollections gets all the currently defined collections. -func (db *explicitTxnDB) GetCollections( - ctx context.Context, - options client.CollectionFetchOptions, -) ([]client.Collection, error) { - cols, err := db.getCollections(ctx, db.txn, options) - if err != nil { - return nil, err - } - - for i := range cols { - cols[i] = cols[i].WithTxn(db.txn) - } - - return cols, nil -} - -// GetSchemaByVersionID returns the schema description for the schema version of the -// ID provided. -// -// Will return an error if it is not found. -func (db *implicitTxnDB) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return client.SchemaDescription{}, err - } - defer txn.Discard(ctx) - - return db.getSchemaByVersionID(ctx, txn, versionID) -} - -// GetSchemaByVersionID returns the schema description for the schema version of the -// ID provided. -// -// Will return an error if it is not found. -func (db *explicitTxnDB) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { - return db.getSchemaByVersionID(ctx, db.txn, versionID) -} - -// GetSchemas returns all schema versions that currently exist within -// this [Store]. -func (db *implicitTxnDB) GetSchemas( - ctx context.Context, - options client.SchemaFetchOptions, -) ([]client.SchemaDescription, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - return db.getSchemas(ctx, txn, options) -} - -// GetSchemas returns all schema versions that currently exist within -// this [Store]. -func (db *explicitTxnDB) GetSchemas( - ctx context.Context, - options client.SchemaFetchOptions, -) ([]client.SchemaDescription, error) { - return db.getSchemas(ctx, db.txn, options) -} - -// GetAllIndexes gets all the indexes in the database. -func (db *implicitTxnDB) GetAllIndexes( - ctx context.Context, -) (map[client.CollectionName][]client.IndexDescription, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - return db.getAllIndexDescriptions(ctx, txn) -} - -// GetAllIndexes gets all the indexes in the database. -func (db *explicitTxnDB) GetAllIndexes( - ctx context.Context, -) (map[client.CollectionName][]client.IndexDescription, error) { - return db.getAllIndexDescriptions(ctx, db.txn) -} - -// AddSchema takes the provided GQL schema in SDL format, and applies it to the database, -// creating the necessary collections, request types, etc. -// -// All schema types provided must not exist prior to calling this, and they may not reference existing -// types previously defined. -func (db *implicitTxnDB) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - cols, err := db.addSchema(ctx, txn, schemaString) - if err != nil { - return nil, err - } - - if err := txn.Commit(ctx); err != nil { - return nil, err - } - return cols, nil -} - -// AddSchema takes the provided GQL schema in SDL format, and applies it to the database, -// creating the necessary collections, request types, etc. -// -// All schema types provided must not exist prior to calling this, and they may not reference existing -// types previously defined. -func (db *explicitTxnDB) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { - return db.addSchema(ctx, db.txn, schemaString) -} - -// PatchSchema takes the given JSON patch string and applies it to the set of CollectionDescriptions -// present in the database. -// -// It will also update the GQL types used by the query system. It will error and not apply any of the -// requested, valid updates should the net result of the patch result in an invalid state. The -// individual operations defined in the patch do not need to result in a valid state, only the net result -// of the full patch. -// -// The collections (including the schema version ID) will only be updated if any changes have actually -// been made, if the net result of the patch matches the current persisted description then no changes -// will be applied. -func (db *implicitTxnDB) PatchSchema( - ctx context.Context, - patchString string, - migration immutable.Option[model.Lens], - setAsDefaultVersion bool, -) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -// PatchSchema takes the given JSON patch string and applies it to the set of CollectionDescriptions -// present in the database. -// -// It will also update the GQL types used by the query system. It will error and not apply any of the -// requested, valid updates should the net result of the patch result in an invalid state. The -// individual operations defined in the patch do not need to result in a valid state, only the net result -// of the full patch. -// -// The collections (including the schema version ID) will only be updated if any changes have actually -// been made, if the net result of the patch matches the current persisted description then no changes -// will be applied. -func (db *explicitTxnDB) PatchSchema( - ctx context.Context, - patchString string, - migration immutable.Option[model.Lens], - setAsDefaultVersion bool, -) error { - return db.patchSchema(ctx, db.txn, patchString, migration, setAsDefaultVersion) -} - -func (db *implicitTxnDB) PatchCollection( - ctx context.Context, - patchString string, -) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.patchCollection(ctx, txn, patchString) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -func (db *explicitTxnDB) PatchCollection( - ctx context.Context, - patchString string, -) error { - return db.patchCollection(ctx, db.txn, patchString) -} - -func (db *implicitTxnDB) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.setActiveSchemaVersion(ctx, txn, schemaVersionID) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -func (db *explicitTxnDB) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { - return db.setActiveSchemaVersion(ctx, db.txn, schemaVersionID) -} - -func (db *implicitTxnDB) SetMigration(ctx context.Context, cfg client.LensConfig) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.setMigration(ctx, txn, cfg) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -func (db *explicitTxnDB) SetMigration(ctx context.Context, cfg client.LensConfig) error { - return db.setMigration(ctx, db.txn, cfg) -} - -func (db *implicitTxnDB) AddView( - ctx context.Context, - query string, - sdl string, - transform immutable.Option[model.Lens], -) ([]client.CollectionDefinition, error) { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - defs, err := db.addView(ctx, txn, query, sdl, transform) - if err != nil { - return nil, err - } - - err = txn.Commit(ctx) - if err != nil { - return nil, err - } - - return defs, nil -} - -func (db *explicitTxnDB) AddView( - ctx context.Context, - query string, - sdl string, - transform immutable.Option[model.Lens], -) ([]client.CollectionDefinition, error) { - return db.addView(ctx, db.txn, query, sdl, transform) -} - -// BasicImport imports a json dataset. -// filepath must be accessible to the node. -func (db *implicitTxnDB) BasicImport(ctx context.Context, filepath string) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.basicImport(ctx, txn, filepath) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -// BasicImport imports a json dataset. -// filepath must be accessible to the node. -func (db *explicitTxnDB) BasicImport(ctx context.Context, filepath string) error { - return db.basicImport(ctx, db.txn, filepath) -} - -// BasicExport exports the current data or subset of data to file in json format. -func (db *implicitTxnDB) BasicExport(ctx context.Context, config *client.BackupConfig) error { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.basicExport(ctx, txn, config) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -// BasicExport exports the current data or subset of data to file in json format. -func (db *explicitTxnDB) BasicExport(ctx context.Context, config *client.BackupConfig) error { - return db.basicExport(ctx, db.txn, config) -} - -// LensRegistry returns the LensRegistry in use by this database instance. -// -// It exposes several useful thread-safe migration related functions. -func (db *explicitTxnDB) LensRegistry() client.LensRegistry { - return db.lensRegistry -} diff --git a/http/handler.go b/http/handler.go index 7cd278593b..e6d83dbdd3 100644 --- a/http/handler.go +++ b/http/handler.go @@ -54,7 +54,6 @@ func NewApiRouter() (*Router, error) { }) router.AddRouteGroup(func(r *Router) { - r.AddMiddleware(LensMiddleware) lens_handler.bindRoutes(r) }) @@ -82,7 +81,6 @@ func NewHandler(db client.DB) (*Handler, error) { r.Use( ApiMiddleware(db, txs), TransactionMiddleware, - StoreMiddleware, ) r.Handle("/*", router) }) diff --git a/http/handler_ccip.go b/http/handler_ccip.go index 36151c5cc3..d89103c78a 100644 --- a/http/handler_ccip.go +++ b/http/handler_ccip.go @@ -35,7 +35,7 @@ type CCIPResponse struct { // ExecCCIP handles GraphQL over Cross Chain Interoperability Protocol requests. func (c *ccipHandler) ExecCCIP(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) var ccipReq CCIPRequest switch req.Method { @@ -61,7 +61,7 @@ func (c *ccipHandler) ExecCCIP(rw http.ResponseWriter, req *http.Request) { } identity := getIdentityFromAuthHeader(req) - result := store.ExecRequest(req.Context(), identity, request.Query) + result := db.ExecRequest(req.Context(), identity, request.Query) if result.Pub != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrStreamingNotSupported}) return diff --git a/http/handler_collection.go b/http/handler_collection.go index 1f41442849..05e842d473 100644 --- a/http/handler_collection.go +++ b/http/handler_collection.go @@ -331,8 +331,8 @@ func (s *collectionHandler) CreateIndex(rw http.ResponseWriter, req *http.Reques } func (s *collectionHandler) GetIndexes(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) - indexesMap, err := store.GetAllIndexes(req.Context()) + db := req.Context().Value(dbContextKey).(client.DB) + indexesMap, err := db.GetAllIndexes(req.Context()) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) diff --git a/http/handler_lens.go b/http/handler_lens.go index 532eaacefc..7104116781 100644 --- a/http/handler_lens.go +++ b/http/handler_lens.go @@ -22,9 +22,9 @@ import ( type lensHandler struct{} func (s *lensHandler) ReloadLenses(rw http.ResponseWriter, req *http.Request) { - lens := req.Context().Value(lensContextKey).(client.LensRegistry) + db := req.Context().Value(dbContextKey).(client.DB) - err := lens.ReloadLenses(req.Context()) + err := db.LensRegistry().ReloadLenses(req.Context()) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -33,7 +33,7 @@ func (s *lensHandler) ReloadLenses(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { - lens := req.Context().Value(lensContextKey).(client.LensRegistry) + db := req.Context().Value(dbContextKey).(client.DB) var request setMigrationRequest if err := requestJSON(req, &request); err != nil { @@ -41,7 +41,7 @@ func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { return } - err := lens.SetMigration(req.Context(), request.CollectionID, request.Config) + err := db.LensRegistry().SetMigration(req.Context(), request.CollectionID, request.Config) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -50,7 +50,7 @@ func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { - lens := req.Context().Value(lensContextKey).(client.LensRegistry) + db := req.Context().Value(dbContextKey).(client.DB) var request migrateRequest if err := requestJSON(req, &request); err != nil { @@ -58,7 +58,7 @@ func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { return } - result, err := lens.MigrateUp(req.Context(), enumerable.New(request.Data), request.CollectionID) + result, err := db.LensRegistry().MigrateUp(req.Context(), enumerable.New(request.Data), request.CollectionID) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -75,7 +75,7 @@ func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) MigrateDown(rw http.ResponseWriter, req *http.Request) { - lens := req.Context().Value(lensContextKey).(client.LensRegistry) + db := req.Context().Value(dbContextKey).(client.DB) var request migrateRequest if err := requestJSON(req, &request); err != nil { @@ -83,7 +83,7 @@ func (s *lensHandler) MigrateDown(rw http.ResponseWriter, req *http.Request) { return } - result, err := lens.MigrateDown(req.Context(), enumerable.New(request.Data), request.CollectionID) + result, err := db.LensRegistry().MigrateDown(req.Context(), enumerable.New(request.Data), request.CollectionID) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return diff --git a/http/handler_store.go b/http/handler_store.go index 4c57eda34f..231316ade7 100644 --- a/http/handler_store.go +++ b/http/handler_store.go @@ -27,14 +27,14 @@ import ( type storeHandler struct{} func (s *storeHandler) BasicImport(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) var config client.BackupConfig if err := requestJSON(req, &config); err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - err := store.BasicImport(req.Context(), config.Filepath) + err := db.BasicImport(req.Context(), config.Filepath) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -43,14 +43,14 @@ func (s *storeHandler) BasicImport(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) BasicExport(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) var config client.BackupConfig if err := requestJSON(req, &config); err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - err := store.BasicExport(req.Context(), &config) + err := db.BasicExport(req.Context(), &config) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -59,14 +59,14 @@ func (s *storeHandler) BasicExport(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) AddSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) schema, err := io.ReadAll(req.Body) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - cols, err := store.AddSchema(req.Context(), string(schema)) + cols, err := db.AddSchema(req.Context(), string(schema)) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -75,7 +75,7 @@ func (s *storeHandler) AddSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) var message patchSchemaRequest err := requestJSON(req, &message) @@ -84,7 +84,7 @@ func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { return } - err = store.PatchSchema(req.Context(), message.Patch, message.Migration, message.SetAsDefaultVersion) + err = db.PatchSchema(req.Context(), message.Patch, message.Migration, message.SetAsDefaultVersion) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -93,7 +93,7 @@ func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) var patch string err := requestJSON(req, &patch) @@ -102,7 +102,7 @@ func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request return } - err = store.PatchCollection(req.Context(), patch) + err = db.PatchCollection(req.Context(), patch) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -111,14 +111,14 @@ func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request } func (s *storeHandler) SetActiveSchemaVersion(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) schemaVersionID, err := io.ReadAll(req.Body) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - err = store.SetActiveSchemaVersion(req.Context(), string(schemaVersionID)) + err = db.SetActiveSchemaVersion(req.Context(), string(schemaVersionID)) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -127,7 +127,7 @@ func (s *storeHandler) SetActiveSchemaVersion(rw http.ResponseWriter, req *http. } func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) var message addViewRequest err := requestJSON(req, &message) @@ -136,7 +136,7 @@ func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { return } - defs, err := store.AddView(req.Context(), message.Query, message.SDL, message.Transform) + defs, err := db.AddView(req.Context(), message.Query, message.SDL, message.Transform) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -146,7 +146,7 @@ func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) var cfg client.LensConfig if err := requestJSON(req, &cfg); err != nil { @@ -154,7 +154,7 @@ func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { return } - err := store.SetMigration(req.Context(), cfg) + err := db.SetMigration(req.Context(), cfg) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -163,7 +163,7 @@ func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) options := client.CollectionFetchOptions{} if req.URL.Query().Has("name") { @@ -186,7 +186,7 @@ func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) options.IncludeInactive = immutable.Some(getInactive) } - cols, err := store.GetCollections(req.Context(), options) + cols, err := db.GetCollections(req.Context(), options) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -199,7 +199,7 @@ func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) } func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) options := client.SchemaFetchOptions{} if req.URL.Query().Has("version_id") { @@ -212,7 +212,7 @@ func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { options.Name = immutable.Some(req.URL.Query().Get("name")) } - schema, err := store.GetSchemas(req.Context(), options) + schema, err := db.GetSchemas(req.Context(), options) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -221,9 +221,9 @@ func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) GetAllIndexes(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) - indexes, err := store.GetAllIndexes(req.Context()) + indexes, err := db.GetAllIndexes(req.Context()) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -296,7 +296,7 @@ func (res *GraphQLResponse) UnmarshalJSON(data []byte) error { } func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) var request GraphQLRequest switch { @@ -313,7 +313,7 @@ func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) { } identity := getIdentityFromAuthHeader(req) - result := store.ExecRequest(req.Context(), identity, request.Query) + result := db.ExecRequest(req.Context(), identity, request.Query) if result.Pub == nil { responseJSON(rw, http.StatusOK, GraphQLResponse{result.GQL.Data, result.GQL.Errors}) diff --git a/http/middleware.go b/http/middleware.go index f18ba8bf60..f7b48f0602 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -23,6 +23,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db" ) const TX_HEADER_NAME = "x-defradb-tx" @@ -34,20 +35,6 @@ var ( txsContextKey = contextKey("txs") // dbContextKey is the context key for the client.DB dbContextKey = contextKey("db") - // txContextKey is the context key for the datastore.Txn - // - // This will only be set if a transaction id is specified. - txContextKey = contextKey("tx") - // storeContextKey is the context key for the client.Store - // - // If a transaction exists, all operations will be executed - // in the current transaction context. - storeContextKey = contextKey("store") - // lensContextKey is the context key for the client.LensRegistry - // - // If a transaction exists, all operations will be executed - // in the current transaction context. - lensContextKey = contextKey("lens") // colContextKey is the context key for the client.Collection // // If a transaction exists, all operations will be executed @@ -103,60 +90,26 @@ func TransactionMiddleware(next http.Handler) http.Handler { return } - ctx := context.WithValue(req.Context(), txContextKey, tx) - next.ServeHTTP(rw, req.WithContext(ctx)) - }) -} - -// StoreMiddleware sets the db context for the current request. -func StoreMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) - - var store client.Store - if tx, ok := req.Context().Value(txContextKey).(datastore.Txn); ok { - store = db.WithTxn(tx) - } else { - store = db + // store transaction in session + session := db.NewSession(req.Context()) + if val, ok := tx.(datastore.Txn); ok { + session = session.WithTxn(val) } - - ctx := context.WithValue(req.Context(), storeContextKey, store) - next.ServeHTTP(rw, req.WithContext(ctx)) - }) -} - -// LensMiddleware sets the lens context for the current request. -func LensMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) - - var lens client.LensRegistry - if tx, ok := req.Context().Value(txContextKey).(datastore.Txn); ok { - lens = store.LensRegistry().WithTxn(tx) - } else { - lens = store.LensRegistry() - } - - ctx := context.WithValue(req.Context(), lensContextKey, lens) - next.ServeHTTP(rw, req.WithContext(ctx)) + next.ServeHTTP(rw, req.WithContext(session)) }) } // CollectionMiddleware sets the collection context for the current request. func CollectionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) - col, err := store.GetCollectionByName(req.Context(), chi.URLParam(req, "name")) + col, err := db.GetCollectionByName(req.Context(), chi.URLParam(req, "name")) if err != nil { rw.WriteHeader(http.StatusNotFound) return } - if tx, ok := req.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } - ctx := context.WithValue(req.Context(), colContextKey, col) next.ServeHTTP(rw, req.WithContext(ctx)) }) diff --git a/net/peer_collection.go b/net/peer_collection.go index 4ef1139a1c..e1ca249700 100644 --- a/net/peer_collection.go +++ b/net/peer_collection.go @@ -19,6 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" + "github.com/sourcenetwork/defradb/db" ) const marker = byte(0xff) @@ -33,8 +34,9 @@ func (p *Peer) AddP2PCollections(ctx context.Context, collectionIDs []string) er // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - storeCol, err := p.db.WithTxn(txn).GetCollections( - p.ctx, + session := db.NewSession(ctx).WithTxn(txn) + storeCol, err := p.db.GetCollections( + session, client.CollectionFetchOptions{ SchemaRoot: immutable.Some(col), }, @@ -112,8 +114,9 @@ func (p *Peer) RemoveP2PCollections(ctx context.Context, collectionIDs []string) // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - storeCol, err := p.db.WithTxn(txn).GetCollections( - p.ctx, + session := db.NewSession(ctx).WithTxn(txn) + storeCol, err := p.db.GetCollections( + session, client.CollectionFetchOptions{ SchemaRoot: immutable.Some(col), }, diff --git a/net/peer_replicator.go b/net/peer_replicator.go index 93f6070f0b..93fdbe190d 100644 --- a/net/peer_replicator.go +++ b/net/peer_replicator.go @@ -21,6 +21,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" + "github.com/sourcenetwork/defradb/db" ) func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { @@ -39,13 +40,14 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { if err := rep.Info.ID.Validate(); err != nil { return err } + session := db.NewSession(ctx).WithTxn(txn) var collections []client.Collection switch { case len(rep.Schemas) > 0: // if specific collections are chosen get them by name for _, name := range rep.Schemas { - col, err := p.db.WithTxn(txn).GetCollectionByName(ctx, name) + col, err := p.db.GetCollectionByName(session, name) if err != nil { return NewErrReplicatorCollections(err) } @@ -60,7 +62,7 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { default: // default to all collections (unless a collection contains a policy). // TODO-ACP: default to all collections after resolving https://github.com/sourcenetwork/defradb/issues/2366 - allCollections, err := p.db.WithTxn(txn).GetCollections(ctx, client.CollectionFetchOptions{}) + allCollections, err := p.db.GetCollections(session, client.CollectionFetchOptions{}) if err != nil { return NewErrReplicatorCollections(err) } @@ -136,12 +138,14 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro return err } + session := db.NewSession(ctx).WithTxn(txn) + var collections []client.Collection switch { case len(rep.Schemas) > 0: // if specific collections are chosen get them by name for _, name := range rep.Schemas { - col, err := p.db.WithTxn(txn).GetCollectionByName(ctx, name) + col, err := p.db.GetCollectionByName(session, name) if err != nil { return NewErrReplicatorCollections(err) } @@ -156,7 +160,7 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro default: // default to all collections - collections, err = p.db.WithTxn(txn).GetCollections(ctx, client.CollectionFetchOptions{}) + collections, err = p.db.GetCollections(session, client.CollectionFetchOptions{}) if err != nil { return NewErrReplicatorCollections(err) } diff --git a/net/server.go b/net/server.go index 58a9f16f75..8b0438579b 100644 --- a/net/server.go +++ b/net/server.go @@ -33,6 +33,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore/badger/v4" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/errors" pb "github.com/sourcenetwork/defradb/net/pb" ) @@ -250,11 +251,11 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL return nil, err } defer txn.Discard(ctx) - store := s.db.WithTxn(txn) + session := db.NewSession(ctx).WithTxn(txn) // Currently a schema is the best way we have to link a push log request to a collection, // this will change with https://github.com/sourcenetwork/defradb/issues/1085 - col, err := s.getActiveCollection(ctx, store, string(req.Body.SchemaRoot)) + col, err := s.getActiveCollection(session, s.db, string(req.Body.SchemaRoot)) if err != nil { return nil, err } @@ -271,9 +272,9 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL return nil, errors.Wrap("failed to decode block to ipld.Node", err) } - var session sync.WaitGroup + var wg sync.WaitGroup bp := newBlockProcessor(s.peer, txn, col, dsKey, getter) - err = bp.processRemoteBlock(ctx, &session, nd, true) + err = bp.processRemoteBlock(ctx, &wg, nd, true) if err != nil { log.ErrorContextE( ctx, @@ -283,7 +284,7 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL corelog.Any("CID", cid), ) } - session.Wait() + wg.Wait() bp.mergeBlocks(ctx) err = s.syncIndexedDocs(ctx, col.WithTxn(txn), docID) diff --git a/tests/bench/query/planner/utils.go b/tests/bench/query/planner/utils.go index 5bb4472840..0b61e9d81b 100644 --- a/tests/bench/query/planner/utils.go +++ b/tests/bench/query/planner/utils.go @@ -19,6 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/planner" "github.com/sourcenetwork/defradb/request/graphql" @@ -57,11 +58,11 @@ func runMakePlanBench( fixture fixtures.Generator, query string, ) error { - db, _, err := benchutils.SetupDBAndCollections(b, ctx, fixture) + d, _, err := benchutils.SetupDBAndCollections(b, ctx, fixture) if err != nil { return err } - defer db.Close() + defer d.Close() parser, err := buildParser(ctx, fixture) if err != nil { @@ -73,18 +74,19 @@ func runMakePlanBench( if len(errs) > 0 { return errors.Wrap("failed to parse query string", errors.New(fmt.Sprintf("%v", errs))) } - txn, err := db.NewTxn(ctx, false) + txn, err := d.NewTxn(ctx, false) if err != nil { return errors.Wrap("failed to create txn", err) } - b.ResetTimer() + + session := db.NewSession(ctx).WithTxn(txn) for i := 0; i < b.N; i++ { planner := planner.New( - ctx, + session, acpIdentity.NoIdentity, acp.NoACP, - db.WithTxn(txn), + d, txn, ) plan, err := planner.MakePlan(q) diff --git a/tests/integration/events/simple/with_create_txn_test.go b/tests/integration/events/simple/with_create_txn_test.go index 7ff1f838e7..c837cc37ef 100644 --- a/tests/integration/events/simple/with_create_txn_test.go +++ b/tests/integration/events/simple/with_create_txn_test.go @@ -19,6 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/db" testUtils "github.com/sourcenetwork/defradb/tests/integration/events" ) @@ -42,8 +43,10 @@ func TestEventsSimpleWithCreateWithTxnDiscarded(t *testing.T) { func(ctx context.Context, d client.DB) { txn, err := d.NewTxn(ctx, false) assert.Nil(t, err) - r := d.WithTxn(txn).ExecRequest( - ctx, + + session := db.NewSession(ctx).WithTxn(txn) + r := d.ExecRequest( + session, acpIdentity.NoIdentity, `mutation { create_Users(input: {name: "Shahzad"}) { diff --git a/tests/integration/lens.go b/tests/integration/lens.go index 69c49a1cbc..d63f25bd3f 100644 --- a/tests/integration/lens.go +++ b/tests/integration/lens.go @@ -14,6 +14,7 @@ import ( "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/db" ) // ConfigureMigration is a test action which will configure a Lens migration using the @@ -42,9 +43,10 @@ func configureMigration( action ConfigureMigration, ) { for _, node := range getNodes(action.NodeID, s.nodes) { - db := getStore(s, node, action.TransactionID, action.ExpectedError) + txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) + session := db.NewSession(s.ctx).WithTxn(txn) - err := db.SetMigration(s.ctx, action.LensConfig) + err := node.SetMigration(session, action.LensConfig) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) diff --git a/tests/integration/utils2.go b/tests/integration/utils2.go index 18c97e76d1..c97a1b1013 100644 --- a/tests/integration/utils2.go +++ b/tests/integration/utils2.go @@ -32,6 +32,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" badgerds "github.com/sourcenetwork/defradb/datastore/badger/v4" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/net" "github.com/sourcenetwork/defradb/request/graphql" @@ -1080,8 +1081,9 @@ func getCollections( action GetCollections, ) { for _, node := range getNodes(action.NodeID, s.nodes) { - db := getStore(s, node, action.TransactionID, "") - results, err := db.GetCollections(s.ctx, action.FilterOptions) + txn := getTransaction(s, node, action.TransactionID, "") + session := db.NewSession(s.ctx).WithTxn(txn) + results, err := node.GetCollections(session, action.FilterOptions) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) @@ -1249,11 +1251,12 @@ func createDocViaGQL( input, ) - db := getStore(s, node, immutable.None[int](), action.ExpectedError) + txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) identity := acpIdentity.NewIdentity(action.Identity) - result := db.ExecRequest( - s.ctx, + session := db.NewSession(s.ctx).WithTxn(txn) + result := node.ExecRequest( + session, identity, request, ) @@ -1426,10 +1429,10 @@ func updateDocViaGQL( input, ) - db := getStore(s, node, immutable.None[int](), action.ExpectedError) - - result := db.ExecRequest( - s.ctx, + txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) + session := db.NewSession(s.ctx).WithTxn(txn) + result := node.ExecRequest( + session, acpIdentity.NewIdentity(action.Identity), request, ) @@ -1591,14 +1594,14 @@ func withRetry( return nil } -func getStore( +func getTransaction( s *state, db client.DB, transactionSpecifier immutable.Option[int], expectedError string, -) client.Store { +) datastore.Txn { if !transactionSpecifier.HasValue() { - return db + return nil } transactionID := transactionSpecifier.Value() @@ -1619,7 +1622,7 @@ func getStore( s.txns[transactionID] = txn } - return db.WithTxn(s.txns[transactionID]) + return s.txns[transactionID] } // commitTransaction commits the given transaction. @@ -1647,9 +1650,10 @@ func executeRequest( ) { var expectedErrorRaised bool for nodeID, node := range getNodes(action.NodeID, s.nodes) { - db := getStore(s, node, action.TransactionID, action.ExpectedError) - result := db.ExecRequest( - s.ctx, + txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) + session := db.NewSession(s.ctx).WithTxn(txn) + result := node.ExecRequest( + session, acpIdentity.NewIdentity(action.Identity), action.Request, ) From 19b7db3278fa3a7b7bd441f27471c9e42b42afe7 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Tue, 9 Apr 2024 12:23:27 -0700 Subject: [PATCH 02/14] move db session to subpackage. replace collection.WithTxn --- cli/collection.go | 5 - cli/index_create.go | 4 - cli/index_drop.go | 5 - cli/index_list.go | 5 - client/collection.go | 6 -- db/backup.go | 9 +- db/collection.go | 95 +++++-------------- db/collection_delete.go | 18 ++-- db/collection_get.go | 6 +- db/collection_index.go | 43 +++++---- db/collection_update.go | 19 ++-- db/{session.go => context.go} | 33 ++----- db/{session_test.go => context_test.go} | 33 +++---- db/index_test.go | 22 +++-- db/indexed_docs_test.go | 16 +++- db/session/session.go | 29 ++++++ http/middleware.go | 8 +- net/peer_collection.go | 10 +- net/peer_replicator.go | 17 ++-- net/server.go | 8 +- planner/create.go | 2 +- planner/delete.go | 2 +- planner/planner.go | 5 +- planner/update.go | 2 +- tests/bench/query/planner/utils.go | 6 +- .../events/simple/with_create_txn_test.go | 6 +- tests/integration/lens.go | 6 +- tests/integration/utils2.go | 18 ++-- 28 files changed, 194 insertions(+), 244 deletions(-) rename db/{session.go => context.go} (56%) rename db/{session_test.go => context_test.go} (76%) create mode 100644 db/session/session.go diff --git a/cli/collection.go b/cli/collection.go index 23ef9194ae..3697977d32 100644 --- a/cli/collection.go +++ b/cli/collection.go @@ -17,7 +17,6 @@ import ( "github.com/spf13/cobra" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) func MakeCollectionCommand() *cobra.Command { @@ -71,10 +70,6 @@ func MakeCollectionCommand() *cobra.Command { } col := cols[0] - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } - ctx := context.WithValue(cmd.Context(), colContextKey, col) cmd.SetContext(ctx) return nil diff --git a/cli/index_create.go b/cli/index_create.go index bfe5ec64c2..0d724da15b 100644 --- a/cli/index_create.go +++ b/cli/index_create.go @@ -14,7 +14,6 @@ import ( "github.com/spf13/cobra" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) func MakeIndexCreateCommand() *cobra.Command { @@ -52,9 +51,6 @@ Example: create a named index for 'Users' collection on 'name' field: if err != nil { return err } - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } desc, err = col.CreateIndex(cmd.Context(), desc) if err != nil { return err diff --git a/cli/index_drop.go b/cli/index_drop.go index 96f007268d..5dd069b5da 100644 --- a/cli/index_drop.go +++ b/cli/index_drop.go @@ -12,8 +12,6 @@ package cli import ( "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeIndexDropCommand() *cobra.Command { @@ -34,9 +32,6 @@ Example: drop the index 'UsersByName' for 'Users' collection: if err != nil { return err } - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } return col.DropIndex(cmd.Context(), nameArg) }, } diff --git a/cli/index_list.go b/cli/index_list.go index bf1fd21251..481acb7d37 100644 --- a/cli/index_list.go +++ b/cli/index_list.go @@ -12,8 +12,6 @@ package cli import ( "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeIndexListCommand() *cobra.Command { @@ -38,9 +36,6 @@ Example: show all index for 'Users' collection: if err != nil { return err } - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } indexes, err := col.GetIndexes(cmd.Context()) if err != nil { return err diff --git a/client/collection.go b/client/collection.go index aa219b3a74..bab61607a9 100644 --- a/client/collection.go +++ b/client/collection.go @@ -14,8 +14,6 @@ import ( "context" "github.com/sourcenetwork/immutable" - - "github.com/sourcenetwork/defradb/datastore" ) // Collection represents a defradb collection. @@ -192,10 +190,6 @@ type Collection interface { showDeleted bool, ) (*Document, error) - // WithTxn returns a new instance of the collection, with a transaction - // handle instead of a raw DB handle. - WithTxn(datastore.Txn) Collection - // GetAllDocIDs returns all the document IDs that exist in the collection. GetAllDocIDs(ctx context.Context, identity immutable.Option[string]) (<-chan DocIDResult, error) diff --git a/db/backup.go b/db/backup.go index 2d3b824be1..17110bec05 100644 --- a/db/backup.go +++ b/db/backup.go @@ -92,7 +92,7 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to backup - err = col.WithTxn(txn).Create(ctx, acpIdentity.NoIdentity, doc) + err = col.Create(ctx, acpIdentity.NoIdentity, doc) if err != nil { return NewErrDocCreate(err) } @@ -104,7 +104,7 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin return NewErrDocUpdate(err) } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to backup - err = col.WithTxn(txn).Update(ctx, acpIdentity.NoIdentity, doc) + err = col.Update(ctx, acpIdentity.NoIdentity, doc) if err != nil { return NewErrDocUpdate(err) } @@ -191,9 +191,8 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client if err != nil { return err } - colTxn := col.WithTxn(txn) // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to export - docIDsCh, err := colTxn.GetAllDocIDs(ctx, acpIdentity.NoIdentity) + docIDsCh, err := col.GetAllDocIDs(ctx, acpIdentity.NoIdentity) if err != nil { return err } @@ -210,7 +209,7 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client } } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to export - doc, err := colTxn.Get(ctx, acpIdentity.NoIdentity, docResultWithID.ID, false) + doc, err := col.Get(ctx, acpIdentity.NoIdentity, docResultWithID.ID, false) if err != nil { return err } diff --git a/db/collection.go b/db/collection.go index d7364df3b2..f23285fc26 100644 --- a/db/collection.go +++ b/db/collection.go @@ -46,18 +46,8 @@ var _ client.Collection = (*collection)(nil) // collection stores data records at Documents, which are gathered // together under a collection name. This is analogous to SQL Tables. type collection struct { - db *db - - // txn represents any externally provided [datastore.Txn] for which any - // operation on this [collection] instance should be scoped to. - // - // If this has no value, operations requiring a transaction should use an - // implicit internally managed transaction, which only lives for duration - // of the operation in question. - txn immutable.Option[datastore.Txn] - - def client.CollectionDefinition - + db *db + def client.CollectionDefinition indexes []CollectionIndex fetcherFactory func() fetcher.Fetcher } @@ -1240,7 +1230,7 @@ func (c *collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], ) (<-chan client.DocIDResult, error) { - txn, err := c.getTxn(ctx, true) + txn, err := getContextTxn(ctx, c.db, true) if err != nil { return nil, err } @@ -1271,7 +1261,7 @@ func (c *collection) getAllDocIDsChan( log.ErrorContextE(ctx, errFailedtoCloseQueryReqAllIDs, err) } close(resCh) - c.discardImplicitTxn(ctx, txn) + txn.Discard(ctx) }() for res := range q.Next() { // check for Done on context first @@ -1351,18 +1341,6 @@ func (c *collection) Definition() client.CollectionDefinition { return c.def } -// WithTxn returns a new instance of the collection, with a transaction -// handle instead of a raw DB handle. -func (c *collection) WithTxn(txn datastore.Txn) client.Collection { - return &collection{ - db: c.db, - txn: immutable.Some(txn), - def: c.def, - indexes: c.indexes, - fetcherFactory: c.fetcherFactory, - } -} - // Create a new document. // Will verify the DocID/CID to ensure that the new document is correctly formatted. func (c *collection) Create( @@ -1370,18 +1348,18 @@ func (c *collection) Create( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.create(ctx, identity, txn, doc) if err != nil { return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } // CreateMany creates a collection of documents at once. @@ -1391,11 +1369,11 @@ func (c *collection) CreateMany( identity immutable.Option[string], docs []*client.Document, ) error { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) for _, doc := range docs { err = c.create(ctx, identity, txn, doc) @@ -1403,7 +1381,7 @@ func (c *collection) CreateMany( return err } } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) getDocIDAndPrimaryKeyFromDoc( @@ -1476,11 +1454,11 @@ func (c *collection) Update( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(doc.ID()) exists, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) @@ -1499,7 +1477,7 @@ func (c *collection) Update( return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } // Contract: DB Exists check is already performed, and a doc with the given ID exists. @@ -1541,11 +1519,11 @@ func (c *collection) Save( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) // Check if document already exists with primary DS key. primaryKey := c.getPrimaryKeyFromDocID(doc.ID()) @@ -1567,7 +1545,7 @@ func (c *collection) Save( return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } // save saves the document state. save MUST not be called outside the `c.create` @@ -1823,11 +1801,11 @@ func (c *collection) Delete( identity immutable.Option[string], docID client.DocID, ) (bool, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return false, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) @@ -1835,7 +1813,7 @@ func (c *collection) Delete( if err != nil { return false, err } - return true, c.commitImplicitTxn(ctx, txn) + return true, txn.Commit(ctx) } // Exists checks if a given document exists with supplied DocID. @@ -1844,18 +1822,18 @@ func (c *collection) Exists( identity immutable.Option[string], docID client.DocID, ) (bool, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return false, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) exists, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) if err != nil && !errors.Is(err, ds.ErrNotFound) { return false, err } - return exists && !isDeleted, c.commitImplicitTxn(ctx, txn) + return exists && !isDeleted, txn.Commit(ctx) } // check if a document exists with the given primary key @@ -1916,35 +1894,6 @@ func (c *collection) saveCompositeToMerkleCRDT( return merkleCRDT.Save(ctx, links) } -// getTxn gets or creates a new transaction from the underlying db. -// If the collection already has a txn, return the existing one. -// Otherwise, create a new implicit transaction. -func (c *collection) getTxn(ctx context.Context, readonly bool) (datastore.Txn, error) { - if c.txn.HasValue() { - return c.txn.Value(), nil - } - return c.db.NewTxn(ctx, readonly) -} - -// discardImplicitTxn is a proxy function used by the collection to execute the Discard() -// transaction function only if its an implicit transaction. -// -// Implicit transactions are transactions that are created *during* an operation execution as a side effect. -// -// Explicit transactions are provided to the collection object via the "WithTxn(...)" function. -func (c *collection) discardImplicitTxn(ctx context.Context, txn datastore.Txn) { - if !c.txn.HasValue() { - txn.Discard(ctx) - } -} - -func (c *collection) commitImplicitTxn(ctx context.Context, txn datastore.Txn) error { - if !c.txn.HasValue() { - return txn.Commit(ctx) - } - return nil -} - func (c *collection) getPrimaryKeyFromDocID(docID client.DocID) core.PrimaryDataStoreKey { return core.PrimaryDataStoreKey{ CollectionRootID: c.Description().RootID, diff --git a/db/collection_delete.go b/db/collection_delete.go index 984cd27a21..155e171e63 100644 --- a/db/collection_delete.go +++ b/db/collection_delete.go @@ -54,12 +54,12 @@ func (c *collection) DeleteWithDocID( identity immutable.Option[string], docID client.DocID, ) (*client.DeleteResult, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) dsKey := c.getPrimaryKeyFromDocID(docID) res, err := c.deleteWithKey(ctx, identity, txn, dsKey) @@ -67,7 +67,7 @@ func (c *collection) DeleteWithDocID( return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } // DeleteWithDocIDs is the same as DeleteWithDocID but accepts multiple DocIDs as a slice. @@ -76,19 +76,19 @@ func (c *collection) DeleteWithDocIDs( identity immutable.Option[string], docIDs []client.DocID, ) (*client.DeleteResult, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) res, err := c.deleteWithIDs(ctx, identity, txn, docIDs, client.Deleted) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } // DeleteWithFilter deletes using a filter to target documents for delete. @@ -97,19 +97,19 @@ func (c *collection) DeleteWithFilter( identity immutable.Option[string], filter any, ) (*client.DeleteResult, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) res, err := c.deleteWithFilter(ctx, identity, txn, filter, client.Deleted) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } func (c *collection) deleteWithKey( diff --git a/db/collection_get.go b/db/collection_get.go index 16d5bd4711..9bcfe54755 100644 --- a/db/collection_get.go +++ b/db/collection_get.go @@ -29,11 +29,11 @@ func (c *collection) Get( showDeleted bool, ) (*client.Document, error) { // create txn - txn, err := c.getTxn(ctx, true) + txn, err := getContextTxn(ctx, c.db, true) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) found, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) @@ -53,7 +53,7 @@ func (c *collection) Get( return nil, client.ErrDocumentNotFoundOrNotAuthorized } - return doc, c.commitImplicitTxn(ctx, txn) + return doc, txn.Commit(ctx) } func (c *collection) get( diff --git a/db/collection_index.go b/db/collection_index.go index 1a7af8cc25..8195730b5b 100644 --- a/db/collection_index.go +++ b/db/collection_index.go @@ -27,6 +27,7 @@ import ( "github.com/sourcenetwork/defradb/db/base" "github.com/sourcenetwork/defradb/db/description" "github.com/sourcenetwork/defradb/db/fetcher" + "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/request/graphql/schema" ) @@ -41,8 +42,8 @@ func (db *db) createCollectionIndex( if err != nil { return client.IndexDescription{}, NewErrCanNotReadCollection(collectionName, err) } - col = col.WithTxn(txn) - return col.CreateIndex(ctx, desc) + sess := session.New(ctx).WithTxn(txn) + return col.CreateIndex(sess, desc) } func (db *db) dropCollectionIndex( @@ -54,8 +55,8 @@ func (db *db) dropCollectionIndex( if err != nil { return NewErrCanNotReadCollection(collectionName, err) } - col = col.WithTxn(txn) - return col.DropIndex(ctx, indexName) + sess := session.New(ctx).WithTxn(txn) + return col.DropIndex(sess, indexName) } // getAllIndexDescriptions returns all the index descriptions in the database. @@ -112,26 +113,26 @@ func (db *db) fetchCollectionIndexDescriptions( } func (c *collection) CreateDocIndex(ctx context.Context, doc *client.Document) error { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.indexNewDoc(ctx, txn, doc) if err != nil { return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client.Document) error { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.deleteIndexedDoc(ctx, txn, oldDoc) if err != nil { @@ -142,22 +143,22 @@ func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client. return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) DeleteDocIndex(ctx context.Context, doc *client.Document) error { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.deleteIndexedDoc(ctx, txn, doc) if err != nil { return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) indexNewDoc(ctx context.Context, txn datastore.Txn, doc *client.Document) error { @@ -242,17 +243,17 @@ func (c *collection) CreateIndex( ctx context.Context, desc client.IndexDescription, ) (client.IndexDescription, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return client.IndexDescription{}, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) index, err := c.createIndex(ctx, txn, desc) if err != nil { return client.IndexDescription{}, err } - return index.Description(), c.commitImplicitTxn(ctx, txn) + return index.Description(), txn.Commit(ctx) } func (c *collection) createIndex( @@ -398,17 +399,17 @@ func (c *collection) indexExistingDocs( // // All index artifacts for existing documents related the index will be removed. func (c *collection) DropIndex(ctx context.Context, indexName string) error { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.dropIndex(ctx, txn, indexName) if err != nil { return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) dropIndex(ctx context.Context, txn datastore.Txn, indexName string) error { @@ -486,11 +487,11 @@ func (c *collection) loadIndexes(ctx context.Context, txn datastore.Txn) error { // GetIndexes returns all indexes for the collection. func (c *collection) GetIndexes(ctx context.Context) ([]client.IndexDescription, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.loadIndexes(ctx, txn) if err != nil { diff --git a/db/collection_update.go b/db/collection_update.go index 1a6371b94a..7504a9cc78 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -57,16 +57,16 @@ func (c *collection) UpdateWithFilter( filter any, updater string, ) (*client.UpdateResult, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) res, err := c.updateWithFilter(ctx, identity, txn, filter, updater) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } // UpdateWithDocID updates using a DocID to target a single document for update. @@ -78,17 +78,17 @@ func (c *collection) UpdateWithDocID( docID client.DocID, updater string, ) (*client.UpdateResult, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) res, err := c.updateWithDocID(ctx, identity, txn, docID, updater) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } // UpdateWithDocIDs is the same as UpdateWithDocID but accepts multiple DocIDs as a slice. @@ -100,17 +100,17 @@ func (c *collection) UpdateWithDocIDs( docIDs []client.DocID, updater string, ) (*client.UpdateResult, error) { - txn, err := c.getTxn(ctx, false) + txn, err := getContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) res, err := c.updateWithIDs(ctx, identity, txn, docIDs, updater) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } func (c *collection) updateWithDocID( @@ -333,7 +333,6 @@ func (c *collection) patchPrimaryDoc( if err != nil { return err } - primaryCol = primaryCol.WithTxn(txn) primarySchema := primaryCol.Schema() primaryField, ok := primaryCol.Description().GetFieldByRelation( diff --git a/db/session.go b/db/context.go similarity index 56% rename from db/session.go rename to db/context.go index 333deca6a4..f49bfbb7a1 100644 --- a/db/session.go +++ b/db/context.go @@ -13,32 +13,10 @@ package db import ( "context" - "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db/session" ) -type contextKey string - -const ( - txnContextKey = contextKey("txn") -) - -// Session wraps a context to make it easier to pass request scoped -// parameters such as transactions. -type Session struct { - context.Context -} - -// NewSession returns a session that wraps the given context. -func NewSession(ctx context.Context) *Session { - return &Session{ctx} -} - -// WithTxn returns a new session with the transaction value set. -func (s *Session) WithTxn(txn datastore.Txn) *Session { - return &Session{context.WithValue(s, txnContextKey, txn)} -} - // explicitTxn is a transaction that is managed outside of the session. type explicitTxn struct { datastore.Txn @@ -52,10 +30,15 @@ func (t *explicitTxn) Discard(ctx context.Context) { // do nothing } +// transactionDB is a db that can create transactions. +type transactionDB interface { + NewTxn(context.Context, bool) (datastore.Txn, error) +} + // getContextTxn returns the explicit transaction from // the context or creates a new implicit one. -func getContextTxn(ctx context.Context, db client.DB, readOnly bool) (datastore.Txn, error) { - txn, ok := ctx.Value(txnContextKey).(datastore.Txn) +func getContextTxn(ctx context.Context, db transactionDB, readOnly bool) (datastore.Txn, error) { + txn, ok := ctx.Value(session.TxnContextKey).(datastore.Txn) if ok { return &explicitTxn{txn}, nil } diff --git a/db/session_test.go b/db/context_test.go similarity index 76% rename from db/session_test.go rename to db/context_test.go index 0808ff0620..57f6621ae5 100644 --- a/db/session_test.go +++ b/db/context_test.go @@ -14,40 +14,41 @@ import ( "context" "testing" + "github.com/sourcenetwork/defradb/db/session" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestSessionWithTxn(t *testing.T) { +func TestGetContextImplicitTxn(t *testing.T) { ctx := context.Background() db, err := newMemoryDB(ctx) require.NoError(t, err) - txn, err := db.NewTxn(ctx, true) - require.NoError(t, err) - - session := NewSession(ctx).WithTxn(txn) - - // get txn from session - out, err := getContextTxn(session, db, true) + txn, err := getContextTxn(ctx, db, true) require.NoError(t, err) - // txn should be explicit - _, ok := out.(*explicitTxn) - assert.True(t, ok) + // txn should be implicit + _, ok := txn.(*explicitTxn) + assert.False(t, ok) } -func TestGetContextTxn(t *testing.T) { +func TestGetContextExplicitTxn(t *testing.T) { ctx := context.Background() db, err := newMemoryDB(ctx) require.NoError(t, err) - txn, err := getContextTxn(ctx, db, true) + txn, err := db.NewTxn(ctx, true) require.NoError(t, err) - // txn should not be explicit - _, ok := txn.(*explicitTxn) - assert.False(t, ok) + // create a session with a transaction + sess := session.New(ctx).WithTxn(txn) + + out, err := getContextTxn(sess, db, true) + require.NoError(t, err) + + // txn should be explicit + _, ok := out.(*explicitTxn) + assert.True(t, ok) } diff --git a/db/index_test.go b/db/index_test.go index 56f90fa35d..fe738e9c3f 100644 --- a/db/index_test.go +++ b/db/index_test.go @@ -29,6 +29,7 @@ import ( "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/datastore/mocks" + "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/request/graphql/schema" ) @@ -784,7 +785,8 @@ func TestCollectionGetIndexes_ShouldCloseQueryIterator(t *testing.T) { mockedTxn.MockSystemstore.EXPECT().Query(mock.Anything, mock.Anything). Return(queryResults, nil) - _, err := f.users.WithTxn(mockedTxn).GetIndexes(f.ctx) + sess := session.New(f.ctx).WithTxn(mockedTxn) + _, err := f.users.GetIndexes(sess) assert.NoError(t, err) } @@ -840,7 +842,8 @@ func TestCollectionGetIndexes_IfSystemStoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - _, err := f.users.WithTxn(mockedTxn).GetIndexes(f.ctx) + sess := session.New(f.ctx).WithTxn(mockedTxn) + _, err := f.users.GetIndexes(sess) require.ErrorIs(t, err, testCase.ExpectedError) } } @@ -902,7 +905,8 @@ func TestCollectionGetIndexes_IfStoredIndexWithUnsupportedType_ReturnError(t *te mockedTxn.MockSystemstore.EXPECT().Query(mock.Anything, mock.Anything). Return(mocks.NewQueryResultsWithValues(t, indexDescData), nil) - _, err = collection.WithTxn(mockedTxn).GetIndexes(f.ctx) + sess := session.New(f.ctx).WithTxn(mockedTxn) + _, err = collection.GetIndexes(sess) require.ErrorIs(t, err, NewErrUnsupportedIndexFieldType(unsupportedKind)) } @@ -1093,17 +1097,18 @@ func TestDropIndex_IfFailsToDeleteFromStorage_ReturnError(t *testing.T) { mockedTxn.MockDatastore.EXPECT().Query(mock.Anything, mock.Anything).Maybe(). Return(mocks.NewQueryResultsWithValues(t), nil) - err := f.users.WithTxn(mockedTxn).DropIndex(f.ctx, testUsersColIndexName) + sess := session.New(f.ctx).WithTxn(mockedTxn) + err := f.users.DropIndex(sess, testUsersColIndexName) require.ErrorIs(t, err, testErr) } func TestDropIndex_ShouldUpdateCollectionsDescription(t *testing.T) { f := newIndexTestFixture(t) defer f.db.Close() - col := f.users.WithTxn(f.txn) - _, err := col.CreateIndex(f.ctx, getUsersIndexDescOnName()) + sess := session.New(f.ctx).WithTxn(f.txn) + _, err := f.users.CreateIndex(sess, getUsersIndexDescOnName()) require.NoError(t, err) - indOnAge, err := col.CreateIndex(f.ctx, getUsersIndexDescOnAge()) + indOnAge, err := f.users.CreateIndex(sess, getUsersIndexDescOnAge()) require.NoError(t, err) f.commitTxn() @@ -1144,7 +1149,8 @@ func TestDropIndex_IfSystemStoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - err := f.users.WithTxn(mockedTxn).DropIndex(f.ctx, testUsersColIndexName) + sess := session.New(f.ctx).WithTxn(mockedTxn) + err := f.users.DropIndex(sess, testUsersColIndexName) require.ErrorIs(t, err, testErr) } diff --git a/db/indexed_docs_test.go b/db/indexed_docs_test.go index c11eb2617f..84000a5672 100644 --- a/db/indexed_docs_test.go +++ b/db/indexed_docs_test.go @@ -31,6 +31,7 @@ import ( "github.com/sourcenetwork/defradb/datastore/mocks" "github.com/sourcenetwork/defradb/db/fetcher" fetcherMocks "github.com/sourcenetwork/defradb/db/fetcher/mocks" + "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/planner/mapper" ) @@ -322,7 +323,8 @@ func TestNonUnique_IfFailsToStoredIndexedDoc_Error(t *testing.T) { dataStoreOn.Put(mock.Anything, key.ToDS(), mock.Anything).Return(errors.New("error")) dataStoreOn.Put(mock.Anything, mock.Anything, mock.Anything).Return(nil) - err := f.users.WithTxn(mockTxn).Create(f.ctx, acpIdentity.NoIdentity, doc) + sess := session.New(f.ctx).WithTxn(mockTxn) + err := f.users.Create(sess, acpIdentity.NoIdentity, doc) require.ErrorIs(f.t, err, NewErrFailedToStoreIndexedField("name", nil)) } @@ -360,7 +362,8 @@ func TestNonUnique_IfSystemStorageHasInvalidIndexDescription_Error(t *testing.T) systemStoreOn.Query(mock.Anything, mock.Anything). Return(mocks.NewQueryResultsWithValues(t, []byte("invalid")), nil) - err := f.users.WithTxn(mockTxn).Create(f.ctx, acpIdentity.NoIdentity, doc) + sess := session.New(f.ctx).WithTxn(mockTxn) + err := f.users.Create(sess, acpIdentity.NoIdentity, doc) assert.ErrorIs(t, err, datastore.NewErrInvalidStoredValue(nil)) } @@ -378,7 +381,8 @@ func TestNonUnique_IfSystemStorageFailsToReadIndexDesc_Error(t *testing.T) { systemStoreOn.Query(mock.Anything, mock.Anything). Return(nil, testErr) - err := f.users.WithTxn(mockTxn).Create(f.ctx, acpIdentity.NoIdentity, doc) + sess := session.New(f.ctx).WithTxn(mockTxn) + err := f.users.Create(sess, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } @@ -806,7 +810,8 @@ func TestNonUniqueUpdate_IfFailsToReadIndexDescription_ReturnError(t *testing.T) usersCol.(*collection).fetcherFactory = func() fetcher.Fetcher { return fetcherMocks.NewStubbedFetcher(t) } - err = usersCol.WithTxn(mockedTxn).Update(f.ctx, acpIdentity.NoIdentity, doc) + sess := session.New(f.ctx).WithTxn(mockedTxn) + err = usersCol.Update(sess, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } @@ -1048,7 +1053,8 @@ func TestNonUniqueUpdate_IfDatastoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Datastore().Unset() mockedTxn.EXPECT().Datastore().Return(mockedTxn.MockDatastore).Maybe() - err = f.users.WithTxn(mockedTxn).Update(f.ctx, acpIdentity.NoIdentity, doc) + sess := session.New(f.ctx).WithTxn(mockedTxn) + err = f.users.Update(sess, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } } diff --git a/db/session/session.go b/db/session/session.go new file mode 100644 index 0000000000..74a375bbd1 --- /dev/null +++ b/db/session/session.go @@ -0,0 +1,29 @@ +package session + +import ( + "context" + + "github.com/sourcenetwork/defradb/datastore" +) + +type contextKey string + +const ( + TxnContextKey = contextKey("txn") +) + +// Session wraps a context to make it easier to pass request scoped +// parameters such as transactions. +type Session struct { + context.Context +} + +// New returns a new session that wraps the given context. +func New(ctx context.Context) *Session { + return &Session{ctx} +} + +// WithTxn returns a new session with the transaction value set. +func (s *Session) WithTxn(txn datastore.Txn) *Session { + return &Session{context.WithValue(s, TxnContextKey, txn)} +} diff --git a/http/middleware.go b/http/middleware.go index f7b48f0602..2e4d5f1371 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -23,7 +23,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/db" + "github.com/sourcenetwork/defradb/db/session" ) const TX_HEADER_NAME = "x-defradb-tx" @@ -91,11 +91,11 @@ func TransactionMiddleware(next http.Handler) http.Handler { } // store transaction in session - session := db.NewSession(req.Context()) + sess := session.New(req.Context()) if val, ok := tx.(datastore.Txn); ok { - session = session.WithTxn(val) + sess = sess.WithTxn(val) } - next.ServeHTTP(rw, req.WithContext(session)) + next.ServeHTTP(rw, req.WithContext(sess)) }) } diff --git a/net/peer_collection.go b/net/peer_collection.go index e1ca249700..ddd3f9ebf0 100644 --- a/net/peer_collection.go +++ b/net/peer_collection.go @@ -19,7 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/db" + "github.com/sourcenetwork/defradb/db/session" ) const marker = byte(0xff) @@ -34,9 +34,9 @@ func (p *Peer) AddP2PCollections(ctx context.Context, collectionIDs []string) er // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - session := db.NewSession(ctx).WithTxn(txn) + sess := session.New(ctx).WithTxn(txn) storeCol, err := p.db.GetCollections( - session, + sess, client.CollectionFetchOptions{ SchemaRoot: immutable.Some(col), }, @@ -114,9 +114,9 @@ func (p *Peer) RemoveP2PCollections(ctx context.Context, collectionIDs []string) // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - session := db.NewSession(ctx).WithTxn(txn) + sess := session.New(ctx).WithTxn(txn) storeCol, err := p.db.GetCollections( - session, + sess, client.CollectionFetchOptions{ SchemaRoot: immutable.Some(col), }, diff --git a/net/peer_replicator.go b/net/peer_replicator.go index 93fdbe190d..9e2c10703f 100644 --- a/net/peer_replicator.go +++ b/net/peer_replicator.go @@ -21,7 +21,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/db" + "github.com/sourcenetwork/defradb/db/session" ) func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { @@ -40,14 +40,14 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { if err := rep.Info.ID.Validate(); err != nil { return err } - session := db.NewSession(ctx).WithTxn(txn) + sess := session.New(ctx).WithTxn(txn) var collections []client.Collection switch { case len(rep.Schemas) > 0: // if specific collections are chosen get them by name for _, name := range rep.Schemas { - col, err := p.db.GetCollectionByName(session, name) + col, err := p.db.GetCollectionByName(sess, name) if err != nil { return NewErrReplicatorCollections(err) } @@ -62,7 +62,7 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { default: // default to all collections (unless a collection contains a policy). // TODO-ACP: default to all collections after resolving https://github.com/sourcenetwork/defradb/issues/2366 - allCollections, err := p.db.GetCollections(session, client.CollectionFetchOptions{}) + allCollections, err := p.db.GetCollections(sess, client.CollectionFetchOptions{}) if err != nil { return NewErrReplicatorCollections(err) } @@ -111,7 +111,7 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { // push all collection documents to the replicator peer for _, col := range added { // TODO-ACP: Support ACP <> P2P - https://github.com/sourcenetwork/defradb/issues/2366 - keysCh, err := col.WithTxn(txn).GetAllDocIDs(ctx, acpIdentity.NoIdentity) + keysCh, err := col.GetAllDocIDs(sess, acpIdentity.NoIdentity) if err != nil { return NewErrReplicatorDocID(err, col.Name().Value(), rep.Info.ID) } @@ -137,15 +137,14 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro if err := rep.Info.ID.Validate(); err != nil { return err } - - session := db.NewSession(ctx).WithTxn(txn) + sess := session.New(ctx).WithTxn(txn) var collections []client.Collection switch { case len(rep.Schemas) > 0: // if specific collections are chosen get them by name for _, name := range rep.Schemas { - col, err := p.db.GetCollectionByName(session, name) + col, err := p.db.GetCollectionByName(sess, name) if err != nil { return NewErrReplicatorCollections(err) } @@ -160,7 +159,7 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro default: // default to all collections - collections, err = p.db.GetCollections(session, client.CollectionFetchOptions{}) + collections, err = p.db.GetCollections(sess, client.CollectionFetchOptions{}) if err != nil { return NewErrReplicatorCollections(err) } diff --git a/net/server.go b/net/server.go index 8b0438579b..1bf9e4a710 100644 --- a/net/server.go +++ b/net/server.go @@ -33,7 +33,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore/badger/v4" - "github.com/sourcenetwork/defradb/db" + "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/errors" pb "github.com/sourcenetwork/defradb/net/pb" ) @@ -252,10 +252,10 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL } defer txn.Discard(ctx) - session := db.NewSession(ctx).WithTxn(txn) + sess := session.New(ctx).WithTxn(txn) // Currently a schema is the best way we have to link a push log request to a collection, // this will change with https://github.com/sourcenetwork/defradb/issues/1085 - col, err := s.getActiveCollection(session, s.db, string(req.Body.SchemaRoot)) + col, err := s.getActiveCollection(sess, s.db, string(req.Body.SchemaRoot)) if err != nil { return nil, err } @@ -287,7 +287,7 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL wg.Wait() bp.mergeBlocks(ctx) - err = s.syncIndexedDocs(ctx, col.WithTxn(txn), docID) + err = s.syncIndexedDocs(sess, col, docID) if err != nil { return nil, err } diff --git a/planner/create.go b/planner/create.go index 3333ae999e..bedb1be5d5 100644 --- a/planner/create.go +++ b/planner/create.go @@ -78,7 +78,7 @@ func (n *createNode) Next() (bool, error) { return false, nil } - if err := n.collection.WithTxn(n.p.txn).Create( + if err := n.collection.Create( n.p.ctx, n.p.identity, n.doc, diff --git a/planner/delete.go b/planner/delete.go index 74bb14d202..87cf0994ac 100644 --- a/planner/delete.go +++ b/planner/delete.go @@ -140,7 +140,7 @@ func (p *Planner) DeleteDocs(parsed *mapper.Mutation) (planNode, error) { p: p, filter: parsed.Filter, docIDs: parsed.DocIDs.Value(), - collection: col.WithTxn(p.txn), + collection: col, source: slctNode, docMapper: docMapper{parsed.DocumentMapping}, }, nil diff --git a/planner/planner.go b/planner/planner.go index eca0168671..0b8cde8aea 100644 --- a/planner/planner.go +++ b/planner/planner.go @@ -21,6 +21,7 @@ import ( "github.com/sourcenetwork/defradb/connor" "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/planner/filter" "github.com/sourcenetwork/defradb/planner/mapper" ) @@ -100,12 +101,14 @@ func New( db client.Store, txn datastore.Txn, ) *Planner { + // all db calls will use this transaction + sess := session.New(ctx).WithTxn(txn) return &Planner{ txn: txn, identity: identity, acp: acp, db: db, - ctx: ctx, + ctx: sess, } } diff --git a/planner/update.go b/planner/update.go index b86c616dbb..458094d4e0 100644 --- a/planner/update.go +++ b/planner/update.go @@ -169,7 +169,7 @@ func (p *Planner) UpdateDocs(parsed *mapper.Mutation) (planNode, error) { if err != nil { return nil, err } - update.collection = col.WithTxn(p.txn) + update.collection = col // create the results Select node resultsNode, err := p.Select(&parsed.Select) diff --git a/tests/bench/query/planner/utils.go b/tests/bench/query/planner/utils.go index 0b61e9d81b..ed2f2bd7e6 100644 --- a/tests/bench/query/planner/utils.go +++ b/tests/bench/query/planner/utils.go @@ -19,7 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/db" + "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/planner" "github.com/sourcenetwork/defradb/request/graphql" @@ -80,10 +80,10 @@ func runMakePlanBench( } b.ResetTimer() - session := db.NewSession(ctx).WithTxn(txn) + sess := session.New(ctx).WithTxn(txn) for i := 0; i < b.N; i++ { planner := planner.New( - session, + sess, acpIdentity.NoIdentity, acp.NoACP, d, diff --git a/tests/integration/events/simple/with_create_txn_test.go b/tests/integration/events/simple/with_create_txn_test.go index c837cc37ef..a9feba5c95 100644 --- a/tests/integration/events/simple/with_create_txn_test.go +++ b/tests/integration/events/simple/with_create_txn_test.go @@ -19,7 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/db" + "github.com/sourcenetwork/defradb/db/session" testUtils "github.com/sourcenetwork/defradb/tests/integration/events" ) @@ -44,9 +44,9 @@ func TestEventsSimpleWithCreateWithTxnDiscarded(t *testing.T) { txn, err := d.NewTxn(ctx, false) assert.Nil(t, err) - session := db.NewSession(ctx).WithTxn(txn) + sess := session.New(ctx).WithTxn(txn) r := d.ExecRequest( - session, + sess, acpIdentity.NoIdentity, `mutation { create_Users(input: {name: "Shahzad"}) { diff --git a/tests/integration/lens.go b/tests/integration/lens.go index d63f25bd3f..8760f957a6 100644 --- a/tests/integration/lens.go +++ b/tests/integration/lens.go @@ -14,7 +14,7 @@ import ( "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/db" + "github.com/sourcenetwork/defradb/db/session" ) // ConfigureMigration is a test action which will configure a Lens migration using the @@ -44,9 +44,9 @@ func configureMigration( ) { for _, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) - session := db.NewSession(s.ctx).WithTxn(txn) + sess := session.New(s.ctx).WithTxn(txn) - err := node.SetMigration(session, action.LensConfig) + err := node.SetMigration(sess, action.LensConfig) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) diff --git a/tests/integration/utils2.go b/tests/integration/utils2.go index c97a1b1013..07a49631f9 100644 --- a/tests/integration/utils2.go +++ b/tests/integration/utils2.go @@ -32,7 +32,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" badgerds "github.com/sourcenetwork/defradb/datastore/badger/v4" - "github.com/sourcenetwork/defradb/db" + "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/net" "github.com/sourcenetwork/defradb/request/graphql" @@ -1082,8 +1082,8 @@ func getCollections( ) { for _, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, "") - session := db.NewSession(s.ctx).WithTxn(txn) - results, err := node.GetCollections(session, action.FilterOptions) + sess := session.New(s.ctx).WithTxn(txn) + results, err := node.GetCollections(sess, action.FilterOptions) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) @@ -1254,9 +1254,9 @@ func createDocViaGQL( txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) identity := acpIdentity.NewIdentity(action.Identity) - session := db.NewSession(s.ctx).WithTxn(txn) + sess := session.New(s.ctx).WithTxn(txn) result := node.ExecRequest( - session, + sess, identity, request, ) @@ -1430,9 +1430,9 @@ func updateDocViaGQL( ) txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) - session := db.NewSession(s.ctx).WithTxn(txn) + sess := session.New(s.ctx).WithTxn(txn) result := node.ExecRequest( - session, + sess, acpIdentity.NewIdentity(action.Identity), request, ) @@ -1651,9 +1651,9 @@ func executeRequest( var expectedErrorRaised bool for nodeID, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) - session := db.NewSession(s.ctx).WithTxn(txn) + sess := session.New(s.ctx).WithTxn(txn) result := node.ExecRequest( - session, + sess, acpIdentity.NewIdentity(action.Identity), action.Request, ) From 447d9569be9117d3dacd6409dd87eacdf8821bd0 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Tue, 9 Apr 2024 17:31:43 -0700 Subject: [PATCH 03/14] remove lens registry WithTxn --- cli/schema_migration_down.go | 9 +-------- cli/schema_migration_reload.go | 9 +-------- cli/schema_migration_up.go | 9 +-------- client/lens.go | 8 -------- db/backup.go | 11 +++++++---- db/context_test.go | 1 + db/session/session.go | 10 ++++++++++ net/server.go | 19 +++++++++---------- planner/planner.go | 1 - tests/bench/query/planner/utils.go | 4 +--- 10 files changed, 31 insertions(+), 50 deletions(-) diff --git a/cli/schema_migration_down.go b/cli/schema_migration_down.go index 1d7622257c..a49f359694 100644 --- a/cli/schema_migration_down.go +++ b/cli/schema_migration_down.go @@ -17,8 +17,6 @@ import ( "github.com/sourcenetwork/immutable/enumerable" "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeSchemaMigrationDownCommand() *cobra.Command { @@ -67,12 +65,7 @@ Example: migrate from stdin if err := json.Unmarshal(srcData, &src); err != nil { return err } - lens := store.LensRegistry() - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - lens = lens.WithTxn(tx) - } - - out, err := lens.MigrateDown(cmd.Context(), enumerable.New(src), collectionID) + out, err := store.LensRegistry().MigrateDown(cmd.Context(), enumerable.New(src), collectionID) if err != nil { return err } diff --git a/cli/schema_migration_reload.go b/cli/schema_migration_reload.go index 4266b3ec3f..8ffb5542f1 100644 --- a/cli/schema_migration_reload.go +++ b/cli/schema_migration_reload.go @@ -12,8 +12,6 @@ package cli import ( "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeSchemaMigrationReloadCommand() *cobra.Command { @@ -23,12 +21,7 @@ func MakeSchemaMigrationReloadCommand() *cobra.Command { Long: `Reload the schema migrations within DefraDB`, RunE: func(cmd *cobra.Command, args []string) error { store := mustGetContextStore(cmd) - - lens := store.LensRegistry() - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - lens = lens.WithTxn(tx) - } - return lens.ReloadLenses(cmd.Context()) + return store.LensRegistry().ReloadLenses(cmd.Context()) }, } return cmd diff --git a/cli/schema_migration_up.go b/cli/schema_migration_up.go index 577b87d4c7..4473c45911 100644 --- a/cli/schema_migration_up.go +++ b/cli/schema_migration_up.go @@ -17,8 +17,6 @@ import ( "github.com/sourcenetwork/immutable/enumerable" "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeSchemaMigrationUpCommand() *cobra.Command { @@ -67,12 +65,7 @@ Example: migrate from stdin if err := json.Unmarshal(srcData, &src); err != nil { return err } - lens := store.LensRegistry() - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - lens = lens.WithTxn(tx) - } - - out, err := lens.MigrateUp(cmd.Context(), enumerable.New(src), collectionID) + out, err := store.LensRegistry().MigrateUp(cmd.Context(), enumerable.New(src), collectionID) if err != nil { return err } diff --git a/client/lens.go b/client/lens.go index 1a6b423991..3f5befc604 100644 --- a/client/lens.go +++ b/client/lens.go @@ -15,8 +15,6 @@ import ( "github.com/lens-vm/lens/host-go/config/model" "github.com/sourcenetwork/immutable/enumerable" - - "github.com/sourcenetwork/defradb/datastore" ) // LensConfig represents the configuration of a Lens migration in Defra. @@ -43,12 +41,6 @@ type LensConfig struct { // LensRegistry exposes several useful thread-safe migration related functions which may // be used to manage migrations. type LensRegistry interface { - // WithTxn returns a new LensRegistry scoped to the given transaction. - // - // WARNING: Currently this does not provide snapshot isolation, if other transactions are committed - // after this has been created, the results of those commits will be visible within this scope. - WithTxn(datastore.Txn) LensRegistry - // SetMigration caches the migration for the given collection ID. It does not persist the migration in long // term storage, for that one should call [Store.SetMigration(ctx, cfg)]. // diff --git a/db/backup.go b/db/backup.go index 17110bec05..9bdd0220a6 100644 --- a/db/backup.go +++ b/db/backup.go @@ -21,6 +21,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db/session" ) func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath string) (err error) { @@ -91,8 +92,9 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin return NewErrDocFromMap(err) } + sess := session.New(ctx).WithTxn(txn) // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to backup - err = col.Create(ctx, acpIdentity.NoIdentity, doc) + err = col.Create(sess, acpIdentity.NoIdentity, doc) if err != nil { return NewErrDocCreate(err) } @@ -104,7 +106,7 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin return NewErrDocUpdate(err) } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to backup - err = col.Update(ctx, acpIdentity.NoIdentity, doc) + err = col.Update(sess, acpIdentity.NoIdentity, doc) if err != nil { return NewErrDocUpdate(err) } @@ -191,8 +193,9 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client if err != nil { return err } + sess := session.New(ctx).WithTxn(txn) // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to export - docIDsCh, err := col.GetAllDocIDs(ctx, acpIdentity.NoIdentity) + docIDsCh, err := col.GetAllDocIDs(sess, acpIdentity.NoIdentity) if err != nil { return err } @@ -209,7 +212,7 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client } } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to export - doc, err := col.Get(ctx, acpIdentity.NoIdentity, docResultWithID.ID, false) + doc, err := col.Get(sess, acpIdentity.NoIdentity, docResultWithID.ID, false) if err != nil { return err } diff --git a/db/context_test.go b/db/context_test.go index 57f6621ae5..c584ce2e2a 100644 --- a/db/context_test.go +++ b/db/context_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/sourcenetwork/defradb/db/session" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/db/session/session.go b/db/session/session.go index 74a375bbd1..7d00a66c7b 100644 --- a/db/session/session.go +++ b/db/session/session.go @@ -1,3 +1,13 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + package session import ( diff --git a/net/server.go b/net/server.go index 1bf9e4a710..57c8a23ebe 100644 --- a/net/server.go +++ b/net/server.go @@ -32,6 +32,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" + "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/datastore/badger/v4" "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/errors" @@ -287,7 +288,7 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL wg.Wait() bp.mergeBlocks(ctx) - err = s.syncIndexedDocs(sess, col, docID) + err = s.syncIndexedDocs(ctx, col, docID, txn) if err != nil { return nil, err } @@ -350,15 +351,13 @@ func (s *server) syncIndexedDocs( ctx context.Context, col client.Collection, docID client.DocID, + txn datastore.Txn, ) error { - preTxnCol, err := s.db.GetCollectionByName(ctx, col.Name().Value()) - if err != nil { - return err - } + sess := session.New(ctx).WithTxn(txn) //TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2365 // Resolve while handling acp <> secondary indexes. - oldDoc, err := preTxnCol.Get(ctx, acpIdentity.NoIdentity, docID, false) + oldDoc, err := col.Get(ctx, acpIdentity.NoIdentity, docID, false) isNewDoc := errors.Is(err, client.ErrDocumentNotFoundOrNotAuthorized) if !isNewDoc && err != nil { return err @@ -366,18 +365,18 @@ func (s *server) syncIndexedDocs( //TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2365 // Resolve while handling acp <> secondary indexes. - doc, err := col.Get(ctx, acpIdentity.NoIdentity, docID, false) + doc, err := col.Get(sess, acpIdentity.NoIdentity, docID, false) isDeletedDoc := errors.Is(err, client.ErrDocumentNotFoundOrNotAuthorized) if !isDeletedDoc && err != nil { return err } if isDeletedDoc { - return preTxnCol.DeleteDocIndex(ctx, oldDoc) + return col.DeleteDocIndex(ctx, oldDoc) } else if isNewDoc { - return col.CreateDocIndex(ctx, doc) + return col.CreateDocIndex(sess, doc) } else { - return col.UpdateDocIndex(ctx, oldDoc, doc) + return col.UpdateDocIndex(sess, oldDoc, doc) } } diff --git a/planner/planner.go b/planner/planner.go index 0b8cde8aea..faeac5a554 100644 --- a/planner/planner.go +++ b/planner/planner.go @@ -101,7 +101,6 @@ func New( db client.Store, txn datastore.Txn, ) *Planner { - // all db calls will use this transaction sess := session.New(ctx).WithTxn(txn) return &Planner{ txn: txn, diff --git a/tests/bench/query/planner/utils.go b/tests/bench/query/planner/utils.go index ed2f2bd7e6..caba91836d 100644 --- a/tests/bench/query/planner/utils.go +++ b/tests/bench/query/planner/utils.go @@ -19,7 +19,6 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/planner" "github.com/sourcenetwork/defradb/request/graphql" @@ -80,10 +79,9 @@ func runMakePlanBench( } b.ResetTimer() - sess := session.New(ctx).WithTxn(txn) for i := 0; i < b.N; i++ { planner := planner.New( - sess, + ctx, acpIdentity.NoIdentity, acp.NoACP, d, From 3fa4b97069938f361dc0de3cfe6cb062eafef5e7 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Wed, 10 Apr 2024 10:31:50 -0700 Subject: [PATCH 04/14] ensure internal db methods have correct transaction context --- db/backup.go | 11 ++-- db/collection.go | 28 +++++++--- db/collection_delete.go | 9 ++- db/collection_get.go | 4 +- db/collection_index.go | 33 +++++++---- db/collection_update.go | 14 ++++- db/context.go | 34 ++++++++--- db/context_test.go | 32 +++++------ db/index_test.go | 27 +++++---- db/indexed_docs_test.go | 21 ++++--- db/{session => }/session.go | 14 ++--- db/session_test.go | 24 ++++++++ db/store.go | 56 ++++++++++++++----- db/subscriptions.go | 2 +- http/middleware.go | 4 +- net/peer_collection.go | 6 +- net/peer_replicator.go | 10 +++- net/server.go | 8 ++- planner/planner.go | 4 +- .../events/simple/with_create_txn_test.go | 4 +- tests/integration/lens.go | 4 +- tests/integration/utils2.go | 10 ++-- 22 files changed, 227 insertions(+), 132 deletions(-) rename db/{session => }/session.go (75%) create mode 100644 db/session_test.go diff --git a/db/backup.go b/db/backup.go index 9bdd0220a6..17110bec05 100644 --- a/db/backup.go +++ b/db/backup.go @@ -21,7 +21,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/db/session" ) func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath string) (err error) { @@ -92,9 +91,8 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin return NewErrDocFromMap(err) } - sess := session.New(ctx).WithTxn(txn) // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to backup - err = col.Create(sess, acpIdentity.NoIdentity, doc) + err = col.Create(ctx, acpIdentity.NoIdentity, doc) if err != nil { return NewErrDocCreate(err) } @@ -106,7 +104,7 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin return NewErrDocUpdate(err) } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to backup - err = col.Update(sess, acpIdentity.NoIdentity, doc) + err = col.Update(ctx, acpIdentity.NoIdentity, doc) if err != nil { return NewErrDocUpdate(err) } @@ -193,9 +191,8 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client if err != nil { return err } - sess := session.New(ctx).WithTxn(txn) // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to export - docIDsCh, err := col.GetAllDocIDs(sess, acpIdentity.NoIdentity) + docIDsCh, err := col.GetAllDocIDs(ctx, acpIdentity.NoIdentity) if err != nil { return err } @@ -212,7 +209,7 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client } } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to export - doc, err := col.Get(sess, acpIdentity.NoIdentity, docResultWithID.ID, false) + doc, err := col.Get(ctx, acpIdentity.NoIdentity, docResultWithID.ID, false) if err != nil { return err } diff --git a/db/collection.go b/db/collection.go index f23285fc26..e733fb469b 100644 --- a/db/collection.go +++ b/db/collection.go @@ -1230,11 +1230,11 @@ func (c *collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], ) (<-chan client.DocIDResult, error) { - txn, err := getContextTxn(ctx, c.db, true) + ctx, err := ensureContextTxn(ctx, c.db, true) if err != nil { return nil, err } - + txn := mustGetContextTxn(ctx) return c.getAllDocIDsChan(ctx, identity, txn) } @@ -1348,10 +1348,12 @@ func (c *collection) Create( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.create(ctx, identity, txn, doc) @@ -1369,10 +1371,12 @@ func (c *collection) CreateMany( identity immutable.Option[string], docs []*client.Document, ) error { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) for _, doc := range docs { @@ -1454,10 +1458,12 @@ func (c *collection) Update( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(doc.ID()) @@ -1519,10 +1525,12 @@ func (c *collection) Save( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) // Check if document already exists with primary DS key. @@ -1801,10 +1809,12 @@ func (c *collection) Delete( identity immutable.Option[string], docID client.DocID, ) (bool, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return false, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) @@ -1822,10 +1832,12 @@ func (c *collection) Exists( identity immutable.Option[string], docID client.DocID, ) (bool, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return false, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) diff --git a/db/collection_delete.go b/db/collection_delete.go index 155e171e63..fdb9005e7e 100644 --- a/db/collection_delete.go +++ b/db/collection_delete.go @@ -54,11 +54,12 @@ func (c *collection) DeleteWithDocID( identity immutable.Option[string], docID client.DocID, ) (*client.DeleteResult, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) dsKey := c.getPrimaryKeyFromDocID(docID) @@ -76,11 +77,12 @@ func (c *collection) DeleteWithDocIDs( identity immutable.Option[string], docIDs []client.DocID, ) (*client.DeleteResult, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res, err := c.deleteWithIDs(ctx, identity, txn, docIDs, client.Deleted) @@ -97,11 +99,12 @@ func (c *collection) DeleteWithFilter( identity immutable.Option[string], filter any, ) (*client.DeleteResult, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res, err := c.deleteWithFilter(ctx, identity, txn, filter, client.Deleted) diff --git a/db/collection_get.go b/db/collection_get.go index 9bcfe54755..b694d962fe 100644 --- a/db/collection_get.go +++ b/db/collection_get.go @@ -29,10 +29,12 @@ func (c *collection) Get( showDeleted bool, ) (*client.Document, error) { // create txn - txn, err := getContextTxn(ctx, c.db, true) + ctx, err := ensureContextTxn(ctx, c.db, true) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) diff --git a/db/collection_index.go b/db/collection_index.go index 8195730b5b..52de557356 100644 --- a/db/collection_index.go +++ b/db/collection_index.go @@ -27,7 +27,6 @@ import ( "github.com/sourcenetwork/defradb/db/base" "github.com/sourcenetwork/defradb/db/description" "github.com/sourcenetwork/defradb/db/fetcher" - "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/request/graphql/schema" ) @@ -42,8 +41,8 @@ func (db *db) createCollectionIndex( if err != nil { return client.IndexDescription{}, NewErrCanNotReadCollection(collectionName, err) } - sess := session.New(ctx).WithTxn(txn) - return col.CreateIndex(sess, desc) + ctx = setContextTxn(ctx, txn) + return col.CreateIndex(ctx, desc) } func (db *db) dropCollectionIndex( @@ -55,8 +54,8 @@ func (db *db) dropCollectionIndex( if err != nil { return NewErrCanNotReadCollection(collectionName, err) } - sess := session.New(ctx).WithTxn(txn) - return col.DropIndex(sess, indexName) + ctx = setContextTxn(ctx, txn) + return col.DropIndex(ctx, indexName) } // getAllIndexDescriptions returns all the index descriptions in the database. @@ -113,10 +112,12 @@ func (db *db) fetchCollectionIndexDescriptions( } func (c *collection) CreateDocIndex(ctx context.Context, doc *client.Document) error { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.indexNewDoc(ctx, txn, doc) @@ -128,10 +129,12 @@ func (c *collection) CreateDocIndex(ctx context.Context, doc *client.Document) e } func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client.Document) error { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.deleteIndexedDoc(ctx, txn, oldDoc) @@ -147,10 +150,12 @@ func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client. } func (c *collection) DeleteDocIndex(ctx context.Context, doc *client.Document) error { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.deleteIndexedDoc(ctx, txn, doc) @@ -243,10 +248,12 @@ func (c *collection) CreateIndex( ctx context.Context, desc client.IndexDescription, ) (client.IndexDescription, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return client.IndexDescription{}, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) index, err := c.createIndex(ctx, txn, desc) @@ -399,10 +406,12 @@ func (c *collection) indexExistingDocs( // // All index artifacts for existing documents related the index will be removed. func (c *collection) DropIndex(ctx context.Context, indexName string) error { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.dropIndex(ctx, txn, indexName) @@ -487,10 +496,12 @@ func (c *collection) loadIndexes(ctx context.Context, txn datastore.Txn) error { // GetIndexes returns all indexes for the collection. func (c *collection) GetIndexes(ctx context.Context) ([]client.IndexDescription, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.loadIndexes(ctx, txn) diff --git a/db/collection_update.go b/db/collection_update.go index 7504a9cc78..96b51d5bf3 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -57,11 +57,13 @@ func (c *collection) UpdateWithFilter( filter any, updater string, ) (*client.UpdateResult, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) + res, err := c.updateWithFilter(ctx, identity, txn, filter, updater) if err != nil { return nil, err @@ -78,11 +80,14 @@ func (c *collection) UpdateWithDocID( docID client.DocID, updater string, ) (*client.UpdateResult, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) + res, err := c.updateWithDocID(ctx, identity, txn, docID, updater) if err != nil { return nil, err @@ -100,11 +105,14 @@ func (c *collection) UpdateWithDocIDs( docIDs []client.DocID, updater string, ) (*client.UpdateResult, error) { - txn, err := getContextTxn(ctx, c.db, false) + ctx, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) + res, err := c.updateWithIDs(ctx, identity, txn, docIDs, updater) if err != nil { return nil, err diff --git a/db/context.go b/db/context.go index f49bfbb7a1..5c3c6e1d54 100644 --- a/db/context.go +++ b/db/context.go @@ -14,10 +14,11 @@ import ( "context" "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/db/session" ) -// explicitTxn is a transaction that is managed outside of the session. +type txnContextKey struct{} + +// explicitTxn is a transaction that is managed outside of a db operation. type explicitTxn struct { datastore.Txn } @@ -35,12 +36,29 @@ type transactionDB interface { NewTxn(context.Context, bool) (datastore.Txn, error) } -// getContextTxn returns the explicit transaction from -// the context or creates a new implicit one. -func getContextTxn(ctx context.Context, db transactionDB, readOnly bool) (datastore.Txn, error) { - txn, ok := ctx.Value(session.TxnContextKey).(datastore.Txn) +// ensureContextTxn ensures that the returned context has a transaction. +// +// If a transactions exists on the context it will be made explicit, +// otherwise a new implicit transaction will be created. +func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (context.Context, error) { + txn, ok := ctx.Value(txnContextKey{}).(datastore.Txn) if ok { - return &explicitTxn{txn}, nil + return setContextTxn(ctx, &explicitTxn{txn}), nil } - return db.NewTxn(ctx, readOnly) + txn, err := db.NewTxn(ctx, readOnly) + if err != nil { + return nil, err + } + return setContextTxn(ctx, txn), nil +} + +// mustGetContextTxn returns the transaction from the context if it exists, +// otherwise it panics. +func mustGetContextTxn(ctx context.Context) datastore.Txn { + return ctx.Value(txnContextKey{}).(datastore.Txn) +} + +// setContextTxn returns a new context with the txn value set. +func setContextTxn(ctx context.Context, txn datastore.Txn) context.Context { + return context.WithValue(ctx, txnContextKey{}, txn) } diff --git a/db/context_test.go b/db/context_test.go index c584ce2e2a..9a72f9b91a 100644 --- a/db/context_test.go +++ b/db/context_test.go @@ -14,42 +14,38 @@ import ( "context" "testing" - "github.com/sourcenetwork/defradb/db/session" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestGetContextImplicitTxn(t *testing.T) { +func TestEnsureContextTxnExplicit(t *testing.T) { ctx := context.Background() db, err := newMemoryDB(ctx) require.NoError(t, err) - txn, err := getContextTxn(ctx, db, true) + txn, err := db.NewTxn(ctx, true) require.NoError(t, err) - // txn should be implicit - _, ok := txn.(*explicitTxn) - assert.False(t, ok) + // set an explicit transaction + ctx = setContextTxn(ctx, txn) + + ctx, err = ensureContextTxn(ctx, db, true) + require.NoError(t, err) + + _, ok := mustGetContextTxn(ctx).(*explicitTxn) + assert.True(t, ok) } -func TestGetContextExplicitTxn(t *testing.T) { +func TestEnsureContextTxnImplicit(t *testing.T) { ctx := context.Background() db, err := newMemoryDB(ctx) require.NoError(t, err) - txn, err := db.NewTxn(ctx, true) - require.NoError(t, err) - - // create a session with a transaction - sess := session.New(ctx).WithTxn(txn) - - out, err := getContextTxn(sess, db, true) + ctx, err = ensureContextTxn(ctx, db, true) require.NoError(t, err) - // txn should be explicit - _, ok := out.(*explicitTxn) - assert.True(t, ok) + _, ok := mustGetContextTxn(ctx).(*explicitTxn) + assert.False(t, ok) } diff --git a/db/index_test.go b/db/index_test.go index fe738e9c3f..a6fc9cbd0d 100644 --- a/db/index_test.go +++ b/db/index_test.go @@ -29,7 +29,6 @@ import ( "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/datastore/mocks" - "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/request/graphql/schema" ) @@ -785,8 +784,8 @@ func TestCollectionGetIndexes_ShouldCloseQueryIterator(t *testing.T) { mockedTxn.MockSystemstore.EXPECT().Query(mock.Anything, mock.Anything). Return(queryResults, nil) - sess := session.New(f.ctx).WithTxn(mockedTxn) - _, err := f.users.GetIndexes(sess) + ctx := setContextTxn(f.ctx, mockedTxn) + _, err := f.users.GetIndexes(ctx) assert.NoError(t, err) } @@ -842,8 +841,8 @@ func TestCollectionGetIndexes_IfSystemStoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - sess := session.New(f.ctx).WithTxn(mockedTxn) - _, err := f.users.GetIndexes(sess) + ctx := setContextTxn(f.ctx, mockedTxn) + _, err := f.users.GetIndexes(ctx) require.ErrorIs(t, err, testCase.ExpectedError) } } @@ -905,8 +904,8 @@ func TestCollectionGetIndexes_IfStoredIndexWithUnsupportedType_ReturnError(t *te mockedTxn.MockSystemstore.EXPECT().Query(mock.Anything, mock.Anything). Return(mocks.NewQueryResultsWithValues(t, indexDescData), nil) - sess := session.New(f.ctx).WithTxn(mockedTxn) - _, err = collection.GetIndexes(sess) + ctx := setContextTxn(f.ctx, mockedTxn) + _, err = collection.GetIndexes(ctx) require.ErrorIs(t, err, NewErrUnsupportedIndexFieldType(unsupportedKind)) } @@ -1097,18 +1096,18 @@ func TestDropIndex_IfFailsToDeleteFromStorage_ReturnError(t *testing.T) { mockedTxn.MockDatastore.EXPECT().Query(mock.Anything, mock.Anything).Maybe(). Return(mocks.NewQueryResultsWithValues(t), nil) - sess := session.New(f.ctx).WithTxn(mockedTxn) - err := f.users.DropIndex(sess, testUsersColIndexName) + ctx := setContextTxn(f.ctx, mockedTxn) + err := f.users.DropIndex(ctx, testUsersColIndexName) require.ErrorIs(t, err, testErr) } func TestDropIndex_ShouldUpdateCollectionsDescription(t *testing.T) { f := newIndexTestFixture(t) defer f.db.Close() - sess := session.New(f.ctx).WithTxn(f.txn) - _, err := f.users.CreateIndex(sess, getUsersIndexDescOnName()) + ctx := setContextTxn(f.ctx, f.txn) + _, err := f.users.CreateIndex(ctx, getUsersIndexDescOnName()) require.NoError(t, err) - indOnAge, err := f.users.CreateIndex(sess, getUsersIndexDescOnAge()) + indOnAge, err := f.users.CreateIndex(ctx, getUsersIndexDescOnAge()) require.NoError(t, err) f.commitTxn() @@ -1149,8 +1148,8 @@ func TestDropIndex_IfSystemStoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - sess := session.New(f.ctx).WithTxn(mockedTxn) - err := f.users.DropIndex(sess, testUsersColIndexName) + ctx := setContextTxn(f.ctx, mockedTxn) + err := f.users.DropIndex(ctx, testUsersColIndexName) require.ErrorIs(t, err, testErr) } diff --git a/db/indexed_docs_test.go b/db/indexed_docs_test.go index 84000a5672..4d353e4ea7 100644 --- a/db/indexed_docs_test.go +++ b/db/indexed_docs_test.go @@ -31,7 +31,6 @@ import ( "github.com/sourcenetwork/defradb/datastore/mocks" "github.com/sourcenetwork/defradb/db/fetcher" fetcherMocks "github.com/sourcenetwork/defradb/db/fetcher/mocks" - "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/planner/mapper" ) @@ -323,8 +322,8 @@ func TestNonUnique_IfFailsToStoredIndexedDoc_Error(t *testing.T) { dataStoreOn.Put(mock.Anything, key.ToDS(), mock.Anything).Return(errors.New("error")) dataStoreOn.Put(mock.Anything, mock.Anything, mock.Anything).Return(nil) - sess := session.New(f.ctx).WithTxn(mockTxn) - err := f.users.Create(sess, acpIdentity.NoIdentity, doc) + ctx := setContextTxn(f.ctx, mockTxn) + err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(f.t, err, NewErrFailedToStoreIndexedField("name", nil)) } @@ -362,8 +361,8 @@ func TestNonUnique_IfSystemStorageHasInvalidIndexDescription_Error(t *testing.T) systemStoreOn.Query(mock.Anything, mock.Anything). Return(mocks.NewQueryResultsWithValues(t, []byte("invalid")), nil) - sess := session.New(f.ctx).WithTxn(mockTxn) - err := f.users.Create(sess, acpIdentity.NoIdentity, doc) + ctx := setContextTxn(f.ctx, mockTxn) + err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) assert.ErrorIs(t, err, datastore.NewErrInvalidStoredValue(nil)) } @@ -381,8 +380,8 @@ func TestNonUnique_IfSystemStorageFailsToReadIndexDesc_Error(t *testing.T) { systemStoreOn.Query(mock.Anything, mock.Anything). Return(nil, testErr) - sess := session.New(f.ctx).WithTxn(mockTxn) - err := f.users.Create(sess, acpIdentity.NoIdentity, doc) + ctx := setContextTxn(f.ctx, mockTxn) + err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } @@ -810,8 +809,8 @@ func TestNonUniqueUpdate_IfFailsToReadIndexDescription_ReturnError(t *testing.T) usersCol.(*collection).fetcherFactory = func() fetcher.Fetcher { return fetcherMocks.NewStubbedFetcher(t) } - sess := session.New(f.ctx).WithTxn(mockedTxn) - err = usersCol.Update(sess, acpIdentity.NoIdentity, doc) + ctx := setContextTxn(f.ctx, mockedTxn) + err = usersCol.Update(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } @@ -1053,8 +1052,8 @@ func TestNonUniqueUpdate_IfDatastoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Datastore().Unset() mockedTxn.EXPECT().Datastore().Return(mockedTxn.MockDatastore).Maybe() - sess := session.New(f.ctx).WithTxn(mockedTxn) - err = f.users.Update(sess, acpIdentity.NoIdentity, doc) + ctx := setContextTxn(f.ctx, mockedTxn) + err = f.users.Update(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } } diff --git a/db/session/session.go b/db/session.go similarity index 75% rename from db/session/session.go rename to db/session.go index 7d00a66c7b..192205b48b 100644 --- a/db/session/session.go +++ b/db/session.go @@ -8,7 +8,7 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -package session +package db import ( "context" @@ -16,24 +16,18 @@ import ( "github.com/sourcenetwork/defradb/datastore" ) -type contextKey string - -const ( - TxnContextKey = contextKey("txn") -) - // Session wraps a context to make it easier to pass request scoped // parameters such as transactions. type Session struct { context.Context } -// New returns a new session that wraps the given context. -func New(ctx context.Context) *Session { +// NewSession returns a new session that wraps the given context. +func NewSession(ctx context.Context) *Session { return &Session{ctx} } // WithTxn returns a new session with the transaction value set. func (s *Session) WithTxn(txn datastore.Txn) *Session { - return &Session{context.WithValue(s, TxnContextKey, txn)} + return &Session{setContextTxn(s, txn)} } diff --git a/db/session_test.go b/db/session_test.go new file mode 100644 index 0000000000..3e71091ca7 --- /dev/null +++ b/db/session_test.go @@ -0,0 +1,24 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSessionWithTxn(t *testing.T) { + sess := NewSession(context.Background()).WithTxn(&explicitTxn{}) + _, ok := sess.Value(txnContextKey{}).(*explicitTxn) + assert.True(t, ok) +} diff --git a/db/store.go b/db/store.go index 4f279a2fac..7839eb099a 100644 --- a/db/store.go +++ b/db/store.go @@ -32,12 +32,14 @@ func (s *store) ExecRequest( identity immutable.Option[string], request string, ) *client.RequestResult { - txn, err := getContextTxn(ctx, s, false) + ctx, err := ensureContextTxn(ctx, s, false) if err != nil { res := &client.RequestResult{} res.GQL.Errors = []error{err} return res } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res := s.db.execRequest(ctx, identity, request, txn) @@ -55,10 +57,12 @@ func (s *store) ExecRequest( // GetCollectionByName returns an existing collection within the database. func (s *store) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { - txn, err := getContextTxn(ctx, s, true) + ctx, err := ensureContextTxn(ctx, s, true) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getCollectionByName(ctx, txn, name) @@ -69,10 +73,12 @@ func (s *store) GetCollections( ctx context.Context, options client.CollectionFetchOptions, ) ([]client.Collection, error) { - txn, err := getContextTxn(ctx, s, true) + ctx, err := ensureContextTxn(ctx, s, true) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getCollections(ctx, txn, options) @@ -83,10 +89,12 @@ func (s *store) GetCollections( // // Will return an error if it is not found. func (s *store) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { - txn, err := getContextTxn(ctx, s, true) + ctx, err := ensureContextTxn(ctx, s, true) if err != nil { return client.SchemaDescription{}, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getSchemaByVersionID(ctx, txn, versionID) @@ -98,10 +106,12 @@ func (s *store) GetSchemas( ctx context.Context, options client.SchemaFetchOptions, ) ([]client.SchemaDescription, error) { - txn, err := getContextTxn(ctx, s, true) + ctx, err := ensureContextTxn(ctx, s, true) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getSchemas(ctx, txn, options) @@ -111,10 +121,12 @@ func (s *store) GetSchemas( func (s *store) GetAllIndexes( ctx context.Context, ) (map[client.CollectionName][]client.IndexDescription, error) { - txn, err := getContextTxn(ctx, s, true) + ctx, err := ensureContextTxn(ctx, s, true) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getAllIndexDescriptions(ctx, txn) @@ -126,10 +138,12 @@ func (s *store) GetAllIndexes( // All schema types provided must not exist prior to calling this, and they may not reference existing // types previously defined. func (s *store) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { - txn, err := getContextTxn(ctx, s, false) + ctx, err := ensureContextTxn(ctx, s, false) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) cols, err := s.db.addSchema(ctx, txn, schemaString) @@ -160,10 +174,12 @@ func (s *store) PatchSchema( migration immutable.Option[model.Lens], setAsDefaultVersion bool, ) error { - txn, err := getContextTxn(ctx, s, false) + ctx, err := ensureContextTxn(ctx, s, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) @@ -178,10 +194,12 @@ func (s *store) PatchCollection( ctx context.Context, patchString string, ) error { - txn, err := getContextTxn(ctx, s, false) + ctx, err := ensureContextTxn(ctx, s, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.patchCollection(ctx, txn, patchString) @@ -193,10 +211,12 @@ func (s *store) PatchCollection( } func (s *store) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { - txn, err := getContextTxn(ctx, s, false) + ctx, err := ensureContextTxn(ctx, s, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.setActiveSchemaVersion(ctx, txn, schemaVersionID) @@ -208,10 +228,12 @@ func (s *store) SetActiveSchemaVersion(ctx context.Context, schemaVersionID stri } func (s *store) SetMigration(ctx context.Context, cfg client.LensConfig) error { - txn, err := getContextTxn(ctx, s, false) + ctx, err := ensureContextTxn(ctx, s, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.setMigration(ctx, txn, cfg) @@ -228,10 +250,12 @@ func (s *store) AddView( sdl string, transform immutable.Option[model.Lens], ) ([]client.CollectionDefinition, error) { - txn, err := getContextTxn(ctx, s, false) + ctx, err := ensureContextTxn(ctx, s, false) if err != nil { return nil, err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) defs, err := s.db.addView(ctx, txn, query, sdl, transform) @@ -250,10 +274,12 @@ func (s *store) AddView( // BasicImport imports a json dataset. // filepath must be accessible to the node. func (s *store) BasicImport(ctx context.Context, filepath string) error { - txn, err := getContextTxn(ctx, s, false) + ctx, err := ensureContextTxn(ctx, s, false) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.basicImport(ctx, txn, filepath) @@ -266,10 +292,12 @@ func (s *store) BasicImport(ctx context.Context, filepath string) error { // BasicExport exports the current data or subset of data to file in json format. func (s *store) BasicExport(ctx context.Context, config *client.BackupConfig) error { - txn, err := getContextTxn(ctx, s, true) + ctx, err := ensureContextTxn(ctx, s, true) if err != nil { return err } + + txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.basicExport(ctx, txn, config) diff --git a/db/subscriptions.go b/db/subscriptions.go index 5958d567be..b0faa0414b 100644 --- a/db/subscriptions.go +++ b/db/subscriptions.go @@ -62,8 +62,8 @@ func (db *db) handleSubscription( continue } + ctx := setContextTxn(ctx, txn) db.handleEvent(ctx, identity, txn, pub, evt, r) - txn.Discard(ctx) } } diff --git a/http/middleware.go b/http/middleware.go index 2e4d5f1371..76945e0a77 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -23,7 +23,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/db/session" + "github.com/sourcenetwork/defradb/db" ) const TX_HEADER_NAME = "x-defradb-tx" @@ -91,7 +91,7 @@ func TransactionMiddleware(next http.Handler) http.Handler { } // store transaction in session - sess := session.New(req.Context()) + sess := db.NewSession(req.Context()) if val, ok := tx.(datastore.Txn); ok { sess = sess.WithTxn(val) } diff --git a/net/peer_collection.go b/net/peer_collection.go index ddd3f9ebf0..69caf8fd46 100644 --- a/net/peer_collection.go +++ b/net/peer_collection.go @@ -19,7 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/db/session" + "github.com/sourcenetwork/defradb/db" ) const marker = byte(0xff) @@ -34,7 +34,7 @@ func (p *Peer) AddP2PCollections(ctx context.Context, collectionIDs []string) er // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - sess := session.New(ctx).WithTxn(txn) + sess := db.NewSession(ctx).WithTxn(txn) storeCol, err := p.db.GetCollections( sess, client.CollectionFetchOptions{ @@ -114,7 +114,7 @@ func (p *Peer) RemoveP2PCollections(ctx context.Context, collectionIDs []string) // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - sess := session.New(ctx).WithTxn(txn) + sess := db.NewSession(ctx).WithTxn(txn) storeCol, err := p.db.GetCollections( sess, client.CollectionFetchOptions{ diff --git a/net/peer_replicator.go b/net/peer_replicator.go index 9e2c10703f..36c42086a4 100644 --- a/net/peer_replicator.go +++ b/net/peer_replicator.go @@ -21,7 +21,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/db/session" + "github.com/sourcenetwork/defradb/db" ) func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { @@ -40,7 +40,9 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { if err := rep.Info.ID.Validate(); err != nil { return err } - sess := session.New(ctx).WithTxn(txn) + + // use a session for all operations + sess := db.NewSession(ctx).WithTxn(txn) var collections []client.Collection switch { @@ -137,7 +139,9 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro if err := rep.Info.ID.Validate(); err != nil { return err } - sess := session.New(ctx).WithTxn(txn) + + // use a session for all operations + sess := db.NewSession(ctx).WithTxn(txn) var collections []client.Collection switch { diff --git a/net/server.go b/net/server.go index 57c8a23ebe..0196d3d2e6 100644 --- a/net/server.go +++ b/net/server.go @@ -34,7 +34,7 @@ import ( "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/datastore/badger/v4" - "github.com/sourcenetwork/defradb/db/session" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/errors" pb "github.com/sourcenetwork/defradb/net/pb" ) @@ -253,7 +253,9 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL } defer txn.Discard(ctx) - sess := session.New(ctx).WithTxn(txn) + // use a session for all operations + sess := db.NewSession(ctx).WithTxn(txn) + // Currently a schema is the best way we have to link a push log request to a collection, // this will change with https://github.com/sourcenetwork/defradb/issues/1085 col, err := s.getActiveCollection(sess, s.db, string(req.Body.SchemaRoot)) @@ -353,7 +355,7 @@ func (s *server) syncIndexedDocs( docID client.DocID, txn datastore.Txn, ) error { - sess := session.New(ctx).WithTxn(txn) + sess := db.NewSession(ctx).WithTxn(txn) //TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2365 // Resolve while handling acp <> secondary indexes. diff --git a/planner/planner.go b/planner/planner.go index faeac5a554..eca0168671 100644 --- a/planner/planner.go +++ b/planner/planner.go @@ -21,7 +21,6 @@ import ( "github.com/sourcenetwork/defradb/connor" "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore" - "github.com/sourcenetwork/defradb/db/session" "github.com/sourcenetwork/defradb/planner/filter" "github.com/sourcenetwork/defradb/planner/mapper" ) @@ -101,13 +100,12 @@ func New( db client.Store, txn datastore.Txn, ) *Planner { - sess := session.New(ctx).WithTxn(txn) return &Planner{ txn: txn, identity: identity, acp: acp, db: db, - ctx: sess, + ctx: ctx, } } diff --git a/tests/integration/events/simple/with_create_txn_test.go b/tests/integration/events/simple/with_create_txn_test.go index a9feba5c95..81f6c8bf30 100644 --- a/tests/integration/events/simple/with_create_txn_test.go +++ b/tests/integration/events/simple/with_create_txn_test.go @@ -19,7 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/db/session" + "github.com/sourcenetwork/defradb/db" testUtils "github.com/sourcenetwork/defradb/tests/integration/events" ) @@ -44,7 +44,7 @@ func TestEventsSimpleWithCreateWithTxnDiscarded(t *testing.T) { txn, err := d.NewTxn(ctx, false) assert.Nil(t, err) - sess := session.New(ctx).WithTxn(txn) + sess := db.NewSession(ctx).WithTxn(txn) r := d.ExecRequest( sess, acpIdentity.NoIdentity, diff --git a/tests/integration/lens.go b/tests/integration/lens.go index 8760f957a6..9b0836d556 100644 --- a/tests/integration/lens.go +++ b/tests/integration/lens.go @@ -14,7 +14,7 @@ import ( "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/db/session" + "github.com/sourcenetwork/defradb/db" ) // ConfigureMigration is a test action which will configure a Lens migration using the @@ -44,7 +44,7 @@ func configureMigration( ) { for _, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) - sess := session.New(s.ctx).WithTxn(txn) + sess := db.NewSession(s.ctx).WithTxn(txn) err := node.SetMigration(sess, action.LensConfig) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) diff --git a/tests/integration/utils2.go b/tests/integration/utils2.go index 07a49631f9..8fb78544cc 100644 --- a/tests/integration/utils2.go +++ b/tests/integration/utils2.go @@ -32,7 +32,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" badgerds "github.com/sourcenetwork/defradb/datastore/badger/v4" - "github.com/sourcenetwork/defradb/db/session" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/net" "github.com/sourcenetwork/defradb/request/graphql" @@ -1082,7 +1082,7 @@ func getCollections( ) { for _, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, "") - sess := session.New(s.ctx).WithTxn(txn) + sess := db.NewSession(s.ctx).WithTxn(txn) results, err := node.GetCollections(sess, action.FilterOptions) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) @@ -1254,7 +1254,7 @@ func createDocViaGQL( txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) identity := acpIdentity.NewIdentity(action.Identity) - sess := session.New(s.ctx).WithTxn(txn) + sess := db.NewSession(s.ctx).WithTxn(txn) result := node.ExecRequest( sess, identity, @@ -1430,7 +1430,7 @@ func updateDocViaGQL( ) txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) - sess := session.New(s.ctx).WithTxn(txn) + sess := db.NewSession(s.ctx).WithTxn(txn) result := node.ExecRequest( sess, acpIdentity.NewIdentity(action.Identity), @@ -1651,7 +1651,7 @@ func executeRequest( var expectedErrorRaised bool for nodeID, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) - sess := session.New(s.ctx).WithTxn(txn) + sess := db.NewSession(s.ctx).WithTxn(txn) result := node.ExecRequest( sess, acpIdentity.NewIdentity(action.Identity), From 2c6b92fb6c27298c6e2c534aa065c2e95d834a5a Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Wed, 10 Apr 2024 11:04:27 -0700 Subject: [PATCH 05/14] cli and http context transaction fixes --- cli/backup_export.go | 4 +-- cli/backup_import.go | 4 +-- cli/client.go | 5 +--- cli/collection.go | 7 ++--- cli/collection_describe.go | 4 +-- cli/collection_patch.go | 4 +-- cli/index_create.go | 4 +-- cli/index_drop.go | 4 +-- cli/index_list.go | 6 ++-- cli/request.go | 4 +-- cli/schema_add.go | 4 +-- cli/schema_describe.go | 4 +-- cli/schema_migration_down.go | 4 +-- cli/schema_migration_reload.go | 4 +-- cli/schema_migration_set.go | 4 +-- cli/schema_migration_set_registry.go | 4 +-- cli/schema_migration_up.go | 4 +-- cli/schema_patch.go | 4 +-- cli/schema_set_active.go | 4 +-- cli/utils.go | 37 ++----------------------- cli/view_add.go | 4 +-- db/context.go | 9 +++--- db/session_test.go | 2 +- http/client.go | 5 ---- http/client_collection.go | 8 ------ http/client_lens.go | 6 ---- http/http_client.go | 17 ++++-------- tests/clients/cli/wrapper.go | 9 +----- tests/clients/cli/wrapper_cli.go | 20 +++++-------- tests/clients/cli/wrapper_collection.go | 10 +------ tests/clients/cli/wrapper_lens.go | 5 ---- tests/clients/http/wrapper.go | 4 --- 32 files changed, 63 insertions(+), 155 deletions(-) diff --git a/cli/backup_export.go b/cli/backup_export.go index b905bdf9c7..496d336c44 100644 --- a/cli/backup_export.go +++ b/cli/backup_export.go @@ -38,7 +38,7 @@ Example: export data for the 'Users' collection: defradb client export --collection Users user_data.json`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) if !isValidExportFormat(format) { return ErrInvalidExportFormat @@ -56,7 +56,7 @@ Example: export data for the 'Users' collection: Collections: collections, } - return store.BasicExport(cmd.Context(), &data) + return db.BasicExport(cmd.Context(), &data) }, } cmd.Flags().BoolVarP(&pretty, "pretty", "p", false, "Set the output JSON to be pretty printed") diff --git a/cli/backup_import.go b/cli/backup_import.go index 56f1907643..092ddb61aa 100644 --- a/cli/backup_import.go +++ b/cli/backup_import.go @@ -24,8 +24,8 @@ Example: import data to the database: defradb client import user_data.json`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) - return store.BasicImport(cmd.Context(), args[0]) + db := mustGetContextDB(cmd) + return db.BasicImport(cmd.Context(), args[0]) }, } return cmd diff --git a/cli/client.go b/cli/client.go index 532712e8f8..e1df983b69 100644 --- a/cli/client.go +++ b/cli/client.go @@ -28,10 +28,7 @@ Execute queries, add schema types, obtain node info, etc.`, if err := setContextConfig(cmd); err != nil { return err } - if err := setContextTransaction(cmd, txID); err != nil { - return err - } - return setContextStore(cmd) + return setContextTransaction(cmd, txID) }, } cmd.PersistentFlags().Uint64Var(&txID, "tx", 0, "Transaction ID") diff --git a/cli/collection.go b/cli/collection.go index 3697977d32..b4178a48df 100644 --- a/cli/collection.go +++ b/cli/collection.go @@ -40,10 +40,7 @@ func MakeCollectionCommand() *cobra.Command { if err := setContextTransaction(cmd, txID); err != nil { return err } - if err := setContextStore(cmd); err != nil { - return err - } - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) options := client.CollectionFetchOptions{} if versionID != "" { @@ -59,7 +56,7 @@ func MakeCollectionCommand() *cobra.Command { options.IncludeInactive = immutable.Some(getInactive) } - cols, err := store.GetCollections(cmd.Context(), options) + cols, err := db.GetCollections(cmd.Context(), options) if err != nil { return err } diff --git a/cli/collection_describe.go b/cli/collection_describe.go index 5d1a85ea5e..5311cf0b80 100644 --- a/cli/collection_describe.go +++ b/cli/collection_describe.go @@ -40,7 +40,7 @@ Example: view collection by version id. This will also return inactive collectio defradb client collection describe --version bae123 `, RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) options := client.CollectionFetchOptions{} if versionID != "" { @@ -56,7 +56,7 @@ Example: view collection by version id. This will also return inactive collectio options.IncludeInactive = immutable.Some(getInactive) } - cols, err := store.GetCollections( + cols, err := db.GetCollections( cmd.Context(), options, ) diff --git a/cli/collection_patch.go b/cli/collection_patch.go index 49d5a91305..7b7f4252b2 100644 --- a/cli/collection_patch.go +++ b/cli/collection_patch.go @@ -39,7 +39,7 @@ Example: patch from stdin: To learn more about the DefraDB GraphQL Schema Language, refer to https://docs.source.network.`, Args: cobra.RangeArgs(0, 1), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) var patch string switch { @@ -61,7 +61,7 @@ To learn more about the DefraDB GraphQL Schema Language, refer to https://docs.s return fmt.Errorf("patch cannot be empty") } - return store.PatchCollection(cmd.Context(), patch) + return db.PatchCollection(cmd.Context(), patch) }, } cmd.Flags().StringVarP(&patchFile, "patch-file", "p", "", "File to load a patch from") diff --git a/cli/index_create.go b/cli/index_create.go index 0d724da15b..d4b27c8077 100644 --- a/cli/index_create.go +++ b/cli/index_create.go @@ -36,7 +36,7 @@ Example: create a named index for 'Users' collection on 'name' field: defradb client index create --collection Users --fields name --name UsersByName`, ValidArgs: []string{"collection", "fields", "name"}, RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) var fields []client.IndexedFieldDescription for _, name := range fieldsArg { @@ -47,7 +47,7 @@ Example: create a named index for 'Users' collection on 'name' field: Fields: fields, Unique: uniqueArg, } - col, err := store.GetCollectionByName(cmd.Context(), collectionArg) + col, err := db.GetCollectionByName(cmd.Context(), collectionArg) if err != nil { return err } diff --git a/cli/index_drop.go b/cli/index_drop.go index 5dd069b5da..60b4f52f6d 100644 --- a/cli/index_drop.go +++ b/cli/index_drop.go @@ -26,9 +26,9 @@ Example: drop the index 'UsersByName' for 'Users' collection: defradb client index create --collection Users --name UsersByName`, ValidArgs: []string{"collection", "name"}, RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) - col, err := store.GetCollectionByName(cmd.Context(), collectionArg) + col, err := db.GetCollectionByName(cmd.Context(), collectionArg) if err != nil { return err } diff --git a/cli/index_list.go b/cli/index_list.go index 481acb7d37..89b091d179 100644 --- a/cli/index_list.go +++ b/cli/index_list.go @@ -28,11 +28,11 @@ Example: show all index for 'Users' collection: defradb client index list --collection Users`, ValidArgs: []string{"collection"}, RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) switch { case collectionArg != "": - col, err := store.GetCollectionByName(cmd.Context(), collectionArg) + col, err := db.GetCollectionByName(cmd.Context(), collectionArg) if err != nil { return err } @@ -42,7 +42,7 @@ Example: show all index for 'Users' collection: } return writeJSON(cmd, indexes) default: - indexes, err := store.GetAllIndexes(cmd.Context()) + indexes, err := db.GetAllIndexes(cmd.Context()) if err != nil { return err } diff --git a/cli/request.go b/cli/request.go index c583d51a28..03de7bae4a 100644 --- a/cli/request.go +++ b/cli/request.go @@ -78,8 +78,8 @@ To learn more about the DefraDB GraphQL Query Language, refer to https://docs.so return errors.New("request cannot be empty") } - store := mustGetContextStore(cmd) - result := store.ExecRequest(cmd.Context(), identity, request) + db := mustGetContextDB(cmd) + result := db.ExecRequest(cmd.Context(), identity, request) var errors []string for _, err := range result.GQL.Errors { diff --git a/cli/schema_add.go b/cli/schema_add.go index e81896322d..5277ddd6bd 100644 --- a/cli/schema_add.go +++ b/cli/schema_add.go @@ -41,7 +41,7 @@ Example: add from stdin: Learn more about the DefraDB GraphQL Schema Language on https://docs.source.network.`, RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) var schema string switch { @@ -63,7 +63,7 @@ Learn more about the DefraDB GraphQL Schema Language on https://docs.source.netw return fmt.Errorf("schema cannot be empty") } - cols, err := store.AddSchema(cmd.Context(), schema) + cols, err := db.AddSchema(cmd.Context(), schema) if err != nil { return err } diff --git a/cli/schema_describe.go b/cli/schema_describe.go index c4133baa8c..ddc43db1d7 100644 --- a/cli/schema_describe.go +++ b/cli/schema_describe.go @@ -40,7 +40,7 @@ Example: view a single schema by version id defradb client schema describe --version bae123 `, RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) options := client.SchemaFetchOptions{} if versionID != "" { @@ -53,7 +53,7 @@ Example: view a single schema by version id options.Name = immutable.Some(name) } - schemas, err := store.GetSchemas(cmd.Context(), options) + schemas, err := db.GetSchemas(cmd.Context(), options) if err != nil { return err } diff --git a/cli/schema_migration_down.go b/cli/schema_migration_down.go index a49f359694..b83f85ca74 100644 --- a/cli/schema_migration_down.go +++ b/cli/schema_migration_down.go @@ -39,7 +39,7 @@ Example: migrate from stdin `, Args: cobra.RangeArgs(0, 1), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) var srcData []byte switch { @@ -65,7 +65,7 @@ Example: migrate from stdin if err := json.Unmarshal(srcData, &src); err != nil { return err } - out, err := store.LensRegistry().MigrateDown(cmd.Context(), enumerable.New(src), collectionID) + out, err := db.LensRegistry().MigrateDown(cmd.Context(), enumerable.New(src), collectionID) if err != nil { return err } diff --git a/cli/schema_migration_reload.go b/cli/schema_migration_reload.go index 8ffb5542f1..a4e9f89934 100644 --- a/cli/schema_migration_reload.go +++ b/cli/schema_migration_reload.go @@ -20,8 +20,8 @@ func MakeSchemaMigrationReloadCommand() *cobra.Command { Short: "Reload the schema migrations within DefraDB", Long: `Reload the schema migrations within DefraDB`, RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) - return store.LensRegistry().ReloadLenses(cmd.Context()) + db := mustGetContextDB(cmd) + return db.LensRegistry().ReloadLenses(cmd.Context()) }, } return cmd diff --git a/cli/schema_migration_set.go b/cli/schema_migration_set.go index f7b32103b9..2a609449d4 100644 --- a/cli/schema_migration_set.go +++ b/cli/schema_migration_set.go @@ -42,7 +42,7 @@ Example: add from stdin: Learn more about the DefraDB GraphQL Schema Language on https://docs.source.network.`, Args: cobra.RangeArgs(2, 3), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) var lensCfgJson string switch { @@ -81,7 +81,7 @@ Learn more about the DefraDB GraphQL Schema Language on https://docs.source.netw Lens: lensCfg, } - return store.SetMigration(cmd.Context(), migrationCfg) + return db.SetMigration(cmd.Context(), migrationCfg) }, } cmd.Flags().StringVarP(&lensFile, "file", "f", "", "Lens configuration file") diff --git a/cli/schema_migration_set_registry.go b/cli/schema_migration_set_registry.go index cc5098afae..99e1ba0104 100644 --- a/cli/schema_migration_set_registry.go +++ b/cli/schema_migration_set_registry.go @@ -32,7 +32,7 @@ Example: set from an argument string: Learn more about the DefraDB GraphQL Schema Language on https://docs.source.network.`, Args: cobra.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) decoder := json.NewDecoder(strings.NewReader(args[1])) decoder.DisallowUnknownFields() @@ -47,7 +47,7 @@ Learn more about the DefraDB GraphQL Schema Language on https://docs.source.netw return err } - return store.LensRegistry().SetMigration(cmd.Context(), uint32(collectionID), lensCfg) + return db.LensRegistry().SetMigration(cmd.Context(), uint32(collectionID), lensCfg) }, } return cmd diff --git a/cli/schema_migration_up.go b/cli/schema_migration_up.go index 4473c45911..491068ad28 100644 --- a/cli/schema_migration_up.go +++ b/cli/schema_migration_up.go @@ -39,7 +39,7 @@ Example: migrate from stdin `, Args: cobra.RangeArgs(0, 1), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) var srcData []byte switch { @@ -65,7 +65,7 @@ Example: migrate from stdin if err := json.Unmarshal(srcData, &src); err != nil { return err } - out, err := store.LensRegistry().MigrateUp(cmd.Context(), enumerable.New(src), collectionID) + out, err := db.LensRegistry().MigrateUp(cmd.Context(), enumerable.New(src), collectionID) if err != nil { return err } diff --git a/cli/schema_patch.go b/cli/schema_patch.go index cf9224d204..1a0f617c8d 100644 --- a/cli/schema_patch.go +++ b/cli/schema_patch.go @@ -44,7 +44,7 @@ Example: patch from stdin: To learn more about the DefraDB GraphQL Schema Language, refer to https://docs.source.network.`, RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) var patch string switch { @@ -90,7 +90,7 @@ To learn more about the DefraDB GraphQL Schema Language, refer to https://docs.s migration = immutable.Some(lensCfg) } - return store.PatchSchema(cmd.Context(), patch, migration, setActive) + return db.PatchSchema(cmd.Context(), patch, migration, setActive) }, } cmd.Flags().BoolVar(&setActive, "set-active", false, diff --git a/cli/schema_set_active.go b/cli/schema_set_active.go index 2b13713461..9560d88276 100644 --- a/cli/schema_set_active.go +++ b/cli/schema_set_active.go @@ -22,8 +22,8 @@ func MakeSchemaSetActiveCommand() *cobra.Command { those without it (if they share the same schema root).`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) - return store.SetActiveSchemaVersion(cmd.Context(), args[0]) + db := mustGetContextDB(cmd) + return db.SetActiveSchemaVersion(cmd.Context(), args[0]) }, } return cmd diff --git a/cli/utils.go b/cli/utils.go index f923021fcf..40259aa21c 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -21,7 +21,7 @@ import ( "github.com/spf13/viper" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/http" ) @@ -32,17 +32,8 @@ var ( cfgContextKey = contextKey("cfg") // rootDirContextKey is the context key for the root directory. rootDirContextKey = contextKey("rootDir") - // txContextKey is the context key for the datastore.Txn - // - // This will only be set if a transaction id is specified. - txContextKey = contextKey("tx") // dbContextKey is the context key for the client.DB dbContextKey = contextKey("db") - // storeContextKey is the context key for the client.Store - // - // If a transaction exists, all operations will be executed - // in the current transaction context. - storeContextKey = contextKey("store") // colContextKey is the context key for the client.Collection // // If a transaction exists, all operations will be executed @@ -57,13 +48,6 @@ func mustGetContextDB(cmd *cobra.Command) client.DB { return cmd.Context().Value(dbContextKey).(client.DB) } -// mustGetContextStore returns the store for the current command context. -// -// If a store is not set in the current context this function panics. -func mustGetContextStore(cmd *cobra.Command) client.Store { - return cmd.Context().Value(storeContextKey).(client.Store) -} - // mustGetContextP2P returns the p2p implementation for the current command context. // // If a p2p implementation is not set in the current context this function panics. @@ -115,24 +99,7 @@ func setContextTransaction(cmd *cobra.Command, txId uint64) error { if err != nil { return err } - ctx := context.WithValue(cmd.Context(), txContextKey, tx) - cmd.SetContext(ctx) - return nil -} - -// setContextStore sets the store for the current command context. -func setContextStore(cmd *cobra.Command) error { - cfg := mustGetContextConfig(cmd) - db, err := http.NewClient(cfg.GetString("api.address")) - if err != nil { - return err - } - ctx := context.WithValue(cmd.Context(), dbContextKey, db) - if tx, ok := ctx.Value(txContextKey).(datastore.Txn); ok { - ctx = context.WithValue(ctx, storeContextKey, db.WithTxn(tx)) - } else { - ctx = context.WithValue(ctx, storeContextKey, db) - } + ctx := context.WithValue(cmd.Context(), db.TxnContextKey{}, tx) cmd.SetContext(ctx) return nil } diff --git a/cli/view_add.go b/cli/view_add.go index 9c7d42b723..7038dae81c 100644 --- a/cli/view_add.go +++ b/cli/view_add.go @@ -34,7 +34,7 @@ Example: add from an argument string: Learn more about the DefraDB GraphQL Schema Language on https://docs.source.network.`, Args: cobra.RangeArgs(2, 4), RunE: func(cmd *cobra.Command, args []string) error { - store := mustGetContextStore(cmd) + db := mustGetContextDB(cmd) query := args[0] sdl := args[1] @@ -69,7 +69,7 @@ Learn more about the DefraDB GraphQL Schema Language on https://docs.source.netw transform = immutable.Some(lensCfg) } - defs, err := store.AddView(cmd.Context(), query, sdl, transform) + defs, err := db.AddView(cmd.Context(), query, sdl, transform) if err != nil { return err } diff --git a/db/context.go b/db/context.go index 5c3c6e1d54..6801b3c5b1 100644 --- a/db/context.go +++ b/db/context.go @@ -16,7 +16,8 @@ import ( "github.com/sourcenetwork/defradb/datastore" ) -type txnContextKey struct{} +// TxnContextKey is the key type for transaction context values. +type TxnContextKey struct{} // explicitTxn is a transaction that is managed outside of a db operation. type explicitTxn struct { @@ -41,7 +42,7 @@ type transactionDB interface { // If a transactions exists on the context it will be made explicit, // otherwise a new implicit transaction will be created. func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (context.Context, error) { - txn, ok := ctx.Value(txnContextKey{}).(datastore.Txn) + txn, ok := ctx.Value(TxnContextKey{}).(datastore.Txn) if ok { return setContextTxn(ctx, &explicitTxn{txn}), nil } @@ -55,10 +56,10 @@ func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (con // mustGetContextTxn returns the transaction from the context if it exists, // otherwise it panics. func mustGetContextTxn(ctx context.Context) datastore.Txn { - return ctx.Value(txnContextKey{}).(datastore.Txn) + return ctx.Value(TxnContextKey{}).(datastore.Txn) } // setContextTxn returns a new context with the txn value set. func setContextTxn(ctx context.Context, txn datastore.Txn) context.Context { - return context.WithValue(ctx, txnContextKey{}, txn) + return context.WithValue(ctx, TxnContextKey{}, txn) } diff --git a/db/session_test.go b/db/session_test.go index 3e71091ca7..de4d82c89a 100644 --- a/db/session_test.go +++ b/db/session_test.go @@ -19,6 +19,6 @@ import ( func TestSessionWithTxn(t *testing.T) { sess := NewSession(context.Background()).WithTxn(&explicitTxn{}) - _, ok := sess.Value(txnContextKey{}).(*explicitTxn) + _, ok := sess.Value(TxnContextKey{}).(*explicitTxn) assert.True(t, ok) } diff --git a/http/client.go b/http/client.go index 69c5f2a503..8837ce2e2d 100644 --- a/http/client.go +++ b/http/client.go @@ -86,11 +86,6 @@ func (c *Client) NewConcurrentTxn(ctx context.Context, readOnly bool) (datastore return &Transaction{txRes.ID, c.http}, nil } -func (c *Client) WithTxn(tx datastore.Txn) client.Store { - client := c.http.withTxn(tx.ID()) - return &Client{client} -} - func (c *Client) BasicImport(ctx context.Context, filepath string) error { methodURL := c.http.baseURL.JoinPath("backup", "import") diff --git a/http/client_collection.go b/http/client_collection.go index c53bc7e7ff..39ede6aafc 100644 --- a/http/client_collection.go +++ b/http/client_collection.go @@ -25,7 +25,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" - "github.com/sourcenetwork/defradb/datastore" ) var _ client.Collection = (*Collection)(nil) @@ -445,13 +444,6 @@ func (c *Collection) Get( return doc, nil } -func (c *Collection) WithTxn(tx datastore.Txn) client.Collection { - return &Collection{ - http: c.http.withTxn(tx.ID()), - def: c.def, - } -} - func (c *Collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], diff --git a/http/client_lens.go b/http/client_lens.go index 9021aa31d6..34945a41d6 100644 --- a/http/client_lens.go +++ b/http/client_lens.go @@ -21,7 +21,6 @@ import ( "github.com/sourcenetwork/immutable/enumerable" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) var _ client.LensRegistry = (*LensRegistry)(nil) @@ -31,11 +30,6 @@ type LensRegistry struct { http *httpClient } -func (c *LensRegistry) WithTxn(tx datastore.Txn) client.LensRegistry { - http := c.http.withTxn(tx.ID()) - return &LensRegistry{http} -} - type setMigrationRequest struct { CollectionID uint32 Config model.Lens diff --git a/http/http_client.go b/http/http_client.go index 13abb3c6d0..f5d450ff94 100644 --- a/http/http_client.go +++ b/http/http_client.go @@ -17,12 +17,14 @@ import ( "net/http" "net/url" "strings" + + "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db" ) type httpClient struct { client *http.Client baseURL *url.URL - txValue string } func newHttpClient(rawURL string) (*httpClient, error) { @@ -40,20 +42,13 @@ func newHttpClient(rawURL string) (*httpClient, error) { return &client, nil } -func (c *httpClient) withTxn(value uint64) *httpClient { - return &httpClient{ - client: c.client, - baseURL: c.baseURL, - txValue: fmt.Sprintf("%d", value), - } -} - func (c *httpClient) setDefaultHeaders(req *http.Request) { req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - if c.txValue != "" { - req.Header.Set(TX_HEADER_NAME, c.txValue) + txn, ok := req.Context().Value(db.TxnContextKey{}).(datastore.Txn) + if ok { + req.Header.Set(TX_HEADER_NAME, fmt.Sprintf("%d", txn.ID())) } } diff --git a/tests/clients/cli/wrapper.go b/tests/clients/cli/wrapper.go index d10188d4b2..2ddaf86137 100644 --- a/tests/clients/cli/wrapper.go +++ b/tests/clients/cli/wrapper.go @@ -406,7 +406,7 @@ func (w *Wrapper) ExecRequest( result := &client.RequestResult{} - stdOut, stdErr, err := w.cmd.executeStream(args) + stdOut, stdErr, err := w.cmd.executeStream(ctx, args) if err != nil { result.GQL.Errors = []error{err} return result @@ -515,13 +515,6 @@ func (w *Wrapper) NewConcurrentTxn(ctx context.Context, readOnly bool) (datastor return &Transaction{tx, w.cmd}, nil } -func (w *Wrapper) WithTxn(tx datastore.Txn) client.Store { - return &Wrapper{ - node: w.node, - cmd: w.cmd.withTxn(tx), - } -} - func (w *Wrapper) Root() datastore.RootStore { return w.node.Root() } diff --git a/tests/clients/cli/wrapper_cli.go b/tests/clients/cli/wrapper_cli.go index 2a985dcb18..d4d6ce45ae 100644 --- a/tests/clients/cli/wrapper_cli.go +++ b/tests/clients/cli/wrapper_cli.go @@ -18,11 +18,11 @@ import ( "github.com/sourcenetwork/defradb/cli" "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db" ) type cliWrapper struct { address string - txValue string } func newCliWrapper(address string) *cliWrapper { @@ -31,15 +31,8 @@ func newCliWrapper(address string) *cliWrapper { } } -func (w *cliWrapper) withTxn(tx datastore.Txn) *cliWrapper { - return &cliWrapper{ - address: w.address, - txValue: fmt.Sprintf("%d", tx.ID()), - } -} - -func (w *cliWrapper) execute(_ context.Context, args []string) ([]byte, error) { - stdOut, stdErr, err := w.executeStream(args) +func (w *cliWrapper) execute(ctx context.Context, args []string) ([]byte, error) { + stdOut, stdErr, err := w.executeStream(ctx, args) if err != nil { return nil, err } @@ -57,12 +50,13 @@ func (w *cliWrapper) execute(_ context.Context, args []string) ([]byte, error) { return stdOutData, nil } -func (w *cliWrapper) executeStream(args []string) (io.ReadCloser, io.ReadCloser, error) { +func (w *cliWrapper) executeStream(ctx context.Context, args []string) (io.ReadCloser, io.ReadCloser, error) { stdOutRead, stdOutWrite := io.Pipe() stdErrRead, stdErrWrite := io.Pipe() - if w.txValue != "" { - args = append(args, "--tx", w.txValue) + tx, ok := ctx.Value(db.TxnContextKey{}).(datastore.Txn) + if ok { + args = append(args, "--tx", fmt.Sprintf("%d", tx.ID())) } args = append(args, "--url", w.address) diff --git a/tests/clients/cli/wrapper_collection.go b/tests/clients/cli/wrapper_collection.go index 9bb8fb9938..861606a2d1 100644 --- a/tests/clients/cli/wrapper_collection.go +++ b/tests/clients/cli/wrapper_collection.go @@ -20,7 +20,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/http" ) @@ -448,13 +447,6 @@ func (c *Collection) Get( return doc, nil } -func (c *Collection) WithTxn(tx datastore.Txn) client.Collection { - return &Collection{ - cmd: c.cmd.withTxn(tx), - def: c.def, - } -} - func (c *Collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], @@ -466,7 +458,7 @@ func (c *Collection) GetAllDocIDs( args := []string{"client", "collection", "docIDs"} args = append(args, "--name", c.Description().Name.Value()) - stdOut, _, err := c.cmd.executeStream(args) + stdOut, _, err := c.cmd.executeStream(ctx, args) if err != nil { return nil, err } diff --git a/tests/clients/cli/wrapper_lens.go b/tests/clients/cli/wrapper_lens.go index da6011b9eb..a9f3e20bd1 100644 --- a/tests/clients/cli/wrapper_lens.go +++ b/tests/clients/cli/wrapper_lens.go @@ -20,7 +20,6 @@ import ( "github.com/sourcenetwork/immutable/enumerable" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) var _ client.LensRegistry = (*LensRegistry)(nil) @@ -29,10 +28,6 @@ type LensRegistry struct { cmd *cliWrapper } -func (w *LensRegistry) WithTxn(tx datastore.Txn) client.LensRegistry { - return &LensRegistry{w.cmd.withTxn(tx)} -} - func (w *LensRegistry) SetMigration(ctx context.Context, collectionID uint32, config model.Lens) error { args := []string{"client", "schema", "migration", "set-registry"} diff --git a/tests/clients/http/wrapper.go b/tests/clients/http/wrapper.go index 415212b99c..51fe7ae66b 100644 --- a/tests/clients/http/wrapper.go +++ b/tests/clients/http/wrapper.go @@ -201,10 +201,6 @@ func (w *Wrapper) NewConcurrentTxn(ctx context.Context, readOnly bool) (datastor return &TxWrapper{server, client}, nil } -func (w *Wrapper) WithTxn(tx datastore.Txn) client.Store { - return w.client.WithTxn(tx) -} - func (w *Wrapper) Root() datastore.RootStore { return w.node.Root() } From e4ecbd8d1e921ea4983a3bfd346a738267c8cff2 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Wed, 10 Apr 2024 11:20:24 -0700 Subject: [PATCH 06/14] fix cli db context --- cli/client.go | 3 +++ cli/collection.go | 3 +++ cli/utils.go | 12 ++++++++++++ 3 files changed, 18 insertions(+) diff --git a/cli/client.go b/cli/client.go index e1df983b69..03ce3cd011 100644 --- a/cli/client.go +++ b/cli/client.go @@ -28,6 +28,9 @@ Execute queries, add schema types, obtain node info, etc.`, if err := setContextConfig(cmd); err != nil { return err } + if err := setContextDB(cmd); err != nil { + return err + } return setContextTransaction(cmd, txID) }, } diff --git a/cli/collection.go b/cli/collection.go index b4178a48df..4e69fd97fd 100644 --- a/cli/collection.go +++ b/cli/collection.go @@ -37,6 +37,9 @@ func MakeCollectionCommand() *cobra.Command { if err := setContextConfig(cmd); err != nil { return err } + if err := setContextDB(cmd); err != nil { + return err + } if err := setContextTransaction(cmd, txID); err != nil { return err } diff --git a/cli/utils.go b/cli/utils.go index 40259aa21c..1f8e97c676 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -76,6 +76,18 @@ func tryGetContextCollection(cmd *cobra.Command) (client.Collection, bool) { return col, ok } +// setContextDB sets the db for the current command context. +func setContextDB(cmd *cobra.Command) error { + cfg := mustGetContextConfig(cmd) + db, err := http.NewClient(cfg.GetString("api.address")) + if err != nil { + return err + } + ctx := context.WithValue(cmd.Context(), dbContextKey, db) + cmd.SetContext(ctx) + return nil +} + // setContextConfig sets teh config for the current command context. func setContextConfig(cmd *cobra.Command) error { rootdir := mustGetContextRootDir(cmd) From caa3afa497c155d13823602cd57e18a3177b2b1a Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Wed, 10 Apr 2024 13:10:33 -0700 Subject: [PATCH 07/14] replace db session with simpler context api --- cli/utils.go | 2 +- db/collection_index.go | 4 +-- db/context.go | 25 +++++++++----- db/context_test.go | 2 +- db/index_test.go | 12 +++---- db/indexed_docs_test.go | 10 +++--- db/session.go | 33 ------------------- db/session_test.go | 24 -------------- db/subscriptions.go | 2 +- http/http_client.go | 3 +- http/middleware.go | 8 ++--- net/peer_collection.go | 8 ++--- net/peer_replicator.go | 18 +++++----- net/server.go | 19 ++++++----- tests/clients/cli/wrapper_cli.go | 3 +- .../events/simple/with_create_txn_test.go | 4 +-- tests/integration/lens.go | 4 +-- tests/integration/utils2.go | 16 ++++----- 18 files changed, 72 insertions(+), 125 deletions(-) delete mode 100644 db/session.go delete mode 100644 db/session_test.go diff --git a/cli/utils.go b/cli/utils.go index 1f8e97c676..d93cce46f2 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -111,7 +111,7 @@ func setContextTransaction(cmd *cobra.Command, txId uint64) error { if err != nil { return err } - ctx := context.WithValue(cmd.Context(), db.TxnContextKey{}, tx) + ctx := db.SetContextTxn(cmd.Context(), tx) cmd.SetContext(ctx) return nil } diff --git a/db/collection_index.go b/db/collection_index.go index 52de557356..0ac1975bda 100644 --- a/db/collection_index.go +++ b/db/collection_index.go @@ -41,7 +41,7 @@ func (db *db) createCollectionIndex( if err != nil { return client.IndexDescription{}, NewErrCanNotReadCollection(collectionName, err) } - ctx = setContextTxn(ctx, txn) + ctx = SetContextTxn(ctx, txn) return col.CreateIndex(ctx, desc) } @@ -54,7 +54,7 @@ func (db *db) dropCollectionIndex( if err != nil { return NewErrCanNotReadCollection(collectionName, err) } - ctx = setContextTxn(ctx, txn) + ctx = SetContextTxn(ctx, txn) return col.DropIndex(ctx, indexName) } diff --git a/db/context.go b/db/context.go index 6801b3c5b1..a95052164d 100644 --- a/db/context.go +++ b/db/context.go @@ -16,8 +16,8 @@ import ( "github.com/sourcenetwork/defradb/datastore" ) -// TxnContextKey is the key type for transaction context values. -type TxnContextKey struct{} +// txnContextKey is the key type for transaction context values. +type txnContextKey struct{} // explicitTxn is a transaction that is managed outside of a db operation. type explicitTxn struct { @@ -42,24 +42,31 @@ type transactionDB interface { // If a transactions exists on the context it will be made explicit, // otherwise a new implicit transaction will be created. func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (context.Context, error) { - txn, ok := ctx.Value(TxnContextKey{}).(datastore.Txn) + txn, ok := TryGetContextTxn(ctx) if ok { - return setContextTxn(ctx, &explicitTxn{txn}), nil + return SetContextTxn(ctx, &explicitTxn{txn}), nil } txn, err := db.NewTxn(ctx, readOnly) if err != nil { return nil, err } - return setContextTxn(ctx, txn), nil + return SetContextTxn(ctx, txn), nil } // mustGetContextTxn returns the transaction from the context if it exists, // otherwise it panics. func mustGetContextTxn(ctx context.Context) datastore.Txn { - return ctx.Value(TxnContextKey{}).(datastore.Txn) + return ctx.Value(txnContextKey{}).(datastore.Txn) } -// setContextTxn returns a new context with the txn value set. -func setContextTxn(ctx context.Context, txn datastore.Txn) context.Context { - return context.WithValue(ctx, TxnContextKey{}, txn) +// TryGetContextTxn returns a transaction and a bool indicating if the +// txn was retrieved from the given context. +func TryGetContextTxn(ctx context.Context) (datastore.Txn, bool) { + txn, ok := ctx.Value(txnContextKey{}).(datastore.Txn) + return txn, ok +} + +// SetContextTxn returns a new context with the txn value set. +func SetContextTxn(ctx context.Context, txn datastore.Txn) context.Context { + return context.WithValue(ctx, txnContextKey{}, txn) } diff --git a/db/context_test.go b/db/context_test.go index 9a72f9b91a..b711b4ae0a 100644 --- a/db/context_test.go +++ b/db/context_test.go @@ -28,7 +28,7 @@ func TestEnsureContextTxnExplicit(t *testing.T) { require.NoError(t, err) // set an explicit transaction - ctx = setContextTxn(ctx, txn) + ctx = SetContextTxn(ctx, txn) ctx, err = ensureContextTxn(ctx, db, true) require.NoError(t, err) diff --git a/db/index_test.go b/db/index_test.go index a6fc9cbd0d..a9cb90a132 100644 --- a/db/index_test.go +++ b/db/index_test.go @@ -784,7 +784,7 @@ func TestCollectionGetIndexes_ShouldCloseQueryIterator(t *testing.T) { mockedTxn.MockSystemstore.EXPECT().Query(mock.Anything, mock.Anything). Return(queryResults, nil) - ctx := setContextTxn(f.ctx, mockedTxn) + ctx := SetContextTxn(f.ctx, mockedTxn) _, err := f.users.GetIndexes(ctx) assert.NoError(t, err) } @@ -841,7 +841,7 @@ func TestCollectionGetIndexes_IfSystemStoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - ctx := setContextTxn(f.ctx, mockedTxn) + ctx := SetContextTxn(f.ctx, mockedTxn) _, err := f.users.GetIndexes(ctx) require.ErrorIs(t, err, testCase.ExpectedError) } @@ -904,7 +904,7 @@ func TestCollectionGetIndexes_IfStoredIndexWithUnsupportedType_ReturnError(t *te mockedTxn.MockSystemstore.EXPECT().Query(mock.Anything, mock.Anything). Return(mocks.NewQueryResultsWithValues(t, indexDescData), nil) - ctx := setContextTxn(f.ctx, mockedTxn) + ctx := SetContextTxn(f.ctx, mockedTxn) _, err = collection.GetIndexes(ctx) require.ErrorIs(t, err, NewErrUnsupportedIndexFieldType(unsupportedKind)) } @@ -1096,7 +1096,7 @@ func TestDropIndex_IfFailsToDeleteFromStorage_ReturnError(t *testing.T) { mockedTxn.MockDatastore.EXPECT().Query(mock.Anything, mock.Anything).Maybe(). Return(mocks.NewQueryResultsWithValues(t), nil) - ctx := setContextTxn(f.ctx, mockedTxn) + ctx := SetContextTxn(f.ctx, mockedTxn) err := f.users.DropIndex(ctx, testUsersColIndexName) require.ErrorIs(t, err, testErr) } @@ -1104,7 +1104,7 @@ func TestDropIndex_IfFailsToDeleteFromStorage_ReturnError(t *testing.T) { func TestDropIndex_ShouldUpdateCollectionsDescription(t *testing.T) { f := newIndexTestFixture(t) defer f.db.Close() - ctx := setContextTxn(f.ctx, f.txn) + ctx := SetContextTxn(f.ctx, f.txn) _, err := f.users.CreateIndex(ctx, getUsersIndexDescOnName()) require.NoError(t, err) indOnAge, err := f.users.CreateIndex(ctx, getUsersIndexDescOnAge()) @@ -1148,7 +1148,7 @@ func TestDropIndex_IfSystemStoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - ctx := setContextTxn(f.ctx, mockedTxn) + ctx := SetContextTxn(f.ctx, mockedTxn) err := f.users.DropIndex(ctx, testUsersColIndexName) require.ErrorIs(t, err, testErr) } diff --git a/db/indexed_docs_test.go b/db/indexed_docs_test.go index 4d353e4ea7..70604fdc1f 100644 --- a/db/indexed_docs_test.go +++ b/db/indexed_docs_test.go @@ -322,7 +322,7 @@ func TestNonUnique_IfFailsToStoredIndexedDoc_Error(t *testing.T) { dataStoreOn.Put(mock.Anything, key.ToDS(), mock.Anything).Return(errors.New("error")) dataStoreOn.Put(mock.Anything, mock.Anything, mock.Anything).Return(nil) - ctx := setContextTxn(f.ctx, mockTxn) + ctx := SetContextTxn(f.ctx, mockTxn) err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(f.t, err, NewErrFailedToStoreIndexedField("name", nil)) } @@ -361,7 +361,7 @@ func TestNonUnique_IfSystemStorageHasInvalidIndexDescription_Error(t *testing.T) systemStoreOn.Query(mock.Anything, mock.Anything). Return(mocks.NewQueryResultsWithValues(t, []byte("invalid")), nil) - ctx := setContextTxn(f.ctx, mockTxn) + ctx := SetContextTxn(f.ctx, mockTxn) err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) assert.ErrorIs(t, err, datastore.NewErrInvalidStoredValue(nil)) } @@ -380,7 +380,7 @@ func TestNonUnique_IfSystemStorageFailsToReadIndexDesc_Error(t *testing.T) { systemStoreOn.Query(mock.Anything, mock.Anything). Return(nil, testErr) - ctx := setContextTxn(f.ctx, mockTxn) + ctx := SetContextTxn(f.ctx, mockTxn) err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } @@ -809,7 +809,7 @@ func TestNonUniqueUpdate_IfFailsToReadIndexDescription_ReturnError(t *testing.T) usersCol.(*collection).fetcherFactory = func() fetcher.Fetcher { return fetcherMocks.NewStubbedFetcher(t) } - ctx := setContextTxn(f.ctx, mockedTxn) + ctx := SetContextTxn(f.ctx, mockedTxn) err = usersCol.Update(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } @@ -1052,7 +1052,7 @@ func TestNonUniqueUpdate_IfDatastoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Datastore().Unset() mockedTxn.EXPECT().Datastore().Return(mockedTxn.MockDatastore).Maybe() - ctx := setContextTxn(f.ctx, mockedTxn) + ctx := SetContextTxn(f.ctx, mockedTxn) err = f.users.Update(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } diff --git a/db/session.go b/db/session.go deleted file mode 100644 index 192205b48b..0000000000 --- a/db/session.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Democratized Data Foundation -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package db - -import ( - "context" - - "github.com/sourcenetwork/defradb/datastore" -) - -// Session wraps a context to make it easier to pass request scoped -// parameters such as transactions. -type Session struct { - context.Context -} - -// NewSession returns a new session that wraps the given context. -func NewSession(ctx context.Context) *Session { - return &Session{ctx} -} - -// WithTxn returns a new session with the transaction value set. -func (s *Session) WithTxn(txn datastore.Txn) *Session { - return &Session{setContextTxn(s, txn)} -} diff --git a/db/session_test.go b/db/session_test.go deleted file mode 100644 index de4d82c89a..0000000000 --- a/db/session_test.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2024 Democratized Data Foundation -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package db - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestSessionWithTxn(t *testing.T) { - sess := NewSession(context.Background()).WithTxn(&explicitTxn{}) - _, ok := sess.Value(TxnContextKey{}).(*explicitTxn) - assert.True(t, ok) -} diff --git a/db/subscriptions.go b/db/subscriptions.go index b0faa0414b..e7f5d997cc 100644 --- a/db/subscriptions.go +++ b/db/subscriptions.go @@ -62,7 +62,7 @@ func (db *db) handleSubscription( continue } - ctx := setContextTxn(ctx, txn) + ctx := SetContextTxn(ctx, txn) db.handleEvent(ctx, identity, txn, pub, evt, r) txn.Discard(ctx) } diff --git a/http/http_client.go b/http/http_client.go index f5d450ff94..5bcda30dcd 100644 --- a/http/http_client.go +++ b/http/http_client.go @@ -18,7 +18,6 @@ import ( "net/url" "strings" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/db" ) @@ -46,7 +45,7 @@ func (c *httpClient) setDefaultHeaders(req *http.Request) { req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - txn, ok := req.Context().Value(db.TxnContextKey{}).(datastore.Txn) + txn, ok := db.TryGetContextTxn(req.Context()) if ok { req.Header.Set(TX_HEADER_NAME, fmt.Sprintf("%d", txn.ID())) } diff --git a/http/middleware.go b/http/middleware.go index 76945e0a77..674921fd73 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -89,13 +89,11 @@ func TransactionMiddleware(next http.Handler) http.Handler { next.ServeHTTP(rw, req) return } - - // store transaction in session - sess := db.NewSession(req.Context()) + ctx := req.Context() if val, ok := tx.(datastore.Txn); ok { - sess = sess.WithTxn(val) + ctx = db.SetContextTxn(ctx, val) } - next.ServeHTTP(rw, req.WithContext(sess)) + next.ServeHTTP(rw, req.WithContext(ctx)) }) } diff --git a/net/peer_collection.go b/net/peer_collection.go index 69caf8fd46..d8d27b361d 100644 --- a/net/peer_collection.go +++ b/net/peer_collection.go @@ -34,9 +34,9 @@ func (p *Peer) AddP2PCollections(ctx context.Context, collectionIDs []string) er // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - sess := db.NewSession(ctx).WithTxn(txn) + ctx = db.SetContextTxn(ctx, txn) storeCol, err := p.db.GetCollections( - sess, + ctx, client.CollectionFetchOptions{ SchemaRoot: immutable.Some(col), }, @@ -114,9 +114,9 @@ func (p *Peer) RemoveP2PCollections(ctx context.Context, collectionIDs []string) // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - sess := db.NewSession(ctx).WithTxn(txn) + ctx = db.SetContextTxn(ctx, txn) storeCol, err := p.db.GetCollections( - sess, + ctx, client.CollectionFetchOptions{ SchemaRoot: immutable.Some(col), }, diff --git a/net/peer_replicator.go b/net/peer_replicator.go index 36c42086a4..1dd3c47cf4 100644 --- a/net/peer_replicator.go +++ b/net/peer_replicator.go @@ -41,15 +41,15 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { return err } - // use a session for all operations - sess := db.NewSession(ctx).WithTxn(txn) + // set transaction for all operations + ctx = db.SetContextTxn(ctx, txn) var collections []client.Collection switch { case len(rep.Schemas) > 0: // if specific collections are chosen get them by name for _, name := range rep.Schemas { - col, err := p.db.GetCollectionByName(sess, name) + col, err := p.db.GetCollectionByName(ctx, name) if err != nil { return NewErrReplicatorCollections(err) } @@ -64,7 +64,7 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { default: // default to all collections (unless a collection contains a policy). // TODO-ACP: default to all collections after resolving https://github.com/sourcenetwork/defradb/issues/2366 - allCollections, err := p.db.GetCollections(sess, client.CollectionFetchOptions{}) + allCollections, err := p.db.GetCollections(ctx, client.CollectionFetchOptions{}) if err != nil { return NewErrReplicatorCollections(err) } @@ -113,7 +113,7 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { // push all collection documents to the replicator peer for _, col := range added { // TODO-ACP: Support ACP <> P2P - https://github.com/sourcenetwork/defradb/issues/2366 - keysCh, err := col.GetAllDocIDs(sess, acpIdentity.NoIdentity) + keysCh, err := col.GetAllDocIDs(ctx, acpIdentity.NoIdentity) if err != nil { return NewErrReplicatorDocID(err, col.Name().Value(), rep.Info.ID) } @@ -140,15 +140,15 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro return err } - // use a session for all operations - sess := db.NewSession(ctx).WithTxn(txn) + // set transaction for all operations + ctx = db.SetContextTxn(ctx, txn) var collections []client.Collection switch { case len(rep.Schemas) > 0: // if specific collections are chosen get them by name for _, name := range rep.Schemas { - col, err := p.db.GetCollectionByName(sess, name) + col, err := p.db.GetCollectionByName(ctx, name) if err != nil { return NewErrReplicatorCollections(err) } @@ -163,7 +163,7 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro default: // default to all collections - collections, err = p.db.GetCollections(sess, client.CollectionFetchOptions{}) + collections, err = p.db.GetCollections(ctx, client.CollectionFetchOptions{}) if err != nil { return NewErrReplicatorCollections(err) } diff --git a/net/server.go b/net/server.go index 0196d3d2e6..535bb16315 100644 --- a/net/server.go +++ b/net/server.go @@ -253,12 +253,12 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL } defer txn.Discard(ctx) - // use a session for all operations - sess := db.NewSession(ctx).WithTxn(txn) + // use a transaction for all operations + ctx = db.SetContextTxn(ctx, txn) // Currently a schema is the best way we have to link a push log request to a collection, // this will change with https://github.com/sourcenetwork/defradb/issues/1085 - col, err := s.getActiveCollection(sess, s.db, string(req.Body.SchemaRoot)) + col, err := s.getActiveCollection(ctx, s.db, string(req.Body.SchemaRoot)) if err != nil { return nil, err } @@ -355,11 +355,12 @@ func (s *server) syncIndexedDocs( docID client.DocID, txn datastore.Txn, ) error { - sess := db.NewSession(ctx).WithTxn(txn) + // remove transaction from old context + oldCtx := db.SetContextTxn(ctx, nil) //TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2365 // Resolve while handling acp <> secondary indexes. - oldDoc, err := col.Get(ctx, acpIdentity.NoIdentity, docID, false) + oldDoc, err := col.Get(oldCtx, acpIdentity.NoIdentity, docID, false) isNewDoc := errors.Is(err, client.ErrDocumentNotFoundOrNotAuthorized) if !isNewDoc && err != nil { return err @@ -367,18 +368,18 @@ func (s *server) syncIndexedDocs( //TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2365 // Resolve while handling acp <> secondary indexes. - doc, err := col.Get(sess, acpIdentity.NoIdentity, docID, false) + doc, err := col.Get(ctx, acpIdentity.NoIdentity, docID, false) isDeletedDoc := errors.Is(err, client.ErrDocumentNotFoundOrNotAuthorized) if !isDeletedDoc && err != nil { return err } if isDeletedDoc { - return col.DeleteDocIndex(ctx, oldDoc) + return col.DeleteDocIndex(oldCtx, oldDoc) } else if isNewDoc { - return col.CreateDocIndex(sess, doc) + return col.CreateDocIndex(ctx, doc) } else { - return col.UpdateDocIndex(sess, oldDoc, doc) + return col.UpdateDocIndex(ctx, oldDoc, doc) } } diff --git a/tests/clients/cli/wrapper_cli.go b/tests/clients/cli/wrapper_cli.go index d4d6ce45ae..9076605857 100644 --- a/tests/clients/cli/wrapper_cli.go +++ b/tests/clients/cli/wrapper_cli.go @@ -17,7 +17,6 @@ import ( "strings" "github.com/sourcenetwork/defradb/cli" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/db" ) @@ -54,7 +53,7 @@ func (w *cliWrapper) executeStream(ctx context.Context, args []string) (io.ReadC stdOutRead, stdOutWrite := io.Pipe() stdErrRead, stdErrWrite := io.Pipe() - tx, ok := ctx.Value(db.TxnContextKey{}).(datastore.Txn) + tx, ok := db.TryGetContextTxn(ctx) if ok { args = append(args, "--tx", fmt.Sprintf("%d", tx.ID())) } diff --git a/tests/integration/events/simple/with_create_txn_test.go b/tests/integration/events/simple/with_create_txn_test.go index 81f6c8bf30..f90fc96a88 100644 --- a/tests/integration/events/simple/with_create_txn_test.go +++ b/tests/integration/events/simple/with_create_txn_test.go @@ -44,9 +44,9 @@ func TestEventsSimpleWithCreateWithTxnDiscarded(t *testing.T) { txn, err := d.NewTxn(ctx, false) assert.Nil(t, err) - sess := db.NewSession(ctx).WithTxn(txn) + ctx = db.SetContextTxn(ctx, txn) r := d.ExecRequest( - sess, + ctx, acpIdentity.NoIdentity, `mutation { create_Users(input: {name: "Shahzad"}) { diff --git a/tests/integration/lens.go b/tests/integration/lens.go index 9b0836d556..541b708a33 100644 --- a/tests/integration/lens.go +++ b/tests/integration/lens.go @@ -44,9 +44,9 @@ func configureMigration( ) { for _, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) - sess := db.NewSession(s.ctx).WithTxn(txn) + ctx := db.SetContextTxn(s.ctx, txn) - err := node.SetMigration(sess, action.LensConfig) + err := node.SetMigration(ctx, action.LensConfig) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) diff --git a/tests/integration/utils2.go b/tests/integration/utils2.go index 8fb78544cc..deb38acde3 100644 --- a/tests/integration/utils2.go +++ b/tests/integration/utils2.go @@ -1082,8 +1082,8 @@ func getCollections( ) { for _, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, "") - sess := db.NewSession(s.ctx).WithTxn(txn) - results, err := node.GetCollections(sess, action.FilterOptions) + ctx := db.SetContextTxn(s.ctx, txn) + results, err := node.GetCollections(ctx, action.FilterOptions) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) @@ -1254,9 +1254,9 @@ func createDocViaGQL( txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) identity := acpIdentity.NewIdentity(action.Identity) - sess := db.NewSession(s.ctx).WithTxn(txn) + ctx := db.SetContextTxn(s.ctx, txn) result := node.ExecRequest( - sess, + ctx, identity, request, ) @@ -1430,9 +1430,9 @@ func updateDocViaGQL( ) txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) - sess := db.NewSession(s.ctx).WithTxn(txn) + ctx := db.SetContextTxn(s.ctx, txn) result := node.ExecRequest( - sess, + ctx, acpIdentity.NewIdentity(action.Identity), request, ) @@ -1651,9 +1651,9 @@ func executeRequest( var expectedErrorRaised bool for nodeID, node := range getNodes(action.NodeID, s.nodes) { txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) - sess := db.NewSession(s.ctx).WithTxn(txn) + ctx := db.SetContextTxn(s.ctx, txn) result := node.ExecRequest( - sess, + ctx, acpIdentity.NewIdentity(action.Identity), action.Request, ) From 3673d0500440260ec7c0363653a5fd063c633839 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Thu, 11 Apr 2024 09:44:34 -0700 Subject: [PATCH 08/14] return transaction from ensureContextTxn --- db/collection.go | 27 ++++++-------------- db/collection_delete.go | 12 +++------ db/collection_get.go | 4 +-- db/collection_index.go | 24 +++++------------- db/collection_update.go | 11 +++----- db/context.go | 14 +++-------- db/context_test.go | 14 ++++++++--- db/store.go | 56 +++++++++++------------------------------ 8 files changed, 48 insertions(+), 114 deletions(-) diff --git a/db/collection.go b/db/collection.go index e733fb469b..1afa1c775a 100644 --- a/db/collection.go +++ b/db/collection.go @@ -1230,11 +1230,10 @@ func (c *collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], ) (<-chan client.DocIDResult, error) { - ctx, err := ensureContextTxn(ctx, c.db, true) + ctx, txn, err := ensureContextTxn(ctx, c.db, true) if err != nil { return nil, err } - txn := mustGetContextTxn(ctx) return c.getAllDocIDsChan(ctx, identity, txn) } @@ -1348,12 +1347,10 @@ func (c *collection) Create( identity immutable.Option[string], doc *client.Document, ) error { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.create(ctx, identity, txn, doc) @@ -1371,12 +1368,10 @@ func (c *collection) CreateMany( identity immutable.Option[string], docs []*client.Document, ) error { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) for _, doc := range docs { @@ -1458,12 +1453,10 @@ func (c *collection) Update( identity immutable.Option[string], doc *client.Document, ) error { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(doc.ID()) @@ -1525,12 +1518,10 @@ func (c *collection) Save( identity immutable.Option[string], doc *client.Document, ) error { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) // Check if document already exists with primary DS key. @@ -1809,12 +1800,10 @@ func (c *collection) Delete( identity immutable.Option[string], docID client.DocID, ) (bool, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return false, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) @@ -1832,12 +1821,10 @@ func (c *collection) Exists( identity immutable.Option[string], docID client.DocID, ) (bool, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return false, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) diff --git a/db/collection_delete.go b/db/collection_delete.go index fdb9005e7e..8d5bf3f2bb 100644 --- a/db/collection_delete.go +++ b/db/collection_delete.go @@ -54,12 +54,10 @@ func (c *collection) DeleteWithDocID( identity immutable.Option[string], docID client.DocID, ) (*client.DeleteResult, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) dsKey := c.getPrimaryKeyFromDocID(docID) @@ -77,12 +75,10 @@ func (c *collection) DeleteWithDocIDs( identity immutable.Option[string], docIDs []client.DocID, ) (*client.DeleteResult, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res, err := c.deleteWithIDs(ctx, identity, txn, docIDs, client.Deleted) @@ -99,12 +95,10 @@ func (c *collection) DeleteWithFilter( identity immutable.Option[string], filter any, ) (*client.DeleteResult, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res, err := c.deleteWithFilter(ctx, identity, txn, filter, client.Deleted) diff --git a/db/collection_get.go b/db/collection_get.go index b694d962fe..8ae0dcae75 100644 --- a/db/collection_get.go +++ b/db/collection_get.go @@ -29,12 +29,10 @@ func (c *collection) Get( showDeleted bool, ) (*client.Document, error) { // create txn - ctx, err := ensureContextTxn(ctx, c.db, true) + ctx, txn, err := ensureContextTxn(ctx, c.db, true) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) diff --git a/db/collection_index.go b/db/collection_index.go index 0ac1975bda..3e33c94709 100644 --- a/db/collection_index.go +++ b/db/collection_index.go @@ -112,12 +112,10 @@ func (db *db) fetchCollectionIndexDescriptions( } func (c *collection) CreateDocIndex(ctx context.Context, doc *client.Document) error { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.indexNewDoc(ctx, txn, doc) @@ -129,12 +127,10 @@ func (c *collection) CreateDocIndex(ctx context.Context, doc *client.Document) e } func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client.Document) error { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.deleteIndexedDoc(ctx, txn, oldDoc) @@ -150,12 +146,10 @@ func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client. } func (c *collection) DeleteDocIndex(ctx context.Context, doc *client.Document) error { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.deleteIndexedDoc(ctx, txn, doc) @@ -248,12 +242,10 @@ func (c *collection) CreateIndex( ctx context.Context, desc client.IndexDescription, ) (client.IndexDescription, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return client.IndexDescription{}, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) index, err := c.createIndex(ctx, txn, desc) @@ -406,12 +398,10 @@ func (c *collection) indexExistingDocs( // // All index artifacts for existing documents related the index will be removed. func (c *collection) DropIndex(ctx context.Context, indexName string) error { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.dropIndex(ctx, txn, indexName) @@ -496,12 +486,10 @@ func (c *collection) loadIndexes(ctx context.Context, txn datastore.Txn) error { // GetIndexes returns all indexes for the collection. func (c *collection) GetIndexes(ctx context.Context) ([]client.IndexDescription, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = c.loadIndexes(ctx, txn) diff --git a/db/collection_update.go b/db/collection_update.go index 96b51d5bf3..6c836412e0 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -57,11 +57,10 @@ func (c *collection) UpdateWithFilter( filter any, updater string, ) (*client.UpdateResult, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res, err := c.updateWithFilter(ctx, identity, txn, filter, updater) @@ -80,12 +79,10 @@ func (c *collection) UpdateWithDocID( docID client.DocID, updater string, ) (*client.UpdateResult, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res, err := c.updateWithDocID(ctx, identity, txn, docID, updater) @@ -105,12 +102,10 @@ func (c *collection) UpdateWithDocIDs( docIDs []client.DocID, updater string, ) (*client.UpdateResult, error) { - ctx, err := ensureContextTxn(ctx, c.db, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res, err := c.updateWithIDs(ctx, identity, txn, docIDs, updater) diff --git a/db/context.go b/db/context.go index a95052164d..49df02fca3 100644 --- a/db/context.go +++ b/db/context.go @@ -41,22 +41,16 @@ type transactionDB interface { // // If a transactions exists on the context it will be made explicit, // otherwise a new implicit transaction will be created. -func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (context.Context, error) { +func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (context.Context, datastore.Txn, error) { txn, ok := TryGetContextTxn(ctx) if ok { - return SetContextTxn(ctx, &explicitTxn{txn}), nil + return SetContextTxn(ctx, &explicitTxn{txn}), txn, nil } txn, err := db.NewTxn(ctx, readOnly) if err != nil { - return nil, err + return nil, txn, err } - return SetContextTxn(ctx, txn), nil -} - -// mustGetContextTxn returns the transaction from the context if it exists, -// otherwise it panics. -func mustGetContextTxn(ctx context.Context) datastore.Txn { - return ctx.Value(txnContextKey{}).(datastore.Txn) + return SetContextTxn(ctx, txn), txn, nil } // TryGetContextTxn returns a transaction and a bool indicating if the diff --git a/db/context_test.go b/db/context_test.go index b711b4ae0a..c8b1a322e5 100644 --- a/db/context_test.go +++ b/db/context_test.go @@ -30,10 +30,13 @@ func TestEnsureContextTxnExplicit(t *testing.T) { // set an explicit transaction ctx = SetContextTxn(ctx, txn) - ctx, err = ensureContextTxn(ctx, db, true) + ctx, txn, err = ensureContextTxn(ctx, db, true) require.NoError(t, err) - _, ok := mustGetContextTxn(ctx).(*explicitTxn) + _, ok := txn.(*explicitTxn) + assert.True(t, ok) + + _, ok = ctx.Value(txnContextKey{}).(*explicitTxn) assert.True(t, ok) } @@ -43,9 +46,12 @@ func TestEnsureContextTxnImplicit(t *testing.T) { db, err := newMemoryDB(ctx) require.NoError(t, err) - ctx, err = ensureContextTxn(ctx, db, true) + ctx, txn, err := ensureContextTxn(ctx, db, true) require.NoError(t, err) - _, ok := mustGetContextTxn(ctx).(*explicitTxn) + _, ok := txn.(*explicitTxn) + assert.False(t, ok) + + _, ok = ctx.Value(txnContextKey{}).(*explicitTxn) assert.False(t, ok) } diff --git a/db/store.go b/db/store.go index 7839eb099a..5b33c8607b 100644 --- a/db/store.go +++ b/db/store.go @@ -32,14 +32,12 @@ func (s *store) ExecRequest( identity immutable.Option[string], request string, ) *client.RequestResult { - ctx, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, s, false) if err != nil { res := &client.RequestResult{} res.GQL.Errors = []error{err} return res } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) res := s.db.execRequest(ctx, identity, request, txn) @@ -57,12 +55,10 @@ func (s *store) ExecRequest( // GetCollectionByName returns an existing collection within the database. func (s *store) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { - ctx, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, s, true) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getCollectionByName(ctx, txn, name) @@ -73,12 +69,10 @@ func (s *store) GetCollections( ctx context.Context, options client.CollectionFetchOptions, ) ([]client.Collection, error) { - ctx, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, s, true) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getCollections(ctx, txn, options) @@ -89,12 +83,10 @@ func (s *store) GetCollections( // // Will return an error if it is not found. func (s *store) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { - ctx, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, s, true) if err != nil { return client.SchemaDescription{}, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getSchemaByVersionID(ctx, txn, versionID) @@ -106,12 +98,10 @@ func (s *store) GetSchemas( ctx context.Context, options client.SchemaFetchOptions, ) ([]client.SchemaDescription, error) { - ctx, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, s, true) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getSchemas(ctx, txn, options) @@ -121,12 +111,10 @@ func (s *store) GetSchemas( func (s *store) GetAllIndexes( ctx context.Context, ) (map[client.CollectionName][]client.IndexDescription, error) { - ctx, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, s, true) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) return s.db.getAllIndexDescriptions(ctx, txn) @@ -138,12 +126,10 @@ func (s *store) GetAllIndexes( // All schema types provided must not exist prior to calling this, and they may not reference existing // types previously defined. func (s *store) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { - ctx, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, s, false) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) cols, err := s.db.addSchema(ctx, txn, schemaString) @@ -174,12 +160,10 @@ func (s *store) PatchSchema( migration immutable.Option[model.Lens], setAsDefaultVersion bool, ) error { - ctx, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, s, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) @@ -194,12 +178,10 @@ func (s *store) PatchCollection( ctx context.Context, patchString string, ) error { - ctx, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, s, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.patchCollection(ctx, txn, patchString) @@ -211,12 +193,10 @@ func (s *store) PatchCollection( } func (s *store) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { - ctx, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, s, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.setActiveSchemaVersion(ctx, txn, schemaVersionID) @@ -228,12 +208,10 @@ func (s *store) SetActiveSchemaVersion(ctx context.Context, schemaVersionID stri } func (s *store) SetMigration(ctx context.Context, cfg client.LensConfig) error { - ctx, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, s, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.setMigration(ctx, txn, cfg) @@ -250,12 +228,10 @@ func (s *store) AddView( sdl string, transform immutable.Option[model.Lens], ) ([]client.CollectionDefinition, error) { - ctx, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, s, false) if err != nil { return nil, err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) defs, err := s.db.addView(ctx, txn, query, sdl, transform) @@ -274,12 +250,10 @@ func (s *store) AddView( // BasicImport imports a json dataset. // filepath must be accessible to the node. func (s *store) BasicImport(ctx context.Context, filepath string) error { - ctx, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, s, false) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.basicImport(ctx, txn, filepath) @@ -292,12 +266,10 @@ func (s *store) BasicImport(ctx context.Context, filepath string) error { // BasicExport exports the current data or subset of data to file in json format. func (s *store) BasicExport(ctx context.Context, config *client.BackupConfig) error { - ctx, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, s, true) if err != nil { return err } - - txn := mustGetContextTxn(ctx) defer txn.Discard(ctx) err = s.db.basicExport(ctx, txn, config) From fbed9d0279a0571ad29009ffe49e35a4f46c21c4 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Thu, 11 Apr 2024 09:52:56 -0700 Subject: [PATCH 09/14] merge db.store funcs into db --- db/collection_update.go | 2 +- db/db.go | 4 +- db/db_test.go | 2 +- db/index_test.go | 2 +- db/request.go | 2 +- db/store.go | 94 ++++++++++++++++++----------------------- db/subscriptions.go | 2 +- 7 files changed, 49 insertions(+), 59 deletions(-) diff --git a/db/collection_update.go b/db/collection_update.go index 6c836412e0..e9ab2e7fa1 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -441,7 +441,7 @@ func (c *collection) makeSelectionPlan( ctx, identity, c.db.acp, - &store{c.db}, + c.db, txn, ) diff --git a/db/db.go b/db/db.go index 5c3269fb59..e7a6fa8d09 100644 --- a/db/db.go +++ b/db/db.go @@ -89,7 +89,7 @@ func newDB( ctx context.Context, rootstore datastore.RootStore, options ...Option, -) (*store, error) { +) (*db, error) { multistore := datastore.MultiStoreFrom(rootstore) parser, err := graphql.NewParser() @@ -119,7 +119,7 @@ func newDB( return nil, err } - return &store{db}, nil + return db, nil } // NewTxn creates a new transaction. diff --git a/db/db_test.go b/db/db_test.go index 89e5aa9c6b..118adb285b 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -19,7 +19,7 @@ import ( badgerds "github.com/sourcenetwork/defradb/datastore/badger/v4" ) -func newMemoryDB(ctx context.Context) (*store, error) { +func newMemoryDB(ctx context.Context) (*db, error) { opts := badgerds.Options{Options: badger.DefaultOptions("").WithInMemory(true)} rootstore, err := badgerds.NewDatastore("", &opts) if err != nil { diff --git a/db/index_test.go b/db/index_test.go index a9cb90a132..aeda2bdd6d 100644 --- a/db/index_test.go +++ b/db/index_test.go @@ -53,7 +53,7 @@ const ( type indexTestFixture struct { ctx context.Context - db *store + db *db txn datastore.Txn users client.Collection t *testing.T diff --git a/db/request.go b/db/request.go index 21474da089..69b300f482 100644 --- a/db/request.go +++ b/db/request.go @@ -59,7 +59,7 @@ func (db *db) execRequest( ctx, identity, db.acp, - &store{db}, + db, txn, ) diff --git a/db/store.go b/db/store.go index 5b33c8607b..aff11f851d 100644 --- a/db/store.go +++ b/db/store.go @@ -20,19 +20,13 @@ import ( "github.com/sourcenetwork/defradb/client" ) -var _ client.Store = (*store)(nil) - -type store struct { - *db -} - // ExecRequest executes a request against the database. -func (s *store) ExecRequest( +func (db *db) ExecRequest( ctx context.Context, identity immutable.Option[string], request string, ) *client.RequestResult { - ctx, txn, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { res := &client.RequestResult{} res.GQL.Errors = []error{err} @@ -40,7 +34,7 @@ func (s *store) ExecRequest( } defer txn.Discard(ctx) - res := s.db.execRequest(ctx, identity, request, txn) + res := db.execRequest(ctx, identity, request, txn) if len(res.GQL.Errors) > 0 { return res } @@ -54,70 +48,70 @@ func (s *store) ExecRequest( } // GetCollectionByName returns an existing collection within the database. -func (s *store) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { - ctx, txn, err := ensureContextTxn(ctx, s, true) +func (db *db) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { + ctx, txn, err := ensureContextTxn(ctx, db, true) if err != nil { return nil, err } defer txn.Discard(ctx) - return s.db.getCollectionByName(ctx, txn, name) + return db.getCollectionByName(ctx, txn, name) } // GetCollections gets all the currently defined collections. -func (s *store) GetCollections( +func (db *db) GetCollections( ctx context.Context, options client.CollectionFetchOptions, ) ([]client.Collection, error) { - ctx, txn, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, db, true) if err != nil { return nil, err } defer txn.Discard(ctx) - return s.db.getCollections(ctx, txn, options) + return db.getCollections(ctx, txn, options) } // GetSchemaByVersionID returns the schema description for the schema version of the // ID provided. // // Will return an error if it is not found. -func (s *store) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { - ctx, txn, err := ensureContextTxn(ctx, s, true) +func (db *db) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { + ctx, txn, err := ensureContextTxn(ctx, db, true) if err != nil { return client.SchemaDescription{}, err } defer txn.Discard(ctx) - return s.db.getSchemaByVersionID(ctx, txn, versionID) + return db.getSchemaByVersionID(ctx, txn, versionID) } // GetSchemas returns all schema versions that currently exist within // this [Store]. -func (s *store) GetSchemas( +func (db *db) GetSchemas( ctx context.Context, options client.SchemaFetchOptions, ) ([]client.SchemaDescription, error) { - ctx, txn, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, db, true) if err != nil { return nil, err } defer txn.Discard(ctx) - return s.db.getSchemas(ctx, txn, options) + return db.getSchemas(ctx, txn, options) } // GetAllIndexes gets all the indexes in the database. -func (s *store) GetAllIndexes( +func (db *db) GetAllIndexes( ctx context.Context, ) (map[client.CollectionName][]client.IndexDescription, error) { - ctx, txn, err := ensureContextTxn(ctx, s, true) + ctx, txn, err := ensureContextTxn(ctx, db, true) if err != nil { return nil, err } defer txn.Discard(ctx) - return s.db.getAllIndexDescriptions(ctx, txn) + return db.getAllIndexDescriptions(ctx, txn) } // AddSchema takes the provided GQL schema in SDL format, and applies it to the database, @@ -125,14 +119,14 @@ func (s *store) GetAllIndexes( // // All schema types provided must not exist prior to calling this, and they may not reference existing // types previously defined. -func (s *store) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { - ctx, txn, err := ensureContextTxn(ctx, s, false) +func (db *db) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { return nil, err } defer txn.Discard(ctx) - cols, err := s.db.addSchema(ctx, txn, schemaString) + cols, err := db.addSchema(ctx, txn, schemaString) if err != nil { return nil, err } @@ -154,19 +148,19 @@ func (s *store) AddSchema(ctx context.Context, schemaString string) ([]client.Co // The collections (including the schema version ID) will only be updated if any changes have actually // been made, if the net result of the patch matches the current persisted description then no changes // will be applied. -func (s *store) PatchSchema( +func (db *db) PatchSchema( ctx context.Context, patchString string, migration immutable.Option[model.Lens], setAsDefaultVersion bool, ) error { - ctx, txn, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { return err } defer txn.Discard(ctx) - err = s.db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) + err = db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) if err != nil { return err } @@ -174,17 +168,17 @@ func (s *store) PatchSchema( return txn.Commit(ctx) } -func (s *store) PatchCollection( +func (db *db) PatchCollection( ctx context.Context, patchString string, ) error { - ctx, txn, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { return err } defer txn.Discard(ctx) - err = s.db.patchCollection(ctx, txn, patchString) + err = db.patchCollection(ctx, txn, patchString) if err != nil { return err } @@ -192,14 +186,14 @@ func (s *store) PatchCollection( return txn.Commit(ctx) } -func (s *store) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { - ctx, txn, err := ensureContextTxn(ctx, s, false) +func (db *db) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { return err } defer txn.Discard(ctx) - err = s.db.setActiveSchemaVersion(ctx, txn, schemaVersionID) + err = db.setActiveSchemaVersion(ctx, txn, schemaVersionID) if err != nil { return err } @@ -207,14 +201,14 @@ func (s *store) SetActiveSchemaVersion(ctx context.Context, schemaVersionID stri return txn.Commit(ctx) } -func (s *store) SetMigration(ctx context.Context, cfg client.LensConfig) error { - ctx, txn, err := ensureContextTxn(ctx, s, false) +func (db *db) SetMigration(ctx context.Context, cfg client.LensConfig) error { + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { return err } defer txn.Discard(ctx) - err = s.db.setMigration(ctx, txn, cfg) + err = db.setMigration(ctx, txn, cfg) if err != nil { return err } @@ -222,19 +216,19 @@ func (s *store) SetMigration(ctx context.Context, cfg client.LensConfig) error { return txn.Commit(ctx) } -func (s *store) AddView( +func (db *db) AddView( ctx context.Context, query string, sdl string, transform immutable.Option[model.Lens], ) ([]client.CollectionDefinition, error) { - ctx, txn, err := ensureContextTxn(ctx, s, false) + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { return nil, err } defer txn.Discard(ctx) - defs, err := s.db.addView(ctx, txn, query, sdl, transform) + defs, err := db.addView(ctx, txn, query, sdl, transform) if err != nil { return nil, err } @@ -249,14 +243,14 @@ func (s *store) AddView( // BasicImport imports a json dataset. // filepath must be accessible to the node. -func (s *store) BasicImport(ctx context.Context, filepath string) error { - ctx, txn, err := ensureContextTxn(ctx, s, false) +func (db *db) BasicImport(ctx context.Context, filepath string) error { + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { return err } defer txn.Discard(ctx) - err = s.db.basicImport(ctx, txn, filepath) + err = db.basicImport(ctx, txn, filepath) if err != nil { return err } @@ -265,21 +259,17 @@ func (s *store) BasicImport(ctx context.Context, filepath string) error { } // BasicExport exports the current data or subset of data to file in json format. -func (s *store) BasicExport(ctx context.Context, config *client.BackupConfig) error { - ctx, txn, err := ensureContextTxn(ctx, s, true) +func (db *db) BasicExport(ctx context.Context, config *client.BackupConfig) error { + ctx, txn, err := ensureContextTxn(ctx, db, true) if err != nil { return err } defer txn.Discard(ctx) - err = s.db.basicExport(ctx, txn, config) + err = db.basicExport(ctx, txn, config) if err != nil { return err } return txn.Commit(ctx) } - -func (s *store) LensRegistry() client.LensRegistry { - return s.db.lensRegistry -} diff --git a/db/subscriptions.go b/db/subscriptions.go index e7f5d997cc..e649769c18 100644 --- a/db/subscriptions.go +++ b/db/subscriptions.go @@ -80,7 +80,7 @@ func (db *db) handleEvent( ctx, identity, db.acp, - &store{db}, + db, txn, ) From 73219a989a63f9b2aa103a05eea452ea4bd8ae3c Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Thu, 11 Apr 2024 10:04:11 -0700 Subject: [PATCH 10/14] preserve client.Store type in cli --- cli/backup_export.go | 4 ++-- cli/backup_import.go | 4 ++-- cli/collection.go | 4 ++-- cli/collection_describe.go | 4 ++-- cli/collection_patch.go | 4 ++-- cli/index_create.go | 4 ++-- cli/index_drop.go | 4 ++-- cli/index_list.go | 6 +++--- cli/request.go | 4 ++-- cli/schema_describe.go | 4 ++-- cli/schema_migration_down.go | 4 ++-- cli/schema_migration_reload.go | 4 ++-- cli/schema_migration_set.go | 4 ++-- cli/schema_migration_set_registry.go | 4 ++-- cli/schema_migration_up.go | 4 ++-- cli/schema_patch.go | 4 ++-- cli/schema_set_active.go | 4 ++-- cli/utils.go | 7 +++++++ cli/view_add.go | 4 ++-- 19 files changed, 44 insertions(+), 37 deletions(-) diff --git a/cli/backup_export.go b/cli/backup_export.go index 496d336c44..b905bdf9c7 100644 --- a/cli/backup_export.go +++ b/cli/backup_export.go @@ -38,7 +38,7 @@ Example: export data for the 'Users' collection: defradb client export --collection Users user_data.json`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) if !isValidExportFormat(format) { return ErrInvalidExportFormat @@ -56,7 +56,7 @@ Example: export data for the 'Users' collection: Collections: collections, } - return db.BasicExport(cmd.Context(), &data) + return store.BasicExport(cmd.Context(), &data) }, } cmd.Flags().BoolVarP(&pretty, "pretty", "p", false, "Set the output JSON to be pretty printed") diff --git a/cli/backup_import.go b/cli/backup_import.go index 092ddb61aa..56f1907643 100644 --- a/cli/backup_import.go +++ b/cli/backup_import.go @@ -24,8 +24,8 @@ Example: import data to the database: defradb client import user_data.json`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) - return db.BasicImport(cmd.Context(), args[0]) + store := mustGetContextStore(cmd) + return store.BasicImport(cmd.Context(), args[0]) }, } return cmd diff --git a/cli/collection.go b/cli/collection.go index 4e69fd97fd..2cdd9b33bd 100644 --- a/cli/collection.go +++ b/cli/collection.go @@ -43,7 +43,7 @@ func MakeCollectionCommand() *cobra.Command { if err := setContextTransaction(cmd, txID); err != nil { return err } - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) options := client.CollectionFetchOptions{} if versionID != "" { @@ -59,7 +59,7 @@ func MakeCollectionCommand() *cobra.Command { options.IncludeInactive = immutable.Some(getInactive) } - cols, err := db.GetCollections(cmd.Context(), options) + cols, err := store.GetCollections(cmd.Context(), options) if err != nil { return err } diff --git a/cli/collection_describe.go b/cli/collection_describe.go index 5311cf0b80..5d1a85ea5e 100644 --- a/cli/collection_describe.go +++ b/cli/collection_describe.go @@ -40,7 +40,7 @@ Example: view collection by version id. This will also return inactive collectio defradb client collection describe --version bae123 `, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) options := client.CollectionFetchOptions{} if versionID != "" { @@ -56,7 +56,7 @@ Example: view collection by version id. This will also return inactive collectio options.IncludeInactive = immutable.Some(getInactive) } - cols, err := db.GetCollections( + cols, err := store.GetCollections( cmd.Context(), options, ) diff --git a/cli/collection_patch.go b/cli/collection_patch.go index 7b7f4252b2..49d5a91305 100644 --- a/cli/collection_patch.go +++ b/cli/collection_patch.go @@ -39,7 +39,7 @@ Example: patch from stdin: To learn more about the DefraDB GraphQL Schema Language, refer to https://docs.source.network.`, Args: cobra.RangeArgs(0, 1), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) var patch string switch { @@ -61,7 +61,7 @@ To learn more about the DefraDB GraphQL Schema Language, refer to https://docs.s return fmt.Errorf("patch cannot be empty") } - return db.PatchCollection(cmd.Context(), patch) + return store.PatchCollection(cmd.Context(), patch) }, } cmd.Flags().StringVarP(&patchFile, "patch-file", "p", "", "File to load a patch from") diff --git a/cli/index_create.go b/cli/index_create.go index d4b27c8077..0d724da15b 100644 --- a/cli/index_create.go +++ b/cli/index_create.go @@ -36,7 +36,7 @@ Example: create a named index for 'Users' collection on 'name' field: defradb client index create --collection Users --fields name --name UsersByName`, ValidArgs: []string{"collection", "fields", "name"}, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) var fields []client.IndexedFieldDescription for _, name := range fieldsArg { @@ -47,7 +47,7 @@ Example: create a named index for 'Users' collection on 'name' field: Fields: fields, Unique: uniqueArg, } - col, err := db.GetCollectionByName(cmd.Context(), collectionArg) + col, err := store.GetCollectionByName(cmd.Context(), collectionArg) if err != nil { return err } diff --git a/cli/index_drop.go b/cli/index_drop.go index 60b4f52f6d..5dd069b5da 100644 --- a/cli/index_drop.go +++ b/cli/index_drop.go @@ -26,9 +26,9 @@ Example: drop the index 'UsersByName' for 'Users' collection: defradb client index create --collection Users --name UsersByName`, ValidArgs: []string{"collection", "name"}, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) - col, err := db.GetCollectionByName(cmd.Context(), collectionArg) + col, err := store.GetCollectionByName(cmd.Context(), collectionArg) if err != nil { return err } diff --git a/cli/index_list.go b/cli/index_list.go index 89b091d179..481acb7d37 100644 --- a/cli/index_list.go +++ b/cli/index_list.go @@ -28,11 +28,11 @@ Example: show all index for 'Users' collection: defradb client index list --collection Users`, ValidArgs: []string{"collection"}, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) switch { case collectionArg != "": - col, err := db.GetCollectionByName(cmd.Context(), collectionArg) + col, err := store.GetCollectionByName(cmd.Context(), collectionArg) if err != nil { return err } @@ -42,7 +42,7 @@ Example: show all index for 'Users' collection: } return writeJSON(cmd, indexes) default: - indexes, err := db.GetAllIndexes(cmd.Context()) + indexes, err := store.GetAllIndexes(cmd.Context()) if err != nil { return err } diff --git a/cli/request.go b/cli/request.go index 03de7bae4a..c583d51a28 100644 --- a/cli/request.go +++ b/cli/request.go @@ -78,8 +78,8 @@ To learn more about the DefraDB GraphQL Query Language, refer to https://docs.so return errors.New("request cannot be empty") } - db := mustGetContextDB(cmd) - result := db.ExecRequest(cmd.Context(), identity, request) + store := mustGetContextStore(cmd) + result := store.ExecRequest(cmd.Context(), identity, request) var errors []string for _, err := range result.GQL.Errors { diff --git a/cli/schema_describe.go b/cli/schema_describe.go index ddc43db1d7..c4133baa8c 100644 --- a/cli/schema_describe.go +++ b/cli/schema_describe.go @@ -40,7 +40,7 @@ Example: view a single schema by version id defradb client schema describe --version bae123 `, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) options := client.SchemaFetchOptions{} if versionID != "" { @@ -53,7 +53,7 @@ Example: view a single schema by version id options.Name = immutable.Some(name) } - schemas, err := db.GetSchemas(cmd.Context(), options) + schemas, err := store.GetSchemas(cmd.Context(), options) if err != nil { return err } diff --git a/cli/schema_migration_down.go b/cli/schema_migration_down.go index b83f85ca74..a49f359694 100644 --- a/cli/schema_migration_down.go +++ b/cli/schema_migration_down.go @@ -39,7 +39,7 @@ Example: migrate from stdin `, Args: cobra.RangeArgs(0, 1), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) var srcData []byte switch { @@ -65,7 +65,7 @@ Example: migrate from stdin if err := json.Unmarshal(srcData, &src); err != nil { return err } - out, err := db.LensRegistry().MigrateDown(cmd.Context(), enumerable.New(src), collectionID) + out, err := store.LensRegistry().MigrateDown(cmd.Context(), enumerable.New(src), collectionID) if err != nil { return err } diff --git a/cli/schema_migration_reload.go b/cli/schema_migration_reload.go index a4e9f89934..8ffb5542f1 100644 --- a/cli/schema_migration_reload.go +++ b/cli/schema_migration_reload.go @@ -20,8 +20,8 @@ func MakeSchemaMigrationReloadCommand() *cobra.Command { Short: "Reload the schema migrations within DefraDB", Long: `Reload the schema migrations within DefraDB`, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) - return db.LensRegistry().ReloadLenses(cmd.Context()) + store := mustGetContextStore(cmd) + return store.LensRegistry().ReloadLenses(cmd.Context()) }, } return cmd diff --git a/cli/schema_migration_set.go b/cli/schema_migration_set.go index 2a609449d4..f7b32103b9 100644 --- a/cli/schema_migration_set.go +++ b/cli/schema_migration_set.go @@ -42,7 +42,7 @@ Example: add from stdin: Learn more about the DefraDB GraphQL Schema Language on https://docs.source.network.`, Args: cobra.RangeArgs(2, 3), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) var lensCfgJson string switch { @@ -81,7 +81,7 @@ Learn more about the DefraDB GraphQL Schema Language on https://docs.source.netw Lens: lensCfg, } - return db.SetMigration(cmd.Context(), migrationCfg) + return store.SetMigration(cmd.Context(), migrationCfg) }, } cmd.Flags().StringVarP(&lensFile, "file", "f", "", "Lens configuration file") diff --git a/cli/schema_migration_set_registry.go b/cli/schema_migration_set_registry.go index 99e1ba0104..cc5098afae 100644 --- a/cli/schema_migration_set_registry.go +++ b/cli/schema_migration_set_registry.go @@ -32,7 +32,7 @@ Example: set from an argument string: Learn more about the DefraDB GraphQL Schema Language on https://docs.source.network.`, Args: cobra.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) decoder := json.NewDecoder(strings.NewReader(args[1])) decoder.DisallowUnknownFields() @@ -47,7 +47,7 @@ Learn more about the DefraDB GraphQL Schema Language on https://docs.source.netw return err } - return db.LensRegistry().SetMigration(cmd.Context(), uint32(collectionID), lensCfg) + return store.LensRegistry().SetMigration(cmd.Context(), uint32(collectionID), lensCfg) }, } return cmd diff --git a/cli/schema_migration_up.go b/cli/schema_migration_up.go index 491068ad28..4473c45911 100644 --- a/cli/schema_migration_up.go +++ b/cli/schema_migration_up.go @@ -39,7 +39,7 @@ Example: migrate from stdin `, Args: cobra.RangeArgs(0, 1), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) var srcData []byte switch { @@ -65,7 +65,7 @@ Example: migrate from stdin if err := json.Unmarshal(srcData, &src); err != nil { return err } - out, err := db.LensRegistry().MigrateUp(cmd.Context(), enumerable.New(src), collectionID) + out, err := store.LensRegistry().MigrateUp(cmd.Context(), enumerable.New(src), collectionID) if err != nil { return err } diff --git a/cli/schema_patch.go b/cli/schema_patch.go index 1a0f617c8d..cf9224d204 100644 --- a/cli/schema_patch.go +++ b/cli/schema_patch.go @@ -44,7 +44,7 @@ Example: patch from stdin: To learn more about the DefraDB GraphQL Schema Language, refer to https://docs.source.network.`, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) var patch string switch { @@ -90,7 +90,7 @@ To learn more about the DefraDB GraphQL Schema Language, refer to https://docs.s migration = immutable.Some(lensCfg) } - return db.PatchSchema(cmd.Context(), patch, migration, setActive) + return store.PatchSchema(cmd.Context(), patch, migration, setActive) }, } cmd.Flags().BoolVar(&setActive, "set-active", false, diff --git a/cli/schema_set_active.go b/cli/schema_set_active.go index 9560d88276..2b13713461 100644 --- a/cli/schema_set_active.go +++ b/cli/schema_set_active.go @@ -22,8 +22,8 @@ func MakeSchemaSetActiveCommand() *cobra.Command { those without it (if they share the same schema root).`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) - return db.SetActiveSchemaVersion(cmd.Context(), args[0]) + store := mustGetContextStore(cmd) + return store.SetActiveSchemaVersion(cmd.Context(), args[0]) }, } return cmd diff --git a/cli/utils.go b/cli/utils.go index d93cce46f2..1df10a3409 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -48,6 +48,13 @@ func mustGetContextDB(cmd *cobra.Command) client.DB { return cmd.Context().Value(dbContextKey).(client.DB) } +// mustGetContextStore returns the store for the current command context. +// +// If a store is not set in the current context this function panics. +func mustGetContextStore(cmd *cobra.Command) client.Store { + return cmd.Context().Value(dbContextKey).(client.Store) +} + // mustGetContextP2P returns the p2p implementation for the current command context. // // If a p2p implementation is not set in the current context this function panics. diff --git a/cli/view_add.go b/cli/view_add.go index 7038dae81c..9c7d42b723 100644 --- a/cli/view_add.go +++ b/cli/view_add.go @@ -34,7 +34,7 @@ Example: add from an argument string: Learn more about the DefraDB GraphQL Schema Language on https://docs.source.network.`, Args: cobra.RangeArgs(2, 4), RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) query := args[0] sdl := args[1] @@ -69,7 +69,7 @@ Learn more about the DefraDB GraphQL Schema Language on https://docs.source.netw transform = immutable.Some(lensCfg) } - defs, err := db.AddView(cmd.Context(), query, sdl, transform) + defs, err := store.AddView(cmd.Context(), query, sdl, transform) if err != nil { return err } From f65a7677ea04ea7923032394f41515c90e31592f Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Thu, 11 Apr 2024 10:10:11 -0700 Subject: [PATCH 11/14] preserve client.Store type in http --- http/handler_ccip.go | 4 ++-- http/handler_collection.go | 4 ++-- http/handler_lens.go | 16 ++++++------- http/handler_store.go | 48 +++++++++++++++++++------------------- 4 files changed, 36 insertions(+), 36 deletions(-) diff --git a/http/handler_ccip.go b/http/handler_ccip.go index d89103c78a..dfe8a66083 100644 --- a/http/handler_ccip.go +++ b/http/handler_ccip.go @@ -35,7 +35,7 @@ type CCIPResponse struct { // ExecCCIP handles GraphQL over Cross Chain Interoperability Protocol requests. func (c *ccipHandler) ExecCCIP(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var ccipReq CCIPRequest switch req.Method { @@ -61,7 +61,7 @@ func (c *ccipHandler) ExecCCIP(rw http.ResponseWriter, req *http.Request) { } identity := getIdentityFromAuthHeader(req) - result := db.ExecRequest(req.Context(), identity, request.Query) + result := store.ExecRequest(req.Context(), identity, request.Query) if result.Pub != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{ErrStreamingNotSupported}) return diff --git a/http/handler_collection.go b/http/handler_collection.go index 05e842d473..8b7f0cf64c 100644 --- a/http/handler_collection.go +++ b/http/handler_collection.go @@ -331,8 +331,8 @@ func (s *collectionHandler) CreateIndex(rw http.ResponseWriter, req *http.Reques } func (s *collectionHandler) GetIndexes(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) - indexesMap, err := db.GetAllIndexes(req.Context()) + store := req.Context().Value(dbContextKey).(client.Store) + indexesMap, err := store.GetAllIndexes(req.Context()) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) diff --git a/http/handler_lens.go b/http/handler_lens.go index 7104116781..94ef9c2abe 100644 --- a/http/handler_lens.go +++ b/http/handler_lens.go @@ -22,9 +22,9 @@ import ( type lensHandler struct{} func (s *lensHandler) ReloadLenses(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) - err := db.LensRegistry().ReloadLenses(req.Context()) + err := store.LensRegistry().ReloadLenses(req.Context()) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -33,7 +33,7 @@ func (s *lensHandler) ReloadLenses(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var request setMigrationRequest if err := requestJSON(req, &request); err != nil { @@ -41,7 +41,7 @@ func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { return } - err := db.LensRegistry().SetMigration(req.Context(), request.CollectionID, request.Config) + err := store.LensRegistry().SetMigration(req.Context(), request.CollectionID, request.Config) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -50,7 +50,7 @@ func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var request migrateRequest if err := requestJSON(req, &request); err != nil { @@ -58,7 +58,7 @@ func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { return } - result, err := db.LensRegistry().MigrateUp(req.Context(), enumerable.New(request.Data), request.CollectionID) + result, err := store.LensRegistry().MigrateUp(req.Context(), enumerable.New(request.Data), request.CollectionID) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -75,7 +75,7 @@ func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) MigrateDown(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var request migrateRequest if err := requestJSON(req, &request); err != nil { @@ -83,7 +83,7 @@ func (s *lensHandler) MigrateDown(rw http.ResponseWriter, req *http.Request) { return } - result, err := db.LensRegistry().MigrateDown(req.Context(), enumerable.New(request.Data), request.CollectionID) + result, err := store.LensRegistry().MigrateDown(req.Context(), enumerable.New(request.Data), request.CollectionID) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return diff --git a/http/handler_store.go b/http/handler_store.go index 231316ade7..c71e108818 100644 --- a/http/handler_store.go +++ b/http/handler_store.go @@ -27,14 +27,14 @@ import ( type storeHandler struct{} func (s *storeHandler) BasicImport(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var config client.BackupConfig if err := requestJSON(req, &config); err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - err := db.BasicImport(req.Context(), config.Filepath) + err := store.BasicImport(req.Context(), config.Filepath) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -43,14 +43,14 @@ func (s *storeHandler) BasicImport(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) BasicExport(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var config client.BackupConfig if err := requestJSON(req, &config); err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - err := db.BasicExport(req.Context(), &config) + err := store.BasicExport(req.Context(), &config) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -59,14 +59,14 @@ func (s *storeHandler) BasicExport(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) AddSchema(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) schema, err := io.ReadAll(req.Body) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - cols, err := db.AddSchema(req.Context(), string(schema)) + cols, err := store.AddSchema(req.Context(), string(schema)) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -75,7 +75,7 @@ func (s *storeHandler) AddSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var message patchSchemaRequest err := requestJSON(req, &message) @@ -84,7 +84,7 @@ func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { return } - err = db.PatchSchema(req.Context(), message.Patch, message.Migration, message.SetAsDefaultVersion) + err = store.PatchSchema(req.Context(), message.Patch, message.Migration, message.SetAsDefaultVersion) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -93,7 +93,7 @@ func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var patch string err := requestJSON(req, &patch) @@ -102,7 +102,7 @@ func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request return } - err = db.PatchCollection(req.Context(), patch) + err = store.PatchCollection(req.Context(), patch) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -111,14 +111,14 @@ func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request } func (s *storeHandler) SetActiveSchemaVersion(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) schemaVersionID, err := io.ReadAll(req.Body) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - err = db.SetActiveSchemaVersion(req.Context(), string(schemaVersionID)) + err = store.SetActiveSchemaVersion(req.Context(), string(schemaVersionID)) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -127,7 +127,7 @@ func (s *storeHandler) SetActiveSchemaVersion(rw http.ResponseWriter, req *http. } func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var message addViewRequest err := requestJSON(req, &message) @@ -136,7 +136,7 @@ func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { return } - defs, err := db.AddView(req.Context(), message.Query, message.SDL, message.Transform) + defs, err := store.AddView(req.Context(), message.Query, message.SDL, message.Transform) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -146,7 +146,7 @@ func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var cfg client.LensConfig if err := requestJSON(req, &cfg); err != nil { @@ -154,7 +154,7 @@ func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { return } - err := db.SetMigration(req.Context(), cfg) + err := store.SetMigration(req.Context(), cfg) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -163,7 +163,7 @@ func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) options := client.CollectionFetchOptions{} if req.URL.Query().Has("name") { @@ -186,7 +186,7 @@ func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) options.IncludeInactive = immutable.Some(getInactive) } - cols, err := db.GetCollections(req.Context(), options) + cols, err := store.GetCollections(req.Context(), options) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -199,7 +199,7 @@ func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) } func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) options := client.SchemaFetchOptions{} if req.URL.Query().Has("version_id") { @@ -212,7 +212,7 @@ func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { options.Name = immutable.Some(req.URL.Query().Get("name")) } - schema, err := db.GetSchemas(req.Context(), options) + schema, err := store.GetSchemas(req.Context(), options) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -221,9 +221,9 @@ func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) GetAllIndexes(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) - indexes, err := db.GetAllIndexes(req.Context()) + indexes, err := store.GetAllIndexes(req.Context()) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -296,7 +296,7 @@ func (res *GraphQLResponse) UnmarshalJSON(data []byte) error { } func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) + store := req.Context().Value(dbContextKey).(client.Store) var request GraphQLRequest switch { @@ -313,7 +313,7 @@ func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) { } identity := getIdentityFromAuthHeader(req) - result := db.ExecRequest(req.Context(), identity, request.Query) + result := store.ExecRequest(req.Context(), identity, request.Query) if result.Pub == nil { responseJSON(rw, http.StatusOK, GraphQLResponse{result.GQL.Data, result.GQL.Errors}) From 5962357565b5cbac69e47c5ae726acbdcadf3021 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Thu, 11 Apr 2024 10:19:35 -0700 Subject: [PATCH 12/14] review fixes --- cli/client.go | 4 ++-- cli/collection.go | 4 ++-- cli/schema_add.go | 4 ++-- net/server.go | 4 +--- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/cli/client.go b/cli/client.go index 03ce3cd011..475f83a80a 100644 --- a/cli/client.go +++ b/cli/client.go @@ -28,10 +28,10 @@ Execute queries, add schema types, obtain node info, etc.`, if err := setContextConfig(cmd); err != nil { return err } - if err := setContextDB(cmd); err != nil { + if err := setContextTransaction(cmd, txID); err != nil { return err } - return setContextTransaction(cmd, txID) + return setContextDB(cmd) }, } cmd.PersistentFlags().Uint64Var(&txID, "tx", 0, "Transaction ID") diff --git a/cli/collection.go b/cli/collection.go index 2cdd9b33bd..5b682e5366 100644 --- a/cli/collection.go +++ b/cli/collection.go @@ -37,10 +37,10 @@ func MakeCollectionCommand() *cobra.Command { if err := setContextConfig(cmd); err != nil { return err } - if err := setContextDB(cmd); err != nil { + if err := setContextTransaction(cmd, txID); err != nil { return err } - if err := setContextTransaction(cmd, txID); err != nil { + if err := setContextDB(cmd); err != nil { return err } store := mustGetContextStore(cmd) diff --git a/cli/schema_add.go b/cli/schema_add.go index 5277ddd6bd..e81896322d 100644 --- a/cli/schema_add.go +++ b/cli/schema_add.go @@ -41,7 +41,7 @@ Example: add from stdin: Learn more about the DefraDB GraphQL Schema Language on https://docs.source.network.`, RunE: func(cmd *cobra.Command, args []string) error { - db := mustGetContextDB(cmd) + store := mustGetContextStore(cmd) var schema string switch { @@ -63,7 +63,7 @@ Learn more about the DefraDB GraphQL Schema Language on https://docs.source.netw return fmt.Errorf("schema cannot be empty") } - cols, err := db.AddSchema(cmd.Context(), schema) + cols, err := store.AddSchema(cmd.Context(), schema) if err != nil { return err } diff --git a/net/server.go b/net/server.go index 535bb16315..73496559cf 100644 --- a/net/server.go +++ b/net/server.go @@ -32,7 +32,6 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/datastore/badger/v4" "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/errors" @@ -290,7 +289,7 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL wg.Wait() bp.mergeBlocks(ctx) - err = s.syncIndexedDocs(ctx, col, docID, txn) + err = s.syncIndexedDocs(ctx, col, docID) if err != nil { return nil, err } @@ -353,7 +352,6 @@ func (s *server) syncIndexedDocs( ctx context.Context, col client.Collection, docID client.DocID, - txn datastore.Txn, ) error { // remove transaction from old context oldCtx := db.SetContextTxn(ctx, nil) From 743ddcf9794578c8440c1436cd80098f5d6e1732 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Thu, 11 Apr 2024 10:33:20 -0700 Subject: [PATCH 13/14] fix ensureContextTxn --- db/context.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/db/context.go b/db/context.go index 49df02fca3..4917ec11fb 100644 --- a/db/context.go +++ b/db/context.go @@ -44,7 +44,7 @@ type transactionDB interface { func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (context.Context, datastore.Txn, error) { txn, ok := TryGetContextTxn(ctx) if ok { - return SetContextTxn(ctx, &explicitTxn{txn}), txn, nil + return SetContextTxn(ctx, &explicitTxn{txn}), &explicitTxn{txn}, nil } txn, err := db.NewTxn(ctx, readOnly) if err != nil { From 83e0d7afcc95e1cb10dae2bc45c30ebba10c2879 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Fri, 12 Apr 2024 09:37:28 -0700 Subject: [PATCH 14/14] update SetContextTxn docs --- db/context.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/db/context.go b/db/context.go index 4917ec11fb..d39472ea5a 100644 --- a/db/context.go +++ b/db/context.go @@ -61,6 +61,8 @@ func TryGetContextTxn(ctx context.Context) (datastore.Txn, bool) { } // SetContextTxn returns a new context with the txn value set. +// +// This will overwrite any previously set transaction value. func SetContextTxn(ctx context.Context, txn datastore.Txn) context.Context { return context.WithValue(ctx, txnContextKey{}, txn) }