Skip to content

Commit

Permalink
Merge pull request #644 from ooni/issue/576
Browse files Browse the repository at this point in the history
tls/timeout: use deadline rather than context
  • Loading branch information
bassosimone authored May 29, 2020
2 parents a75f7aa + 367fcf7 commit fee7672
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 56 deletions.
12 changes: 12 additions & 0 deletions netx/dialer/eof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ func (EOFConn) RemoteAddr() net.Addr {
return EOFAddr{}
}

func (EOFConn) SetDeadline(t time.Time) error {
return nil
}

func (EOFConn) SetReadDeadline(t time.Time) error {
return nil
}

func (EOFConn) SetWriteDeadline(t time.Time) error {
return nil
}

type EOFAddr struct{}

func (EOFAddr) Network() string {
Expand Down
24 changes: 9 additions & 15 deletions netx/dialer/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,10 @@ func (h SystemTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
tlsconn := tls.Client(conn, config)
errch := make(chan error, 1)
go func() {
errch <- tlsconn.Handshake()
}()
select {
case err := <-errch:
if err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
case <-ctx.Done():
return nil, tls.ConnectionState{}, ctx.Err()
if err := tlsconn.Handshake(); err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
}

// TimeoutTLSHandshaker is a TLSHandshaker with timeout
Expand All @@ -54,9 +45,12 @@ func (h TimeoutTLSHandshaker) Handshake(
if h.HandshakeTimeout != 0 {
timeout = h.HandshakeTimeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return h.TLSHandshaker.Handshake(ctx, conn, config)
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, tls.ConnectionState{}, err
}
tlsconn, connstate, err := h.TLSHandshaker.Handshake(ctx, conn, config)
conn.SetDeadline(time.Time{})
return tlsconn, connstate, err
}

// ErrorWrapperTLSHandshaker wraps the returned error to be an OONI error
Expand Down
112 changes: 87 additions & 25 deletions netx/dialer/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,60 +14,86 @@ import (
"github.com/ooni/probe-engine/netx/modelx"
)

func TestUnitSystemTLSHandshakerContextDone(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // immeditely cancel
func TestUnitSystemTLSHandshakerEOFError(t *testing.T) {
h := dialer.SystemTLSHandshaker{}
conn, _, err := h.Handshake(ctx, dialer.EOFConn{}, new(tls.Config))
if err != context.Canceled {
conn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{
ServerName: "x.org",
})
if err != io.EOF {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}

func TestUnitSystemTLSHandshakerEOFError(t *testing.T) {
h := dialer.SystemTLSHandshaker{}
conn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{
ServerName: "x.org",
})
if err != io.EOF {
func TestUnitTimeoutTLSHandshakerSetDeadlineError(t *testing.T) {
h := dialer.TimeoutTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
expected := errors.New("mocked error")
conn, _, err := h.Handshake(
context.Background(), &dialer.FakeConn{SetDeadlineError: expected},
new(tls.Config))
if !errors.Is(err, expected) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}

func TestUnitTimeoutTLSHandshaker(t *testing.T) {
func TestUnitTimeoutTLSHandshakerEOFError(t *testing.T) {
h := dialer.TimeoutTLSHandshaker{
TLSHandshaker: SlowTLSHandshaker{},
TLSHandshaker: dialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
conn, _, err := h.Handshake(
context.Background(), dialer.EOFConn{}, new(tls.Config))
if err != context.DeadlineExceeded {
context.Background(), dialer.EOFConn{}, &tls.Config{ServerName: "x.org"})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}

type SlowTLSHandshaker struct{}

func (SlowTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
select {
case <-ctx.Done():
return nil, tls.ConnectionState{}, ctx.Err()
case <-time.After(30 * time.Second):
return nil, tls.ConnectionState{}, io.EOF
func TestUnitTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) {
h := dialer.TimeoutTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
underlying := &SetDeadlineConn{}
conn, _, err := h.Handshake(
context.Background(), underlying, &tls.Config{ServerName: "x.org"})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
if len(underlying.deadlines) != 2 {
t.Fatal("SetDeadline not called twice")
}
if underlying.deadlines[0].Before(time.Now()) {
t.Fatal("the first SetDeadline call was incorrect")
}
if !underlying.deadlines[1].IsZero() {
t.Fatal("the second SetDeadline call was incorrect")
}
}

type SetDeadlineConn struct {
dialer.EOFConn
deadlines []time.Time
}

func (c *SetDeadlineConn) SetDeadline(t time.Time) error {
c.deadlines = append(c.deadlines, t)
return nil
}

func TestUnitErrorWrapperTLSHandshakerFailure(t *testing.T) {
h := dialer.ErrorWrapperTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}}
conn, _, err := h.Handshake(
Expand Down Expand Up @@ -212,3 +238,39 @@ func (h *RecorderTLSHandshaker) Handshake(
h.SNI = config.ServerName
return h.TLSHandshaker.Handshake(ctx, conn, config)
}

func TestIntegrationDialTLSContextGood(t *testing.T) {
dialer := dialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: dialer.SystemTLSHandshaker{},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("connection is nil")
}
conn.Close()
}

func TestIntegrationDialTLSContextTimeout(t *testing.T) {
dialer := dialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: dialer.ErrorWrapperTLSHandshaker{
TLSHandshaker: dialer.TimeoutTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
HandshakeTimeout: 10 * time.Microsecond,
},
},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err.Error() != modelx.FailureGenericTimeoutError {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}
25 changes: 12 additions & 13 deletions netx/selfcensor/selfcensor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,22 @@ type Spec struct {
// values are the IP addresses to return. If you set the values for
// a domain to `[]string{"NXDOMAIN"}`, the system resolver will return
// an NXDOMAIN response. If you set the values for a domain to
// `[]string{"TIMEOUT"}` the system resolver will block until the
// context used by the code is expired.
// `[]string{"TIMEOUT"}` the system resolver will return "i/o timeout".
PoisonSystemDNS map[string][]string

// BlockedEndpoints allows you to block specific IP endpoints. The key is
// `IP:port` to block. The format is the same of net.JoinHostPort. If
// the value is "REJECT", then the connection attempt will fail with
// ECONNREFUSED. If the value is "TIMEOUT", then the connector will block
// until the context is expired. If the value is anything else, we
// will perform a "REJECT".
// ECONNREFUSED. If the value is "TIMEOUT", then the connector will return
// claiming "i/o timeout". If the value is anything else, we will
// perform a "REJECT".
BlockedEndpoints map[string]string

// BlockedFingerprints allows you to block packets whose body contains
// specific fingerprints. Of course, the key is the fingerprint. If
// the value is "RST", then the connection will be reset. If the value
// is "TIMEOUT", then the code will block until the context is
// expired. If the value is anything else, we will perform a "RST".
// is "TIMEOUT", then the code will return claiming "i/o timeout". If
// the value is anything else, we will perform a "RST".
BlockedFingerprints map[string]string
}

Expand Down Expand Up @@ -109,6 +108,9 @@ func MaybeEnable(data string) (err error) {
// not censor anything unless you call selfcensor.Enable().
type SystemResolver struct{}

// errTimeout indicates that a timeout error has occurred.
var errTimeout = errors.New("i/o timeout")

// LookupHost implements Resolver.LookupHost
func (r SystemResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if enabled.Load() != 0 { // jumps not taken by default
Expand All @@ -121,8 +123,7 @@ func (r SystemResolver) LookupHost(ctx context.Context, hostname string) ([]stri
return nil, errors.New("no such host")
}
if len(values) == 1 && values[0] == "TIMEOUT" {
<-ctx.Done()
return nil, ctx.Err()
return nil, errTimeout
}
if len(values) > 0 {
return values, nil
Expand Down Expand Up @@ -160,8 +161,7 @@ func (d SystemDialer) DialContext(
if spec.BlockedEndpoints != nil {
action, ok := spec.BlockedEndpoints[address]
if ok && action == "TIMEOUT" {
<-ctx.Done()
return nil, ctx.Err()
return nil, errTimeout
}
if ok {
switch network {
Expand Down Expand Up @@ -205,8 +205,7 @@ func (c connWrapper) match(p []byte, n int) (int, error) {
for key, value := range c.fingerprints {
if bytes.Index(p, []byte(key)) != -1 {
if value == "TIMEOUT" {
<-c.closed
return 0, errors.New("use of closed network connection")
return 0, errTimeout
}
return 0, errors.New("connection reset by peer")
}
Expand Down
5 changes: 2 additions & 3 deletions netx/selfcensor/selfcensor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package selfcensor_test

import (
"context"
"errors"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -108,7 +107,7 @@ func TestResolveCauseTimeout(t *testing.T) {
t.Fatal("we expected self censorship to be enabled now")
}
addrs, err := selfcensor.SystemResolver{}.LookupHost(ctx, "dns.google")
if !errors.Is(err, context.DeadlineExceeded) {
if err == nil || err.Error() != "i/o timeout" {
t.Fatal("not the error we expected")
}
if addrs != nil {
Expand Down Expand Up @@ -181,7 +180,7 @@ func TestDialCauseTimeout(t *testing.T) {
t.Fatal("we expected self censorship to be enabled now")
}
addrs, err := selfcensor.SystemDialer{}.DialContext(ctx, "tcp", "8.8.8.8:443")
if !errors.Is(err, context.DeadlineExceeded) {
if err == nil || err.Error() != "i/o timeout" {
t.Fatal("not the error we expected")
}
if addrs != nil {
Expand Down

0 comments on commit fee7672

Please sign in to comment.