Skip to content

Commit

Permalink
dial: add the ability to connect via socket fd
Browse files Browse the repository at this point in the history
This patch introduces `FdDialer`, which connects to Tarantool
using an existing socket file descriptor.

`FdDialer` is not authenticated when creating a connection.

Closes #321
  • Loading branch information
askalt committed Nov 17, 2023
1 parent ea5b53a commit a8c757e
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
work_dir*
.rocks
bench*
testdata/sidecar
56 changes: 56 additions & 0 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net"
"os"
"strings"
"time"

Expand Down Expand Up @@ -267,6 +268,61 @@ func (d OpenSslDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
return conn, nil
}

type FdDialer struct {
// Fd is a socket file descrpitor.
Fd uintptr
// RequiredProtocol contains minimal protocol version and
// list of protocol features that should be supported by
// Tarantool server. By default, there are no restrictions.
RequiredProtocolInfo ProtocolInfo
}

type fdAddr struct {
fd uintptr
}

func (a fdAddr) Network() string {
return "fd"
}

func (a fdAddr) String() string {
return fmt.Sprintf("fd://%d", a.fd)
}

type fdConn struct {
fd uintptr
net.Conn
}

func (c *fdConn) LocalAddr() net.Addr {
return fdAddr{fd: c.fd}
}

// Dial makes FdDialer satisfy the Dialer interface.
func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
file := os.NewFile(d.Fd, "")
c, err := net.FileConn(file)
if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
}

conn := new(tntConn)
conn.isLocal = true
conn.net = &fdConn{fd: d.Fd, Conn: c}

dc := &deadlineIO{to: opts.IoTimeout, c: conn.net}
conn.reader = bufio.NewReaderSize(dc, bufSize)
conn.writer = bufio.NewWriterSize(dc, bufSize)

_, err = rawDial(conn, d.RequiredProtocolInfo)
if err != nil {
conn.net.Close()
return nil, err
}

return conn, nil
}

