Skip to content

Commit

Permalink
dial: add DialContext function
Browse files Browse the repository at this point in the history
In order to replace timeouts with contexts in `Connect` instance
creation (go-tarantool), I need a `DialContext` function.
It accepts context, and cancels, if context is canceled by user.

Part of tarantool/go-tarantool#136
  • Loading branch information
DerekBum committed Oct 1, 2023
1 parent b452431 commit a46839f
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 27 deletions.
121 changes: 94 additions & 27 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package openssl

import (
"context"
"errors"
"net"
"time"
Expand Down Expand Up @@ -89,8 +90,53 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
// parameters.
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
flags DialFlags) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
return dialSession(d, network, addr, ctx, flags, nil)
host, err := parseHost(addr)
if err != nil {
return nil, err
}

conn, err := net.DialTimeout(network, addr, timeout)
if err != nil {
return nil, err
}
ctx, err = prepareCtx(ctx)
if err != nil {
return nil, err
}
client, err := createSession(conn, flags, host, ctx, nil)
if err != nil {
conn.Close()
}
return client, err
}

// DialContext acts like Dial but takes a context for network dial.
//
// The context includes only network dial. It does not include OpenSSL calls.
//
// See func Dial for a description of the network, addr, ctx and flags
// parameters.
func DialContext(context context.Context, network, addr string,
ctx *Ctx, flags DialFlags) (*Conn, error) {
host, err := parseHost(addr)
if err != nil {
return nil, err
}

dialer := net.Dialer{}
conn, err := dialer.DialContext(context, network, addr)
if err != nil {
return nil, err
}
ctx, err = prepareCtx(ctx)
if err != nil {
return nil, err
}
client, err := createSession(conn, flags, host, ctx, nil)
if err != nil {
conn.Close()
}
return client, err
}

// DialSession will connect to network/address and then wrap the corresponding
Expand All @@ -108,59 +154,80 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
// can be retrieved from the GetSession method on the Conn.
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
var d net.Dialer
return dialSession(d, network, addr, ctx, flags, session)
}

func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
host, _, err := net.SplitHostPort(addr)
host, err := parseHost(addr)
if err != nil {
return nil, err
}
if ctx == nil {
var err error
ctx, err = NewCtx()
if err != nil {
return nil, err
}
// TODO: use operating system default certificate chain?
}

c, err := d.Dial(network, addr)
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
conn, err := Client(c, ctx)
ctx, err = prepareCtx(ctx)
if err != nil {
c.Close()
return nil, err
}
if session != nil {
err := conn.setSession(session)
client, err := createSession(conn, flags, host, ctx, session)
if err != nil {
conn.Close()
}
return client, err
}

func prepareCtx(ctx *Ctx) (*Ctx, error) {
if ctx == nil {
var err error
ctx, err = NewCtx()
if err != nil {
c.Close()
return nil, err
}
// TODO: use operating system default certificate chain?
}
return ctx, nil
}

func parseHost(addr string) (string, error) {
host, _, err := net.SplitHostPort(addr)
return host, err
}

func handshake(conn *Conn, host string, flags DialFlags) error {
var err error
if flags&DisableSNI == 0 {
err = conn.SetTlsExtHostName(host)
if err != nil {
conn.Close()
return nil, err
return err
}
}
err = conn.Handshake()
if err != nil {
conn.Close()
return nil, err
return err
}
if flags&InsecureSkipHostVerification == 0 {
err = conn.VerifyHostname(host)
if err != nil {
return err
}
}
return nil
}

func createSession(c net.Conn, flags DialFlags, host string, ctx *Ctx,
session []byte) (*Conn, error) {
conn, err := Client(c, ctx)
if err != nil {
return nil, err
}
if session != nil {
err := conn.setSession(session)
if err != nil {
conn.Close()
return nil, err
}
}
if err := handshake(conn, host, flags); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
94 changes: 94 additions & 0 deletions net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package openssl

import (
"context"
"crypto/rand"
"io"
"net"
"sync"
"testing"
"time"
)

var conn net.Conn

func sslConnect(t *testing.T, ssl_listener net.Listener) {
for {
var err error
conn, err = ssl_listener.Accept()
if err != nil {
t.Errorf("failed accept: %s", err)
continue
}
io.Copy(conn, io.LimitReader(rand.Reader, 1024))
break
}
}

func TestDial(t *testing.T) {
ctx := getCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
sslConnect(t, ssl_listener)
wg.Done()
}()

client, err := Dial(ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), ctx, InsecureSkipHostVerification)

wg.Wait()

if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if client.is_shutdown {
t.Fatal("client is closed after creation")
}
}

func TestDialTimeout(t *testing.T) {
ctx := getCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

client, err := DialTimeout(ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), time.Nanosecond, ctx, 0)

if client != nil || err == nil {
t.Fatalf("expected error")
}
}

func TestDialContext(t *testing.T) {
ctx := getCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
client, err := DialContext(cancelCtx, ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), ctx, 0)

if client != nil || err == nil {
t.Fatalf("expected error")
}
}

0 comments on commit a46839f

Please sign in to comment.