From 8254a773b74ee9cf79752c9bb44d87b462a7de92 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 5 Sep 2024 12:57:34 +0200 Subject: [PATCH 1/7] Make hostInfoFromMap not tied to Session type --- conn.go | 2 +- host_source.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index cb363be71..d4f8e4d8c 100644 --- a/conn.go +++ b/conn.go @@ -1838,7 +1838,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { } for _, row := range rows { - host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}) + host, err := hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}, c.session.cfg.translateAddressPort) if err != nil { goto cont } diff --git a/host_source.go b/host_source.go index 23c3dfb2c..33fe7ea4e 100644 --- a/host_source.go +++ b/host_source.go @@ -658,7 +658,7 @@ func checkSystemSchema(control *controlConn) (bool, error) { // Given a map that represents a row from either system.local or system.peers // return as much information as we can in *HostInfo -func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) { +func hostInfoFromMap(row map[string]interface{}, host *HostInfo, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) { const assertErrorMsg = "Assertion failed for %s" var ok bool @@ -771,7 +771,7 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* } host.untranslatedConnectAddress = host.ConnectAddress() - ip, port := s.cfg.translateAddressPort(host.untranslatedConnectAddress, host.port) + ip, port := translateAddressPort(host.untranslatedConnectAddress, host.port) host.connectAddress = ip host.port = port @@ -789,7 +789,7 @@ func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPor return nil, errors.New("query returned 0 rows") } - host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}) + host, err := hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}, s.cfg.translateAddressPort) if err != nil { return nil, err } @@ -840,7 +840,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, er for _, row := range rows { // extract all available info about the peer - host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}) + host, err := hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}, r.session.cfg.translateAddressPort) if err != nil { return nil, err } else if !isValidPeer(host) { @@ -913,7 +913,7 @@ func (r *ringDescriber) getHostInfo(hostID UUID) (*HostInfo, error) { } for _, row := range rows { - h, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}) + h, err := hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}, r.session.cfg.translateAddressPort) if err != nil { return nil, err } From e7ff543ad176b3dec0e0ba691d60dac3a62bc659 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 5 Sep 2024 13:14:02 +0200 Subject: [PATCH 2/7] Fix error redeclaration in awaitSchemaAgreement --- conn.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index d4f8e4d8c..34752636a 100644 --- a/conn.go +++ b/conn.go @@ -1815,7 +1815,7 @@ func (c *Conn) querySystemLocal(ctx context.Context) *Iter { return c.query(ctx, "SELECT * FROM system.local WHERE key='local'"+usingClause) } -func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { +func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { usingClause := "" if c.session.control != nil { usingClause = c.session.usingTimeoutClause @@ -1827,18 +1827,25 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement) + var err error + for time.Now().Before(endDeadline) { iter := c.querySystemPeers(ctx, c.host.version) versions = make(map[string]struct{}) - rows, err := iter.SliceMap() + var rows []map[string]interface{} + rows, err = iter.SliceMap() if err != nil { + if err == ErrConnectionClosed { + break + } goto cont } for _, row := range rows { - host, err := hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}, c.session.cfg.translateAddressPort) + var host *HostInfo + host, err = hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}, c.session.cfg.translateAddressPort) if err != nil { goto cont } From 07c5a13e09aa87b022165479b980f127a471c0b4 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 5 Sep 2024 14:01:16 +0200 Subject: [PATCH 3/7] Replace time.After with time.NewTicker for periodic checks in awaitSchemaAgreement --- conn.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 34752636a..1c68fa6d0 100644 --- a/conn.go +++ b/conn.go @@ -1828,6 +1828,8 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement) var err error + ticker := time.NewTicker(200 * time.Millisecond) // Create a ticker that ticks every 200ms + defer ticker.Stop() for time.Now().Before(endDeadline) { iter := c.querySystemPeers(ctx, c.host.version) @@ -1879,7 +1881,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case <-time.After(200 * time.Millisecond): + case <-ticker.C: } } From 9294080994c921c77797b904695fe348877abf19 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 5 Sep 2024 14:55:18 +0200 Subject: [PATCH 4/7] Add new type of error for schema disagreement --- conn.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 1c68fa6d0..65382229e 100644 --- a/conn.go +++ b/conn.go @@ -1895,7 +1895,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } // not exported - return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas) + return &ErrSchemaMismatch{schemas: schemas} } var ( @@ -1905,3 +1905,11 @@ var ( ErrConnectionClosed = errors.New("gocql: connection closed waiting for response") ErrNoStreams = errors.New("gocql: no streams available on connection") ) + +type ErrSchemaMismatch struct { + schemas []string +} + +func (e *ErrSchemaMismatch) Error() string { + return fmt.Sprintf("gocql: cluster schema versions not consistent: %+v", e.schemas) +} From cb664253936d4b452242f200e573e785f8ae026a Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 5 Sep 2024 15:21:41 +0200 Subject: [PATCH 5/7] Add extra tests for AwaitSchemaAgreement --- integration_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/integration_test.go b/integration_test.go index f548a829f..9e01f068e 100644 --- a/integration_test.go +++ b/integration_test.go @@ -6,6 +6,7 @@ package gocql // This file groups integration tests where Cassandra has to be set up with some special integration variables import ( "context" + "errors" "reflect" "testing" "time" @@ -233,3 +234,24 @@ func TestSessionAwaitSchemaAgreement(t *testing.T) { t.Fatalf("expected session.AwaitSchemaAgreement to not return an error but got '%v'", err) } } + +func TestSessionAwaitSchemaAgreementSessionClosed(t *testing.T) { + session := createSession(t) + session.Close() + + if err := session.AwaitSchemaAgreement(context.Background()); !errors.Is(err, ErrConnectionClosed) { + t.Fatalf("expected session.AwaitSchemaAgreement to return ErrConnectionClosed but got '%v'", err) + } + +} + +func TestSessionAwaitSchemaAgreementContextCanceled(t *testing.T) { + session := createSession(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + if err := session.AwaitSchemaAgreement(ctx); !errors.Is(err, context.Canceled) { + t.Fatalf("expected session.AwaitSchemaAgreement to return 'context canceled' but got '%v'", err) + } + +} From d0152a38e3073d4a648fffc6f6099af7d3d722ae Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 5 Sep 2024 16:21:08 +0200 Subject: [PATCH 6/7] Remove goto from the logic in awaitSchemaAgreement --- conn.go | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index 65382229e..043276e4c 100644 --- a/conn.go +++ b/conn.go @@ -1831,6 +1831,15 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { ticker := time.NewTicker(200 * time.Millisecond) // Create a ticker that ticks every 200ms defer ticker.Stop() + waitForNextTick := func() error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + return nil + } + } + for time.Now().Before(endDeadline) { iter := c.querySystemPeers(ctx, c.host.version) @@ -1842,14 +1851,16 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { if err == ErrConnectionClosed { break } - goto cont + waitForNextTick() + continue } for _, row := range rows { var host *HostInfo host, err = hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}, c.session.cfg.translateAddressPort) if err != nil { - goto cont + waitForNextTick() + continue } if !isValidPeer(host) || host.schemaVersion == "" { c.logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host) @@ -1860,7 +1871,8 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } if err = iter.Close(); err != nil { - goto cont + waitForNextTick() + continue } iter = c.query(ctx, localSchemas) @@ -1870,18 +1882,16 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } if err = iter.Close(); err != nil { - goto cont + waitForNextTick() + continue } if len(versions) <= 1 { return nil } - cont: - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: + if err := waitForNextTick(); err != nil { + return err } } From 9ff526dfe7e6e2d2b88b136eca419a753cc87cbd Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Fri, 6 Sep 2024 09:11:10 +0200 Subject: [PATCH 7/7] Extract the code getting schema versions to separate function --- conn.go | 48 +++++++++++++++++++++++------------------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/conn.go b/conn.go index 043276e4c..1aeb93d3e 100644 --- a/conn.go +++ b/conn.go @@ -1840,7 +1840,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } } - for time.Now().Before(endDeadline) { + getSchemaAgreement := func() error { iter := c.querySystemPeers(ctx, c.host.version) versions = make(map[string]struct{}) @@ -1848,19 +1848,14 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { var rows []map[string]interface{} rows, err = iter.SliceMap() if err != nil { - if err == ErrConnectionClosed { - break - } - waitForNextTick() - continue + return err } for _, row := range rows { var host *HostInfo host, err = hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}, c.session.cfg.translateAddressPort) if err != nil { - waitForNextTick() - continue + return err } if !isValidPeer(host) || host.schemaVersion == "" { c.logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host) @@ -1871,8 +1866,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } if err = iter.Close(); err != nil { - waitForNextTick() - continue + return err } iter = c.query(ctx, localSchemas) @@ -1882,30 +1876,34 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error { } if err = iter.Close(); err != nil { - waitForNextTick() - continue + return err } - if len(versions) <= 1 { - return nil - } + if len(versions) > 1 { + schemas := make([]string, 0, len(versions)) + for schema := range versions { + schemas = append(schemas, schema) + } - if err := waitForNextTick(); err != nil { - return err + return &ErrSchemaMismatch{schemas: schemas} } - } - if err != nil { - return err + return nil } - schemas := make([]string, 0, len(versions)) - for schema := range versions { - schemas = append(schemas, schema) + for time.Now().Before(endDeadline) { + err = getSchemaAgreement() + + if err == ErrConnectionClosed || err == nil { + return err + } + + if tickerErr := waitForNextTick(); tickerErr != nil { + return tickerErr + } } - // not exported - return &ErrSchemaMismatch{schemas: schemas} + return err } var (