diff --git a/spanner/session.go b/spanner/session.go
index 76d611fe262d..a6831b8c09b1 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -24,6 +24,7 @@ import (
 	"log"
 	"math"
 	"math/rand"
+	"os"
 	"runtime/debug"
 	"strings"
 	"sync"
@@ -37,6 +38,7 @@ import (
 	"go.opencensus.io/tag"
 	octrace "go.opencensus.io/trace"
 	"go.opentelemetry.io/otel/metric"
+	"google.golang.org/api/option"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/metadata"
 )
@@ -85,6 +87,8 @@ type sessionHandle struct {
 	// session is a pointer to a session object. Transactions never need to
 	// access it directly.
 	session *session
+	// client is the RPC client to Cloud Spanner. It is set only once, during session acquisition.
+	client *vkit.Client
 	// checkoutTime is the time the session was checked out of the pool.
 	checkoutTime time.Time
 	// lastUseTime is the time the session was last used after checked out of the pool.
@@ -115,6 +119,9 @@ func (sh *sessionHandle) recycle() {
 	tracked := sh.trackedSessionHandle
 	s := sh.session
 	sh.session = nil
+	if sh.client != nil {
+		sh.client = nil
+	}
 	sh.trackedSessionHandle = nil
 	sh.checkoutTime = time.Time{}
 	sh.lastUseTime = time.Time{}
@@ -149,6 +156,10 @@ func (sh *sessionHandle) getClient() *vkit.Client {
 	if sh.session == nil {
 		return nil
 	}
+	if sh.session.isMultiplexed {
+		// Use the gRPC client stored on the session handle.
+		return sh.client
+	}
 	return sh.session.client
 }
 
@@ -185,6 +196,9 @@ func (sh *sessionHandle) destroy() {
 	}
 	tracked := sh.trackedSessionHandle
 	sh.session = nil
+	if sh.client != nil {
+		sh.client = nil
+	}
 	sh.trackedSessionHandle = nil
 	sh.checkoutTime = time.Time{}
 	sh.lastUseTime = time.Time{}
@@ -252,6 +266,8 @@ type session struct {
 	tx transactionID
 	// firstHCDone indicates whether the first health check is done or not.
 	firstHCDone bool
+	// isMultiplexed is true if the session is multiplexed.
+	isMultiplexed bool
 }
 
 // isValid returns true if the session is still valid for use.
@@ -569,12 +585,20 @@ type sessionPool struct {
 	// idleList caches idle session IDs. Session IDs in this list can be
 	// allocated for use.
 	idleList list.List
+	// multiplexedSession is the pool's multiplexed session, if one has been created.
+	multiplexedSession *session
 	// mayGetSession is for broadcasting that session retrival/creation may
 	// proceed.
 	mayGetSession chan struct{}
+	// mayGetMultiplexedSession is for broadcasting that multiplexed session
+	// retrieval/creation may proceed.
+	mayGetMultiplexedSession chan struct{}
 	// sessionCreationError is the last error that occurred during session
 	// creation and is propagated to any waiters waiting for a session.
	sessionCreationError error
+	// multiplexedSessionCreationError is the last error that occurred during multiplexed session
+	// creation and is propagated to any waiters waiting for a session.
+	multiplexedSessionCreationError error
 	// numOpened is the total number of open sessions from the session pool.
 	numOpened uint64
 	// createReqs is the number of ongoing session creation requests.
@@ -616,6 +640,9 @@ type sessionPool struct {
 	numOfLeakedSessionsRemoved uint64
 
 	otConfig *openTelemetryConfig
+
+	// enableMultiplexSession is a flag to enable multiplexed sessions.
+	enableMultiplexSession bool
 }
 
 // newSessionPool creates a new session pool.
@@ -652,13 +679,15 @@ func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool,
 	}
 	pool := &sessionPool{
-		sc:                sc,
-		valid:             true,
-		mayGetSession:     make(chan struct{}),
-		SessionPoolConfig: config,
-		mw:                newMaintenanceWindow(config.MaxOpened),
-		rand:              rand.New(rand.NewSource(time.Now().UnixNano())),
-		otConfig:          sc.otConfig,
+		sc:                       sc,
+		valid:                    true,
+		mayGetSession:            make(chan struct{}),
+		mayGetMultiplexedSession: make(chan struct{}),
+		SessionPoolConfig:        config,
+		mw:                       newMaintenanceWindow(config.MaxOpened),
+		rand:                     rand.New(rand.NewSource(time.Now().UnixNano())),
+		otConfig:                 sc.otConfig,
+		enableMultiplexSession:   os.Getenv("GOOGLE_CLOUD_SPANNER_ENABLE_MULTIPLEXED_SESSIONS") == "true" && os.Getenv("GOOGLE_CLOUD_SPANNER_FORCE_DISABLE_MULTIPLEXED_SESSIONS") != "true",
 	}
 
 	_, instance, database, err := parseDatabaseName(sc.database)
@@ -812,6 +841,17 @@ func (p *sessionPool) growPoolLocked(numSessions uint64, distributeOverChannels
 	return p.sc.batchCreateSessions(int32(numSessions), distributeOverChannels, p)
 }
 
+func (p *sessionPool) getMultiplexedSession(ctx context.Context) error {
+	p.sc.mu.Lock()
+	defer p.sc.mu.Unlock()
+	client, err := p.sc.nextClient()
+	if err != nil {
+		return err
+	}
+	go p.sc.executeCreateMultiplexedSessions(ctx, client, p.sc.sessionLabels, p.sc.md, p)
+	return nil
+}
+
 // sessionReady is executed by the SessionClient when a session has been
 // created and is ready to use. This method will add the new session to the
 // pool and decrease the number of sessions that is being created.
@@ -819,6 +859,16 @@ func (p *sessionPool) sessionReady(s *session) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 	// Clear any session creation error.
+	if s.isMultiplexed {
+		s.pool = p
+		p.hc.register(s)
+		p.multiplexedSession = s
+		p.multiplexedSessionCreationError = nil
+		p.incNumSessionsLocked(context.Background())
+		close(p.mayGetMultiplexedSession)
+		p.mayGetMultiplexedSession = make(chan struct{})
+		return
+	}
 	p.sessionCreationError = nil
 	// Set this pool as the home pool of the session and register it with the
 	// health checker.
@@ -848,9 +898,15 @@
 // or more requested sessions finished with an error. sessionCreationFailed will
 // decrease the number of sessions being created and notify any waiters that
 // the session creation failed.
-func (p *sessionPool) sessionCreationFailed(err error, numSessions int32) {
+func (p *sessionPool) sessionCreationFailed(err error, numSessions int32, isMultiplexed bool) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
+	if isMultiplexed {
+		p.multiplexedSessionCreationError = err
+		close(p.mayGetMultiplexedSession)
+		p.mayGetMultiplexedSession = make(chan struct{})
+		return
+	}
 	p.createReqs -= uint64(numSessions)
 	p.numOpened -= uint64(numSessions)
 	p.recordStat(context.Background(), OpenSessionCount, int64(p.numOpened))
@@ -923,6 +979,21 @@ var errGetSessionTimeout = spannerErrorf(codes.Canceled, "timeout / context canc
 // sessions being checked out of the pool.
 func (p *sessionPool) newSessionHandle(s *session) (sh *sessionHandle) {
 	sh = &sessionHandle{session: s, checkoutTime: time.Now(), lastUseTime: time.Now()}
+	if s.isMultiplexed {
+		// TODO: handle 1-QPS style traffic; in that case we can return the same client that was used for session creation.
+
+		// Allocate a new client for multiplexed session requests using round-robin channel selection.
+		p.mu.Lock()
+		p.sc.mu.Lock()
+		clientOpt := option.WithGRPCConn(p.sc.connPool.Conn())
+		p.sc.mu.Unlock()
+		p.mu.Unlock()
+		client, err := vkit.NewClient(context.Background(), clientOpt)
+		if err != nil {
+			return nil
+		}
+		sh.client = client
+	}
 	if p.TrackSessionHandles || p.ActionOnInactiveTransaction == Warn || p.ActionOnInactiveTransaction == WarnAndClose || p.ActionOnInactiveTransaction == Close {
 		p.mu.Lock()
 		sh.trackedSessionHandle = p.trackedSessionHandles.PushBack(sh)
@@ -935,7 +1006,7 @@ func (p *sessionPool) newSessionHandle(s *session) (sh *sessionHandle) {
 }
 
 // errGetSessionTimeout returns error for context timeout during
-// sessionPool.take().
+// sessionPool.take() or sessionPool.takeMultiplexed().
 func (p *sessionPool) errGetSessionTimeout(ctx context.Context) error {
 	var code codes.Code
 	if ctx.Err() == context.DeadlineExceeded {
@@ -1106,6 +1177,73 @@ func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) {
 	}
 }
 
+// takeMultiplexed returns the cached multiplexed session if one is available; if there isn't
+// any, it tries to create one.
+func (p *sessionPool) takeMultiplexed(ctx context.Context) (*sessionHandle, error) {
+	if !p.enableMultiplexSession {
+		return p.take(ctx)
+	}
+	trace.TracePrintf(ctx, nil, "Acquiring a multiplexed session")
+	for {
+		var s *session
+
+		p.mu.Lock()
+		if !p.valid {
+			p.mu.Unlock()
+			return nil, errInvalidSessionPool
+		}
+		if p.multiplexedSession != nil {
+			// A multiplexed session is available, get it.
+			s = p.multiplexedSession
+			trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
+				"Acquired multiplexed session")
+			p.decNumSessionsLocked(ctx) // TODO: add tag to differentiate from normal session.
+		}
+		if s != nil {
+			p.mu.Unlock()
+			// If the health check workers failed to schedule a health check
+			// for the session in time, do the check here. Because a session
+			// check is still much cheaper than session creation, sessions
+			// should be reused as much as possible.
+			if !p.isHealthy(s) {
+				continue
+			}
+			p.incNumInUse(ctx) // TODO: add tag to differentiate from normal session.
+			return p.newSessionHandle(s), nil
+		}
+
+		// No multiplexed session is available. Start creating one.
+		if err := p.getMultiplexedSession(ctx); err != nil {
+			p.mu.Unlock()
+			return nil, err
+		}
+
+		mayGetSession := p.mayGetMultiplexedSession
+		p.mu.Unlock()
+		trace.TracePrintf(ctx, nil, "Waiting for multiplexed session to become available")
+		select {
+		case <-ctx.Done():
+			trace.TracePrintf(ctx, nil, "Context done waiting for session")
+			p.recordStat(ctx, GetSessionTimeoutsCount, 1) // TODO: add tag to differentiate from normal session.
+			if p.otConfig != nil {
+				p.recordOTStat(ctx, p.otConfig.getSessionTimeoutsCount, 1) // TODO: add tag to differentiate from normal session.
+			}
+			p.mu.Lock()
+			p.mu.Unlock()
+			return nil, p.errGetSessionTimeout(ctx)
+		case <-mayGetSession:
+			p.mu.Lock()
+			if p.multiplexedSessionCreationError != nil {
+				trace.TracePrintf(ctx, nil, "Error creating multiplexed session: %v", p.multiplexedSessionCreationError)
+				err := p.multiplexedSessionCreationError
+				p.mu.Unlock()
+				return nil, err
+			}
+			p.mu.Unlock()
+		}
+	}
+}
+
 // recycle puts session s back to the session pool's idle list, it returns true
 // if the session pool successfully recycles session s.
 func (p *sessionPool) recycle(s *session) bool {
@@ -1136,10 +1274,14 @@ func (p *sessionPool) recycleLocked(s *session) bool {
 func (p *sessionPool) remove(s *session, isExpire bool) bool {
 	p.mu.Lock()
 	defer p.mu.Unlock()
-	if isExpire && (p.numOpened <= p.MinOpened || s.getIdleList() == nil) {
-		// Don't expire session if the session is not in idle list (in use), or
-		// if number of open sessions is going below p.MinOpened.
-		return false
+	if s.isMultiplexed {
+		p.multiplexedSession = nil
+	} else {
+		if isExpire && (p.numOpened <= p.MinOpened || s.getIdleList() == nil) {
+			// Don't expire session if the session is not in idle list (in use), or
+			// if number of open sessions is going below p.MinOpened.
+			return false
+		}
 	}
 	ol := s.setIdleList(nil)
 	ctx := context.Background()
@@ -1156,6 +1298,11 @@
 	p.decNumInUseLocked(ctx)
 	p.recordStat(ctx, OpenSessionCount, int64(p.numOpened))
 	// Broadcast that a session has been destroyed.
+	if s.isMultiplexed {
+		close(p.mayGetMultiplexedSession)
+		p.mayGetMultiplexedSession = make(chan struct{})
+		return true
+	}
 	close(p.mayGetSession)
 	p.mayGetSession = make(chan struct{})
 	return true
diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go
index ac5a37b34cd7..766d85e9c7be 100644
--- a/spanner/sessionclient.go
+++ b/spanner/sessionclient.go
@@ -77,7 +77,7 @@ type sessionConsumer interface {
 	// sessions failed. The numSessions argument specifies the number of
 	// sessions that could not be created as a result of this error. A
 	// consumer may receive multiple errors per batch.
-	sessionCreationFailed(err error, numSessions int32)
+	sessionCreationFailed(err error, numSessions int32, isMultiplexed bool)
 }
 
 // sessionClient creates sessions for a database, either in batches or one at a
@@ -254,12 +254,12 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC
 		if closed {
 			err := spannerErrorf(codes.Canceled, "Session client closed")
 			trace.TracePrintf(ctx, nil, "Session client closed while creating a batch of %d sessions: %v", createCount, err)
-			consumer.sessionCreationFailed(err, remainingCreateCount)
+			consumer.sessionCreationFailed(err, remainingCreateCount, false)
 			break
 		}
 		if ctx.Err() != nil {
 			trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err())
-			consumer.sessionCreationFailed(ToSpannerError(ctx.Err()), remainingCreateCount)
+			consumer.sessionCreationFailed(ToSpannerError(ctx.Err()), remainingCreateCount, false)
 			break
 		}
 		var mdForGFELatency metadata.MD
@@ -294,7 +294,7 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC
 		}
 		if err != nil {
 			trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err)
-			consumer.sessionCreationFailed(ToSpannerError(err), remainingCreateCount)
+			consumer.sessionCreationFailed(ToSpannerError(err), remainingCreateCount, false)
 			break
 		}
 		actuallyCreated := int32(len(response.Session))
@@ -313,6 +313,62 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC
 	}
 }
 
+func (sc *sessionClient) executeCreateMultiplexedSessions(ctx context.Context, client *vkit.Client, labels map[string]string, md metadata.MD, consumer sessionConsumer) {
+	ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.CreateSession")
+	defer func() { trace.EndSpan(ctx, nil) }()
+	trace.TracePrintf(ctx, nil, "Creating a multiplexed session")
+	sc.mu.Lock()
+	closed := sc.closed
+	sc.mu.Unlock()
+	if closed {
+		err := spannerErrorf(codes.Canceled, "Session client closed")
+		trace.TracePrintf(ctx, nil, "Session client closed while creating a multiplexed session: %v", err)
+		consumer.sessionCreationFailed(err, 1, true)
+		return
+	}
+	if ctx.Err() != nil {
+		trace.TracePrintf(ctx, nil, "Context error while creating a multiplexed session: %v", ctx.Err())
+		consumer.sessionCreationFailed(ToSpannerError(ctx.Err()), 1, true)
+		return
+	}
+	var mdForGFELatency metadata.MD
+	response, err := client.CreateSession(contextWithOutgoingMetadata(ctx, sc.md, sc.disableRouteToLeader), &sppb.CreateSessionRequest{
+		Database: sc.database,
+		// Multiplexed sessions do not support labels.
+		Session: &sppb.Session{CreatorRole: sc.databaseRole, Multiplexed: true},
+	}, gax.WithGRPCOptions(grpc.Header(&mdForGFELatency)))
+
+	if getGFELatencyMetricsFlag() && mdForGFELatency != nil {
+		_, instance, database, err := parseDatabaseName(sc.database)
+		if err != nil {
+			trace.TracePrintf(ctx, nil, "Error getting instance and database name: %v", err)
+		}
+		// Errors should not prevent initializing the session pool.
+		ctxGFE, err := tag.New(ctx,
+			tag.Upsert(tagKeyClientID, sc.id),
+			tag.Upsert(tagKeyDatabase, database),
+			tag.Upsert(tagKeyInstance, instance),
+			tag.Upsert(tagKeyLibVersion, internal.Version),
+		)
+		if err != nil {
+			trace.TracePrintf(ctx, nil, "Error in adding tags in CreateSession for GFE Latency: %v", err)
+		}
+		err = captureGFELatencyStats(ctxGFE, mdForGFELatency, "executeCreateSession")
+		if err != nil {
+			trace.TracePrintf(ctx, nil, "Error in Capturing GFE Latency and Header Missing count. Try disabling and rerunning. Error: %v", err)
+		}
+	}
+	if metricErr := recordGFELatencyMetricsOT(ctx, mdForGFELatency, "executeCreateSession", sc.otConfig); metricErr != nil {
+		trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. Error: %v", metricErr)
+	}
+	if err != nil {
+		trace.TracePrintf(ctx, nil, "Error creating a multiplexed session: %v", err)
+		consumer.sessionCreationFailed(ToSpannerError(err), 1, true)
+		return
+	}
+	consumer.sessionReady(&session{valid: true, client: client, id: response.Name, createTime: time.Now(), md: md, logger: sc.logger, isMultiplexed: response.Multiplexed})
+	trace.TracePrintf(ctx, nil, "Finished creating a multiplexed session")
+}
+
 func (sc *sessionClient) sessionWithID(id string) (*session, error) {
 	sc.mu.Lock()
 	defer sc.mu.Unlock()
diff --git a/spanner/transaction.go b/spanner/transaction.go
index f54c87dba16d..09324a07d833 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -728,7 +728,7 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error {
 	}()
 	// Retry the BeginTransaction call if a 'Session not found' is returned.
 	for {
-		sh, err = t.sp.take(ctx)
+		sh, err = t.sp.takeMultiplexed(ctx)
 		if err != nil {
 			return err
 		}
@@ -818,7 +818,7 @@ func (t *ReadOnlyTransaction) acquireSingleUse(ctx context.Context) (*sessionHan
 			},
 		},
 	}
-	sh, err := t.sp.take(ctx)
+	sh, err := t.sp.takeMultiplexed(ctx)
 	if err != nil {
 		return nil, nil, err
 	}
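
For reviewers who want to exercise the new path, below is a minimal usage sketch. It is not part of the patch: only the two environment variable names come from this change, the rest uses the public cloud.google.com/go/spanner API, and the database path is a placeholder. The pool reads the variables in newSessionPool, so they must be set before the client is created; once set, ReadOnlyTransaction.begin and acquireSingleUse call takeMultiplexed, so read-only work shares the pool's single multiplexed session instead of checking sessions out of the idle list.

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"cloud.google.com/go/spanner"
	"google.golang.org/api/iterator"
)

func main() {
	// Opt in before creating the client: newSessionPool requires the enable
	// flag to be "true" and the force-disable flag to not be "true".
	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_MULTIPLEXED_SESSIONS", "true")
	os.Setenv("GOOGLE_CLOUD_SPANNER_FORCE_DISABLE_MULTIPLEXED_SESSIONS", "false")

	ctx := context.Background()
	// Placeholder database path, for illustration only.
	client, err := spanner.NewClient(ctx, "projects/p/instances/i/databases/d")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// Single-use reads go through acquireSingleUse, which now calls
	// takeMultiplexed, so this query runs on the multiplexed session.
	iter := client.Single().Query(ctx, spanner.NewStatement("SELECT 1"))
	defer iter.Stop()
	for {
		row, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		var v int64
		if err := row.Columns(&v); err != nil {
			log.Fatal(err)
		}
		fmt.Println(v)
	}
}

Read-write transactions are unchanged by this diff; they still call sessionPool.take and use regular pooled sessions.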