diff --git a/.release-please-manifest-individual.json b/.release-please-manifest-individual.json index d526a3e881c9..1776a5c7838d 100644 --- a/.release-please-manifest-individual.json +++ b/.release-please-manifest-individual.json @@ -2,7 +2,7 @@ "auth": "0.0.0", "bigquery": "1.54.0", "bigtable": "1.19.0", - "datastore": "1.13.0", + "datastore": "1.14.0", "errorreporting": "0.3.0", "firestore": "1.12.0", "logging": "1.8.1", diff --git a/datastore/CHANGES.md b/datastore/CHANGES.md index 49784aec0982..2a105c236727 100644 --- a/datastore/CHANGES.md +++ b/datastore/CHANGES.md @@ -1,5 +1,24 @@ # Changes +## [1.14.0](https://github.com/googleapis/google-cloud-go/compare/datastore/v1.13.0...datastore/v1.14.0) (2023-08-22) + + +### Features + +* **datastore:** SUM and AVG aggregations ([#8307](https://github.com/googleapis/google-cloud-go/issues/8307)) ([a9fff18](https://github.com/googleapis/google-cloud-go/commit/a9fff181e4ea8281ad907e7b2e0d90e70013a4de)) +* **datastore:** Support aggregation query in transaction ([#8439](https://github.com/googleapis/google-cloud-go/issues/8439)) ([37681ff](https://github.com/googleapis/google-cloud-go/commit/37681ff291c0ccf4c908be55b97639c04b9dec48)) + + +### Bug Fixes + +* **datastore:** Correcting string representation of Key ([#8363](https://github.com/googleapis/google-cloud-go/issues/8363)) ([4cb1211](https://github.com/googleapis/google-cloud-go/commit/4cb12110ba229dfbe21568eb06c243bdffd1fee7)) +* **datastore:** Fix NoIndex for array property ([#7674](https://github.com/googleapis/google-cloud-go/issues/7674)) ([01951e6](https://github.com/googleapis/google-cloud-go/commit/01951e64f3955dc337172a30d78e2f92f65becb2)) + + +### Documentation + +* **datastore/admin:** Specify limit for `properties` in `Index` message in Datastore Admin API ([b890425](https://github.com/googleapis/google-cloud-go/commit/b8904253a0f8424ea4548469e5feef321bd7396a)) + ## [1.13.0](https://github.com/googleapis/google-cloud-go/compare/datastore/v1.12.1...datastore/v1.13.0) (2023-07-26) diff --git a/datastore/integration_test.go b/datastore/integration_test.go index 6a01926449cc..3dfd082c67f0 100644 --- a/datastore/integration_test.go +++ b/datastore/integration_test.go @@ -722,7 +722,15 @@ func TestIntegration_AggregationQueries(t *testing.T) { for i := range keys { keys[i] = IncompleteKey("SQChild", parent) } - keys, err := client.PutMulti(ctx, keys, children) + + // Create transaction with read before creating entities + readTime := time.Now() + txBeforeCreate, err := client.NewTransaction(ctx, []TransactionOption{ReadOnly, WithReadTime(readTime)}...) + if err != nil { + t.Fatalf("client.NewTransaction: %v", err) + } + + keys, err = client.PutMulti(ctx, keys, children) if err != nil { t.Fatalf("client.PutMulti: %v", err) } @@ -733,13 +741,22 @@ func TestIntegration_AggregationQueries(t *testing.T) { } }() + // Create transaction with read after creating entities + readTime = time.Now() + txAfterCreate, err := client.NewTransaction(ctx, []TransactionOption{ReadOnly, WithReadTime(readTime)}...) + if err != nil { + t.Fatalf("client.NewTransaction: %v", err) + } + testCases := []struct { - desc string - aggQuery *AggregationQuery - wantFailure bool - wantErrMsg string - wantAggResult AggregationResult + desc string + aggQuery *AggregationQuery + transactionOpts []TransactionOption + wantFailure bool + wantErrMsg string + wantAggResult AggregationResult }{ + { desc: "Count Failure - Missing index", aggQuery: NewQuery("SQChild").Ancestor(parent).Filter("T>=", now). @@ -757,6 +774,34 @@ func TestIntegration_AggregationQueries(t *testing.T) { "count": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 5}}, }, }, + { + desc: "Aggregations in transaction before creating entities", + aggQuery: NewQuery("SQChild").Ancestor(parent).Filter("T=", now). + Transaction(txBeforeCreate). + NewAggregationQuery(). + WithCount("count"). + WithSum("I", "sum"). + WithAvg("I", "avg"), + wantAggResult: map[string]interface{}{ + "count": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 0}}, + "sum": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 0}}, + "avg": &pb.Value{ValueType: &pb.Value_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, + }, + }, + { + desc: "Aggregations in transaction after creating entities", + aggQuery: NewQuery("SQChild").Ancestor(parent).Filter("T=", now). + Transaction(txAfterCreate). + NewAggregationQuery(). + WithCount("count"). + WithSum("I", "sum"). + WithAvg("I", "avg"), + wantAggResult: map[string]interface{}{ + "count": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 8}}, + "sum": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 28}}, + "avg": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: 3.5}}, + }, + }, { desc: "Multiple aggregations", aggQuery: NewQuery("SQChild").Ancestor(parent).Filter("T=", now). diff --git a/datastore/internal/version.go b/datastore/internal/version.go index efedadbea253..c9ba91825c29 100644 --- a/datastore/internal/version.go +++ b/datastore/internal/version.go @@ -15,4 +15,4 @@ package internal // Version is the current tagged release of the library. -const Version = "1.13.0" +const Version = "1.14.0" diff --git a/datastore/query.go b/datastore/query.go index c9f9163c310d..e833ff8d8ff4 100644 --- a/datastore/query.go +++ b/datastore/query.go @@ -1026,7 +1026,6 @@ func DecodeCursor(s string) (Cursor, error) { // NewAggregationQuery returns an AggregationQuery with this query as its // base query. func (q *Query) NewAggregationQuery() *AggregationQuery { - q.eventual = true return &AggregationQuery{ query: q, aggregationQueries: make([]*pb.AggregationQuery_Aggregation, 0), diff --git a/datastore/query_test.go b/datastore/query_test.go index 22d13e1776ab..40fc8f327795 100644 --- a/datastore/query_test.go +++ b/datastore/query_test.go @@ -126,11 +126,6 @@ func fakeRunAggregationQuery(req *pb.RunAggregationQueryRequest) (*pb.RunAggrega }, }, }, - ReadOptions: &pb.ReadOptions{ - ConsistencyType: &pb.ReadOptions_ReadConsistency_{ - ReadConsistency: pb.ReadOptions_EVENTUAL, - }, - }, } if !proto.Equal(req, expectedIn) { return nil, fmt.Errorf("unsupported argument: got %v want %v", req, expectedIn) diff --git a/spanner/client.go b/spanner/client.go index a95b02431ad4..bfe00c2dcec0 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -563,6 +563,10 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea } } if t.shouldExplicitBegin(attempt) { + // Make sure we set the current session handle before calling BeginTransaction. + // Note that the t.begin(ctx) call could change the session that is being used by the transaction, as the + // BeginTransaction RPC invocation will be retried on a new session if it returns SessionNotFound. + t.txReadOnly.sh = sh if err = t.begin(ctx); err != nil { trace.TracePrintf(ctx, nil, "Error while BeginTransaction during retrying a ReadWrite transaction: %v", ToSpannerError(err)) return ToSpannerError(err) @@ -571,9 +575,9 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea t = &ReadWriteTransaction{ txReadyOrClosed: make(chan struct{}), } + t.txReadOnly.sh = sh } attempt++ - t.txReadOnly.sh = sh t.txReadOnly.sp = c.idleSessions t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo diff --git a/spanner/client_test.go b/spanner/client_test.go index 6f963102805e..1141a6ebb5bc 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -727,6 +727,202 @@ func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction_WithMaxOne } } +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement_AndThenSessionNotFoundForBeginTransaction(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + server.TestSpanner.PutExecutionTime( + MethodBeginTransaction, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_AbortedForFirstStatement_AndThenSessionNotFoundForBeginTransaction(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}}, + ) + server.TestSpanner.PutExecutionTime( + MethodBeginTransaction, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement_DoesNotLeakSession(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + MaxOpened: 1, + }, + }) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BatchCreateSessionsRequest{}, // We need to create more sessions, as the one used first was destroyed. + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + func TestClient_ReadOnlyTransaction_QueryOptions(t *testing.T) { for _, tt := range queryOptionsTestCases() { t.Run(tt.name, func(t *testing.T) { diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 7aa7d622df66..5d016ede945d 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -42,12 +42,14 @@ import ( adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" instance "cloud.google.com/go/spanner/admin/instance/apiv1" "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" + v1 "cloud.google.com/go/spanner/apiv1" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "cloud.google.com/go/spanner/internal" "go.opencensus.io/stats/view" "go.opencensus.io/tag" "google.golang.org/api/iterator" "google.golang.org/api/option" + "google.golang.org/api/option/internaloption" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" @@ -846,6 +848,55 @@ func TestIntegration_SingleUse_WithQueryOptions(t *testing.T) { } } +func TestIntegration_TransactionWasStartedInDifferentSession(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + // Set up testing environment. + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) + defer cleanup() + + attempts := 0 + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, transaction *ReadWriteTransaction) error { + attempts++ + if attempts == 1 { + deleteTestSession(ctx, t, transaction.sh.getID()) + } + if _, err := readAll(transaction.Query(ctx, NewStatement("select * from singers"))); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if g, w := attempts, 2; g != w { + t.Fatalf("attempts mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func deleteTestSession(ctx context.Context, t *testing.T, sessionName string) { + var opts []option.ClientOption + if emulatorAddr := os.Getenv("SPANNER_EMULATOR_HOST"); emulatorAddr != "" { + emulatorOpts := []option.ClientOption{ + option.WithEndpoint(emulatorAddr), + option.WithGRPCDialOption(grpc.WithInsecure()), + option.WithoutAuthentication(), + internaloption.SkipDialSettingsValidation(), + } + opts = append(emulatorOpts, opts...) + } + gapic, err := v1.NewClient(ctx, opts...) + if err != nil { + t.Fatalf("could not create gapic client: %v", err) + } + defer gapic.Close() + if err := gapic.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sessionName}); err != nil { + t.Fatal(err) + } +} + func TestIntegration_SingleUse_ReadingWithLimit(t *testing.T) { t.Parallel() diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 922ae6ad1328..b1adf02f2182 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -581,13 +581,17 @@ func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, option return res } -func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { +func (s *inMemSpannerServer) getTransactionByID(session *spannerpb.Session, id []byte) (*spannerpb.Transaction, error) { s.mu.Lock() defer s.mu.Unlock() tx, ok := s.transactions[string(id)] if !ok { return nil, gstatus.Error(codes.NotFound, "Transaction not found") } + if !strings.HasPrefix(string(id), session.Name) { + return nil, gstatus.Error(codes.InvalidArgument, "Transaction was started in a different session.") + } + aborted, ok := s.abortedTransactions[string(id)] if ok && aborted { return nil, newAbortedErrorWithMinimalRetryDelay() @@ -813,7 +817,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec var id []byte s.updateSessionLastUseTime(session.Name) if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return nil, err } @@ -860,7 +864,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques s.updateSessionLastUseTime(session.Name) var id []byte if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return err } @@ -932,7 +936,7 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb s.updateSessionLastUseTime(session.Name) var id []byte if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return nil, err } @@ -1031,7 +1035,7 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe if req.GetSingleUseTransaction() != nil { tx = s.beginTransaction(session, req.GetSingleUseTransaction()) } else if req.GetTransactionId() != nil { - tx, err = s.getTransactionByID(req.GetTransactionId()) + tx, err = s.getTransactionByID(session, req.GetTransactionId()) if err != nil { return nil, err } @@ -1064,7 +1068,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba return nil, err } s.updateSessionLastUseTime(session.Name) - tx, err := s.getTransactionByID(req.TransactionId) + tx, err := s.getTransactionByID(session, req.TransactionId) if err != nil { return nil, err } @@ -1091,7 +1095,7 @@ func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb. var tx *spannerpb.Transaction s.updateSessionLastUseTime(session.Name) if id = s.getTransactionID(session, req.Transaction); id != nil { - tx, err = s.getTransactionByID(id) + tx, err = s.getTransactionByID(session, id) if err != nil { return nil, err } diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md index e737bd6810f2..0714328af0a1 100644 --- a/spanner/spannertest/README.md +++ b/spanner/spannertest/README.md @@ -33,7 +33,6 @@ by ascending esotericism: - case insensitivity of table and column names and query aliases - transaction simulation - FOREIGN KEY and CHECK constraints -- INSERT DML statements - set operations (UNION, INTERSECT, EXCEPT) - STRUCT types - partition support diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go index 6babaf93e04a..31cebd9cebbd 100644 --- a/spanner/spannertest/db.go +++ b/spanner/spannertest/db.go @@ -1262,6 +1262,64 @@ func (d *database) Execute(stmt spansql.DMLStmt, params queryParams) (int, error } } return n, nil + case *spansql.Insert: + t, err := d.table(stmt.Table) + if err != nil { + return 0, err + } + + t.mu.Lock() + defer t.mu.Unlock() + + ec := evalContext{ + cols: t.cols, + params: params, + } + + values := make(row, len(t.cols)) + input := stmt.Input.(spansql.Values) + if len(input) > 0 { + for i := 0; i < len(input); i++ { + val := input[i] + for k, v := range val { + switch v := v.(type) { + // if spanner.Statement.Params is not empty, scratch row with ec.parameters + case spansql.Param: + values[k] = ec.params[t.cols[k].Name.SQL()].Value + // if nil is included in parameters, pass nil + case spansql.ID: + cutset := `""` + str := strings.Trim(v.SQL(), cutset) + if str == "nil" { + values[k] = nil + } else { + expr, err := ec.evalExpr(v) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid parameter format") + } + values[k] = expr + } + // if parameter is embedded in SQL as string, not in statement.Params, analyze parameters + default: + expr, err := ec.evalExpr(v) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid parameter format") + } + values[k] = expr + } + } + } + } + + // pk check if the primary key already exists + pk := values[:t.pkCols] + rowNum, found := t.rowForPK(pk) + if found { + return 0, status.Errorf(codes.AlreadyExists, "row already in table") + } + t.insertRow(rowNum, values) + + return 1, nil } } diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index cfe75c171e97..5dc3c0958aa7 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -708,15 +708,83 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{1, "abar"}), spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{2, nil}), spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{3, "bbar"}), - - spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{0, "joe", nil}), - spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{1, "doe", "joan"}), - spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{2, "wong", "wong"}), }) if err != nil { t.Fatalf("Inserting sample data: %v", err) } + // Perform INSERT DML; the results are checked later on. + n = 0 + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + for _, u := range []string{ + `INSERT INTO Updateable (id, first, last) VALUES (0, "joe", nil)`, + `INSERT INTO Updateable (id, first, last) VALUES (1, "doe", "joan")`, + `INSERT INTO Updateable (id, first, last) VALUES (2, "wong", "wong")`, + } { + nr, err := tx.Update(ctx, spanner.NewStatement(u)) + if err != nil { + return err + } + n += nr + } + return nil + }) + if err != nil { + t.Fatalf("Inserting with DML: %v", err) + } + if n != 3 { + t.Errorf("Inserting with DML affected %d rows, want 3", n) + } + + // Perform INSERT DML with statement.Params; the results are checked later on. + n = 0 + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + stmt := spanner.Statement{ + SQL: "INSERT INTO Updateable (id, first, last) VALUES (@id, @first, @last)", + Params: map[string]interface{}{ + "id": 3, + "first": "tom", + "last": "jerry", + }, + } + nr, err := tx.Update(ctx, stmt) + if err != nil { + return err + } + n += nr + return nil + }) + if err != nil { + t.Fatalf("Inserting with DML: %v", err) + } + if n != 1 { + t.Errorf("Inserting with DML affected %d rows, want 1", n) + } + + // Perform INSERT DML with statement.Params and inline parameter; the results are checked later on. + n = 0 + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + stmt := spanner.Statement{ + SQL: `INSERT INTO Updateable (id, first, last) VALUES (@id, "jim", @last)`, + Params: map[string]interface{}{ + "id": 4, + "last": nil, + }, + } + nr, err := tx.Update(ctx, stmt) + if err != nil { + return err + } + n += nr + return nil + }) + if err != nil { + t.Fatalf("Inserting with DML: %v", err) + } + if n != 1 { + t.Errorf("Inserting with DML affected %d rows, want 1", n) + } + // Perform UPDATE DML; the results are checked later on. n = 0 _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { @@ -724,7 +792,7 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { `UPDATE Updateable SET last = "bloggs" WHERE id = 0`, `UPDATE Updateable SET first = last, last = first WHERE id = 1`, `UPDATE Updateable SET last = DEFAULT WHERE id = 2`, - `UPDATE Updateable SET first = "noname" WHERE id = 3`, // no id=3 + `UPDATE Updateable SET first = "noname" WHERE id = 5`, // no id=5 } { nr, err := tx.Update(ctx, spanner.NewStatement(u)) if err != nil { @@ -1156,6 +1224,8 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { {int64(0), "joe", "bloggs"}, {int64(1), "joan", "doe"}, {int64(2), "wong", nil}, + {int64(3), "tom", "jerry"}, + {int64(4), "jim", nil}, }, }, // Regression test for aggregating no rows; it used to return an empty row. diff --git a/spanner/transaction.go b/spanner/transaction.go index 85de18327f89..81d7e036521c 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -1380,15 +1380,13 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { }() // Retry the BeginTransaction call if a 'Session not found' is returned. for { - if sh == nil || sh.getID() == "" || sh.getClient() == nil { + tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts) + if isSessionNotFoundError(err) { + sh.destroy() sh, err = t.sp.take(ctx) if err != nil { return err } - } - tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts) - if isSessionNotFoundError(err) { - sh.destroy() continue } else { err = ToSpannerError(err) @@ -1399,7 +1397,7 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { t.mu.Lock() t.tx = tx t.sh = sh - // State transite to txActive. + // Transition state to txActive. t.state = txActive t.mu.Unlock() }