diff --git a/spanner/client.go b/spanner/client.go index db9c8b1e669c..8ac786c4b769 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -334,18 +334,20 @@ type ClientConfig struct { } type openTelemetryConfig struct { - meterProvider metric.MeterProvider - attributeMap []attribute.KeyValue - otMetricRegistration metric.Registration - openSessionCount metric.Int64ObservableGauge - maxAllowedSessionsCount metric.Int64ObservableGauge - sessionsCount metric.Int64ObservableGauge - maxInUseSessionsCount metric.Int64ObservableGauge - getSessionTimeoutsCount metric.Int64Counter - acquiredSessionsCount metric.Int64Counter - releasedSessionsCount metric.Int64Counter - gfeLatency metric.Int64Histogram - gfeHeaderMissingCount metric.Int64Counter + meterProvider metric.MeterProvider + attributeMap []attribute.KeyValue + attributeMapWithMultiplexed []attribute.KeyValue + attributeMapWithoutMultiplexed []attribute.KeyValue + otMetricRegistration metric.Registration + openSessionCount metric.Int64ObservableGauge + maxAllowedSessionsCount metric.Int64ObservableGauge + sessionsCount metric.Int64ObservableGauge + maxInUseSessionsCount metric.Int64ObservableGauge + getSessionTimeoutsCount metric.Int64Counter + acquiredSessionsCount metric.Int64Counter + releasedSessionsCount metric.Int64Counter + gfeLatency metric.Int64Histogram + gfeHeaderMissingCount metric.Int64Counter } func contextWithOutgoingMetadata(ctx context.Context, md metadata.MD, disableRouteToLeader bool) context.Context { diff --git a/spanner/client_test.go b/spanner/client_test.go index 6d775b917966..818fcef89456 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -23,6 +23,7 @@ import ( "math/big" "net" "os" + "strconv" "strings" "sync" "sync/atomic" @@ -37,6 +38,7 @@ import ( "github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp/multiendpoint" "github.com/google/go-cmp/cmp/cmpopts" "github.com/googleapis/gax-go/v2" + "golang.org/x/sync/errgroup" "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/grpc/codes" @@ -63,6 +65,85 @@ func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config Client } func setupMockedTestServerWithConfigAndGCPMultiendpointPool(t *testing.T, config ClientConfig, clientOptions []option.ClientOption, poolCfg *grpc_gcp.ChannelPoolConfig) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { + grpcHeaderChecker := &itestutil.HeadersEnforcer{ + OnFailure: t.Fatalf, + Checkers: []*itestutil.HeaderChecker{ + { + Key: "x-goog-api-client", + ValuesValidator: func(token ...string) error { + if len(token) != 1 { + return status.Errorf(codes.Internal, "unexpected number of api client token headers: %v", len(token)) + } + if !strings.HasPrefix(token[0], "gl-go/") { + return status.Errorf(codes.Internal, "unexpected api client token: %v", token[0]) + } + if !strings.Contains(token[0], "gccl/") { + return status.Errorf(codes.Internal, "unexpected api client token: %v", token[0]) + } + return nil + }, + }, + }, + } + if config.Compression == gzip.Name { + grpcHeaderChecker.Checkers = append(grpcHeaderChecker.Checkers, &itestutil.HeaderChecker{ + Key: "x-response-encoding", + ValuesValidator: func(token ...string) error { + if len(token) != 1 { + return status.Errorf(codes.Internal, "unexpected number of compression headers: %v", len(token)) + } + if token[0] != gzip.Name { + return status.Errorf(codes.Internal, "unexpected compression: %v", token[0]) + } + return nil + }, + }) + } + clientOptions = append(clientOptions, 
grpcHeaderChecker.CallOptions()...) + server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) + opts = append(opts, clientOptions...) + ctx := context.Background() + formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var err error + if useGRPCgcp { + gmeCfg := &grpcgcp.GCPMultiEndpointOptions{ + GRPCgcpConfig: &grpc_gcp.ApiConfig{ + ChannelPool: poolCfg, + }, + MultiEndpoints: map[string]*multiendpoint.MultiEndpointOptions{ + "default": { + Endpoints: []string{server.ServerAddress}, + }, + }, + Default: "default", + } + client, _, err = NewMultiEndpointClientWithConfig(ctx, formattedDatabase, config, gmeCfg, opts...) + } else { + client, err = NewClientWithConfig(ctx, formattedDatabase, config, opts...) + } + if err != nil { + t.Fatal(err) + } + if isMultiplexEnabled { + waitFor(t, func() error { + client.idleSessions.mu.Lock() + defer client.idleSessions.mu.Unlock() + if client.idleSessions.multiplexedSession == nil { + return errInvalidSessionPool + } + return nil + }) + } + return server, client, func() { + client.Close() + serverTeardown() + } +} + +func setupMockedTestServerWithoutWaitingForMultiplexedSessionInit(t *testing.T) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { + config := ClientConfig{} + clientOptions := []option.ClientOption{} + var poolCfg *grpc_gcp.ChannelPoolConfig grpcHeaderChecker := &itestutil.HeadersEnforcer{ OnFailure: t.Fatalf, Checkers: []*itestutil.HeaderChecker{ @@ -392,6 +473,178 @@ func TestClient_MultiEndpoint(t *testing.T) { } } +func TestClient_MultiplexedSession(t *testing.T) { + var tests = []struct { + name string + test func(client *Client) error + validate func(server InMemSpannerServer) + wantErr error + }{ + { + name: "Given if multiplexed session is enabled, When executing single use R/O transactions, should use multiplexed session", + test: func(client *Client) error { + ctx := context.Background() + // Test the single use read-only transaction + _, err := client.Single().ReadRow(ctx, "Albums", Key{"foo"}, []string{"SingerId", "AlbumId", "AlbumTitle"}) + return err + }, + validate: func(server InMemSpannerServer) { + // Validate the multiplexed session is used + expectedSessionCount := uint(1) + if !isMultiplexEnabled { + expectedSessionCount = uint(25) // BatchCreateSession request from regular session pool + } + if !testEqual(expectedSessionCount, server.TotalSessionsCreated()) { + t.Errorf("TestClient_MultiplexedSession expected session creation with multiplexed=%s should be=%v, got: %v", strconv.FormatBool(isMultiplexEnabled), expectedSessionCount, server.TotalSessionsCreated()) + } + reqs := drainRequestsFromServer(server) + for _, s := range reqs { + switch s.(type) { + case *sppb.ReadRequest: + req, _ := s.(*sppb.ReadRequest) + // Validate the session is multiplexed + if !testEqual(isMultiplexEnabled, strings.Contains(req.Session, "multiplexed")) { + t.Errorf("TestClient_MultiplexedSession expected multiplexed session to be used, got: %v", req.Session) + } + + } + } + }, + }, + { + name: "Given if multiplexed session is enabled, When executing multi use R/O transactions, should use multiplexed session", + test: func(client *Client) error { + ctx := context.Background() + // Test the multi use read-only transaction + roTxn := client.ReadOnlyTransaction() + defer roTxn.Close() + iter := roTxn.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + if err := iter.Do(func(row *Row) error { + return nil + }); err != nil { + 
return err
+				}
+				iter = roTxn.Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"})
+				return iter.Do(func(row *Row) error {
+					return nil
+				})
+			},
+			validate: func(server InMemSpannerServer) {
+				// Validate the multiplexed session is used
+				expectedSessionCount := uint(1)
+				if !isMultiplexEnabled {
+					expectedSessionCount = uint(25) // BatchCreateSession request from regular session pool
+				}
+				if !testEqual(expectedSessionCount, server.TotalSessionsCreated()) {
+					t.Errorf("TestClient_MultiplexedSession expected session creation with multiplexed=%s should be=%v, got: %v", strconv.FormatBool(isMultiplexEnabled), expectedSessionCount, server.TotalSessionsCreated())
+				}
+				reqs := drainRequestsFromServer(server)
+				for _, s := range reqs {
+					switch s.(type) {
+					case *sppb.ReadRequest:
+						req, _ := s.(*sppb.ReadRequest)
+						// Validate the session is multiplexed
+						if !testEqual(isMultiplexEnabled, strings.Contains(req.Session, "multiplexed")) {
+							t.Errorf("TestClient_MultiplexedSession expected multiplexed session to be used, got: %v", req.Session)
+						}
+					case *sppb.ExecuteSqlRequest:
+						req, _ := s.(*sppb.ExecuteSqlRequest)
+						// Validate the session is multiplexed
+						if !testEqual(isMultiplexEnabled, strings.Contains(req.Session, "multiplexed")) {
+							t.Errorf("TestClient_MultiplexedSession expected multiplexed session to be used, got: %v", req.Session)
+						}
+					}
+				}
+			},
+		},
+		{
+			name: "Given if multiplexed session is enabled, When executing R/W transactions, should always use regular session",
+			test: func(client *Client) error {
+				ctx := context.Background()
+				// Test the read-write transaction
+				_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *ReadWriteTransaction) error {
+					iter := txn.Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"})
+					return iter.Do(func(r *Row) error {
+						return nil
+					})
+				})
+				return err
+			},
+			validate: func(server InMemSpannerServer) {
+				// Validate that regular sessions are used; total sessions created should be
+				// 26 when multiplexing is enabled (25 pool sessions + 1 multiplexed).
+				expectedSessionCount := uint(26)
+				if !isMultiplexEnabled {
+					expectedSessionCount = uint(25) // BatchCreateSession request from regular session pool
+				}
+				if !testEqual(expectedSessionCount, server.TotalSessionsCreated()) {
+					t.Errorf("TestClient_MultiplexedSession expected session creation with multiplexed=%s should be=%v, got: %v", strconv.FormatBool(isMultiplexEnabled), expectedSessionCount, server.TotalSessionsCreated())
+				}
+				reqs := drainRequestsFromServer(server)
+				for _, s := range reqs {
+					switch s.(type) {
+					case *sppb.ReadRequest:
+						req, _ := s.(*sppb.ReadRequest)
+						// Validate the session is not multiplexed
+						if !testEqual(false, strings.Contains(req.Session, "multiplexed")) {
+							t.Errorf("TestClient_MultiplexedSession expected regular session to be used, got: %v", req.Session)
+						}
+					}
+				}
+			},
+		},
+		{
+			name: "Given if multiplexed session is enabled, Only one multiplex session should be created for multiple read only transactions",
+			test: func(client *Client) error {
+				// Test parallel single use read-only transactions
+				g := new(errgroup.Group)
+				for i := 0; i < 25; i++ {
+					g.Go(func() error {
+						ctx := context.Background()
+						// Test the single use read-only transaction
+						_, err := client.Single().ReadRow(ctx, "Albums", Key{"foo"}, []string{"SingerId", "AlbumId", "AlbumTitle"})
+						return err
+					})
+				}
+				return g.Wait()
+			},
+			validate: func(server InMemSpannerServer) {
+				// Validate the multiplexed session is used
+				expectedSessionCount := uint(1)
+
if !isMultiplexEnabled { + expectedSessionCount = uint(25) // BatchCreateSession request from regular session pool + } + if !testEqual(expectedSessionCount, server.TotalSessionsCreated()) { + t.Errorf("TestClient_MultiplexedSession expected session creation with multiplexed=%s should be=%v, got: %v", strconv.FormatBool(isMultiplexEnabled), expectedSessionCount, server.TotalSessionsCreated()) + } + reqs := drainRequestsFromServer(server) + for _, s := range reqs { + switch s.(type) { + case *sppb.ReadRequest: + req, _ := s.(*sppb.ReadRequest) + // Verify that a multiplexed session is used when that is enabled. + if !testEqual(isMultiplexEnabled, strings.Contains(req.Session, "multiplexed")) { + t.Errorf("TestClient_MultiplexedSession expected multiplexed session to be used, got: %v", req.Session) + } + } + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server, client, teardown := setupMockedTestServer(t) + defer teardown() + gotErr := tt.test(client) + if !testEqual(gotErr, tt.wantErr) { + t.Errorf("TestClient_MultiplexedSession error=%v, wantErr: %v", gotErr, tt.wantErr) + } else { + tt.validate(server.TestSpanner) + } + }) + } +} + func TestClient_Single(t *testing.T) { t.Parallel() err := testSingleQuery(t, nil) @@ -532,7 +785,11 @@ func TestClient_Single_WhenInactiveTransactionsAndSessionIsNotFoundOnBackend_Rem if g, w := sh.eligibleForLongRunning, false; g != w { t.Fatalf("isLongRunningTransaction mismatch\nGot: %v\nWant: %v\n", g, w) } - if g, w := p.numOfLeakedSessionsRemoved, uint64(1); g != w { + expectedLeakedSessions := uint64(1) + if isMultiplexEnabled { + expectedLeakedSessions = 0 + } + if g, w := p.numOfLeakedSessionsRemoved, expectedLeakedSessions; g != w { t.Fatalf("Number of leaked sessions removed mismatch\nGot: %d\nWant: %d\n", g, w) } } @@ -914,6 +1171,7 @@ func testSingleQuery(t *testing.T, serverError error) error { ctx := context.Background() server, client, teardown := setupMockedTestServer(t) defer teardown() + if serverError != nil { server.TestSpanner.SetError(serverError) } @@ -2375,16 +2633,20 @@ func TestClient_ReadWriteTransactionWithOptimisticLockMode_ExecuteSqlRequest(t * &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if requests[1].(*sppb.ExecuteSqlRequest).GetTransaction().GetBegin().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if requests[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).GetTransaction().GetBegin().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { t.Fatal("Transaction is not set to optimistic") } - if requests[2].(*sppb.ExecuteSqlRequest).GetTransaction().GetBegin().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { + if requests[2+muxCreateBuffer].(*sppb.ExecuteSqlRequest).GetTransaction().GetBegin().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { t.Fatal("Transaction is not set to optimistic") } - if requests[3].(*sppb.BeginTransactionRequest).GetOptions().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { + if requests[3+muxCreateBuffer].(*sppb.BeginTransactionRequest).GetOptions().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { t.Fatal("Begin Transaction is not set to optimistic") } - if _, ok := requests[4].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { + if _, ok 
:= requests[4+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { t.Fatal("expected streaming query to use transactionID from explicit begin transaction") } } @@ -2406,7 +2668,9 @@ func TestClient_ReadWriteTransactionWithOptimisticLockMode_ReadRequest(t *testin if err != nil { t.Fatalf("Failed to execute the transaction: %s", err) } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ &sppb.BatchCreateSessionsRequest{}, &sppb.ReadRequest{}, @@ -2416,16 +2680,20 @@ func TestClient_ReadWriteTransactionWithOptimisticLockMode_ReadRequest(t *testin &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if requests[1].(*sppb.ReadRequest).GetTransaction().GetBegin().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if requests[1+muxCreateBuffer].(*sppb.ReadRequest).GetTransaction().GetBegin().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { t.Fatal("Transaction is not set to optimistic") } - if requests[2].(*sppb.ReadRequest).GetTransaction().GetBegin().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { + if requests[2+muxCreateBuffer].(*sppb.ReadRequest).GetTransaction().GetBegin().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { t.Fatal("Transaction is not set to optimistic") } - if requests[3].(*sppb.BeginTransactionRequest).GetOptions().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { + if requests[3+muxCreateBuffer].(*sppb.BeginTransactionRequest).GetOptions().GetReadWrite().GetReadLockMode() != sppb.TransactionOptions_ReadWrite_OPTIMISTIC { t.Fatal("Begin Transaction is not set to optimistic") } - if _, ok := requests[4].(*sppb.ReadRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { + if _, ok := requests[4+muxCreateBuffer].(*sppb.ReadRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { t.Fatal("expected streaming read to use transactionID from explicit begin transaction") } } @@ -2927,6 +3195,7 @@ func TestClient_ReadWriteTransaction_FirstStatementAsQueryReturnsUnavailableRetr t.Parallel() server, client, teardown := setupMockedTestServer(t) defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, SimulatedExecutionTime{ Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable"), status.Error(codes.Aborted, "Transaction aborted")}, @@ -2959,13 +3228,17 @@ func TestClient_ReadWriteTransaction_FirstStatementAsQueryReturnsUnavailableRetr &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if _, ok := requests[1].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if _, ok := requests[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { t.Fatal("expected streaming query to use TransactionSelector::Begin") } - if _, ok := requests[2].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { + if _, ok := requests[2+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { t.Fatal("expected streaming query to use TransactionSelector::Begin") } - if _, ok := 
requests[4].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { + if _, ok := requests[4+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { t.Fatal("expected streaming query to use transactionID from explicit begin transaction") } } @@ -2976,6 +3249,7 @@ func TestClient_ReadWriteTransaction_FirstStatementAsReadFailsHalfway(t *testing t.Parallel() server, client, teardown := setupMockedTestServer(t) defer teardown() + server.TestSpanner.AddPartialResultSetError( SelectSingerIDAlbumIDAlbumTitleFromAlbums, PartialResultSetExecutionTime{ @@ -3008,13 +3282,17 @@ func TestClient_ReadWriteTransaction_FirstStatementAsReadFailsHalfway(t *testing &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if _, ok := requests[1].(*sppb.ReadRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if _, ok := requests[1+muxCreateBuffer].(*sppb.ReadRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { t.Fatal("expected streaming read to use TransactionSelector::Begin") } - if _, ok := requests[2].(*sppb.ReadRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { + if _, ok := requests[2+muxCreateBuffer].(*sppb.ReadRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { t.Fatal("expected streaming read to use transactionID from previous success request") } - if requests[2].(*sppb.ReadRequest).ResumeToken == nil { + if requests[2+muxCreateBuffer].(*sppb.ReadRequest).ResumeToken == nil { t.Fatal("expected streaming read to include resume token") } } @@ -3058,13 +3336,17 @@ func TestClient_ReadWriteTransaction_BatchDmlWithErrorOnFirstStatement(t *testin } // The first statement will fail and not return a transaction id. This will trigger a retry of // the entire transaction, and the retry will do an explicit BeginTransaction RPC. - if _, ok := requests[1].(*sppb.ExecuteBatchDmlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if _, ok := requests[1+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { t.Fatal("expected first BatchUpdate to use TransactionSelector::Begin") } - if _, ok := requests[3].(*sppb.ExecuteBatchDmlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { + if _, ok := requests[3+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { t.Fatal("expected second BatchUpdate to use transactionID from explicit begin") } - if _, ok := requests[4].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { + if _, ok := requests[4+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { t.Fatal("expected second ExecuteSqlRequest to use transactionID from explicit begin") } } @@ -3107,10 +3389,14 @@ func TestClient_ReadWriteTransaction_BatchDmlWithErrorOnSecondStatement(t *testi } // Although the batch DML returned an error, that error was for the second statement. That // means that the transaction was started by the first statement. 
- if _, ok := requests[1].(*sppb.ExecuteBatchDmlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if _, ok := requests[1+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Begin); !ok { t.Fatal("expected BatchUpdate to use TransactionSelector::Begin") } - if _, ok := requests[2].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { + if _, ok := requests[2+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetSelector().(*sppb.TransactionSelector_Id); !ok { t.Fatal("expected ExecuteSqlRequest use transactionID from BatchUpdate request") } } @@ -3126,6 +3412,7 @@ func TestClient_ReadWriteTransaction_MultipleReadsWithoutNext(t *testing.T) { Err: status.Errorf(codes.Internal, "stream terminated by RST_STREAM"), }, ) + _, err := client.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *ReadWriteTransaction) error { iter := tx.Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"}) iter.Stop() @@ -3272,7 +3559,11 @@ func TestClient_ApplyAtLeastOnceReuseSession(t *testing.T) { if g, w := uint64(sp.idleList.Len())+sp.createReqs, sp.incStep; g != w { t.Fatalf("idle session count mismatch:\nGot: %v\nWant: %v", g, w) } - if g, w := uint64(len(server.TestSpanner.DumpSessions())), sp.incStep; g != w { + expectedSessions := sp.incStep + if isMultiplexEnabled { + expectedSessions++ + } + if g, w := uint64(len(server.TestSpanner.DumpSessions())), expectedSessions; g != w { t.Fatalf("server session count mismatch:\nGot: %v\nWant: %v", g, w) } sp.mu.Unlock() @@ -3314,7 +3605,11 @@ func TestClient_ApplyAtLeastOnceInvalidArgument(t *testing.T) { if g, w := uint64(sp.idleList.Len())+sp.createReqs, sp.incStep; g != w { t.Fatalf("idle session count mismatch:\nGot: %v\nWant: %v", g, w) } - if g, w := uint64(len(server.TestSpanner.DumpSessions())), sp.incStep; g != w { + var countMuxSess uint64 + if isMultiplexEnabled { + countMuxSess = 1 + } + if g, w := uint64(len(server.TestSpanner.DumpSessions())), sp.incStep+countMuxSess; g != w { t.Fatalf("server session count mismatch:\nGot: %v\nWant: %v", g, w) } sp.mu.Unlock() @@ -3865,6 +4160,7 @@ func TestFailedCommit_NoRollback(t *testing.T) { }, }) defer teardown() + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid mutations")}, @@ -3895,6 +4191,7 @@ func TestFailedUpdate_ShouldRollback(t *testing.T) { }, }) defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid update"), status.Errorf(codes.InvalidArgument, "Invalid update")}, @@ -5208,7 +5505,7 @@ func TestClient_CloseWithUnresponsiveBackend(t *testing.T) { server.TestSpanner.Freeze() defer server.TestSpanner.Unfreeze() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() sp.close(ctx) @@ -5563,6 +5860,7 @@ func TestClient_ReadWriteTransactionWithTag_SessionNotFound(t *testing.T) { t.Parallel() server, client, teardown := setupMockedTestServer(t) defer teardown() + ctx := context.Background() server.TestSpanner.PutExecutionTime(MethodBeginTransaction, SimulatedExecutionTime{Errors: 
[]error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}) @@ -5602,10 +5900,14 @@ func TestClient_ReadWriteTransactionWithTag_SessionNotFound(t *testing.T) { }, requests); err != nil { t.Fatal(err) } - if g, w := requests[3].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if g, w := requests[3+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[5].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + if g, w := requests[5+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } } @@ -5649,19 +5951,23 @@ func TestClient_NestedReadWriteTransactionWithTag_AbortedOnce(t *testing.T) { }, requests); err != nil { t.Fatal(err) } - if g, w := requests[1].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if g, w := requests[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[2].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + if g, w := requests[2+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[3].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + if g, w := requests[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[4].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + if g, w := requests[4+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[6].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { + if g, w := requests[6+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } } @@ -5670,6 +5976,7 @@ func TestClient_NestedReadWriteTransactionWithTag_OuterAbortedOnce(t *testing.T) t.Parallel() server, client, teardown := setupMockedTestServer(t) defer teardown() + ctx := context.Background() server.TestSpanner.PutExecutionTime(MethodCommitTransaction, SimulatedExecutionTime{Errors: []error{nil, status.Error(codes.Aborted, "Transaction aborted")}}) @@ -5707,22 +6014,26 @@ func TestClient_NestedReadWriteTransactionWithTag_OuterAbortedOnce(t *testing.T) }, requests); err != nil { t.Fatal(err) } - if g, w := requests[1].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if g, w := requests[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[2].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + if g, w := 
requests[2+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[4].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { + if g, w := requests[4+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[5].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + if g, w := requests[5+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[6].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + if g, w := requests[6+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[8].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { + if g, w := requests[8+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } } @@ -5768,22 +6079,27 @@ func TestClient_NestedReadWriteTransactionWithTag_InnerBlindWrite(t *testing.T) }, requests); err != nil { t.Fatal(err) } - if g, w := requests[2].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + + if g, w := requests[2+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[3].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { + if g, w := requests[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[4].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { + if g, w := requests[4+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[6].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { + if g, w := requests[6+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag2"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[7].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { + if g, w := requests[7+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } - if g, w := requests[8].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { + if g, w := requests[8+muxCreateBuffer].(*sppb.CommitRequest).RequestOptions.TransactionTag, "test-tag1"; g != w { t.Fatalf("transaction tag mismatch\nGot: %s\nWant: %s", g, w) } } @@ -5810,7 +6126,11 @@ func TestClient_ReadWriteTransactionWithExcludeTxnFromChangeStreams_ExecuteSqlRe &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if !requests[1].(*sppb.ExecuteSqlRequest).Transaction.GetBegin().ExcludeTxnFromChangeStreams { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + 
} + if !requests[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetBegin().ExcludeTxnFromChangeStreams { t.Fatal("Transaction is not set to be excluded from change streams") } } @@ -5838,7 +6158,11 @@ func TestClient_ReadWriteTransactionWithExcludeTxnFromChangeStreams_BufferWrite( &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if !requests[1].(*sppb.BeginTransactionRequest).Options.ExcludeTxnFromChangeStreams { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if !requests[1+muxCreateBuffer].(*sppb.BeginTransactionRequest).Options.ExcludeTxnFromChangeStreams { t.Fatal("Transaction is not set to be excluded from change streams") } } @@ -5865,7 +6189,11 @@ func TestClient_ReadWriteTransactionWithExcludeTxnFromChangeStreams_BatchUpdate( &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if !requests[1].(*sppb.ExecuteBatchDmlRequest).Transaction.GetBegin().ExcludeTxnFromChangeStreams { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if !requests[1+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Transaction.GetBegin().ExcludeTxnFromChangeStreams { t.Fatal("Transaction is not set to be excluded from change streams") } } @@ -5927,7 +6255,11 @@ func TestClient_ApplyExcludeTxnFromChangeStreams(t *testing.T) { &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if !requests[1].(*sppb.BeginTransactionRequest).Options.ExcludeTxnFromChangeStreams { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if !requests[1+muxCreateBuffer].(*sppb.BeginTransactionRequest).Options.ExcludeTxnFromChangeStreams { t.Fatal("Transaction is not set to be excluded from change streams") } } @@ -5951,7 +6283,11 @@ func TestClient_ApplyAtLeastOnceExcludeTxnFromChangeStreams(t *testing.T) { &sppb.CommitRequest{}}, requests); err != nil { t.Fatal(err) } - if !requests[1].(*sppb.CommitRequest).Transaction.(*sppb.CommitRequest_SingleUseTransaction).SingleUseTransaction.ExcludeTxnFromChangeStreams { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if !requests[1+muxCreateBuffer].(*sppb.CommitRequest).Transaction.(*sppb.CommitRequest_SingleUseTransaction).SingleUseTransaction.ExcludeTxnFromChangeStreams { t.Fatal("Transaction is not set to be excluded from change streams") } } @@ -5983,7 +6319,11 @@ func TestClient_BatchWriteExcludeTxnFromChangeStreams(t *testing.T) { &sppb.BatchWriteRequest{}}, requests); err != nil { t.Fatal(err) } - if !requests[1].(*sppb.BatchWriteRequest).ExcludeTxnFromChangeStreams { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if !requests[1+muxCreateBuffer].(*sppb.BatchWriteRequest).ExcludeTxnFromChangeStreams { t.Fatal("Transaction is not set to be excluded from change streams") } } diff --git a/spanner/go.mod b/spanner/go.mod index b25ba3ec2b7c..730cf53480fb 100644 --- a/spanner/go.mod +++ b/spanner/go.mod @@ -13,6 +13,7 @@ require ( go.opentelemetry.io/otel v1.24.0 go.opentelemetry.io/otel/metric v1.24.0 golang.org/x/oauth2 v0.21.0 + golang.org/x/sync v0.7.0 golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 google.golang.org/api v0.189.0 google.golang.org/genproto v0.0.0-20240722135656-d784300faade @@ -45,7 +46,6 @@ require ( go.opentelemetry.io/otel/trace v1.24.0 // indirect golang.org/x/crypto v0.25.0 // indirect golang.org/x/net v0.27.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect diff 
--git a/spanner/integration_test.go b/spanner/integration_test.go index 2e1d215a2c4d..3543104f44bb 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -91,6 +91,8 @@ var ( // GCLOUD_TESTS_GOLANG_SPANNER_INSTANCE_CONFIG. instanceConfig = getInstanceConfig() + isMultiplexEnabled = getMultiplexEnableFlag() + dbNameSpace = uid.NewSpace("gotest", &uid.Options{Sep: '_', Short: true}) instanceNameSpace = uid.NewSpace("gotest", &uid.Options{Sep: '-', Short: true}) backupIDSpace = uid.NewSpace("gotest", &uid.Options{Sep: '_', Short: true}) @@ -386,6 +388,10 @@ func getInstanceConfig() string { return os.Getenv("GCLOUD_TESTS_GOLANG_SPANNER_INSTANCE_CONFIG") } +func getMultiplexEnableFlag() bool { + return os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") == "true" +} + const ( str1 = "alice" str2 = "a@example.com" @@ -1899,6 +1905,23 @@ func TestIntegration_DbRemovalRecovery(t *testing.T) { // from repeatedly trying to create sessions for the invalid database. client, dbPath, cleanup := prepareIntegrationTest(ctx, t, SessionPoolConfig{}, statements[testDialect][singerDDLStatements]) defer cleanup() + if isMultiplexEnabled { + // TODO: confirm that this is the valid scenario for multiplexed sessions, and what's expected behavior. + // wait for the multiplexed session to be created. + waitFor(t, func() error { + client.idleSessions.mu.Lock() + defer client.idleSessions.mu.Unlock() + if client.idleSessions.multiplexedSession == nil { + return errInvalidSessionPool + } + return nil + }) + // Close the multiplexed session to prevent the session pool maintainer + // from repeatedly trying to use sessions for the invalid database. + client.idleSessions.mu.Lock() + client.idleSessions.multiplexedSession = nil + client.idleSessions.mu.Unlock() + } // Drop the testing database. 
 	if err := databaseAdmin.DropDatabase(ctx, &adminpb.DropDatabaseRequest{Database: dbPath}); err != nil {
diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go
index 968f10980479..ba1d1f0b17bf 100644
--- a/spanner/internal/testutil/inmem_spanner_server.go
+++ b/spanner/internal/testutil/inmem_spanner_server.go
@@ -525,8 +525,11 @@ func (s *inMemSpannerServer) initDefaults() {
 	s.transactionCounters = make(map[string]*uint64)
 }
 
-func (s *inMemSpannerServer) generateSessionNameLocked(database string) string {
+func (s *inMemSpannerServer) generateSessionNameLocked(database string, isMultiplexed bool) string {
 	s.sessionCounter++
+	if isMultiplexed {
+		return fmt.Sprintf("%s/sessions/multiplexed-%d", database, s.sessionCounter)
+	}
 	return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter)
 }
 
@@ -705,13 +708,21 @@ func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.C
 	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal {
 		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
 	}
-	sessionName := s.generateSessionNameLocked(req.Database)
 	ts := getCurrentTimestamp()
-	var creatorRole string
+	var (
+		creatorRole   string
+		isMultiplexed bool
+	)
 	if req.Session != nil {
 		creatorRole = req.Session.CreatorRole
+		isMultiplexed = req.Session.Multiplexed
+	}
+	sessionName := s.generateSessionNameLocked(req.Database, isMultiplexed)
+	header := metadata.New(map[string]string{"server-timing": "gfet4t7; dur=123"})
+	if err := grpc.SendHeader(ctx, header); err != nil {
+		return nil, gstatus.Errorf(codes.Internal, "unable to send 'server-timing' header")
 	}
-	session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts, CreatorRole: creatorRole}
+	session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts, CreatorRole: creatorRole, Multiplexed: isMultiplexed}
 	s.totalSessionsCreated++
 	s.sessions[sessionName] = session
 	return session, nil
@@ -742,7 +753,7 @@ func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spann
 	}
 	sessions := make([]*spannerpb.Session, sessionsToCreate)
 	for i := int32(0); i < sessionsToCreate; i++ {
-		sessionName := s.generateSessionNameLocked(req.Database)
+		sessionName := s.generateSessionNameLocked(req.Database, false)
 		ts := getCurrentTimestamp()
 		var creatorRole string
 		if req.SessionTemplate != nil {
diff --git a/spanner/kokoro/presubmit.sh b/spanner/kokoro/presubmit.sh
new file mode 100755
index 000000000000..9d5ab3a7ebf6
--- /dev/null
+++ b/spanner/kokoro/presubmit.sh
@@ -0,0 +1,111 @@
+#!/bin/bash
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Fail on any error
+set -eo pipefail
+
+# Display commands being run
+set -x
+
+# cd to project dir on Kokoro instance
+cd github/google-cloud-go
+
+go version
+
+export GOCLOUD_HOME=$KOKORO_ARTIFACTS_DIR/google-cloud-go/
+export PATH="$GOPATH/bin:$PATH"
+export GO111MODULE=on
+export GOPROXY=https://proxy.golang.org
+export GOOGLE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/${GOOGLE_APPLICATION_CREDENTIALS}
+# Move code into artifacts dir
+mkdir -p $GOCLOUD_HOME
+git clone . $GOCLOUD_HOME
+cd $GOCLOUD_HOME
+
+try3() { eval "$*" || eval "$*" || eval "$*"; }
+
+# All packages, including +build tools, are fetched.
+try3 go mod download
+
+set +e # Run all tests, don't stop after the first failure.
+exit_code=0
+
+case $JOB_TYPE in
+integration-with-multiplexed-session )
+  # Export so the flag reaches the go test child processes below.
+  export GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS=true
+  echo "running presubmit with multiplexed sessions enabled: $GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS"
+  ;;
+esac
+
+# Run tests in the current directory and tee output to log file,
+# to be pushed to GCS as artifact.
+runPresubmitTests() {
+  if [[ $PWD == *"/internal/"* ]] ||
+    [[ $PWD == *"/third_party/"* ]]; then
+    # internal tools only expected to work with latest go version
+    return
+  fi
+
+  if [ -z "${RUN_INTEGRATION_TESTS}" ]; then
+    GOWORK=off go test -race -v -timeout 15m -short ./... 2>&1 |
+      tee sponge_log.log
+  else
+    GOWORK=off go test -race -v -timeout 45m ./... 2>&1 |
+      tee sponge_log.log
+  fi
+
+  # Skip running integration tests against an emulator, since the emulator
+  # does not support multiplexed sessions.
+#  if [ -f "emulator_test.sh" ]; then
+#    ./emulator_test.sh
+#  fi
+  # Takes the kokoro output log (raw stdout) and creates a machine-parseable
+  # xUnit XML file.
+  cat sponge_log.log |
+    go-junit-report -set-exit-code >sponge_log.xml
+  # Add the exit codes together so we exit non-zero if any module fails.
+  exit_code=$(($exit_code + $?))
+  if [[ $PWD != *"/internal/"* ]]; then
+    GOWORK=off go build ./...
+  fi
+  exit_code=$(($exit_code + $?))
+}
+
+SIGNIFICANT_CHANGES=$(git --no-pager diff --name-only origin/main...$KOKORO_GIT_COMMIT_google_cloud_go |
+  grep -Ev '(\.md$|^\.github|\.json$|\.yaml$)' | xargs dirname | sort -u || true)
+
+if [ -z "$SIGNIFICANT_CHANGES" ]; then
+  echo "No changes detected, skipping tests"
+  exit 0
+fi
+
+# CHANGED_DIRS is the list of significant top-level directories that changed,
+# but weren't deleted by the current PR. CHANGED_DIRS will be empty when run on main.
+CHANGED_DIRS=$(echo "$SIGNIFICANT_CHANGES" | tr ' ' '\n' | cut -d/ -f1 | sort -u |
+  tr '\n' ' ' | xargs ls -d 2>/dev/null || true)
+
+echo "Running tests only in changed submodules: $CHANGED_DIRS"
+for d in $CHANGED_DIRS; do
+  # run tests only if the spanner module is part of $CHANGED_DIRS
+  if [[ $CHANGED_DIRS =~ spanner ]]; then
+    for i in $(find "$d" -name go.mod); do
+      pushd $(dirname $i)
+      runPresubmitTests
+      popd
+    done
+  fi
+done
+
+exit $exit_code
\ No newline at end of file
diff --git a/spanner/oc_test.go b/spanner/oc_test.go
index 76a710e5ac5e..f49a7bf7b0c2 100644
--- a/spanner/oc_test.go
+++ b/spanner/oc_test.go
@@ -52,6 +52,19 @@ func TestOCStats(t *testing.T) {
 func TestOCStats_SessionPool(t *testing.T) {
 	skipForPGTest(t)
 	DisableGfeLatencyAndHeaderMissingCountViews()
+	// expectedValues is a map of expected values for different configurations of
+	// multiplexed session env="GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS".
+ expectedValues := map[string]map[bool]string{ + "open_session_count": { + false: "25", + // since we are doing only R/O operations and MinOpened=0, we should have only one session. + true: "1", + }, + "max_in_use_sessions": { + false: "1", + true: "0", + }, + } for _, test := range []struct { name string view *view.View @@ -62,7 +75,7 @@ func TestOCStats_SessionPool(t *testing.T) { "OpenSessionCount", OpenSessionCountView, "open_session_count", - "25", + expectedValues["open_session_count"][isMultiplexEnabled], }, { "MaxAllowedSessionsCount", @@ -74,7 +87,7 @@ func TestOCStats_SessionPool(t *testing.T) { "MaxInUseSessionsCount", MaxInUseSessionsCountView, "max_in_use_sessions", - "1", + expectedValues["max_in_use_sessions"][isMultiplexEnabled], }, { "AcquiredSessionsCount", @@ -167,11 +180,16 @@ func TestOCStats_SessionPool_SessionsCount(t *testing.T) { }) client.Single().ReadRow(context.Background(), "Users", Key{"alice"}, []string{"email"}) + expectedStats := 2 + if isMultiplexEnabled { + // num_in_use_sessions is not exported when multiplexed sessions are enabled and only ReadOnly transactions are performed. + expectedStats = 1 + } // Wait for a while to see all exported metrics. waitFor(t, func() error { select { case stat := <-te.Stats: - if len(stat.Rows) >= 2 { + if len(stat.Rows) >= expectedStats { return nil } } @@ -183,7 +201,7 @@ func TestOCStats_SessionPool_SessionsCount(t *testing.T) { case stat := <-te.Stats: // There are 4 types for this metric, so we should see at least four // rows. - if len(stat.Rows) < 2 { + if len(stat.Rows) < expectedStats { t.Fatal("No enough metrics are exported") } if got, want := stat.View.Measure.Name(), statsPrefix+"num_sessions_in_pool"; got != want { @@ -220,14 +238,17 @@ func TestOCStats_SessionPool_GetSessionTimeoutsCount(t *testing.T) { te := testutil.NewTestExporter(GetSessionTimeoutsCountView) defer te.Unregister() - server, client, teardown := setupMockedTestServer(t) + server, client, teardown := setupMockedTestServerWithoutWaitingForMultiplexedSessionInit(t) defer teardown() server.TestSpanner.PutExecutionTime(stestutil.MethodBatchCreateSession, stestutil.SimulatedExecutionTime{ MinimumExecutionTime: 2 * time.Millisecond, }) - + server.TestSpanner.PutExecutionTime(stestutil.MethodCreateSession, + stestutil.SimulatedExecutionTime{ + MinimumExecutionTime: 2 * time.Millisecond, + }) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) defer cancel() client.Single().ReadRow(ctx, "Users", Key{"alice"}, []string{"email"}) diff --git a/spanner/ot_metrics.go b/spanner/ot_metrics.go index 2f49cadec3e5..16190860ce77 100644 --- a/spanner/ot_metrics.go +++ b/spanner/ot_metrics.go @@ -33,12 +33,13 @@ const OtInstrumentationScope = "cloud.google.com/go" const metricsPrefix = "spanner/" var ( - attributeKeyClientID = attribute.Key("client_id") - attributeKeyDatabase = attribute.Key("database") - attributeKeyInstance = attribute.Key("instance_id") - attributeKeyLibVersion = attribute.Key("library_version") - attributeKeyType = attribute.Key("type") - attributeKeyMethod = attribute.Key("grpc_client_method") + attributeKeyClientID = attribute.Key("client_id") + attributeKeyDatabase = attribute.Key("database") + attributeKeyInstance = attribute.Key("instance_id") + attributeKeyLibVersion = attribute.Key("library_version") + attributeKeyType = attribute.Key("type") + attributeKeyMethod = attribute.Key("grpc_client_method") + attributeKeyIsMultiplexed = attribute.Key("is_multiplexed") attributeNumInUseSessions = 
attributeKeyType.String("num_in_use_sessions") attributeNumSessions = attributeKeyType.String("num_sessions") @@ -69,6 +70,12 @@ func createOpenTelemetryConfig(mp metric.MeterProvider, logger *log.Logger, sess } config.attributeMap = append(config.attributeMap, attributeMap...) + config.attributeMapWithMultiplexed = append(config.attributeMapWithMultiplexed, attributeMap...) + config.attributeMapWithMultiplexed = append(config.attributeMapWithMultiplexed, attributeKeyIsMultiplexed.String("true")) + + config.attributeMapWithoutMultiplexed = append(config.attributeMapWithoutMultiplexed, attributeMap...) + config.attributeMapWithoutMultiplexed = append(config.attributeMapWithoutMultiplexed, attributeKeyIsMultiplexed.String("false")) + setOpenTelemetryMetricProvider(config, mp, logger) return config, nil } @@ -197,13 +204,14 @@ func registerSessionPoolOTMetrics(pool *sessionPool) error { func(ctx context.Context, o metric.Observer) error { pool.mu.Lock() defer pool.mu.Unlock() - + if pool.multiplexedSession != nil { + o.ObserveInt64(otConfig.openSessionCount, int64(1), metric.WithAttributes(otConfig.attributeMapWithMultiplexed...)) + } o.ObserveInt64(otConfig.openSessionCount, int64(pool.numOpened), metric.WithAttributes(attributes...)) o.ObserveInt64(otConfig.maxAllowedSessionsCount, int64(pool.SessionPoolConfig.MaxOpened), metric.WithAttributes(attributes...)) - o.ObserveInt64(otConfig.sessionsCount, int64(pool.numInUse), metric.WithAttributes(attributesInUseSessions...)) + o.ObserveInt64(otConfig.sessionsCount, int64(pool.numInUse), metric.WithAttributes(append(attributesInUseSessions, attribute.Key("is_multiplexed").String("false"))...)) o.ObserveInt64(otConfig.sessionsCount, int64(pool.numSessions), metric.WithAttributes(attributesAvailableSessions...)) - o.ObserveInt64(otConfig.maxInUseSessionsCount, int64(pool.maxNumInUse), metric.WithAttributes(attributes...)) - + o.ObserveInt64(otConfig.maxInUseSessionsCount, int64(pool.maxNumInUse), metric.WithAttributes(append(attributes, attribute.Key("is_multiplexed").String("false"))...)) return nil }, otConfig.openSessionCount, diff --git a/spanner/pdml_test.go b/spanner/pdml_test.go index 1c5b57f0d15e..f905c68424ed 100644 --- a/spanner/pdml_test.go +++ b/spanner/pdml_test.go @@ -101,8 +101,12 @@ func TestPartitionedUpdate_Aborted(t *testing.T) { if err != nil { t.Fatal(err) } - id1 := gotReqs[2].(*sppb.ExecuteSqlRequest).Transaction.GetId() - id2 := gotReqs[4].(*sppb.ExecuteSqlRequest).Transaction.GetId() + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + id1 := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetId() + id2 := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Transaction.GetId() if bytes.Equal(id1, id2) { t.Errorf("same transaction used twice, expected two different transactions\ngot tx1: %q\ngot tx2: %q", id1, id2) } @@ -196,8 +200,12 @@ func TestPartitionedUpdate_ExcludeTxnFromChangeStreams(t *testing.T) { &sppb.ExecuteSqlRequest{}}, requests); err != nil { t.Fatal(err) } + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } - if !requests[1].(*sppb.BeginTransactionRequest).GetOptions().GetExcludeTxnFromChangeStreams() { + if !requests[1+muxCreateBuffer].(*sppb.BeginTransactionRequest).GetOptions().GetExcludeTxnFromChangeStreams() { t.Fatal("Transaction is not set to be excluded from change streams") } } diff --git a/spanner/session.go b/spanner/session.go index 2948790e663d..3e587095a007 100644 --- a/spanner/session.go +++ b/spanner/session.go @@ -24,6 +24,7 
@@ import ( "log" "math" "math/rand" + "os" "runtime/debug" "strings" "sync" @@ -36,12 +37,16 @@ import ( "go.opencensus.io/stats" "go.opencensus.io/tag" octrace "go.opencensus.io/trace" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" ) -const healthCheckIntervalMins = 50 +const ( + healthCheckIntervalMins = 50 + multiplexSessionRefreshInterval = 7 * 24 * time.Hour +) // ActionOnInactiveTransactionKind describes the kind of action taken when there are inactive transactions. type ActionOnInactiveTransactionKind int @@ -85,6 +90,8 @@ type sessionHandle struct { // session is a pointer to a session object. Transactions never need to // access it directly. session *session + // client is the RPC channel to Cloud Spanner. It is set only once during session acquisition. + client *vkit.Client // checkoutTime is the time the session was checked out of the pool. checkoutTime time.Time // lastUseTime is the time the session was last used after checked out of the pool. @@ -115,6 +122,7 @@ func (sh *sessionHandle) recycle() { tracked := sh.trackedSessionHandle s := sh.session sh.session = nil + sh.client = nil sh.trackedSessionHandle = nil sh.checkoutTime = time.Time{} sh.lastUseTime = time.Time{} @@ -149,6 +157,10 @@ func (sh *sessionHandle) getClient() *vkit.Client { if sh.session == nil { return nil } + if sh.client != nil { + // Use the gRPC connection from the session handle + return sh.client + } return sh.session.client } @@ -185,6 +197,7 @@ func (sh *sessionHandle) destroy() { } tracked := sh.trackedSessionHandle sh.session = nil + sh.client = nil sh.trackedSessionHandle = nil sh.checkoutTime = time.Time{} sh.lastUseTime = time.Time{} @@ -253,6 +266,8 @@ type session struct { tx transactionID // firstHCDone indicates whether the first health check is done or not. firstHCDone bool + // isMultiplexed is true if the session is multiplexed. + isMultiplexed bool } // isValid returns true if the session is still valid for use. @@ -371,6 +386,11 @@ func (s *session) getNextCheck() time.Time { func (s *session) recycle() { s.setTransactionID(nil) s.pool.mu.Lock() + if s.isMultiplexed { + s.pool.decNumMultiplexedInUseLocked(context.Background()) + s.pool.mu.Unlock() + return + } if !s.pool.recycleLocked(s) { // s is rejected by its home session pool because it expired and the // session pool currently has enough open sessions. @@ -476,6 +496,11 @@ type SessionPoolConfig struct { // Defaults to 50m. HealthCheckInterval time.Duration + // MultiplexSessionCheckInterval is the interval at which the multiplexed session is checked whether it needs to be refreshed. + // + // Defaults to 10 mins. + MultiplexSessionCheckInterval time.Duration + // TrackSessionHandles determines whether the session pool will keep track // of the stacktrace of the goroutines that take sessions from the pool. // This setting can be used to track down session leak problems. @@ -555,6 +580,11 @@ func (spc *SessionPoolConfig) validate() error { return nil } +type muxSessionCreateRequest struct { + ctx context.Context + force bool +} + // sessionPool creates and caches Cloud Spanner sessions. type sessionPool struct { // mu protects sessionPool from concurrent access. @@ -570,12 +600,25 @@ type sessionPool struct { // idleList caches idle session IDs. Session IDs in this list can be // allocated for use. idleList list.List + // multiplexSessionClientCounter is the counter for the multiplexed session client. 
+ multiplexSessionClientCounter int + // clientPool is a pool of Cloud Spanner grpc clients. + clientPool []*vkit.Client + // multiplexedSession contains the multiplexed session + multiplexedSession *session // mayGetSession is for broadcasting that session retrival/creation may // proceed. mayGetSession chan struct{} + // multiplexedSessionReq is the ongoing multiplexed session creation request (if any). + multiplexedSessionReq chan muxSessionCreateRequest + // mayGetMultiplexedSession is for broadcasting that multiplexed session retrieval is possible. + mayGetMultiplexedSession chan bool // sessionCreationError is the last error that occurred during session // creation and is propagated to any waiters waiting for a session. sessionCreationError error + // multiplexedSessionCreationError is the error that occurred during multiplexed session + // creation for the first time and is propagated to any waiters waiting for a session. + multiplexedSessionCreationError error // numOpened is the total number of open sessions from the session pool. numOpened uint64 // createReqs is the number of ongoing session creation requests. @@ -617,6 +660,9 @@ type sessionPool struct { numOfLeakedSessionsRemoved uint64 otConfig *openTelemetryConfig + + // enableMultiplexSession is a flag to enable multiplexed session. + enableMultiplexSession bool } // newSessionPool creates a new session pool. @@ -651,15 +697,24 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool, if config.usedSessionsRatioThreshold == 0 { config.usedSessionsRatioThreshold = DefaultSessionPoolConfig.usedSessionsRatioThreshold } - + if config.MultiplexSessionCheckInterval == 0 { + config.MultiplexSessionCheckInterval = 10 * time.Minute + } + isMultiplexed := strings.ToLower(os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS")) + if isMultiplexed != "" && isMultiplexed != "true" && isMultiplexed != "false" { + return nil, spannerErrorf(codes.InvalidArgument, "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS must be either true or false") + } pool := &sessionPool{ - sc: sc, - valid: true, - mayGetSession: make(chan struct{}), - SessionPoolConfig: config, - mw: newMaintenanceWindow(config.MaxOpened), - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - otConfig: sc.otConfig, + sc: sc, + valid: true, + mayGetSession: make(chan struct{}), + mayGetMultiplexedSession: make(chan bool), + multiplexedSessionReq: make(chan muxSessionCreateRequest), + SessionPoolConfig: config, + mw: newMaintenanceWindow(config.MaxOpened), + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + otConfig: sc.otConfig, + enableMultiplexSession: isMultiplexed == "true", } _, instance, database, err := parseDatabaseName(sc.database) @@ -682,7 +737,7 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool, // 10ms to finish, given a 5 minutes interval and 10 healthcheck workers, a // healthChecker can effectively mantain // 100 checks_per_worker/sec * 10 workers * 300 seconds = 300K sessions. - pool.hc = newHealthChecker(config.HealthCheckInterval, config.HealthCheckWorkers, config.healthCheckSampleInterval, pool) + pool.hc = newHealthChecker(config.HealthCheckInterval, config.MultiplexSessionCheckInterval, config.HealthCheckWorkers, config.healthCheckSampleInterval, pool) // First initialize the pool before we indicate that the healthchecker is // ready. 
This prevents the maintainer from starting before the pool has
@@ -694,6 +749,21 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool,
 			return nil, err
 		}
 	}
+	if pool.enableMultiplexSession {
+		go pool.createMultiplexedSession()
+		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		pool.multiplexedSessionReq <- muxSessionCreateRequest{force: true, ctx: ctx}
+		// Listen for the session to be created, releasing the context's
+		// resources whichever branch wins.
+		go func() {
+			defer cancel()
+			select {
+			case <-ctx.Done():
+			// Wait for the session to be created.
+			case <-pool.mayGetMultiplexedSession:
+			}
+		}()
+	}
 	pool.recordStat(context.Background(), MaxAllowedSessionsCount, int64(config.MaxOpened))
 
 	err = registerSessionPoolOTMetrics(pool)
@@ -718,9 +789,17 @@ func (p *sessionPool) recordStat(ctx context.Context, m *stats.Int64Measure, n i
 	recordStat(ctx, m, n)
 }
 
-func (p *sessionPool) recordOTStat(ctx context.Context, m metric.Int64Counter, val int64) {
+type recordOTStatOption struct {
+	attr []attribute.KeyValue
+}
+
+func (p *sessionPool) recordOTStat(ctx context.Context, m metric.Int64Counter, val int64, option recordOTStatOption) {
 	if m != nil {
-		m.Add(ctx, val, metric.WithAttributes(p.otConfig.attributeMap...))
+		attrs := p.otConfig.attributeMap
+		if len(option.attr) > 0 {
+			attrs = option.attr
+		}
+		m.Add(ctx, val, metric.WithAttributes(attrs...))
 	}
 }
 
@@ -813,13 +892,57 @@ func (p *sessionPool) growPoolLocked(numSessions uint64, distributeOverChannels
 	return p.sc.batchCreateSessions(int32(numSessions), distributeOverChannels, p)
 }
 
+func (p *sessionPool) createMultiplexedSession() {
+	for c := range p.multiplexedSessionReq {
+		p.mu.Lock()
+		sess := p.multiplexedSession
+		p.mu.Unlock()
+		if c.force || sess == nil {
+			p.mu.Lock()
+			p.sc.mu.Lock()
+			client, err := p.sc.nextClient()
+			p.sc.mu.Unlock()
+			p.mu.Unlock()
+			if err != nil {
+				// If we can't get a client, we can't create a session.
+				p.mu.Lock()
+				p.multiplexedSessionCreationError = err
+				p.mu.Unlock()
+				p.mayGetMultiplexedSession <- true
+				continue
+			}
+			p.sc.executeCreateMultiplexedSession(c.ctx, client, p.sc.md, p)
+			continue
+		}
+		select {
+		case p.mayGetMultiplexedSession <- true:
+		case <-c.ctx.Done():
+			return
+		}
+	}
+}
+
 // sessionReady is executed by the SessionClient when a session has been
 // created and is ready to use. This method will add the new session to the
 // pool and decrease the number of sessions that is being created.
-func (p *sessionPool) sessionReady(s *session) {
+func (p *sessionPool) sessionReady(ctx context.Context, s *session) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
+	if s.isMultiplexed {
+		s.pool = p
+		p.multiplexedSession = s
+		p.multiplexedSessionCreationError = nil
+		p.recordStat(context.Background(), OpenSessionCount, int64(1), tag.Tag{Key: tagKeyIsMultiplexed, Value: "true"})
+		p.recordStat(context.Background(), SessionsCount, 1, tagNumSessions, tag.Tag{Key: tagKeyIsMultiplexed, Value: "true"})
+		// Notify the waiting goroutine, if any, unless the context is done.
+		select {
+		case p.mayGetMultiplexedSession <- true:
+		case <-ctx.Done():
+			return
+		}
+		return
+	}
 	// Clear any session creation error.
 	p.sessionCreationError = nil
 	// Set this pool as the home pool of the session and register it with the
 	// health checker.
@@ -849,12 +972,32 @@
 // or more requested sessions finished with an error. sessionCreationFailed will
 // decrease the number of sessions being created and notify any waiters that
 // the session creation failed.
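+// For the multiplexed session, numSessions is always 1 and the error is
+// stored in multiplexedSessionCreationError rather than sessionCreationError.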
-func (p *sessionPool) sessionCreationFailed(err error, numSessions int32) { +func (p *sessionPool) sessionCreationFailed(ctx context.Context, err error, numSessions int32, isMultiplexed bool) { p.mu.Lock() defer p.mu.Unlock() + if isMultiplexed { + // Ignore the error if multiplexed session already present + if p.multiplexedSession != nil { + p.multiplexedSessionCreationError = nil + select { + case p.mayGetMultiplexedSession <- true: + case <-ctx.Done(): + return + } + return + } + p.recordStat(context.Background(), OpenSessionCount, int64(0), tag.Tag{Key: tagKeyIsMultiplexed, Value: "true"}) + p.multiplexedSessionCreationError = err + select { + case p.mayGetMultiplexedSession <- true: + case <-ctx.Done(): + return + } + return + } p.createReqs -= uint64(numSessions) p.numOpened -= uint64(numSessions) - p.recordStat(context.Background(), OpenSessionCount, int64(p.numOpened)) + p.recordStat(context.Background(), OpenSessionCount, int64(p.numOpened), tag.Tag{Key: tagKeyIsMultiplexed, Value: "false"}) // Notify other waiters blocking on session creation. p.sessionCreationError = err close(p.mayGetSession) @@ -924,6 +1067,12 @@ var errGetSessionTimeout = spannerErrorf(codes.Canceled, "timeout / context canc // sessions being checked out of the pool. func (p *sessionPool) newSessionHandle(s *session) (sh *sessionHandle) { sh = &sessionHandle{session: s, checkoutTime: time.Now(), lastUseTime: time.Now()} + if s.isMultiplexed { + p.mu.Lock() + sh.client = p.getRoundRobinClient() + p.mu.Unlock() + return sh + } if p.TrackSessionHandles || p.ActionOnInactiveTransaction == Warn || p.ActionOnInactiveTransaction == WarnAndClose || p.ActionOnInactiveTransaction == Close { p.mu.Lock() sh.trackedSessionHandle = p.trackedSessionHandles.PushBack(sh) @@ -935,8 +1084,29 @@ func (p *sessionPool) newSessionHandle(s *session) (sh *sessionHandle) { return sh } +func (p *sessionPool) getRoundRobinClient() *vkit.Client { + p.sc.mu.Lock() + defer func() { + p.multiplexSessionClientCounter++ + p.sc.mu.Unlock() + }() + if len(p.clientPool) == 0 { + p.clientPool = make([]*vkit.Client, p.sc.connPool.Num()) + for i := 0; i < p.sc.connPool.Num(); i++ { + c, err := p.sc.nextClient() + if err != nil { + // If we can't get a client, use the session's client. + return nil + } + p.clientPool[i] = c + } + } + p.multiplexSessionClientCounter = p.multiplexSessionClientCounter % len(p.clientPool) + return p.clientPool[p.multiplexSessionClientCounter] +} + // errGetSessionTimeout returns error for context timeout during -// sessionPool.take(). +// sessionPool.take() or sessionPool.takeMultiplexed(). func (p *sessionPool) errGetSessionTimeout(ctx context.Context) error { var code codes.Code if ctx.Err() == context.DeadlineExceeded { @@ -988,37 +1158,6 @@ func (p *sessionPool) getTrackedSessionHandleStacksLocked() string { return stackTraces } -func (p *sessionPool) createSession(ctx context.Context) (*session, error) { - trace.TracePrintf(ctx, nil, "Creating a new session") - doneCreate := func(done bool) { - p.mu.Lock() - if !done { - // Session creation failed, give budget back. - p.numOpened-- - p.recordStat(ctx, OpenSessionCount, int64(p.numOpened)) - } - p.createReqs-- - // Notify other waiters blocking on session creation. - close(p.mayGetSession) - p.mayGetSession = make(chan struct{}) - p.mu.Unlock() - } - s, err := p.sc.createSession(ctx) - if err != nil { - doneCreate(false) - // Should return error directly because of the previous retries on - // CreateSession RPC. 
-		// If the error is a timeout, there is a chance that the session was
-		// created on the server but is not known to the session pool. This
-		// session will then be garbage collected by the server after 1 hour.
-		return nil, err
-	}
-	s.pool = p
-	p.hc.register(s)
-	doneCreate(true)
-	return s, nil
-}
-
 func (p *sessionPool) isHealthy(s *session) bool {
 	if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) {
 		if err := s.ping(); isSessionNotFoundError(err) {
@@ -1085,9 +1224,9 @@ func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) {
 		select {
 		case <-ctx.Done():
 			trace.TracePrintf(ctx, nil, "Context done waiting for session")
-			p.recordStat(ctx, GetSessionTimeoutsCount, 1)
+			p.recordStat(ctx, GetSessionTimeoutsCount, 1, tag.Tag{Key: tagKeyIsMultiplexed, Value: "false"})
 			if p.otConfig != nil {
-				p.recordOTStat(ctx, p.otConfig.getSessionTimeoutsCount, 1)
+				p.recordOTStat(ctx, p.otConfig.getSessionTimeoutsCount, 1, recordOTStatOption{attr: p.otConfig.attributeMapWithoutMultiplexed})
 			}
 			p.mu.Lock()
 			p.numWaiters--
@@ -1107,6 +1246,62 @@ func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) {
 	}
 }
 
+// takeMultiplexed returns the multiplexed session if one is available; if it
+// has not been created yet, it requests creation and waits for the result.
+func (p *sessionPool) takeMultiplexed(ctx context.Context) (*sessionHandle, error) {
+	trace.TracePrintf(ctx, nil, "Acquiring a multiplexed session")
+	for {
+		var s *session
+		p.mu.Lock()
+		if !p.valid {
+			p.mu.Unlock()
+			return nil, errInvalidSessionPool
+		}
+		if !p.enableMultiplexSession {
+			p.mu.Unlock()
+			return p.take(ctx)
+		}
+		// Use the multiplexed session if it has already been created.
+		if p.multiplexedSession != nil {
+			s = p.multiplexedSession
+			trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
+				"Acquired multiplexed session")
+			p.mu.Unlock()
+			p.incNumMultiplexedInUse(ctx)
+			return p.newSessionHandle(s), nil
+		}
+		mayGetSession := p.mayGetMultiplexedSession
+		p.mu.Unlock()
+		p.multiplexedSessionReq <- muxSessionCreateRequest{force: false, ctx: ctx}
+		select {
+		case <-ctx.Done():
+			trace.TracePrintf(ctx, nil, "Context done waiting for multiplexed session")
+			p.recordStat(ctx, GetSessionTimeoutsCount, 1, tag.Tag{Key: tagKeyIsMultiplexed, Value: "true"})
+			if p.otConfig != nil {
+				p.recordOTStat(ctx, p.otConfig.getSessionTimeoutsCount, 1, recordOTStatOption{attr: p.otConfig.attributeMapWithMultiplexed})
+			}
+			return nil, p.errGetSessionTimeout(ctx)
+		case <-mayGetSession: // Block until the multiplexed session is created.
+			p.mu.Lock()
+			if p.multiplexedSessionCreationError != nil {
+				trace.TracePrintf(ctx, nil, "Error creating multiplexed session: %v", p.multiplexedSessionCreationError)
+				err := p.multiplexedSessionCreationError
+				if isUnimplementedError(err) {
+					logf(p.sc.logger, "Multiplexed session is not enabled on this project, continuing with regular sessions")
+					p.enableMultiplexSession = false
+				} else {
+					p.mu.Unlock()
+					// Return the error to the caller; the next call to
+					// takeMultiplexed will trigger a new attempt to create
+					// the multiplexed session.
+					return nil, err
+				}
+			}
+			p.mu.Unlock()
+		}
+	}
+}
+
 // recycle puts session s back to the session pool's idle list, it returns true
 // if the session pool successfully recycles session s.
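+// The multiplexed session never enters the idle list; its handles are
+// released through (*session).recycle, which only updates usage metrics.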
func (p *sessionPool) recycle(s *session) bool { @@ -1135,6 +1331,9 @@ func (p *sessionPool) recycleLocked(s *session) bool { // If isExpire == true, the removal is triggered by session expiration and in // such cases, only idle sessions can be removed. func (p *sessionPool) remove(s *session, isExpire bool, wasInUse bool) bool { + if s.isMultiplexed { + return false + } p.mu.Lock() defer p.mu.Unlock() if isExpire && (p.numOpened <= p.MinOpened || s.getIdleList() == nil) { @@ -1142,6 +1341,7 @@ func (p *sessionPool) remove(s *session, isExpire bool, wasInUse bool) bool { // if number of open sessions is going below p.MinOpened. return false } + ol := s.setIdleList(nil) ctx := context.Background() // If the session is in the idlelist, remove it. @@ -1158,7 +1358,6 @@ func (p *sessionPool) remove(s *session, isExpire bool, wasInUse bool) bool { p.decNumInUseLocked(ctx) } p.recordStat(ctx, OpenSessionCount, int64(p.numOpened)) - // Broadcast that a session has been destroyed. close(p.mayGetSession) p.mayGetSession = make(chan struct{}) return true @@ -1178,14 +1377,21 @@ func (p *sessionPool) incNumInUse(ctx context.Context) { func (p *sessionPool) incNumInUseLocked(ctx context.Context) { p.numInUse++ - p.recordStat(ctx, SessionsCount, int64(p.numInUse), tagNumInUseSessions) - p.recordStat(ctx, AcquiredSessionsCount, 1) + p.recordStat(ctx, SessionsCount, int64(p.numInUse), tagNumInUseSessions, tag.Tag{Key: tagKeyIsMultiplexed, Value: "false"}) + p.recordStat(ctx, AcquiredSessionsCount, 1, tag.Tag{Key: tagKeyIsMultiplexed, Value: "false"}) if p.otConfig != nil { - p.recordOTStat(ctx, p.otConfig.acquiredSessionsCount, 1) + p.recordOTStat(ctx, p.otConfig.acquiredSessionsCount, 1, recordOTStatOption{attr: p.otConfig.attributeMapWithoutMultiplexed}) } if p.numInUse > p.maxNumInUse { p.maxNumInUse = p.numInUse - p.recordStat(ctx, MaxInUseSessionsCount, int64(p.maxNumInUse)) + p.recordStat(ctx, MaxInUseSessionsCount, int64(p.maxNumInUse), tag.Tag{Key: tagKeyIsMultiplexed, Value: "false"}) + } +} + +func (p *sessionPool) incNumMultiplexedInUse(ctx context.Context) { + p.recordStat(ctx, AcquiredSessionsCount, 1, tag.Tag{Key: tagKeyIsMultiplexed, Value: "true"}) + if p.otConfig != nil { + p.recordOTStat(ctx, p.otConfig.acquiredSessionsCount, 1, recordOTStatOption{attr: p.otConfig.attributeMapWithMultiplexed}) } } @@ -1196,10 +1402,17 @@ func (p *sessionPool) decNumInUseLocked(ctx context.Context) { logf(p.sc.logger, "Number of sessions in use is negative, resetting it to currSessionsCheckedOutLocked. 
Stack trace: %s", string(debug.Stack())) p.numInUse = p.currSessionsCheckedOutLocked() } - p.recordStat(ctx, SessionsCount, int64(p.numInUse), tagNumInUseSessions) - p.recordStat(ctx, ReleasedSessionsCount, 1) + p.recordStat(ctx, SessionsCount, int64(p.numInUse), tagNumInUseSessions, tag.Tag{Key: tagKeyIsMultiplexed, Value: "false"}) + p.recordStat(ctx, ReleasedSessionsCount, 1, tag.Tag{Key: tagKeyIsMultiplexed, Value: "false"}) if p.otConfig != nil { - p.recordOTStat(ctx, p.otConfig.releasedSessionsCount, 1) + p.recordOTStat(ctx, p.otConfig.releasedSessionsCount, 1, recordOTStatOption{attr: p.otConfig.attributeMapWithoutMultiplexed}) + } +} + +func (p *sessionPool) decNumMultiplexedInUseLocked(ctx context.Context) { + p.recordStat(ctx, ReleasedSessionsCount, 1, tag.Tag{Key: tagKeyIsMultiplexed, Value: "true"}) + if p.otConfig != nil { + p.recordOTStat(ctx, p.otConfig.releasedSessionsCount, 1, recordOTStatOption{attr: p.otConfig.attributeMapWithMultiplexed}) } } @@ -1342,6 +1555,8 @@ type healthChecker struct { pool *sessionPool // sampleInterval is the interval of sampling by the maintainer. sampleInterval time.Duration + // multiplexSessionRefreshInterval is the interval of refreshing multiplexed session. + multiplexSessionRefreshInterval time.Duration // ready is used to signal that maintainer can start running. ready chan struct{} // done is used to signal that health checker should be closed. @@ -1352,18 +1567,19 @@ type healthChecker struct { } // newHealthChecker initializes new instance of healthChecker. -func newHealthChecker(interval time.Duration, workers int, sampleInterval time.Duration, pool *sessionPool) *healthChecker { +func newHealthChecker(interval, multiplexSessionRefreshInterval time.Duration, workers int, sampleInterval time.Duration, pool *sessionPool) *healthChecker { if workers <= 0 { workers = 1 } hc := &healthChecker{ - interval: interval, - workers: workers, - pool: pool, - sampleInterval: sampleInterval, - ready: make(chan struct{}), - done: make(chan struct{}), - maintainerCancel: func() {}, + interval: interval, + multiplexSessionRefreshInterval: multiplexSessionRefreshInterval, + workers: workers, + pool: pool, + sampleInterval: sampleInterval, + ready: make(chan struct{}), + done: make(chan struct{}), + maintainerCancel: func() {}, } hc.waitWorkers.Add(1) go hc.maintainer() @@ -1371,6 +1587,9 @@ func newHealthChecker(interval time.Duration, workers int, sampleInterval time.D hc.waitWorkers.Add(1) go hc.worker(i) } + if hc.pool.enableMultiplexSession { + go hc.multiplexSessionWorker() + } return hc } @@ -1462,6 +1681,9 @@ func (hc *healthChecker) markDone(s *session) { // healthCheck checks the health of the session and pings it if needed. func (hc *healthChecker) healthCheck(s *session) { defer hc.markDone(s) + if s.isMultiplexed { + return + } if !s.pool.isValid() { // Session pool is closed, perform a garbage collection. 
s.destroy(false, false) @@ -1551,7 +1773,7 @@ func (hc *healthChecker) maintainer() { now := time.Now() if now.After(hc.pool.lastResetTime.Add(10 * time.Minute)) { hc.pool.maxNumInUse = hc.pool.numInUse - hc.pool.recordStat(context.Background(), MaxInUseSessionsCount, int64(hc.pool.maxNumInUse)) + hc.pool.recordStat(context.Background(), MaxInUseSessionsCount, int64(hc.pool.maxNumInUse), tag.Tag{Key: tagKeyIsMultiplexed, Value: "false"}) hc.pool.lastResetTime = now } hc.pool.mu.Unlock() @@ -1658,6 +1880,36 @@ func (hc *healthChecker) shrinkPool(ctx context.Context, shrinkToNumSessions uin } } +func (hc *healthChecker) multiplexSessionWorker() { + for { + if hc.isClosing() { + return + } + hc.pool.mu.Lock() + createTime := time.Now() + s := hc.pool.multiplexedSession + if s != nil { + createTime = hc.pool.multiplexedSession.createTime + } + hc.pool.mu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if createTime.Add(multiplexSessionRefreshInterval).Before(time.Now()) { + // Multiplexed session is idle for more than 7 days, replace it. + hc.pool.multiplexedSessionReq <- muxSessionCreateRequest{force: true, ctx: ctx} + // wait for the new multiplexed session to be created. + <-hc.pool.mayGetMultiplexedSession + } + // Sleep for a while to avoid burning CPU. + select { + case <-time.After(hc.multiplexSessionRefreshInterval): + cancel() + case <-hc.done: + cancel() + return + } + } +} + // maxUint64 returns the maximum of two uint64. func maxUint64(a, b uint64) uint64 { if a > b { @@ -1691,6 +1943,17 @@ func isSessionNotFoundError(err error) bool { return strings.Contains(err.Error(), "Session not found") } +// isUnimplementedError returns true if the gRPC error code is Unimplemented. +func isUnimplementedError(err error) bool { + if err == nil { + return false + } + if ErrCode(err) == codes.Unimplemented { + return true + } + return false +} + func isFailedInlineBeginTransaction(err error) bool { if err == nil { return false diff --git a/spanner/session_test.go b/spanner/session_test.go index 0a9181b09258..2cf12995abce 100644 --- a/spanner/session_test.go +++ b/spanner/session_test.go @@ -196,7 +196,11 @@ func TestTakeFromIdleList(t *testing.T) { // Make sure maintainer keeps the idle sessions. server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ - SessionPoolConfig: SessionPoolConfig{MaxIdle: 10, MaxOpened: 10}, + SessionPoolConfig: SessionPoolConfig{ + MaxIdle: 10, + MaxOpened: 10, + healthCheckSampleInterval: 10 * time.Millisecond, + }, }) defer teardown() sp := client.idleSessions @@ -232,6 +236,9 @@ func TestTakeFromIdleList(t *testing.T) { if len(gotSessions) != 10 { t.Fatalf("got %v unique sessions, want 10", len(gotSessions)) } + if sp.multiplexedSession != nil { + gotSessions[sp.multiplexedSession.getID()] = true + } if !testEqual(gotSessions, wantSessions) { t.Fatalf("got sessions: %v, want %v", gotSessions, wantSessions) } @@ -361,12 +368,12 @@ func TestSessionLeak(t *testing.T) { t.Fatalf("Idle sessions count mismatch\nGot: %d\nWant: %d\n", g, w) } // The checked out session should contain a stack trace. 
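+	// With multiplexed sessions enabled, Single() is served by the multiplexed
+	// session, whose handle is not tracked, so no stacktrace is recorded.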
- if single.sh.stack == nil { + if single.sh.stack == nil && !isMultiplexEnabled { t.Fatalf("Missing stacktrace from session handle") } stack := fmt.Sprintf("%s", single.sh.stack) testMethod := "TestSessionLeak" - if !strings.Contains(stack, testMethod) { + if !strings.Contains(stack, testMethod) && !isMultiplexEnabled { t.Fatalf("Stacktrace does not contain '%s'\nGot: %s", testMethod, stack) } // Return the session to the pool. @@ -395,13 +402,18 @@ func TestSessionLeak(t *testing.T) { iter2 := single2.Query(ctxWithTimeout, NewStatement(SelectFooFromBar)) _, gotErr := iter2.Next() wantErr := client.idleSessions.errGetSessionTimeoutWithTrackedSessionHandles(codes.DeadlineExceeded) + if isMultiplexEnabled { + wantErr = nil + } // The error should contain the stacktraces of all the checked out // sessions. if !testEqual(gotErr, wantErr) { t.Fatalf("Error mismatch on iterating result set.\nGot: %v\nWant: %v\n", gotErr, wantErr) } - if !strings.Contains(gotErr.Error(), testMethod) { - t.Fatalf("Error does not contain '%s'\nGot: %s", testMethod, gotErr.Error()) + if wantErr != nil { + if !strings.Contains(gotErr.Error(), testMethod) { + t.Fatalf("Error does not contain '%s'\nGot: %s", testMethod, gotErr.Error()) + } } // Close iterators to check sessions back into the pool before closing. iter2.Stop() @@ -447,8 +459,10 @@ func TestSessionLeak_WhenInactiveTransactions_RemoveSessionsFromPool(t *testing. // The checked out session should contain a stack trace as Logging is true. single.sh.mu.Lock() if single.sh.stack == nil { - single.sh.mu.Unlock() - t.Fatalf("Missing stacktrace from session handle") + if !isMultiplexEnabled { + single.sh.mu.Unlock() + t.Fatalf("Missing stacktrace from session handle") + } } if g, w := single.sh.eligibleForLongRunning, false; g != w { single.sh.mu.Unlock() @@ -464,7 +478,6 @@ func TestSessionLeak_WhenInactiveTransactions_RemoveSessionsFromPool(t *testing. // The session should have been removed from pool. p.mu.Lock() - defer p.mu.Unlock() if g, w := p.idleList.Len(), 0; g != w { t.Fatalf("Idle Sessions in pool, count mismatch\nGot: %d\nWant: %d\n", g, w) } @@ -474,14 +487,19 @@ func TestSessionLeak_WhenInactiveTransactions_RemoveSessionsFromPool(t *testing. if g, w := p.numOpened, uint64(0); g != w { t.Fatalf("Session pool size mismatch\nGot: %d\nWant: %d\n", g, w) } - if g, w := p.numOfLeakedSessionsRemoved, uint64(1); g != w { + expectedLeakedSession := uint64(1) + if isMultiplexEnabled { + expectedLeakedSession = 0 + } + if g, w := p.numOfLeakedSessionsRemoved, expectedLeakedSession; g != w { t.Fatalf("Number of leaked sessions removed mismatch\nGot: %d\nWant: %d\n", g, w) } + p.mu.Unlock() iter.Stop() } func TestMaintainer_LongRunningTransactionsCleanup_IfClose_VerifyInactiveSessionsClosed(t *testing.T) { - t.Parallel() + ctx := context.Background() _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ @@ -1079,6 +1097,7 @@ func TestMaxBurst(t *testing.T) { }, }) defer teardown() + sp := client.idleSessions // Will cause session creation RPC to be retried forever. 
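Note: the tests in this file run both with and without multiplexed sessions enabled. For reference, a minimal sketch of how an application would opt in to the feature added by this change; the database name is a placeholder, and the flag must be set before the client (and therefore the session pool) is created, since it is read once in newSessionPool:

```go
package main

import (
	"context"
	"log"
	"os"

	"cloud.google.com/go/spanner"
)

func main() {
	// Must be set before NewClient: the flag is read once in newSessionPool.
	os.Setenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS", "true")

	ctx := context.Background()
	client, err := spanner.NewClient(ctx, "projects/p/instances/i/databases/d")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// Single-use and multi-use read-only transactions are now served by the
	// multiplexed session; read-write work keeps using the regular pool.
	_, err = client.Single().ReadRow(ctx, "Albums", spanner.Key{"foo"}, []string{"AlbumTitle"})
	if err != nil {
		log.Fatal(err)
	}
}
```

TestMultiplexSessionWorker below backdates the session's createTime by 10 days; since that exceeds the 7-day multiplexSessionRefreshInterval, the worker (polling every MultiplexSessionCheckInterval, here 1ms) must replace the session once creation succeeds again.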
@@ -1271,6 +1290,13 @@ func TestHealthCheckScheduler(t *testing.T) {
 		gotPings[p]++
 	}
 	for s := range liveSessions {
+		if strings.Contains(s, "multiplexed") {
+			// No pings are expected for multiplexed sessions.
+			if gotPings[s] > 0 {
+				return fmt.Errorf("got %v healthchecks on multiplexed session %v, want 0", gotPings[s], s)
+			}
+			continue
+		}
 		want := int64(20)
 		if got := gotPings[s]; got < want/2 || got > want+want/2 {
 			// This is an unnacceptable amount of pings.
@@ -1653,6 +1679,82 @@ func TestMaintainer(t *testing.T) {
 	})
 }
 
+func TestMultiplexSessionWorker(t *testing.T) {
+	t.Parallel()
+	if !isMultiplexEnabled {
+		t.Skip("Multiplexing is not enabled")
+	}
+	ctx := context.Background()
+
+	server, client, teardown := setupMockedTestServerWithConfig(t,
+		ClientConfig{
+			SessionPoolConfig: SessionPoolConfig{
+				MultiplexSessionCheckInterval: time.Millisecond,
+			},
+		})
+	defer teardown()
+	_, err := client.Single().ReadRow(ctx, "Albums", Key{"foo"}, []string{"SingerId", "AlbumId", "AlbumTitle"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	sp := client.idleSessions
+	waitFor(t, func() error {
+		sp.mu.Lock()
+		defer sp.mu.Unlock()
+		if sp.multiplexedSession == nil {
+			return errInvalidSessionPool
+		}
+		return nil
+	})
+	if !testEqual(uint(1), server.TestSpanner.TotalSessionsCreated()) {
+		t.Fatalf("expected 1 session to be created, got %v", server.TestSpanner.TotalSessionsCreated())
+	}
+	// Will cause the session creation RPC to fail.
+	server.TestSpanner.PutExecutionTime(MethodCreateSession,
+		SimulatedExecutionTime{
+			Errors:    []error{status.Errorf(codes.PermissionDenied, "try later")},
+			KeepError: true,
+		})
+	// To save test time, backdate the multiplexed session's creation time to trigger a refresh.
+	sp.mu.Lock()
+	oldMultiplexedSession := sp.multiplexedSession.id
+	sp.multiplexedSession.createTime = sp.multiplexedSession.createTime.Add(-10 * 24 * time.Hour)
+	sp.mu.Unlock()
+
+	// A subsequent read should use the existing session.
+	_, err = client.Single().ReadRow(ctx, "Albums", Key{"foo"}, []string{"SingerId", "AlbumId", "AlbumTitle"})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	// The multiplexed session should not have been replaced while creation fails.
+	sp.mu.Lock()
+	multiplexSessionID := sp.multiplexedSession.id
+	sp.mu.Unlock()
+	if !testEqual(oldMultiplexedSession, multiplexSessionID) {
+		t.Errorf("TestMultiplexSessionWorker expected multiplexed session id to remain %v, got: %v", oldMultiplexedSession, multiplexSessionID)
+	}
+
+	// Let session creation succeed again.
+	server.TestSpanner.Freeze()
+	server.TestSpanner.PutExecutionTime(MethodCreateSession, SimulatedExecutionTime{})
+	server.TestSpanner.Unfreeze()
+
+	waitFor(t, func() error {
+		if server.TestSpanner.TotalSessionsCreated() != 2 {
+			return errInvalidSessionPool
+		}
+		return nil
+	})
+
+	sp.mu.Lock()
+	multiplexSessionID = sp.multiplexedSession.id
+	sp.mu.Unlock()
+
+	if testEqual(oldMultiplexedSession, multiplexSessionID) {
+		t.Errorf("TestMultiplexSessionWorker expected multiplexed session id to be different, got: %v", multiplexSessionID)
+	}
+}
+
 // Tests that the session pool creates up to MinOpened connections.
// // Historical context: This test also checks that a low @@ -1689,6 +1791,11 @@ loop: numOpened = sp.idleList.Len() sp.mu.Unlock() if numOpened == 10 { + if isMultiplexEnabled { + if sp.multiplexedSession == nil { + continue + } + } break loop } } diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go index ac5a37b34cd7..e0a56f9af7dc 100644 --- a/spanner/sessionclient.go +++ b/spanner/sessionclient.go @@ -71,13 +71,13 @@ func (cg *clientIDGenerator) nextID(database string) string { type sessionConsumer interface { // sessionReady is called when a session has been created and is ready for // use. - sessionReady(s *session) + sessionReady(ctx context.Context, s *session) // sessionCreationFailed is called when the creation of a sub-batch of // sessions failed. The numSessions argument specifies the number of // sessions that could not be created as a result of this error. A // consumer may receive multiple errors per batch. - sessionCreationFailed(err error, numSessions int32) + sessionCreationFailed(ctx context.Context, err error, numSessions int32, isMultiplexed bool) } // sessionClient creates sessions for a database, either in batches or one at a @@ -254,12 +254,12 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC if closed { err := spannerErrorf(codes.Canceled, "Session client closed") trace.TracePrintf(ctx, nil, "Session client closed while creating a batch of %d sessions: %v", createCount, err) - consumer.sessionCreationFailed(err, remainingCreateCount) + consumer.sessionCreationFailed(ctx, err, remainingCreateCount, false) break } if ctx.Err() != nil { trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err()) - consumer.sessionCreationFailed(ToSpannerError(ctx.Err()), remainingCreateCount) + consumer.sessionCreationFailed(ctx, ToSpannerError(ctx.Err()), remainingCreateCount, false) break } var mdForGFELatency metadata.MD @@ -294,13 +294,13 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC } if err != nil { trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err) - consumer.sessionCreationFailed(ToSpannerError(err), remainingCreateCount) + consumer.sessionCreationFailed(ctx, ToSpannerError(err), remainingCreateCount, false) break } actuallyCreated := int32(len(response.Session)) trace.TracePrintf(ctx, nil, "Received a batch of %d sessions", actuallyCreated) for _, s := range response.Session { - consumer.sessionReady(&session{valid: true, client: client, id: s.Name, createTime: time.Now(), md: md, logger: sc.logger}) + consumer.sessionReady(ctx, &session{valid: true, client: client, id: s.Name, createTime: time.Now(), md: md, logger: sc.logger}) } if actuallyCreated < remainingCreateCount { // Spanner could return less sessions than requested. 
In that case, we @@ -313,6 +313,62 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC } } +func (sc *sessionClient) executeCreateMultiplexedSession(ctx context.Context, client *vkit.Client, md metadata.MD, consumer sessionConsumer) { + ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.CreateSession") + defer func() { trace.EndSpan(ctx, nil) }() + trace.TracePrintf(ctx, nil, "Creating a multiplexed session") + sc.mu.Lock() + closed := sc.closed + sc.mu.Unlock() + if closed { + err := spannerErrorf(codes.Canceled, "Session client closed") + trace.TracePrintf(ctx, nil, "Session client closed while creating a multiplexed session: %v", err) + return + } + if ctx.Err() != nil { + trace.TracePrintf(ctx, nil, "Context error while creating a multiplexed session: %v", ctx.Err()) + consumer.sessionCreationFailed(ctx, ToSpannerError(ctx.Err()), 1, true) + return + } + var mdForGFELatency metadata.MD + response, err := client.CreateSession(contextWithOutgoingMetadata(ctx, sc.md, sc.disableRouteToLeader), &sppb.CreateSessionRequest{ + Database: sc.database, + // Multiplexed sessions do not support labels. + Session: &sppb.Session{CreatorRole: sc.databaseRole, Multiplexed: true}, + }, gax.WithGRPCOptions(grpc.Header(&mdForGFELatency))) + + if getGFELatencyMetricsFlag() && mdForGFELatency != nil { + _, instance, database, err := parseDatabaseName(sc.database) + if err != nil { + trace.TracePrintf(ctx, nil, "Error getting instance and database name: %v", err) + } + // Errors should not prevent initializing the session pool. + ctxGFE, err := tag.New(ctx, + tag.Upsert(tagKeyClientID, sc.id), + tag.Upsert(tagKeyDatabase, database), + tag.Upsert(tagKeyInstance, instance), + tag.Upsert(tagKeyLibVersion, internal.Version), + ) + if err != nil { + trace.TracePrintf(ctx, nil, "Error in adding tags in CreateSession for GFE Latency: %v", err) + } + err = captureGFELatencyStats(ctxGFE, mdForGFELatency, "executeCreateSession") + if err != nil { + trace.TracePrintf(ctx, nil, "Error in Capturing GFE Latency and Header Missing count. Try disabling and rerunning. Error: %v", err) + } + } + if metricErr := recordGFELatencyMetricsOT(ctx, mdForGFELatency, "executeCreateSession", sc.otConfig); metricErr != nil { + trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. 
Error: %v", metricErr) + } + if err != nil { + trace.TracePrintf(ctx, nil, "Error creating a multiplexed sessions: %v", err) + consumer.sessionCreationFailed(ctx, ToSpannerError(err), 1, true) + return + } + consumer.sessionReady(ctx, &session{valid: true, client: client, id: response.Name, createTime: time.Now(), md: md, logger: sc.logger, isMultiplexed: response.Multiplexed}) + trace.TracePrintf(ctx, nil, "Finished creating multiplexed sessions") +} + func (sc *sessionClient) sessionWithID(id string) (*session, error) { sc.mu.Lock() defer sc.mu.Unlock() diff --git a/spanner/sessionclient_test.go b/spanner/sessionclient_test.go index 02234eed0fbd..d1813754267b 100644 --- a/spanner/sessionclient_test.go +++ b/spanner/sessionclient_test.go @@ -52,14 +52,14 @@ type testConsumer struct { receivedAll chan struct{} } -func (tc *testConsumer) sessionReady(s *session) { +func (tc *testConsumer) sessionReady(_ context.Context, s *session) { tc.mu.Lock() defer tc.mu.Unlock() tc.sessions = append(tc.sessions, s) tc.checkReceivedAll() } -func (tc *testConsumer) sessionCreationFailed(err error, num int32) { +func (tc *testConsumer) sessionCreationFailed(_ context.Context, err error, num int32, _ bool) { tc.mu.Lock() defer tc.mu.Unlock() tc.errors = append(tc.errors, &testSessionCreateError{ @@ -148,7 +148,11 @@ func TestCreateAndCloseSession(t *testing.T) { if s == nil { t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s) } - if server.TestSpanner.TotalSessionsCreated() != 1 { + expectedCount := uint(1) + if isMultiplexEnabled { + expectedCount = 2 + } + if server.TestSpanner.TotalSessionsCreated() != expectedCount { t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1) } s.delete(context.Background()) @@ -174,7 +178,11 @@ func TestCreateSessionWithDatabaseRole(t *testing.T) { if s == nil { t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s) } - if g, w := server.TestSpanner.TotalSessionsCreated(), uint(1); g != w { + expectedCount := uint(1) + if isMultiplexEnabled { + expectedCount = 2 + } + if g, w := server.TestSpanner.TotalSessionsCreated(), expectedCount; g != w { t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", g, w) } @@ -246,6 +254,18 @@ func TestBatchCreateAndCloseSession(t *testing.T) { if err != nil { t.Fatal(err) } + + if isMultiplexEnabled { + waitFor(t, func() error { + client.idleSessions.mu.Lock() + defer client.idleSessions.mu.Unlock() + if client.idleSessions.multiplexedSession == nil { + return fmt.Errorf("multiplexed session not created yet") + } + return nil + }) + } + consumer := newTestConsumer(numSessions) client.sc.batchCreateSessions(numSessions, true, consumer) <-consumer.receivedAll @@ -253,8 +273,12 @@ func TestBatchCreateAndCloseSession(t *testing.T) { t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions) } created := server.TestSpanner.TotalSessionsCreated() - prevCreated - if created != uint(numSessions) { - t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions) + expectedNumSessions := numSessions + if isMultiplexEnabled { + expectedNumSessions++ + } + if created != uint(expectedNumSessions) { + t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, expectedNumSessions) } // Check that all channels are used evenly. 
channelCounts := make(map[*vkit.Client]int32) @@ -282,7 +306,15 @@ func TestBatchCreateAndCloseSession(t *testing.T) { } for a, c := range connCounts { if c != 1 { - t.Fatalf("connection %q used an unexpected number of times\ngot: %v\nwant %v", a, c, 1) + if isMultiplexEnabled { + // The multiplexed session creation will use one of the connections + if c != 2 { + t.Fatalf("connection %q used an unexpected number of times\ngot: %v\nwant %v", a, c, 2) + } + } else { + t.Fatalf("connection %q used an unexpected number of times\ngot: %v\nwant %v", a, c, 1) + } + } } // Delete the sessions. @@ -311,6 +343,7 @@ func TestBatchCreateSessionsWithDatabaseRole(t *testing.T) { } server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc, DatabaseRole: "test"}) defer teardown() + ctx := context.Background() consumer := newTestConsumer(1) @@ -319,7 +352,11 @@ func TestBatchCreateSessionsWithDatabaseRole(t *testing.T) { if g, w := len(consumer.sessions), 1; g != w { t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", g, w) } - if g, w := server.TestSpanner.TotalSessionsCreated(), uint(1); g != w { + expectedCount := uint(1) + if isMultiplexEnabled { + expectedCount = 2 + } + if g, w := server.TestSpanner.TotalSessionsCreated(), expectedCount; g != w { t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", g, w) } s := consumer.sessions[0] @@ -435,10 +472,22 @@ func TestBatchCreateSessions_ServerExhausted(t *testing.T) { }, }) defer teardown() + if isMultiplexEnabled { + waitFor(t, func() error { + if client.idleSessions.multiplexedSession == nil { + return fmt.Errorf("multiplexed session not created yet") + } + return nil + }) + } numSessions := int32(100) maxSessions := int32(50) // Ensure that the server will never return more than 50 sessions in total. 
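+	// The multiplexed session created at client startup also counts against
+	// this limit, so allow for one extra session when multiplexing is enabled.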
- server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions) + if isMultiplexEnabled { + server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions + 1) + } else { + server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions) + } consumer := newTestConsumer(numSessions) client.sc.batchCreateSessions(numSessions, true, consumer) <-consumer.receivedAll diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go index e73f99ca9059..21b929ce0598 100644 --- a/spanner/spannertest/inmem.go +++ b/spanner/spannertest/inmem.go @@ -347,11 +347,20 @@ func (s *server) GetDatabaseDdl(ctx context.Context, req *adminpb.GetDatabaseDdl func (s *server) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { //s.logf("CreateSession(%q)", req.Database) - return s.newSession(), nil + isMultiplexed := false + if req.Session != nil && req.Session.Multiplexed { + isMultiplexed = true + } + sess := s.newSession(isMultiplexed) + + return sess, nil } -func (s *server) newSession() *spannerpb.Session { +func (s *server) newSession(isMultiplexed bool) *spannerpb.Session { id := genRandomSession() + if isMultiplexed { + id = "multiplexed-" + id + } now := time.Now() sess := &session{ name: id, @@ -359,13 +368,15 @@ func (s *server) newSession() *spannerpb.Session { lastUse: now, transactions: make(map[string]*transaction), } + sess.ctx, sess.cancel = context.WithCancel(context.Background()) s.mu.Lock() s.sessions[id] = sess s.mu.Unlock() - - return sess.Proto() + sesspb := sess.Proto() + sesspb.Multiplexed = isMultiplexed + return sesspb } func (s *server) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { @@ -373,7 +384,7 @@ func (s *server) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCr var sessions []*spannerpb.Session for i := int32(0); i < req.GetSessionCount(); i++ { - sessions = append(sessions, s.newSession()) + sessions = append(sessions, s.newSession(false)) } return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil diff --git a/spanner/stats.go b/spanner/stats.go index 8e77ecf3dbc5..bc8176b6d2a0 100644 --- a/spanner/stats.go +++ b/spanner/stats.go @@ -31,12 +31,14 @@ const statsPrefix = "cloud.google.com/go/spanner/" // Deprecated: OpenCensus project is deprecated. Use OpenTelemetry for capturing metrics. 
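+// tagKeyIsMultiplexed distinguishes datapoints recorded for the multiplexed
+// session ("true") from datapoints for regular pooled sessions ("false").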
var ( - tagKeyClientID = tag.MustNewKey("client_id") - tagKeyDatabase = tag.MustNewKey("database") - tagKeyInstance = tag.MustNewKey("instance_id") - tagKeyLibVersion = tag.MustNewKey("library_version") - tagKeyType = tag.MustNewKey("type") - tagCommonKeys = []tag.Key{tagKeyClientID, tagKeyDatabase, tagKeyInstance, tagKeyLibVersion} + tagKeyClientID = tag.MustNewKey("client_id") + tagKeyDatabase = tag.MustNewKey("database") + tagKeyInstance = tag.MustNewKey("instance_id") + tagKeyLibVersion = tag.MustNewKey("library_version") + tagKeyType = tag.MustNewKey("type") + tagKeyIsMultiplexed = tag.MustNewKey("is_multiplexed") + + tagCommonKeys = []tag.Key{tagKeyClientID, tagKeyDatabase, tagKeyInstance, tagKeyLibVersion} tagNumInUseSessions = tag.Tag{Key: tagKeyType, Value: "num_in_use_sessions"} tagNumSessions = tag.Tag{Key: tagKeyType, Value: "num_sessions"} diff --git a/spanner/test/opentelemetry/test/ot_metrics_test.go b/spanner/test/opentelemetry/test/ot_metrics_test.go index 444d87e0e325..a29de04b21c1 100644 --- a/spanner/test/opentelemetry/test/ot_metrics_test.go +++ b/spanner/test/opentelemetry/test/ot_metrics_test.go @@ -21,6 +21,7 @@ package test import ( "context" "errors" + "strconv" "testing" "time" @@ -73,6 +74,12 @@ func TestOTMetrics_SessionPool(t *testing.T) { defer teardown() client.Single().ReadRow(context.Background(), "Users", spanner.Key{"alice"}, []string{"email"}) + expectedOpenSessionCount := int64(25) + expectedAcquiredSessionsCount := int64(1) + if isMultiplexEnabled { + expectedOpenSessionCount = 1 + expectedAcquiredSessionsCount = int64(2) + } for _, test := range []struct { name string expectedMetric metricdata.Metrics @@ -87,7 +94,7 @@ func TestOTMetrics_SessionPool(t *testing.T) { DataPoints: []metricdata.DataPoint[int64]{ { Attributes: attribute.NewSet(getAttributes(client.ClientID())...), - Value: 25, + Value: expectedOpenSessionCount, }, }, }, @@ -118,7 +125,7 @@ func TestOTMetrics_SessionPool(t *testing.T) { Data: metricdata.Gauge[int64]{ DataPoints: []metricdata.DataPoint[int64]{ { - Attributes: attribute.NewSet(getAttributes(client.ClientID())...), + Attributes: attribute.NewSet(append(getAttributes(client.ClientID()), attribute.Key("is_multiplexed").String(strconv.FormatBool(isMultiplexEnabled)))...), Value: 1, }, }, @@ -134,8 +141,8 @@ func TestOTMetrics_SessionPool(t *testing.T) { Data: metricdata.Sum[int64]{ DataPoints: []metricdata.DataPoint[int64]{ { - Attributes: attribute.NewSet(getAttributes(client.ClientID())...), - Value: 1, + Attributes: attribute.NewSet(append(getAttributes(client.ClientID()), attribute.Key("is_multiplexed").String(strconv.FormatBool(isMultiplexEnabled)))...), + Value: expectedAcquiredSessionsCount, // }, }, Temporality: metricdata.CumulativeTemporality, @@ -152,8 +159,8 @@ func TestOTMetrics_SessionPool(t *testing.T) { Data: metricdata.Sum[int64]{ DataPoints: []metricdata.DataPoint[int64]{ { - Attributes: attribute.NewSet(getAttributes(client.ClientID())...), - Value: 1, + Attributes: attribute.NewSet(append(getAttributes(client.ClientID()), attribute.Key("is_multiplexed").String(strconv.FormatBool(isMultiplexEnabled)))...), + Value: expectedAcquiredSessionsCount, // should be same as acquired sessions count }, }, Temporality: metricdata.CumulativeTemporality, @@ -165,6 +172,26 @@ func TestOTMetrics_SessionPool(t *testing.T) { t.Run(test.name, func(t *testing.T) { metricName := test.expectedMetric.Name expectedMetric := test.expectedMetric + if isMultiplexEnabled { + if metricName == "spanner/max_in_use_sessions" { + 
t.Skip("Skipping test for " + metricName + " as it is not applicable for multiplexed sessions") + } + if metricName == "spanner/open_session_count" { + // For multiplexed sessions, the open session count should be 1. + expectedMetric.Data = metricdata.Gauge[int64]{ + DataPoints: []metricdata.DataPoint[int64]{ + { + Attributes: attribute.NewSet(getAttributes(client.ClientID())...), + Value: 0, + }, + { + Attributes: attribute.NewSet(append(getAttributes(client.ClientID()), attribute.Key("is_multiplexed").String(strconv.FormatBool(isMultiplexEnabled)))...), + Value: 1, + }, + }, + } + } + } validateOTMetric(ctx, t, te, metricName, expectedMetric) }) } @@ -184,15 +211,20 @@ func TestOTMetrics_SessionPool_SessionsCount(t *testing.T) { // Wait for the session pool initialization to finish. expectedReads := spanner.DefaultSessionPoolConfig.MinOpened waitFor(t, func() error { - if uint64(server.TestSpanner.TotalSessionsCreated()) == expectedReads { - return nil + if isMultiplexEnabled { + if uint64(server.TestSpanner.TotalSessionsCreated()) == expectedReads+1 { + return nil + } + } else { + if uint64(server.TestSpanner.TotalSessionsCreated()) == expectedReads { + return nil + } } return errors.New("Not yet initialized") }) client.Single().ReadRow(context.Background(), "Users", spanner.Key{"alice"}, []string{"email"}) - attributesNumInUseSessions := append(getAttributes(client.ClientID()), attribute.Key("type").String("num_in_use_sessions")) attributesNumSessions := append(getAttributes(client.ClientID()), attribute.Key("type").String("num_sessions")) expectedMetricData := metricdata.Metrics{ @@ -202,7 +234,7 @@ func TestOTMetrics_SessionPool_SessionsCount(t *testing.T) { Data: metricdata.Gauge[int64]{ DataPoints: []metricdata.DataPoint[int64]{ { - Attributes: attribute.NewSet(attributesNumInUseSessions...), + Attributes: attribute.NewSet(append(getAttributes(client.ClientID()), attribute.Key("type").String("num_in_use_sessions"), attribute.Key("is_multiplexed").String("false"))...), Value: 0, }, { @@ -217,6 +249,10 @@ func TestOTMetrics_SessionPool_SessionsCount(t *testing.T) { } func TestOTMetrics_SessionPool_GetSessionTimeoutsCount(t *testing.T) { + if isMultiplexEnabled { + // multiplexed sessions will be always available in background, so this metric is not applicable. 
+ t.Skip("Skipping test for GetSessionTimeoutsCount as it is not applicable for multiplexed sessions") + } ctx1 := context.Background() te := newOpenTelemetryTestExporter(false, false) t.Cleanup(func() { @@ -242,7 +278,7 @@ func TestOTMetrics_SessionPool_GetSessionTimeoutsCount(t *testing.T) { Data: metricdata.Sum[int64]{ DataPoints: []metricdata.DataPoint[int64]{ { - Attributes: attribute.NewSet(getAttributes(client.ClientID())...), + Attributes: attribute.NewSet(append(getAttributes(client.ClientID()), attribute.Key("is_multiplexed").String(strconv.FormatBool(isMultiplexEnabled)))...), Value: 1, }, }, @@ -296,7 +332,11 @@ func TestOTMetrics_GFELatency(t *testing.T) { } } - attributeGFELatency := append(getAttributes(client.ClientID()), attribute.Key("grpc_client_method").String("executeBatchCreateSessions")) + method := "executeBatchCreateSessions" + if isMultiplexEnabled { + method = "executeCreateSession" + } + attributeGFELatency := append(getAttributes(client.ClientID()), attribute.Key("grpc_client_method").String(method)) resourceMetrics, err := te.metrics(context.Background()) if err != nil { @@ -349,6 +389,10 @@ func TestOTMetrics_GFELatency(t *testing.T) { IsMonotonic: true, }, } + if isMultiplexEnabled { + // add datapoint from initial wait for multiplexed session to be available + expectedMetricData.Data.(metricdata.Sum[int64]).DataPoints[0].Value = 2 + } metricdatatest.AssertEqual(t, expectedMetricData, resourceMetrics.ScopeMetrics[0].Metrics[idx1], metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreExemplars()) } diff --git a/spanner/test/opentelemetry/test/ot_traces_test.go b/spanner/test/opentelemetry/test/ot_traces_test.go index c2a4ed942f98..5d9e61ef541b 100644 --- a/spanner/test/opentelemetry/test/ot_traces_test.go +++ b/spanner/test/opentelemetry/test/ot_traces_test.go @@ -57,6 +57,11 @@ func TestSpannerTracesWithOpenTelemetry(t *testing.T) { defer teardown() waitFor(t, func() error { + if isMultiplexEnabled { + if uint64(server.TestSpanner.TotalSessionsCreated()) == minOpened+1 { + return nil + } + } if uint64(server.TestSpanner.TotalSessionsCreated()) == minOpened { return nil } diff --git a/spanner/test/opentelemetry/test/test_util.go b/spanner/test/opentelemetry/test/test_util.go index 896551877b66..0b89aee19e15 100644 --- a/spanner/test/opentelemetry/test/test_util.go +++ b/spanner/test/opentelemetry/test/test_util.go @@ -22,6 +22,7 @@ package test import ( "context" "fmt" + "os" "testing" "time" @@ -29,6 +30,14 @@ import ( stestutil "cloud.google.com/go/spanner/internal/testutil" ) +var ( + isMultiplexEnabled = getMultiplexEnableFlag() +) + +func getMultiplexEnableFlag() bool { + return os.Getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") == "true" +} + func setupMockedTestServerWithConfig(t *testing.T, config spanner.ClientConfig) (server *stestutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { server, opts, serverTeardown := stestutil.NewMockedSpannerInMemTestServer(t) ctx := context.Background() @@ -37,6 +46,13 @@ func setupMockedTestServerWithConfig(t *testing.T, config spanner.ClientConfig) if err != nil { t.Fatal(err) } + if isMultiplexEnabled { + // trigger R/O txn for multiplexed session creation to avoid flakiness + waitFor(t, func() error { + iter := client.Single().Query(ctx, spanner.NewStatement(stestutil.SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + return iter.Do(func(_ *spanner.Row) error { return nil }) + }) + } return server, client, func() { client.Close() serverTeardown() diff --git a/spanner/transaction.go 
b/spanner/transaction.go index 8bce7cc1e710..ab10ba433177 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -776,7 +776,7 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error { }() // Retry the BeginTransaction call if a 'Session not found' is returned. for { - sh, err = t.sp.take(ctx) + sh, err = t.sp.takeMultiplexed(ctx) if err != nil { return err } @@ -866,7 +866,7 @@ func (t *ReadOnlyTransaction) acquireSingleUse(ctx context.Context) (*sessionHan }, }, } - sh, err := t.sp.take(ctx) + sh, err := t.sp.takeMultiplexed(ctx) if err != nil { return nil, nil, err } diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 9d23c763462d..90a3d7867b7d 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -56,14 +56,18 @@ func TestSingle(t *testing.T) { } // Only one BatchCreateSessionsRequest is sent. - if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{&sppb.BatchCreateSessionsRequest{}}); err != nil { + expectedReqs := []interface{}{&sppb.BatchCreateSessionsRequest{}} + if isMultiplexEnabled { + expectedReqs = []interface{}{&sppb.CreateSessionRequest{}} + } + if _, err := shouldHaveReceived(server.TestSpanner, expectedReqs); err != nil { t.Fatal(err) } } // Re-using ReadOnlyTransaction: can recover from acquire failure. func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) { - t.Parallel() + ctx := context.Background() server, client, teardown := setupMockedTestServer(t) defer teardown() @@ -316,17 +320,20 @@ func TestBatchDML_WithMultipleDML(t *testing.T) { if err != nil { t.Fatal(err) } - - if got, want := gotReqs[1].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if got, want := gotReqs[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want { t.Errorf("got %d, want %d", got, want) } - if got, want := gotReqs[2].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want { + if got, want := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want { t.Errorf("got %d, want %d", got, want) } - if got, want := gotReqs[3].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want { + if got, want := gotReqs[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want { t.Errorf("got %d, want %d", got, want) } - if got, want := gotReqs[4].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want { + if got, want := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want { t.Errorf("got %d, want %d", got, want) } } @@ -602,17 +609,20 @@ func TestBatchDML_StatementBased_WithMultipleDML(t *testing.T) { if err != nil { t.Fatal(err) } - - if got, want := gotReqs[2].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want { + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if got, want := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want { t.Errorf("got %d, want %d", got, want) } - if got, want := gotReqs[3].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want { + if got, want := gotReqs[3+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want { t.Errorf("got %d, want %d", got, want) } - if got, want := gotReqs[4].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want { + if got, want := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want { t.Errorf("got %d, want %d", got, want) } - if got, want := 
gotReqs[5].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want {
+	if got, want := gotReqs[5+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want {
 		t.Errorf("got %d, want %d", got, want)
 	}
 }
@@ -670,22 +680,26 @@ func TestPriorityInQueryOptions(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	if got, want := gotReqs[2].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_LOW; got != want {
+	muxCreateBuffer := 0
+	if isMultiplexEnabled {
+		muxCreateBuffer = 1
+	}
+	if got, want := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_LOW; got != want {
 		t.Errorf("got %d, want %d", got, want)
 	}
-	if got, want := gotReqs[3].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_MEDIUM; got != want {
+	if got, want := gotReqs[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_MEDIUM; got != want {
 		t.Errorf("got %d, want %d", got, want)
 	}
-	if got, want := gotReqs[4].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_LOW; got != want {
+	if got, want := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_LOW; got != want {
 		t.Errorf("got %d, want %d", got, want)
 	}
-	if got, want := gotReqs[5].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_LOW; got != want {
+	if got, want := gotReqs[5+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_LOW; got != want {
 		t.Errorf("got %d, want %d", got, want)
 	}
-	if got, want := gotReqs[6].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_LOW; got != want {
+	if got, want := gotReqs[6+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_LOW; got != want {
 		t.Errorf("got %d, want %d", got, want)
 	}
-	if got, want := gotReqs[7].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_MEDIUM; got != want {
+	if got, want := gotReqs[7+muxCreateBuffer].(*sppb.ExecuteSqlRequest).RequestOptions.Priority, sppb.RequestOptions_PRIORITY_MEDIUM; got != want {
 		t.Errorf("got %d, want %d", got, want)
 	}
 }
@@ -702,6 +716,37 @@ func shouldHaveReceived(server InMemSpannerServer, want []interface{}) ([]interf
 
 // Compares expected requests (want) with actual requests (got).
 func compareRequests(want []interface{}, got []interface{}) error {
+	// Session creation races with the first RPC, so the expected
+	// BatchCreateSessionsRequest may be listed out of order; move it to the
+	// front of want before the element-wise comparison below.
+	if reflect.TypeOf(want[0]) != reflect.TypeOf(&sppb.BatchCreateSessionsRequest{}) {
+		sessReq := 0
+		for i := 0; i < len(want); i++ {
+			if reflect.TypeOf(want[i]) == reflect.TypeOf(&sppb.BatchCreateSessionsRequest{}) {
+				sessReq = i
+				break
+			}
+		}
+		want[0], want[sessReq] = want[sessReq], want[0]
+	}
+	if isMultiplexEnabled {
+		// A CreateSessionRequest for the multiplexed session is sent at client
+		// startup: make sure it is expected, and move the received one ahead
+		// of any BatchCreateSessionsRequest so the comparison is order-independent.
+		if reflect.TypeOf(want[0]) != reflect.TypeOf(&sppb.CreateSessionRequest{}) {
+			want = append([]interface{}{&sppb.CreateSessionRequest{}}, want...)
+		}
+		if reflect.TypeOf(got[0]) == reflect.TypeOf(&sppb.BatchCreateSessionsRequest{}) {
+			muxSessionIndex := 0
+			for i := 0; i < len(got); i++ {
+				if reflect.TypeOf(got[i]) == reflect.TypeOf(&sppb.CreateSessionRequest{}) {
+					muxSessionIndex = i
+					break
+				}
+			}
+			got[0], got[muxSessionIndex] = got[muxSessionIndex], got[0]
+		}
+	}
 	if len(got) != len(want) {
 		var gotMsg string
 		for _, r := range got {