diff --git a/CHANGELOG.md b/CHANGELOG.md index 2393216de..3019df4e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Added +- Support iproto feature discovery (#120). + ### Changed ### Fixed diff --git a/connection.go b/connection.go index 27fee0929..d4401aeef 100644 --- a/connection.go +++ b/connection.go @@ -14,6 +14,7 @@ import ( "math" "net" "runtime" + "strings" "sync" "sync/atomic" "time" @@ -146,6 +147,8 @@ type Connection struct { lenbuf [PacketLengthBytes]byte lastStreamId uint64 + + serverProtocolInfo ProtocolInfo } var _ = Connector(&Connection{}) // Check compatibility with connector interface. @@ -269,6 +272,10 @@ type Opts struct { Transport string // SslOpts is used only if the Transport == 'ssl' is set. Ssl SslOpts + // RequiredProtocolInfo contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default there are no restrictions + RequiredProtocolInfo ProtocolInfo } // SslOpts is a way to configure ssl transport. @@ -295,8 +302,11 @@ type SslOpts struct { // Copy returns the copy of an Opts object. // Beware that Notify channel, Logger and Handle are not copied. +// Any changes in copy RequiredProtocolInfo will not affect the original +// RequiredProtocolInfo value. func (opts Opts) Copy() Opts { optsCopy := opts + optsCopy.RequiredProtocolInfo = opts.RequiredProtocolInfo.Copy() return optsCopy } @@ -510,6 +520,18 @@ func (conn *Connection) dial() (err error) { conn.Greeting.Version = bytes.NewBuffer(greeting[:64]).String() conn.Greeting.auth = bytes.NewBuffer(greeting[64:108]).String() + // IPROTO_ID requests can be processed without authentication. + // https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/requests/#iproto-id + if err = conn.identify(w, r); err != nil { + connection.Close() + return err + } + + if err = checkProtocolInfo(opts.RequiredProtocolInfo, conn.serverProtocolInfo); err != nil { + connection.Close() + return fmt.Errorf("identify: %w", err) + } + // Auth if opts.User != "" { scr, err := scramble(conn.Greeting.auth, opts.Pass) @@ -616,6 +638,17 @@ func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) error return nil } +func (conn *Connection) writeIdRequest(w *bufio.Writer, protocolInfo ProtocolInfo) error { + req := NewIdRequest(protocolInfo) + + err := conn.writeRequest(w, req) + if err != nil { + return fmt.Errorf("identify: %w", err) + } + + return nil +} + func (conn *Connection) readResponse(r io.Reader) (Response, error) { respBytes, err := conn.read(r) if err != nil { @@ -648,6 +681,15 @@ func (conn *Connection) readAuthResponse(r io.Reader) error { return nil } +func (conn *Connection) readIdResponse(r io.Reader) (Response, error) { + resp, err := conn.readResponse(r) + if err != nil { + return resp, fmt.Errorf("identify: %w", err) + } + + return resp, nil +} + func (conn *Connection) createConnection(reconnect bool) (err error) { var reconnects uint for conn.c == nil && conn.state == connDisconnected { @@ -1191,3 +1233,89 @@ func (conn *Connection) NewStream() (*Stream, error) { Conn: conn, }, nil } + +// checkProtocolInfo checks that expected protocol version is +// and protocol features are supported. +func checkProtocolInfo(expected ProtocolInfo, actual ProtocolInfo) error { + var found bool + var missingFeatures []ProtocolFeature + + if expected.Version > actual.Version { + return fmt.Errorf("protocol version %d is not supported", expected.Version) + } + + // It seems that iterating over a small list is way faster + // than building a map: https://stackoverflow.com/a/52710077/11646599 + for _, expectedFeature := range expected.Features { + found = false + for _, actualFeature := range actual.Features { + if expectedFeature == actualFeature { + found = true + } + } + if !found { + missingFeatures = append(missingFeatures, expectedFeature) + } + } + + if len(missingFeatures) == 1 { + return fmt.Errorf("protocol feature %s is not supported", missingFeatures[0]) + } + + if len(missingFeatures) > 1 { + var sarr []string + for _, missingFeature := range missingFeatures { + sarr = append(sarr, missingFeature.String()) + } + return fmt.Errorf("protocol features %s are not supported", strings.Join(sarr, ", ")) + } + + return nil +} + +// identify sends info about client protocol, receives info +// about server protocol in response and stores it in the connection. +func (conn *Connection) identify(w *bufio.Writer, r *bufio.Reader) error { + var ok bool + + werr := conn.writeIdRequest(w, clientProtocolInfo) + if werr != nil { + return werr + } + + resp, rerr := conn.readIdResponse(r) + if rerr != nil { + if resp.Code == ErrUnknownRequestType { + // IPROTO_ID requests are not supported by server. + return nil + } + + return rerr + } + + if len(resp.Data) == 0 { + return fmt.Errorf("identify: unexpected response: no data") + } + + conn.serverProtocolInfo, ok = resp.Data[0].(ProtocolInfo) + if !ok { + return fmt.Errorf("identify: unexpected response: wrong data") + } + + return nil +} + +// ServerProtocolVersion returns protocol version and protocol features +// supported by connected Tarantool server. Beware that values might be +// outdated if connection is in a disconnected state. +// Since 1.10.0 +func (conn *Connection) ServerProtocolInfo() ProtocolInfo { + return conn.serverProtocolInfo.Copy() +} + +// ClientProtocolVersion returns protocol version and protocol features +// supported by Go connection client. +// Since 1.10.0 +func (conn *Connection) ClientProtocolInfo() ProtocolInfo { + return clientProtocolInfo.Copy() +} diff --git a/connection_pool/example_test.go b/connection_pool/example_test.go index 9a486924a..9f7512538 100644 --- a/connection_pool/example_test.go +++ b/connection_pool/example_test.go @@ -19,7 +19,7 @@ type Tuple struct { var testRoles = []bool{true, true, false, true, true} -func examplePool(roles []bool) (*connection_pool.ConnectionPool, error) { +func examplePool(roles []bool, connOpts tarantool.Opts) (*connection_pool.ConnectionPool, error) { err := test_helpers.SetClusterRO(servers, connOpts, roles) if err != nil { return nil, fmt.Errorf("ConnectionPool is not established") @@ -33,7 +33,7 @@ func examplePool(roles []bool) (*connection_pool.ConnectionPool, error) { } func ExampleConnectionPool_Select() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -94,7 +94,7 @@ func ExampleConnectionPool_Select() { } func ExampleConnectionPool_SelectTyped() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -156,7 +156,7 @@ func ExampleConnectionPool_SelectTyped() { } func ExampleConnectionPool_SelectAsync() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -239,7 +239,7 @@ func ExampleConnectionPool_SelectAsync() { func ExampleConnectionPool_SelectAsync_err() { roles := []bool{true, true, true, true, true} - pool, err := examplePool(roles) + pool, err := examplePool(roles, connOpts) if err != nil { fmt.Println(err) } @@ -258,7 +258,7 @@ func ExampleConnectionPool_SelectAsync_err() { } func ExampleConnectionPool_Ping() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -276,7 +276,7 @@ func ExampleConnectionPool_Ping() { } func ExampleConnectionPool_Insert() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -325,7 +325,7 @@ func ExampleConnectionPool_Insert() { } func ExampleConnectionPool_Delete() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -377,7 +377,7 @@ func ExampleConnectionPool_Delete() { } func ExampleConnectionPool_Replace() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -448,7 +448,7 @@ func ExampleConnectionPool_Replace() { } func ExampleConnectionPool_Update() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -492,7 +492,7 @@ func ExampleConnectionPool_Update() { } func ExampleConnectionPool_Call() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -512,7 +512,7 @@ func ExampleConnectionPool_Call() { } func ExampleConnectionPool_Eval() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -532,7 +532,7 @@ func ExampleConnectionPool_Eval() { } func ExampleConnectionPool_Do() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -551,7 +551,7 @@ func ExampleConnectionPool_Do() { } func ExampleConnectionPool_NewPrepared() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } @@ -586,7 +586,17 @@ func ExampleCommitRequest() { return } - pool, err := examplePool(testRoles) + // Assert that server supports expected features + txnOpts := connOpts.Copy() + txnOpts.RequiredProtocolInfo = tarantool.ProtocolInfo{ + Version: tarantool.ProtocolVersion(1), + Features: []tarantool.ProtocolFeature{ + tarantool.StreamsFeature, + tarantool.TransactionsFeature, + }, + } + + pool, err := examplePool(testRoles, txnOpts) if err != nil { fmt.Println(err) return @@ -672,8 +682,18 @@ func ExampleRollbackRequest() { return } + // Assert that server supports expected features + txnOpts := connOpts.Copy() + txnOpts.RequiredProtocolInfo = tarantool.ProtocolInfo{ + Version: tarantool.ProtocolVersion(1), + Features: []tarantool.ProtocolFeature{ + tarantool.StreamsFeature, + tarantool.TransactionsFeature, + }, + } + // example pool has only one rw instance - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, txnOpts) if err != nil { fmt.Println(err) return @@ -758,8 +778,18 @@ func ExampleBeginRequest_TxnIsolation() { return } + // Assert that server supports expected features + txnOpts := connOpts.Copy() + txnOpts.RequiredProtocolInfo = tarantool.ProtocolInfo{ + Version: tarantool.ProtocolVersion(1), + Features: []tarantool.ProtocolFeature{ + tarantool.StreamsFeature, + tarantool.TransactionsFeature, + }, + } + // example pool has only one rw instance - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, txnOpts) if err != nil { fmt.Println(err) return @@ -836,7 +866,7 @@ func ExampleBeginRequest_TxnIsolation() { } func ExampleConnectorAdapter() { - pool, err := examplePool(testRoles) + pool, err := examplePool(testRoles, connOpts) if err != nil { fmt.Println(err) } diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 000000000..6da4465a3 --- /dev/null +++ b/connection_test.go @@ -0,0 +1,31 @@ +package tarantool_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + . "github.com/tarantool/go-tarantool" +) + +func TestOptsCopyPreservesRequiredProtocolFeatures(t *testing.T) { + original := Opts{ + RequiredProtocolInfo: ProtocolInfo{ + Version: ProtocolVersion(100), + Features: []ProtocolFeature{ProtocolFeature(99), ProtocolFeature(100)}, + }, + } + + origCopy := original.Copy() + + original.RequiredProtocolInfo.Features[1] = ProtocolFeature(98) + + require.Equal(t, + origCopy, + Opts{ + RequiredProtocolInfo: ProtocolInfo{ + Version: ProtocolVersion(100), + Features: []ProtocolFeature{ProtocolFeature(99), ProtocolFeature(100)}, + }, + }) +} diff --git a/const.go b/const.go index 4a3cb6833..35ec83380 100644 --- a/const.go +++ b/const.go @@ -18,6 +18,7 @@ const ( RollbackRequestCode = 16 PingRequestCode = 64 SubscribeRequestCode = 66 + IdRequestCode = 73 KeyCode = 0x00 KeySync = 0x01 @@ -41,6 +42,8 @@ const ( KeySQLBind = 0x41 KeySQLInfo = 0x42 KeyStmtID = 0x43 + KeyVersion = 0x54 + KeyFeatures = 0x55 KeyTimeout = 0x56 KeyTxnIsolation = 0x59 diff --git a/example_test.go b/example_test.go index 37939a268..f3fad05dd 100644 --- a/example_test.go +++ b/example_test.go @@ -18,7 +18,7 @@ type Tuple struct { Name string } -func example_connect() *tarantool.Connection { +func example_connect(opts tarantool.Opts) *tarantool.Connection { conn, err := tarantool.Connect(server, opts) if err != nil { panic("Connection is not established: " + err.Error()) @@ -45,7 +45,7 @@ func ExampleSslOpts() { } func ExampleConnection_Select() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) @@ -71,7 +71,7 @@ func ExampleConnection_Select() { } func ExampleConnection_SelectTyped() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() var res []Tuple @@ -94,7 +94,7 @@ func ExampleConnection_SelectTyped() { } func ExampleConnection_SelectAsync() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() spaceNo := uint32(517) @@ -128,7 +128,7 @@ func ExampleConnection_SelectAsync() { } func ExampleConnection_GetTyped() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() const space = "test" @@ -145,7 +145,7 @@ func ExampleConnection_GetTyped() { } func ExampleIntKey() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() const space = "test" @@ -162,7 +162,7 @@ func ExampleIntKey() { } func ExampleUintKey() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() const space = "test" @@ -179,7 +179,7 @@ func ExampleUintKey() { } func ExampleStringKey() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() const space = "teststring" @@ -199,7 +199,7 @@ func ExampleStringKey() { } func ExampleIntIntKey() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() const space = "testintint" @@ -220,7 +220,7 @@ func ExampleIntIntKey() { } func ExampleSelectRequest() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() req := tarantool.NewSelectRequest(517). @@ -250,7 +250,7 @@ func ExampleSelectRequest() { } func ExampleUpdateRequest() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() req := tarantool.NewUpdateRequest(517). @@ -280,7 +280,7 @@ func ExampleUpdateRequest() { } func ExampleUpsertRequest() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() var req tarantool.Request @@ -320,6 +320,33 @@ func ExampleUpsertRequest() { // response is []interface {}{[]interface {}{0x459, "first", "updated"}} } +func ExampleProtocolVersion() { + conn := example_connect(opts) + defer conn.Close() + + clientProtocolInfo := conn.ClientProtocolInfo() + fmt.Println("Connector client protocol version:", clientProtocolInfo.Version) + fmt.Println("Connector client protocol features:", clientProtocolInfo.Features) + // Output: + // Connector client protocol version: 4 + // Connector client protocol features: [StreamsFeature TransactionsFeature] +} + +func getTestTxnOpts() tarantool.Opts { + txnOpts := opts.Copy() + + // Assert that server supports expected protocol features + txnOpts.RequiredProtocolInfo = tarantool.ProtocolInfo{ + Version: tarantool.ProtocolVersion(1), + Features: []tarantool.ProtocolFeature{ + tarantool.StreamsFeature, + tarantool.TransactionsFeature, + }, + } + + return txnOpts +} + func ExampleCommitRequest() { var req tarantool.Request var resp *tarantool.Response @@ -331,7 +358,8 @@ func ExampleCommitRequest() { return } - conn := example_connect() + txnOpts := getTestTxnOpts() + conn := example_connect(txnOpts) defer conn.Close() stream, _ := conn.NewStream() @@ -407,7 +435,8 @@ func ExampleRollbackRequest() { return } - conn := example_connect() + txnOpts := getTestTxnOpts() + conn := example_connect(txnOpts) defer conn.Close() stream, _ := conn.NewStream() @@ -483,7 +512,8 @@ func ExampleBeginRequest_TxnIsolation() { return } - conn := example_connect() + txnOpts := getTestTxnOpts() + conn := example_connect(txnOpts) defer conn.Close() stream, _ := conn.NewStream() @@ -551,7 +581,7 @@ func ExampleBeginRequest_TxnIsolation() { } func ExampleFuture_GetIterator() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() const timeout = 3 * time.Second @@ -584,7 +614,7 @@ func ExampleFuture_GetIterator() { } func ExampleConnection_Ping() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() // Ping a Tarantool instance to check connection. @@ -599,7 +629,7 @@ func ExampleConnection_Ping() { } func ExampleConnection_Insert() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() // Insert a new tuple { 31, 1 }. @@ -632,7 +662,7 @@ func ExampleConnection_Insert() { } func ExampleConnection_Delete() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() // Insert a new tuple { 35, 1 }. @@ -665,7 +695,7 @@ func ExampleConnection_Delete() { } func ExampleConnection_Replace() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() // Insert a new tuple { 13, 1 }. @@ -714,7 +744,7 @@ func ExampleConnection_Replace() { } func ExampleConnection_Update() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() // Insert a new tuple { 14, 1 }. @@ -734,7 +764,7 @@ func ExampleConnection_Update() { } func ExampleConnection_Call() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() // Call a function 'simple_concat' with arguments. @@ -751,7 +781,7 @@ func ExampleConnection_Call() { } func ExampleConnection_Eval() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() // Run raw Lua code. @@ -788,7 +818,7 @@ func ExampleConnect() { // Example demonstrates how to retrieve information with space schema. func ExampleSchema() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() schema := conn.Schema @@ -810,7 +840,7 @@ func ExampleSchema() { // Example demonstrates how to retrieve information with space schema. func ExampleSpace() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() // Save Schema to a local variable to avoid races @@ -1021,7 +1051,7 @@ func ExampleConnection_NewPrepared() { // of the request. For those purposes use context.WithTimeout() as // the root context. func ExamplePingRequest_Context() { - conn := example_connect() + conn := example_connect(opts) defer conn.Close() timeout := time.Nanosecond diff --git a/export_test.go b/export_test.go index cc9a2a594..464a85844 100644 --- a/export_test.go +++ b/export_test.go @@ -111,6 +111,12 @@ func RefImplRollbackBody(enc *encoder) error { return fillRollback(enc) } +// RefImplIdBody is reference implementation for filling of an id +// request's body. +func RefImplIdBody(enc *encoder, protocolInfo ProtocolInfo) error { + return fillId(enc, protocolInfo) +} + func NewEncoder(w io.Writer) *encoder { return newEncoder(w) } diff --git a/protocol.go b/protocol.go new file mode 100644 index 000000000..7d7e451df --- /dev/null +++ b/protocol.go @@ -0,0 +1,139 @@ +package tarantool + +import ( + "context" + "fmt" +) + +// ProtocolVersion type stores Tarantool protocol version. +type ProtocolVersion uint64 + +// ProtocolVersion type stores a Tarantool protocol feature. +type ProtocolFeature uint64 + +// ProtocolInfo type aggregates Tarantool protocol version and features info. +type ProtocolInfo struct { + // Version is the supported protocol version. + Version ProtocolVersion + // Features are supported protocol features. + Features []ProtocolFeature +} + +// Copy returns the exact copy of a ProtocolInfo object. +// Any changes in copy will not affect the original values. +func (info ProtocolInfo) Copy() ProtocolInfo { + infoCopy := info + + if info.Features != nil { + infoCopy.Features = make([]ProtocolFeature, len(info.Features)) + copy(infoCopy.Features, info.Features) + } + + return infoCopy +} + +const ( + // StreamsFeature represents streams support (supported by connector). + StreamsFeature ProtocolFeature = 0 + // TransactionsFeature represents interactive transactions support. + // (supported by connector). + TransactionsFeature ProtocolFeature = 1 + // ErrorExtensionFeature represents support of MP_ERROR objects over MessagePack + // (unsupported by connector). + ErrorExtensionFeature ProtocolFeature = 2 + // WatchersFeature represents support of watchers + // (unsupported by connector). + WatchersFeature ProtocolFeature = 3 + // PaginationFeature represents support of pagination + // (unsupported by connector). + PaginationFeature ProtocolFeature = 4 +) + +// String returns the name of a Tarantool feature. +// If value X is not a known feature, returns "Unknown feature (code X)" string. +func (ftr ProtocolFeature) String() string { + switch ftr { + case StreamsFeature: + return "StreamsFeature" + case TransactionsFeature: + return "TransactionsFeature" + case ErrorExtensionFeature: + return "ErrorExtensionFeature" + case WatchersFeature: + return "WatchersFeature" + case PaginationFeature: + return "PaginationFeature" + default: + return fmt.Sprintf("Unknown feature (code %d)", ftr) + } +} + +var clientProtocolInfo ProtocolInfo = ProtocolInfo{ + // Protocol version supported by connector. Version 3 + // was introduced in Tarantool 2.10.0, version 4 was + // introduced in master 948e5cd (possible 2.10.5 or 2.11.0). + // Support of protocol version on connector side was introduced in + // 1.10.0. + Version: ProtocolVersion(4), + // Streams and transactions were introduced in protocol version 1 + // (Tarantool 2.10.0), in connector since 1.7.0. + Features: []ProtocolFeature{ + StreamsFeature, + TransactionsFeature, + }, +} + +// IdRequest informs the server about supported protocol +// version and protocol features. +type IdRequest struct { + baseRequest + protocolInfo ProtocolInfo +} + +func fillId(enc *encoder, protocolInfo ProtocolInfo) error { + enc.EncodeMapLen(2) + + encodeUint(enc, KeyVersion) + if err := enc.Encode(protocolInfo.Version); err != nil { + return err + } + + encodeUint(enc, KeyFeatures) + + t := len(protocolInfo.Features) + if err := enc.EncodeArrayLen(t); err != nil { + return err + } + + for _, feature := range protocolInfo.Features { + if err := enc.Encode(feature); err != nil { + return err + } + } + + return nil +} + +// NewIdRequest returns a new IdRequest. +func NewIdRequest(protocolInfo ProtocolInfo) *IdRequest { + req := new(IdRequest) + req.requestCode = IdRequestCode + req.protocolInfo = protocolInfo.Copy() + return req +} + +// Body fills an encoder with the id request body. +func (req *IdRequest) Body(res SchemaResolver, enc *encoder) error { + return fillId(enc, req.protocolInfo) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *IdRequest) Context(ctx context.Context) *IdRequest { + req.ctx = ctx + return req +} diff --git a/protocol_test.go b/protocol_test.go new file mode 100644 index 000000000..401b90a42 --- /dev/null +++ b/protocol_test.go @@ -0,0 +1,37 @@ +package tarantool_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + . "github.com/tarantool/go-tarantool" +) + +func TestProtocolInfoCopyPreservesFeatures(t *testing.T) { + original := ProtocolInfo{ + Version: ProtocolVersion(100), + Features: []ProtocolFeature{ProtocolFeature(99), ProtocolFeature(100)}, + } + + origCopy := original.Copy() + + original.Features[1] = ProtocolFeature(98) + + require.Equal(t, + origCopy, + ProtocolInfo{ + Version: ProtocolVersion(100), + Features: []ProtocolFeature{ProtocolFeature(99), ProtocolFeature(100)}, + }) +} + +func TestFeatureStringRepresentation(t *testing.T) { + require.Equal(t, StreamsFeature.String(), "StreamsFeature") + require.Equal(t, TransactionsFeature.String(), "TransactionsFeature") + require.Equal(t, ErrorExtensionFeature.String(), "ErrorExtensionFeature") + require.Equal(t, WatchersFeature.String(), "WatchersFeature") + require.Equal(t, PaginationFeature.String(), "PaginationFeature") + + require.Equal(t, ProtocolFeature(15532).String(), "Unknown feature (code 15532)") +} diff --git a/response.go b/response.go index 7b203bc54..9e38e970d 100644 --- a/response.go +++ b/response.go @@ -147,8 +147,10 @@ func (resp *Response) decodeBody() (err error) { offset := resp.buf.Offset() defer resp.buf.Seek(offset) - var l int + var l, larr int var stmtID, bindCount uint64 + var serverProtocolInfo ProtocolInfo + var feature ProtocolFeature d := newDecoder(&resp.buf) @@ -190,6 +192,22 @@ func (resp *Response) decodeBody() (err error) { if bindCount, err = d.DecodeUint64(); err != nil { return err } + case KeyVersion: + if err = d.Decode(&serverProtocolInfo.Version); err != nil { + return err + } + case KeyFeatures: + if larr, err = d.DecodeArrayLen(); err != nil { + return err + } + + serverProtocolInfo.Features = make([]ProtocolFeature, larr) + for i := 0; i < larr; i++ { + if err = d.Decode(&feature); err != nil { + return err + } + serverProtocolInfo.Features[i] = feature + } default: if err = d.Skip(); err != nil { return err @@ -204,6 +222,18 @@ func (resp *Response) decodeBody() (err error) { } resp.Data = []interface{}{stmt} } + + // Tarantool may send only version >= 1 + if (serverProtocolInfo.Version != ProtocolVersion(0)) || (serverProtocolInfo.Features != nil) { + if serverProtocolInfo.Version == ProtocolVersion(0) { + return fmt.Errorf("No protocol version provided in Id response") + } + if serverProtocolInfo.Features == nil { + return fmt.Errorf("No features provided in Id response") + } + resp.Data = []interface{}{serverProtocolInfo} + } + if resp.Code != OkCode && resp.Code != PushCode { resp.Code &^= ErrorCodeBit err = Error{resp.Code, resp.Error} diff --git a/tarantool_test.go b/tarantool_test.go index 1350390f9..5d97c6108 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -20,6 +20,17 @@ import ( "github.com/tarantool/go-tarantool/test_helpers" ) +var startOpts test_helpers.StartOpts = test_helpers.StartOpts{ + InitScript: "config.lua", + Listen: server, + WorkDir: "work_dir", + User: opts.User, + Pass: opts.Pass, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 3, + RetryTimeout: 500 * time.Millisecond, +} + type Member struct { Name string Nonce string @@ -2830,6 +2841,313 @@ func TestStream_DoWithClosedConn(t *testing.T) { } } +func TestConnectionProtocolInfoSupported(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + // First Tarantool protocol version (1, StreamsFeature and TransactionsFeature) + // was introduced between 2.10.0-beta1 and 2.10.0-beta2. + // Versions 2 (ErrorExtensionFeature) and 3 (WatchersFeature) were also + // introduced between 2.10.0-beta1 and 2.10.0-beta2. Version 4 + // (PaginationFeature) was introduced in master 948e5cd (possible 2.10.5 or + // 2.11.0). So each release Tarantool >= 2.10 (same as each Tarantool with + // id support) has protocol version >= 3 and first four features. + tarantool210ProtocolInfo := ProtocolInfo{ + Version: ProtocolVersion(3), + Features: []ProtocolFeature{ + StreamsFeature, + TransactionsFeature, + ErrorExtensionFeature, + WatchersFeature, + }, + } + + clientProtocolInfo := conn.ClientProtocolInfo() + require.Equal(t, + clientProtocolInfo, + ProtocolInfo{ + Version: ProtocolVersion(4), + Features: []ProtocolFeature{StreamsFeature, TransactionsFeature}, + }) + + serverProtocolInfo := conn.ServerProtocolInfo() + require.GreaterOrEqual(t, + serverProtocolInfo.Version, + tarantool210ProtocolInfo.Version) + require.Subset(t, + serverProtocolInfo.Features, + tarantool210ProtocolInfo.Features) +} + +func TestClientIdRequestObject(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + tarantool210ProtocolInfo := ProtocolInfo{ + Version: ProtocolVersion(3), + Features: []ProtocolFeature{ + StreamsFeature, + TransactionsFeature, + ErrorExtensionFeature, + WatchersFeature, + }, + } + + req := NewIdRequest(ProtocolInfo{ + Version: ProtocolVersion(1), + Features: []ProtocolFeature{StreamsFeature}, + }) + resp, err := conn.Do(req).Get() + require.Nilf(t, err, "No errors on Id request execution") + require.NotNilf(t, resp, "Response not empty") + require.NotNilf(t, resp.Data, "Response data not empty") + require.Equal(t, len(resp.Data), 1, "Response data contains exactly one object") + + serverProtocolInfo, ok := resp.Data[0].(ProtocolInfo) + require.Truef(t, ok, "Response Data object is an ProtocolInfo object") + require.GreaterOrEqual(t, + serverProtocolInfo.Version, + tarantool210ProtocolInfo.Version) + require.Subset(t, + serverProtocolInfo.Features, + tarantool210ProtocolInfo.Features) +} + +func TestClientIdRequestObjectWithNilContext(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + tarantool210ProtocolInfo := ProtocolInfo{ + Version: ProtocolVersion(3), + Features: []ProtocolFeature{ + StreamsFeature, + TransactionsFeature, + ErrorExtensionFeature, + WatchersFeature, + }, + } + + req := NewIdRequest(ProtocolInfo{ + Version: ProtocolVersion(1), + Features: []ProtocolFeature{StreamsFeature}, + }).Context(nil) //nolint + resp, err := conn.Do(req).Get() + require.Nilf(t, err, "No errors on Id request execution") + require.NotNilf(t, resp, "Response not empty") + require.NotNilf(t, resp.Data, "Response data not empty") + require.Equal(t, len(resp.Data), 1, "Response data contains exactly one object") + + serverProtocolInfo, ok := resp.Data[0].(ProtocolInfo) + require.Truef(t, ok, "Response Data object is an ProtocolInfo object") + require.GreaterOrEqual(t, + serverProtocolInfo.Version, + tarantool210ProtocolInfo.Version) + require.Subset(t, + serverProtocolInfo.Features, + tarantool210ProtocolInfo.Features) +} + +func TestClientIdRequestObjectWithPassedCanceledContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewIdRequest(ProtocolInfo{ + Version: ProtocolVersion(1), + Features: []ProtocolFeature{StreamsFeature}, + }).Context(ctx) //nolint + cancel() + resp, err := conn.Do(req).Get() + require.Nilf(t, resp, "Response is empty") + require.NotNilf(t, err, "Error is not empty") + require.Equal(t, err.Error(), "context is done") +} + +func TestClientIdRequestObjectWithContext(t *testing.T) { + var err error + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewIdRequest(ProtocolInfo{ + Version: ProtocolVersion(1), + Features: []ProtocolFeature{StreamsFeature}, + }).Context(ctx) //nolint + fut := conn.Do(req) + cancel() + resp, err := fut.Get() + require.Nilf(t, resp, "Response is empty") + require.NotNilf(t, err, "Error is not empty") + require.Equal(t, err.Error(), "context is done") +} + +func TestConnectionProtocolInfoUnsupported(t *testing.T) { + test_helpers.SkipIfIdSupported(t) + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + clientProtocolInfo := conn.ClientProtocolInfo() + require.Equal(t, + clientProtocolInfo, + ProtocolInfo{ + Version: ProtocolVersion(4), + Features: []ProtocolFeature{StreamsFeature, TransactionsFeature}, + }) + + serverProtocolInfo := conn.ServerProtocolInfo() + require.Equal(t, serverProtocolInfo, ProtocolInfo{}) +} + +func TestConnectionClientFeaturesUmmutable(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + info := conn.ClientProtocolInfo() + infoOrig := info.Copy() + info.Features[0] = ProtocolFeature(15532) + + require.Equal(t, conn.ClientProtocolInfo(), infoOrig) + require.NotEqual(t, conn.ClientProtocolInfo(), info) +} + +func TestConnectionServerFeaturesUmmutable(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + info := conn.ServerProtocolInfo() + infoOrig := info.Copy() + info.Features[0] = ProtocolFeature(15532) + + require.Equal(t, conn.ServerProtocolInfo(), infoOrig) + require.NotEqual(t, conn.ServerProtocolInfo(), info) +} + +func TestConnectionProtocolVersionRequirementSuccess(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + connOpts := opts.Copy() + connOpts.RequiredProtocolInfo = ProtocolInfo{ + Version: ProtocolVersion(3), + } + + conn, err := Connect(server, connOpts) + + require.Nilf(t, err, "No errors on connect") + require.NotNilf(t, conn, "Connect success") + + conn.Close() +} + +func TestConnectionProtocolVersionRequirementFail(t *testing.T) { + test_helpers.SkipIfIdSupported(t) + + connOpts := opts.Copy() + connOpts.RequiredProtocolInfo = ProtocolInfo{ + Version: ProtocolVersion(3), + } + + conn, err := Connect(server, connOpts) + + require.Nilf(t, conn, "Connect fail") + require.NotNilf(t, err, "Got error on connect") + require.Contains(t, err.Error(), "identify: protocol version 3 is not supported") +} + +func TestConnectionProtocolFeatureRequirementSuccess(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + connOpts := opts.Copy() + connOpts.RequiredProtocolInfo = ProtocolInfo{ + Features: []ProtocolFeature{TransactionsFeature}, + } + + conn, err := Connect(server, connOpts) + + require.NotNilf(t, conn, "Connect success") + require.Nilf(t, err, "No errors on connect") + + conn.Close() +} + +func TestConnectionProtocolFeatureRequirementFail(t *testing.T) { + test_helpers.SkipIfIdSupported(t) + + connOpts := opts.Copy() + connOpts.RequiredProtocolInfo = ProtocolInfo{ + Features: []ProtocolFeature{TransactionsFeature}, + } + + conn, err := Connect(server, connOpts) + + require.Nilf(t, conn, "Connect fail") + require.NotNilf(t, err, "Got error on connect") + require.Contains(t, err.Error(), "identify: protocol feature TransactionsFeature is not supported") +} + +func TestConnectionProtocolFeatureRequirementManyFail(t *testing.T) { + test_helpers.SkipIfIdSupported(t) + + connOpts := opts.Copy() + connOpts.RequiredProtocolInfo = ProtocolInfo{ + Features: []ProtocolFeature{TransactionsFeature, ProtocolFeature(15532)}, + } + + conn, err := Connect(server, connOpts) + + require.Nilf(t, conn, "Connect fail") + require.NotNilf(t, err, "Got error on connect") + require.Contains(t, + err.Error(), + "identify: protocol features TransactionsFeature, Unknown feature (code 15532) are not supported") +} + +func TestConnectionFeatureOptsImmutable(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + restartOpts := startOpts + restartOpts.Listen = "127.0.0.1:3014" + inst, err := test_helpers.StartTarantool(restartOpts) + defer test_helpers.StopTarantoolWithCleanup(inst) + + if err != nil { + log.Fatalf("Failed to prepare test tarantool: %s", err) + } + + retries := uint(10) + timeout := 100 * time.Millisecond + + connOpts := opts.Copy() + connOpts.Reconnect = timeout + connOpts.MaxReconnects = retries + connOpts.RequiredProtocolInfo = ProtocolInfo{ + Features: []ProtocolFeature{TransactionsFeature}, + } + + // Connect with valid opts + conn := test_helpers.ConnectWithValidation(t, server, connOpts) + defer conn.Close() + + // Change opts outside + connOpts.RequiredProtocolInfo.Features[0] = ProtocolFeature(15532) + + // Trigger reconnect with opts re-check + test_helpers.StopTarantool(inst) + err = test_helpers.RestartTarantool(&inst) + require.Nilf(t, err, "Failed to restart tarantool") + + connected := test_helpers.WaitUntilReconnected(conn, retries, timeout) + require.True(t, connected, "Reconnect success") +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body @@ -2842,17 +3160,9 @@ func runTestMain(m *testing.M) int { log.Fatalf("Could not check the Tarantool version") } - inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ - InitScript: "config.lua", - Listen: server, - WorkDir: "work_dir", - User: opts.User, - Pass: opts.Pass, - WaitStart: 100 * time.Millisecond, - ConnectRetry: 3, - RetryTimeout: 500 * time.Millisecond, - MemtxUseMvccEngine: !isStreamUnsupported, - }) + startOpts.MemtxUseMvccEngine = !isStreamUnsupported + + inst, err := test_helpers.StartTarantool(startOpts) defer test_helpers.StopTarantoolWithCleanup(inst) if err != nil { diff --git a/test_helpers/utils.go b/test_helpers/utils.go index c936e90b3..dff0bb357 100644 --- a/test_helpers/utils.go +++ b/test_helpers/utils.go @@ -2,6 +2,7 @@ package test_helpers import ( "testing" + "time" "github.com/tarantool/go-tarantool" ) @@ -40,6 +41,26 @@ func DeleteRecordByKey(t *testing.T, conn tarantool.Connector, } } +// WaitUntilReconnected waits until connection is reestablished. +// Returns false in case of connection is not in the connected state +// after specified retries count, true otherwise. +func WaitUntilReconnected(conn *tarantool.Connection, retries uint, timeout time.Duration) bool { + for i := uint(0); ; i++ { + connected := conn.ConnectedNow() + if connected { + return true + } + + if i == retries { + break + } + + time.Sleep(timeout) + } + + return false +} + func SkipIfSQLUnsupported(t testing.TB) { t.Helper() @@ -66,3 +87,36 @@ func SkipIfStreamsUnsupported(t *testing.T) { t.Skip("Skipping test for Tarantool without streams support") } } + +// SkipIfIdUnsupported skips test run if Tarantool without +// IPROTO_ID support is used. +func SkipIfIdUnsupported(t *testing.T) { + t.Helper() + + // Tarantool supports Id requests since version 2.10.0 + isLess, err := IsTarantoolVersionLess(2, 10, 0) + if err != nil { + t.Fatalf("Could not check the Tarantool version") + } + + if isLess { + t.Skip("Skipping test for Tarantool without id requests support") + } +} + +// SkipIfIdSupported skips test run if Tarantool with +// IPROTO_ID support is used. Skip is useful for tests validating +// that protocol info is processed as expected even for pre-IPROTO_ID instances. +func SkipIfIdSupported(t *testing.T) { + t.Helper() + + // Tarantool supports Id requests since version 2.10.0 + isLess, err := IsTarantoolVersionLess(2, 10, 0) + if err != nil { + t.Fatalf("Could not check the Tarantool version") + } + + if !isLess { + t.Skip("Skipping test for Tarantool with non-zero protocol version and features") + } +}