Skip to content

Commit

Permalink
fix(netxlite): allow overriding default cert pool (#1069)
Browse files Browse the repository at this point in the history
This diff tweaks #1068
to make sure overriding the default cert pool works.

In #1068 we introduced
code to add this functionality but we never tested it was working
as intended. It turns out it was not!

Because this diff amends the previous diff, we'll consider it
part of ooni/probe#2135.
  • Loading branch information
bassosimone authored Feb 1, 2023
1 parent 7485153 commit f78b7fb
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 23 deletions.
2 changes: 1 addition & 1 deletion internal/netxlite/dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestDialerSystem(t *testing.T) {
t.Fatal("unexpected conn")
}
if stop.Sub(start) > 100*time.Millisecond {
t.Fatal("undable to enforce timeout")
t.Fatal("unable to enforce timeout")
}
})
})
Expand Down
2 changes: 1 addition & 1 deletion internal/netxlite/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context,
func (d *quicDialerQUICGo) maybeApplyTLSDefaults(config *tls.Config, port int) *tls.Config {
config = config.Clone()
if config.RootCAs == nil {
config.RootCAs = defaultCertPool
config.RootCAs = NewDefaultCertPool()
}
if len(config.NextProtos) <= 0 {
switch port {
Expand Down
8 changes: 4 additions & 4 deletions internal/netxlite/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ func TestQUICDialerQUICGo(t *testing.T) {
if tlsConfig.RootCAs != nil {
t.Fatal("tlsConfig.RootCAs should not have been changed")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("invalid gotTLSConfig.RootCAs")
if gotTLSConfig.RootCAs == nil {
t.Fatal("gotTLSConfig.RootCAs should have been set")
}
if tlsConfig.NextProtos != nil {
t.Fatal("tlsConfig.NextProtos should not have been changed")
Expand Down Expand Up @@ -289,8 +289,8 @@ func TestQUICDialerQUICGo(t *testing.T) {
if tlsConfig.RootCAs != nil {
t.Fatal("tlsConfig.RootCAs should not have been changed")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("invalid gotTLSConfig.RootCAs")
if gotTLSConfig.RootCAs == nil {
t.Fatal("gotTLSConfig.RootCAs should have been set")
}
if tlsConfig.NextProtos != nil {
t.Fatal("tlsConfig.NextProtos should not have been changed")
Expand Down
6 changes: 1 addition & 5 deletions internal/netxlite/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,6 @@ type tlsHandshakerConfigurable struct {

var _ model.TLSHandshaker = &tlsHandshakerConfigurable{}

// defaultCertPool is the cert pool we use by default. We store this
// value into a private variable to enable for unit testing.
var defaultCertPool = NewDefaultCertPool()

// tlsMaybeConnectionState returns the connection state if error is nil
// and otherwise just returns an empty state to the caller.
func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState {
Expand All @@ -213,7 +209,7 @@ func (h *tlsHandshakerConfigurable) Handshake(
conn.SetDeadline(time.Now().Add(timeout))
if config.RootCAs == nil {
config = config.Clone()
config.RootCAs = defaultCertPool
config.RootCAs = NewDefaultCertPool()
}
tlsconn, err := h.newConn(conn, config)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/netxlite/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
if config.RootCAs != nil {
t.Fatal("config.RootCAs should still be nil")
}
if gotTLSConfig.RootCAs != defaultCertPool {
if gotTLSConfig.RootCAs == nil {
t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
}
})
Expand Down
57 changes: 46 additions & 11 deletions internal/netxlite/tproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package netxlite
import (
"context"
"crypto/x509"
"net"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"
"time"

"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)

func TestDefaultTProxy(t *testing.T) {
Expand All @@ -36,16 +39,48 @@ func TestDefaultTProxy(t *testing.T) {
}

func TestWithCustomTProxy(t *testing.T) {
expected := x509.NewCertPool()
tproxy := &mocks.UnderlyingNetwork{
MockMaybeModifyPool: func(pool *x509.CertPool) *x509.CertPool {
runtimex.Assert(expected != pool, "got unexpected pool")
return expected
},
}
WithCustomTProxy(tproxy, func() {
if NewDefaultCertPool() != expected {
t.Fatal("unexpected pool")

t.Run("we can override the default cert pool", func(t *testing.T) {
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(444)
}))
defer srvr.Close()

// TODO(bassosimone): we need a more compact and ergonomic
// way of overriding the underlying network
tproxy := &mocks.UnderlyingNetwork{
MockDialContext: func(ctx context.Context, timeout time.Duration, network string, address string) (net.Conn, error) {
return (&DefaultTProxy{}).DialContext(ctx, timeout, network, address)
},
MockListenUDP: func(network string, addr *net.UDPAddr) (model.UDPLikeConn, error) {
return (&DefaultTProxy{}).ListenUDP(network, addr)
},
MockGetaddrinfoLookupANY: func(ctx context.Context, domain string) ([]string, string, error) {
return (&DefaultTProxy{}).GetaddrinfoLookupANY(ctx, domain)
},
MockGetaddrinfoResolverNetwork: func() string {
return (&DefaultTProxy{}).GetaddrinfoResolverNetwork()
},
MockMaybeModifyPool: func(*x509.CertPool) *x509.CertPool {
pool := x509.NewCertPool()
pool.AddCert(srvr.Certificate())
return pool
},
}

WithCustomTProxy(tproxy, func() {
clnt := NewHTTPClientStdlib(model.DiscardLogger)
req, err := http.NewRequestWithContext(context.Background(), "GET", srvr.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := clnt.Do(req)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != 444 {
t.Fatal("unexpected status code")
}
})
})
}

0 comments on commit f78b7fb

Please sign in to comment.