// Addr makes tntConn satisfy the Conn interface.
func (c *tntConn) Addr() net.Addr {
if c.isLocal {
Expand Down
56 changes: 52 additions & 4 deletions dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,11 @@ func assertRequest(t *testing.T, r io.Reader, expected []byte) {
}

type testDialOpts struct {
errGreeting bool
errId bool
errAuth bool
idUnsupported bool
errGreeting bool
errId bool
errAuth bool
idUnsupported bool
authNoRequired bool
}

func testDialAccept(t *testing.T, ch chan struct{}, opts testDialOpts, l net.Listener) {
Expand Down Expand Up @@ -458,6 +459,9 @@ func testDialAccept(t *testing.T, ch chan struct{}, opts testDialOpts, l net.Lis
client.Write(idResponse)
}

if opts.authNoRequired {
return
}
// Check Auth request.
assertRequest(t, client, authRequestExpected)
if opts.errAuth {
Expand Down Expand Up @@ -566,3 +570,47 @@ func TestOpenSslDialer_Dial(t *testing.T) {
testDialer(t, l, dialer)
}
}

func TestFdDialer_Dial(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()

for _, cs := range testDialCases {
opts := cs.opts
if opts.errAuth {
// No need to test FdDialer for auth errors.
continue
}
// FdDialer doesn't make auth requests.
opts.authNoRequired = true

t.Run(cs.name, func(t *testing.T) {
ch := make(chan struct{})
go testDialAccept(t, ch, opts, l)

sock, err := net.Dial("tcp", addr)
require.NoError(t, err)
f, err := sock.(*net.TCPConn).File()
require.NoError(t, err)

dialer := tarantool.FdDialer{
Fd: f.Fd(),
}
ctx, cancel := test_helpers.GetConnectContext()
defer cancel()
conn, err := dialer.Dial(ctx, tarantool.DialOpts{
IoTimeout: time.Second * 2,
})
<-ch
if cs.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, cs.protocolInfo, conn.ProtocolInfo())
require.Equal(t, cs.version, []byte(conn.Greeting().Version))
conn.Close()
})
}
}
31 changes: 31 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tarantool_test
import (
"context"
"fmt"
"net"
"time"

"github.com/tarantool/go-iproto"
Expand Down Expand Up @@ -1337,3 +1338,33 @@ func ExampleWatchOnceRequest() {
fmt.Println(resp.Data)
}
}

// This example demonstrates how to use an existing socket file descriptor
// to establish a connection with Tarantool. This can be useful if the socket fd
// was inherited from the Tarantool process itself.
// For details, please see TestFdDialer.
func ExampleFdDialer() {
addr := dialer.Address
c, err := net.Dial("tcp", addr)
if err != nil {
fmt.Printf("can't establish connection: %v\n", err)
return
}
f, err := c.(*net.TCPConn).File()
if err != nil {
fmt.Printf("unexpected error: %v\n", err)
}
dialer := tarantool.FdDialer{
Fd: f.Fd(),
}
// Use an existing socket fd to create connection with Tarantool.
conn, err := tarantool.Connect(context.Background(), dialer, opts)
if err != nil {
fmt.Printf("connect error: %v\n", err)
return
}
resp, err := conn.Do(tarantool.NewPingRequest()).Get()
fmt.Println(resp.Code, err)
// Output:
// 0 <nil>
}
82 changes: 82 additions & 0 deletions tarantool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"log"
"math"
"os"
"os/exec"
"path/filepath"
"reflect"
"runtime"
"strings"
Expand Down Expand Up @@ -76,6 +78,7 @@ func (m *Member) DecodeMsgpack(d *msgpack.Decoder) error {
}

var server = "127.0.0.1:3013"
var fdDialerTestServer = "127.0.0.1:3014"
var spaceNo = uint32(617)
var spaceName = "test"
var indexNo = uint32(0)
Expand Down Expand Up @@ -3950,6 +3953,85 @@ func TestConnect_context_cancel(t *testing.T) {
}
}

func buildSidecar(dir string) error {
goPath, err := exec.LookPath("go")
if err != nil {
return err
}
cmd := exec.Command(goPath, "build", "sidecar.go")
cmd.Dir = filepath.Join(dir, "testdata")
return cmd.Run()
}

func TestFdDialer(t *testing.T) {
isLess, err := test_helpers.IsTarantoolVersionLess(3, 0, 0)
if err != nil || isLess {
t.Skip("box.session.new present in Tarantool since version 3.0")
}

wd, err := os.Getwd()
require.NoError(t, err)

err = buildSidecar(wd)
require.NoErrorf(t, err, "failed to build sidecar: %v", err)

instOpts := startOpts
instOpts.Listen = fdDialerTestServer
inst, err := test_helpers.StartTarantool(instOpts, TtDialer{
Address: fdDialerTestServer,
User: "test",
Password: "test",
})
require.NoError(t, err)
defer test_helpers.StopTarantoolWithCleanup(inst)

conn := test_helpers.ConnectWithValidation(t, dialer, opts)
defer conn.Close()

sidecarExe := filepath.Join(wd, "testdata", "sidecar")

evalBody := fmt.Sprintf(`
local socket = require('socket')
local popen = require('popen')
local os = require('os')
local s1, s2 = socket.socketpair('AF_UNIX', 'SOCK_STREAM', 0)
--[[ Tell sidecar which fd use to connect. --]]
os.setenv('SOCKET_FD', tostring(s2:fd()))
box.session.new({
type = 'binary',
fd = s1:fd(),
user = 'test',
})
s1:detach()
local ph, err = popen.new({'%s'}, {
stdout = popen.opts.PIPE,
stderr = popen.opts.PIPE,
inherit_fds = {s2:fd()},
})
if err ~= nil then
return 1, err
end
ph:wait()
local status_code = ph:info().status.exit_code
local stderr = ph:read({stderr=true}):rstrip()
local stdout = ph:read({stdout=true}):rstrip()
return status_code, stderr, stdout
`, sidecarExe)

var resp []interface{}
err = conn.EvalTyped(evalBody, []interface{}{}, &resp)
require.NoError(t, err)
require.Equal(t, "", resp[1], resp[1])
require.Equal(t, "", resp[2], resp[2])
require.Equal(t, int8(0), resp[0])
}

// runTestMain is a body of TestMain function
// (see https://pkg.go.dev/testing#hdr-Main).
// Using defer + os.Exit is not works so TestMain body
Expand Down
38 changes: 38 additions & 0 deletions testdata/sidecar.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package main

import (
"context"
"os"
"strconv"

"github.com/tarantool/go-tarantool/v2"
)

func main() {
fd, err := strconv.Atoi(os.Getenv("SOCKET_FD"))
if err != nil {
panic(err)
}
dialer := tarantool.FdDialer{
Fd: uintptr(fd),
}
conn, err := tarantool.Connect(context.Background(), dialer, tarantool.Opts{})
if err != nil {
panic(err)
}
if _, err := conn.Do(tarantool.NewPingRequest()).
Get(); err != nil {
panic(err)
}
// Insert new tuple.
if _, err := conn.Do(tarantool.NewInsertRequest("test").
Tuple([]interface{}{239})).Get(); err != nil {
panic(err)
}
// Delete inserted tuple.
if _, err := conn.Do(tarantool.NewDeleteRequest("test").
Index("primary").
Key([]interface{}{239})).Get(); err != nil {
panic(err)
}
}

0 comments on commit a8c757e

Please sign in to comment.