Skip to content

Commit

Permalink
feat(firestore): add GetCommitTime TransactionOption
Browse files Browse the repository at this point in the history
  • Loading branch information
galenwarren committed Oct 31, 2022
1 parent 5d7d4ec commit b850276
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
33 changes: 30 additions & 3 deletions firestore/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package firestore
import (
"context"
"errors"
"time"

"cloud.google.com/go/internal/trace"
gax "github.com/googleapis/gax-go/v2"
Expand All @@ -40,6 +41,7 @@ type Transaction struct {
// A TransactionOption is an option passed to Client.Transaction.
type TransactionOption interface {
config(t *Transaction)
handleCommitResponse(r *pb.CommitResponse)
}

// MaxAttempts is a TransactionOption that configures the maximum number of times to
Expand All @@ -48,7 +50,8 @@ func MaxAttempts(n int) maxAttempts { return maxAttempts(n) }

type maxAttempts int

func (m maxAttempts) config(t *Transaction) { t.maxAttempts = int(m) }
func (m maxAttempts) config(t *Transaction) { t.maxAttempts = int(m) }
func (m maxAttempts) handleCommitResponse(r *pb.CommitResponse) {}

// DefaultTransactionMaxAttempts is the default number of times to attempt a transaction.
const DefaultTransactionMaxAttempts = 5
Expand All @@ -59,7 +62,23 @@ var ReadOnly = ro{}

type ro struct{}

func (ro) config(t *Transaction) { t.readOnly = true }
func (ro) config(t *Transaction) { t.readOnly = true }
func (ro) handleCommitResponse(r *pb.CommitResponse) {}

// GetCommitTime is a TransactionOption that allows the caller to indicate where the commit
// time for the transaction should be stored, upon successful commit.
func GetCommitTime(t *time.Time) commitTime {
return commitTime{Time: t}
}

type commitTime struct {
*time.Time
}

func (c commitTime) config(t *Transaction) {}
func (c commitTime) handleCommitResponse(r *pb.CommitResponse) {
*c.Time = r.CommitTime.AsTime()
}

var (
// Defined here for testing.
Expand Down Expand Up @@ -114,6 +133,7 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr
}
}
var backoff gax.Backoff
var commitResponse *pb.CommitResponse
// TODO(jba): use other than the standard backoff parameters?
// TODO(jba): get backoff time from gRPC trailer metadata? See
// extractRetryDelay in https://code.googlesource.com/gocloud/+/master/spanner/retry.go.
Expand Down Expand Up @@ -141,13 +161,20 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr
return err
}
t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/firestore.Client.Commit")
_, err = t.c.c.Commit(t.ctx, &pb.CommitRequest{
commitResponse, err = t.c.c.Commit(t.ctx, &pb.CommitRequest{
Database: t.c.path(),
Writes: t.writes,
Transaction: t.id,
})
trace.EndSpan(t.ctx, err)

// on success, handle the commit response
if err == nil {
for _, opt := range opts {
opt.handleCommitResponse(commitResponse)
}
}

// If a read-write transaction returns Aborted, retry.
// On success or other failures, return here.
if t.readOnly || status.Code(err) != codes.Aborted {
Expand Down
8 changes: 7 additions & 1 deletion firestore/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func TestRunTransaction(t *testing.T) {
},
&pb.CommitResponse{CommitTime: aTimestamp3},
)
var commitTime time.Time
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
docref := c.Collection("C").Doc("a")
doc, err := tx.Get(docref)
Expand All @@ -95,11 +96,16 @@ func TestRunTransaction(t *testing.T) {
return err
}
return tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}})
})
}, GetCommitTime(&commitTime))
if err != nil {
t.Fatal(err)
}

// validate commit time
if commitTime != aTimestamp3.AsTime() {
t.Fatalf("commit time %v should equal %v", commitTime, aTimestamp3)
}

// Query
srv.reset()
srv.addRPC(beginReq, beginRes)
Expand Down

0 comments on commit b850276

Please sign in to comment.