From 7501c6a8a6c867c6c99794c9dc4a8f0f12a42e3f Mon Sep 17 00:00:00 2001 From: Andrew Sisley Date: Wed, 6 Mar 2024 17:06:45 -0500 Subject: [PATCH] Make returned collections respect transaction --- client/db.go | 6 ++++++ db/txn_db.go | 18 ++++++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/client/db.go b/client/db.go index 7fe5470ecb..7b0cc8060f 100644 --- a/client/db.go +++ b/client/db.go @@ -187,6 +187,9 @@ type Store interface { // GetCollectionByName attempts to retrieve a collection matching the given name. // // If no matching collection is found an error will be returned. + // + // If a transaction was explicitly provided to this [Store] via [DB].[WithTxn], any function calls + // made via the returned [Collection] will respect that transaction. GetCollectionByName(context.Context, CollectionName) (Collection, error) // GetCollections returns all collections and their descriptions matching the given options @@ -194,6 +197,9 @@ type Store interface { // // Inactive collections are not returned by default unless a specific schema version ID // is provided. + // + // If a transaction was explicitly provided to this [Store] via [DB].[WithTxn], any function calls + // made via the returned [Collection]s will respect that transaction. GetCollections(context.Context, CollectionFetchOptions) ([]Collection, error) // GetSchemaByVersionID returns the schema description for the schema version of the diff --git a/db/txn_db.go b/db/txn_db.go index 96afcda0f6..455b062b93 100644 --- a/db/txn_db.go +++ b/db/txn_db.go @@ -79,7 +79,12 @@ func (db *implicitTxnDB) GetCollectionByName(ctx context.Context, name string) ( // GetCollectionByName returns an existing collection within the database. func (db *explicitTxnDB) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { - return db.getCollectionByName(ctx, db.txn, name) + col, err := db.getCollectionByName(ctx, db.txn, name) + if err != nil { + return nil, err + } + + return col.WithTxn(db.txn), nil } // GetCollections gets all the currently defined collections. @@ -101,7 +106,16 @@ func (db *explicitTxnDB) GetCollections( ctx context.Context, options client.CollectionFetchOptions, ) ([]client.Collection, error) { - return db.getCollections(ctx, db.txn, options) + cols, err := db.getCollections(ctx, db.txn, options) + if err != nil { + return nil, err + } + + for i := range cols { + cols[i] = cols[i].WithTxn(db.txn) + } + + return cols, nil } // GetSchemaByVersionID returns the schema description for the schema version of the