Skip to content

Commit

Permalink
Refresh gocql session on no connection error (#4058)
Browse files Browse the repository at this point in the history
  • Loading branch information
yycptt authored Mar 17, 2021
1 parent 86e6c7c commit ed82bb7
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 32 deletions.
12 changes: 9 additions & 3 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,20 @@ const (
CounterBatch
)

func newBatch(
gocqlBatch *gocql.Batch,
) *batch {
return &batch{
Batch: gocqlBatch,
}
}

func (b *batch) WithContext(ctx context.Context) Batch {
b2 := b.Batch.WithContext(ctx)
if b2 == nil {
return nil
}
return &batch{
Batch: b2,
}
return newBatch(b2)
}

func (b *batch) WithTimestamp(timestamp int64) Batch {
Expand Down
13 changes: 1 addition & 12 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,7 @@ func NewClient() Client {
func (c client) CreateSession(
config ClusterConfig,
) (Session, error) {
cluster := newCassandraCluster(config)
cluster.ProtoVersion = config.ProtoVersion
cluster.Consistency = mustConvertConsistency(config.Consistency)
cluster.SerialConsistency = mustConvertSerialConsistency(config.SerialConsistency)
cluster.Timeout = config.Timeout
gocqlSession, err := cluster.CreateSession()
if err != nil {
return nil, err
}
return &session{
Session: gocqlSession,
}, nil
return newSession(config)
}

func (c client) IsTimeoutError(err error) bool {
Expand Down
54 changes: 51 additions & 3 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/query.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) 2017-2020 Uber Technologies, Inc.
// Portions of the Software are attributed to Copyright (c) 2020 Temporal Technologies Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -31,9 +32,54 @@ var _ Query = (*query)(nil)
type (
query struct {
*gocql.Query

session *session
}
)

func newQuery(
session *session,
gocqlQuery *gocql.Query,
) *query {
return &query{
Query: gocqlQuery,
session: session,
}
}

func (q *query) Exec() error {
err := q.Query.Exec()
return q.handleError(err)
}

func (q *query) Scan(
dest ...interface{},
) error {
err := q.Query.Scan(dest...)
return q.handleError(err)
}

func (q *query) ScanCAS(
dest ...interface{},
) (bool, error) {
applied, err := q.Query.ScanCAS(dest...)
return applied, q.handleError(err)
}

func (q *query) MapScan(
m map[string]interface{},
) error {
err := q.Query.MapScan(m)
return q.handleError(err)
}

func (q *query) MapScanCAS(
dest map[string]interface{},
) (bool, error) {
applied, err := q.Query.MapScanCAS(dest)
return applied, q.handleError(err)
}

func (q *query) Iter() Iter {
iter := q.Query.Iter()
if iter == nil {
Expand Down Expand Up @@ -67,12 +113,14 @@ func (q *query) WithContext(ctx context.Context) Query {
if q2 == nil {
return nil
}
return &query{
Query: q2,
}
return newQuery(q.session, q2)
}

func (q *query) Bind(v ...interface{}) Query {
q.Query.Bind(v...)
return q
}

func (q *query) handleError(err error) error {
return q.session.handleError(err)
}
109 changes: 95 additions & 14 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/session.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) 2017-2020 Uber Technologies, Inc.
// Portions of the Software are attributed to Copyright (c) 2020 Temporal Technologies Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand All @@ -20,54 +21,134 @@

package gocql

import "github.com/gocql/gocql"
import (
"sync"
"sync/atomic"
"time"

"github.com/gocql/gocql"

"github.com/uber/cadence/common"
)

var _ Session = (*session)(nil)

const (
sessionRefreshMinInternal = 5 * time.Second
)

type (
session struct {
*gocql.Session
atomic.Value // *gocql.Session
sync.Mutex

status int32
config ClusterConfig
sessionInitTime time.Time
}
)

func newSession(
config ClusterConfig,
) (*session, error) {
gocqlSession, err := initSession(config)
if err != nil {
return nil, err
}
session := &session{
status: common.DaemonStatusStarted,
config: config,
sessionInitTime: time.Now().UTC(),
}
session.Value.Store(gocqlSession)
return session, nil
}

func initSession(
config ClusterConfig,
) (*gocql.Session, error) {
cluster := newCassandraCluster(config)
cluster.ProtoVersion = config.ProtoVersion
cluster.Consistency = mustConvertConsistency(config.Consistency)
cluster.SerialConsistency = mustConvertSerialConsistency(config.SerialConsistency)
cluster.Timeout = config.Timeout
return cluster.CreateSession()
}

func (s *session) refresh() error {
if atomic.LoadInt32(&s.status) != common.DaemonStatusStarted {
return nil
}

s.Lock()
defer s.Unlock()

if time.Now().UTC().Sub(s.sessionInitTime) < sessionRefreshMinInternal {
return nil
}

newSession, err := initSession(s.config)
if err != nil {
return err
}

s.sessionInitTime = time.Now().UTC()
oldSession := s.Value.Load().(*gocql.Session)
s.Value.Store(newSession)
oldSession.Close()
return nil
}

func (s *session) Query(
stmt string,
values ...interface{},
) Query {
q := s.Session.Query(stmt, values...)
q := s.Value.Load().(*gocql.Session).Query(stmt, values...)
if q == nil {
return nil
}
return &query{
Query: q,
}
return newQuery(s, q)
}

func (s *session) NewBatch(
batchType BatchType,
) Batch {
b := s.Session.NewBatch(mustConvertBatchType(batchType))
b := s.Value.Load().(*gocql.Session).NewBatch(mustConvertBatchType(batchType))
if b == nil {
return nil
}
return &batch{
Batch: b,
}
return newBatch(b)
}

func (s *session) ExecuteBatch(
b Batch,
) error {
return s.Session.ExecuteBatch(b.(*batch).Batch)
err := s.Value.Load().(*gocql.Session).ExecuteBatch(b.(*batch).Batch)
return s.handleError(err)
}

func (s *session) MapExecuteBatchCAS(
b Batch,
previous map[string]interface{},
) (bool, Iter, error) {
applied, iter, err := s.Session.MapExecuteBatchCAS(b.(*batch).Batch, previous)
applied, iter, err := s.Value.Load().(*gocql.Session).MapExecuteBatchCAS(b.(*batch).Batch, previous)
if iter == nil {
return applied, nil, err
return applied, nil, s.handleError(err)
}
return applied, iter, s.handleError(err)
}

func (s *session) Close() {
if !atomic.CompareAndSwapInt32(&s.status, common.DaemonStatusStarted, common.DaemonStatusStopped) {
return
}

s.Value.Load().(*gocql.Session).Close()
}

func (s *session) handleError(err error) error {
if err == gocql.ErrNoConnections {
_ = s.refresh()
}
return applied, iter, err
return err
}

0 comments on commit ed82bb7

Please sign in to comment.