Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: move tlsdialer to netxlite #404

Merged
merged 1 commit into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/engine/legacy/netx/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) {
// - SystemTLSHandshaker
//
// If you have others needs, manually build the chain you need.
func newTLSDialer(d dialer.Dialer, config *tls.Config) tlsdialer.TLSDialer {
return tlsdialer.TLSDialer{
func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
return &netxlite.TLSDialer{
Config: config,
Dialer: d,
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/netx/netx.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func NewTLSDialer(config Config) TLSDialer {
}
config.TLSConfig.RootCAs = config.CertPool
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
return tlsdialer.TLSDialer{
return &netxlite.TLSDialer{
Config: config.TLSConfig,
Dialer: config.Dialer,
TLSHandshaker: h,
Expand Down
14 changes: 7 additions & 7 deletions internal/engine/netx/netx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {

func TestNewTLSDialerVanilla(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{})
rtd, ok := td.(tlsdialer.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialer)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
Expand Down Expand Up @@ -243,7 +243,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{
TLSConfig: new(tls.Config),
})
rtd, ok := td.(tlsdialer.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialer)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
Expand Down Expand Up @@ -272,7 +272,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{
Logger: log.Log,
})
rtd, ok := td.(tlsdialer.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialer)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
Expand Down Expand Up @@ -312,7 +312,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{
TLSSaver: saver,
})
rtd, ok := td.(tlsdialer.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialer)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
Expand Down Expand Up @@ -352,7 +352,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
TLSConfig: new(tls.Config),
NoTLSVerify: true,
})
rtd, ok := td.(tlsdialer.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialer)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
Expand Down Expand Up @@ -384,7 +384,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{
NoTLSVerify: true,
})
rtd, ok := td.(tlsdialer.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialer)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
Expand Down Expand Up @@ -444,7 +444,7 @@ func TestNewWithDialer(t *testing.T) {

func TestNewWithTLSDialer(t *testing.T) {
expected := errors.New("mocked error")
tlsDialer := tlsdialer.TLSDialer{
tlsDialer := &netxlite.TLSDialer{
Config: new(tls.Config),
Dialer: netx.FakeDialer{Err: expected},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Expand Down
14 changes: 5 additions & 9 deletions internal/engine/netx/tlsdialer/integration_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package tlsdialer_test

import (
"context"
"net"
"net/http"
"testing"

"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)

Expand All @@ -16,18 +14,16 @@ func TestTLSDialerSuccess(t *testing.T) {
t.Skip("skip test in short mode")
}
log.SetLevel(log.DebugLevel)
dialer := tlsdialer.TLSDialer{Dialer: new(net.Dialer),
dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer),
TLSHandshaker: &netxlite.TLSHandshakerLogger{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Logger: log.Log,
},
}
txp := &http.Transport{DialTLS: func(network, address string) (net.Conn, error) {
// AlpineLinux edge is still using Go 1.13. We cannot switch to
// using DialTLSContext here as we'd like to until either Alpine
// switches to Go 1.14 or we drop the MK dependency.
return dialer.DialTLSContext(context.Background(), network, address)
}}
txp := &http.Transport{
DialTLSContext: dialer.DialTLSContext,
ForceAttemptHTTP2: true,
}
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if err != nil {
Expand Down
12 changes: 6 additions & 6 deletions internal/engine/netx/tlsdialer/saver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
}
nextprotos := []string{"h2"}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
tlsdlr := &netxlite.TLSDialer{
Config: &tls.Config{NextProtos: nextprotos},
Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
Expand Down Expand Up @@ -115,7 +115,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
}
nextprotos := []string{"h2"}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
tlsdlr := &netxlite.TLSDialer{
Config: &tls.Config{NextProtos: nextprotos},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
Expand Down Expand Up @@ -181,7 +181,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Expand Down Expand Up @@ -247,7 +247,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Expand Down Expand Up @@ -280,7 +280,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
tlsdlr := &netxlite.TLSDialer{
Config: &tls.Config{InsecureSkipVerify: true},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
Expand Down
38 changes: 0 additions & 38 deletions internal/engine/netx/tlsdialer/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,41 +66,3 @@ func (h EmitterTLSHandshaker) Handshake(
})
return tlsconn, state, err
}

// TLSDialer is the TLS dialer
type TLSDialer struct {
Config *tls.Config
Dialer UnderlyingDialer
TLSHandshaker TLSHandshaker
}

// DialTLSContext is like tls.DialTLS but with the signature of net.Dialer.DialContext
func (d TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
// Implementation note: when DialTLS is not set, the code in
// net/http will perform the handshake. Otherwise, if DialTLS
// is set, we will end up here. This code is still used when
// performing non-HTTP TLS-enabled dial operations.
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
config := d.Config
if config == nil {
config = new(tls.Config)
} else {
config = config.Clone()
}
if config.ServerName == "" {
config.ServerName = host
}
tlsconn, _, err := d.TLSHandshaker.Handshake(ctx, conn, config)
if err != nil {
conn.Close()
return nil, err
}
return tlsconn, nil
}
113 changes: 0 additions & 113 deletions internal/engine/netx/tlsdialer/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/tls"
"errors"
"io"
"net"
"testing"
"time"

Expand Down Expand Up @@ -97,115 +96,3 @@ func TestEmitterTLSHandshakerFailure(t *testing.T) {
t.Fatal("expected nonzero DurationSinceBeginning")
}
}

func TestTLSDialerFailureSplitHostPort(t *testing.T) {
dialer := tlsdialer.TLSDialer{}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com") // missing port
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}

func TestTLSDialerFailureDialing(t *testing.T) {
dialer := tlsdialer.TLSDialer{Dialer: tlsdialer.EOFDialer{}}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}

func TestTLSDialerFailureHandshaking(t *testing.T) {
rec := &RecorderTLSHandshaker{TLSHandshaker: &netxlite.TLSHandshakerStdlib{}}
dialer := tlsdialer.TLSDialer{
Dialer: tlsdialer.EOFConnDialer{},
TLSHandshaker: rec,
}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
if rec.SNI != "www.google.com" {
t.Fatal("unexpected SNI value")
}
}

func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) {
rec := &RecorderTLSHandshaker{TLSHandshaker: &netxlite.TLSHandshakerStdlib{}}
dialer := tlsdialer.TLSDialer{
Config: &tls.Config{
ServerName: "x.org",
},
Dialer: tlsdialer.EOFConnDialer{},
TLSHandshaker: rec,
}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
if rec.SNI != "x.org" {
t.Fatal("unexpected SNI value")
}
}

type RecorderTLSHandshaker struct {
tlsdialer.TLSHandshaker
SNI string
}

func (h *RecorderTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
h.SNI = config.ServerName
return h.TLSHandshaker.Handshake(ctx, conn, config)
}

func TestDialTLSContextGood(t *testing.T) {
dialer := tlsdialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("connection is nil")
}
conn.Close()
}

func TestDialTLSContextTimeout(t *testing.T) {
dialer := tlsdialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{
Timeout: 10 * time.Microsecond,
},
},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err.Error() != errorx.FailureGenericTimeoutError {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}
Loading