Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor awaitSchemaAgreement logic #256

Merged
merged 7 commits into from
Sep 6, 2024
73 changes: 49 additions & 24 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1815,7 +1815,7 @@ func (c *Conn) querySystemLocal(ctx context.Context) *Iter {
return c.query(ctx, "SELECT * FROM system.local WHERE key='local'"+usingClause)
}

func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
usingClause := ""
if c.session.control != nil {
usingClause = c.session.usingTimeoutClause
Expand All @@ -1827,20 +1827,35 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {

endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)

for time.Now().Before(endDeadline) {
var err error
ticker := time.NewTicker(200 * time.Millisecond) // Create a ticker that ticks every 200ms
defer ticker.Stop()

waitForNextTick := func() error {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
return nil
}
}

getSchemaAgreement := func() error {
iter := c.querySystemPeers(ctx, c.host.version)

versions = make(map[string]struct{})

rows, err := iter.SliceMap()
var rows []map[string]interface{}
rows, err = iter.SliceMap()
if err != nil {
goto cont
return err
}

for _, row := range rows {
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
var host *HostInfo
host, err = hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}, c.session.cfg.translateAddressPort)
if err != nil {
goto cont
return err
}
if !isValidPeer(host) || host.schemaVersion == "" {
c.logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host)
Expand All @@ -1851,7 +1866,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
}

if err = iter.Close(); err != nil {
goto cont
return err
}

iter = c.query(ctx, localSchemas)
Expand All @@ -1861,32 +1876,34 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
}

if err = iter.Close(); err != nil {
goto cont
return err
}

if len(versions) <= 1 {
return nil
}
if len(versions) > 1 {
schemas := make([]string, 0, len(versions))
for schema := range versions {
schemas = append(schemas, schema)
}

cont:
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(200 * time.Millisecond):
return &ErrSchemaMismatch{schemas: schemas}
}
}

if err != nil {
return err
return nil
}

schemas := make([]string, 0, len(versions))
for schema := range versions {
schemas = append(schemas, schema)
for time.Now().Before(endDeadline) {
err = getSchemaAgreement()

if err == ErrConnectionClosed || err == nil {
return err
}

if tickerErr := waitForNextTick(); tickerErr != nil {
return tickerErr
}
}

// not exported
return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
return err
}

var (
Expand All @@ -1896,3 +1913,11 @@ var (
ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
ErrNoStreams = errors.New("gocql: no streams available on connection")
)

type ErrSchemaMismatch struct {
schemas []string
}

func (e *ErrSchemaMismatch) Error() string {
return fmt.Sprintf("gocql: cluster schema versions not consistent: %+v", e.schemas)
}
10 changes: 5 additions & 5 deletions host_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ func checkSystemSchema(control *controlConn) (bool, error) {

// Given a map that represents a row from either system.local or system.peers
// return as much information as we can in *HostInfo
func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) {
func hostInfoFromMap(row map[string]interface{}, host *HostInfo, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) {
const assertErrorMsg = "Assertion failed for %s"
var ok bool

Expand Down Expand Up @@ -771,7 +771,7 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*
}

host.untranslatedConnectAddress = host.ConnectAddress()
ip, port := s.cfg.translateAddressPort(host.untranslatedConnectAddress, host.port)
ip, port := translateAddressPort(host.untranslatedConnectAddress, host.port)
host.connectAddress = ip
host.port = port

Expand All @@ -789,7 +789,7 @@ func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPor
return nil, errors.New("query returned 0 rows")
}

host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort})
host, err := hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}, s.cfg.translateAddressPort)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -840,7 +840,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, er

for _, row := range rows {
// extract all available info about the peer
host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
host, err := hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}, r.session.cfg.translateAddressPort)
if err != nil {
return nil, err
} else if !isValidPeer(host) {
Expand Down Expand Up @@ -913,7 +913,7 @@ func (r *ringDescriber) getHostInfo(hostID UUID) (*HostInfo, error) {
}

for _, row := range rows {
h, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
h, err := hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}, r.session.cfg.translateAddressPort)
if err != nil {
return nil, err
}
Expand Down
22 changes: 22 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package gocql
// This file groups integration tests where Cassandra has to be set up with some special integration variables
import (
"context"
"errors"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -233,3 +234,24 @@ func TestSessionAwaitSchemaAgreement(t *testing.T) {
t.Fatalf("expected session.AwaitSchemaAgreement to not return an error but got '%v'", err)
}
}

func TestSessionAwaitSchemaAgreementSessionClosed(t *testing.T) {
sylwiaszunejko marked this conversation as resolved.
Show resolved Hide resolved
session := createSession(t)
session.Close()

if err := session.AwaitSchemaAgreement(context.Background()); !errors.Is(err, ErrConnectionClosed) {
t.Fatalf("expected session.AwaitSchemaAgreement to return ErrConnectionClosed but got '%v'", err)
}

}

func TestSessionAwaitSchemaAgreementContextCanceled(t *testing.T) {
session := createSession(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()

if err := session.AwaitSchemaAgreement(ctx); !errors.Is(err, context.Canceled) {
t.Fatalf("expected session.AwaitSchemaAgreement to return 'context canceled' but got '%v'", err)
}

}
Loading