diff --git a/internal/enginenetx/httpsdialer_internal_test.go b/internal/enginenetx/httpsdialer_internal_test.go index 4077f9c0f1..e38742fc4e 100644 --- a/internal/enginenetx/httpsdialer_internal_test.go +++ b/internal/enginenetx/httpsdialer_internal_test.go @@ -1,13 +1,8 @@ package enginenetx import ( - "context" "crypto/tls" "errors" - "fmt" - "net" - "sync" - "sync/atomic" "testing" "github.com/ooni/probe-cli/v3/internal/mocks" @@ -15,45 +10,6 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestHTTPSDialerTacticsEmitter(t *testing.T) { - t.Run("we correctly handle the case of a canceled context", func(t *testing.T) { - hd := &HTTPSDialer{ - idGenerator: &atomic.Int64{}, - logger: model.DiscardLogger, - netx: &netxlite.Netx{Underlying: nil}, // nil means: use netxlite's singleton - policy: &HTTPSDialerNullPolicy{}, - resolver: netxlite.NewStdlibResolver(model.DiscardLogger), - rootCAs: netxlite.NewMozillaCertPool(), - wg: &sync.WaitGroup{}, - } - - var tactics []*HTTPSDialerTactic - for idx := 0; idx < 255; idx++ { - tactics = append(tactics, &HTTPSDialerTactic{ - Endpoint: net.JoinHostPort(fmt.Sprintf("10.0.0.%d", idx), "443"), - InitialDelay: 0, - SNI: "www.example.com", - VerifyHostname: "www.example.com", - }) - } - - ctx, cancel := context.WithCancel(context.Background()) - cancel() // we want the tested function to run with a canceled context - - out := hd.tacticsEmitter(ctx, tactics...) - - for range out { - // Here we do nothing! - // - // Ideally, we would like to count and assert that we have - // got no tactic from the channel but the selection of ready - // channels is nondeterministic, so we cannot really be - // asserting that. This leaves us with asking the question - // of what we should be asserting here? - } - }) -} - func TestHTTPSDialerVerifyCertificateChain(t *testing.T) { t.Run("without any peer certificate", func(t *testing.T) { tlsConn := &mocks.TLSConn{ @@ -81,3 +37,81 @@ func TestHTTPSDialerVerifyCertificateChain(t *testing.T) { } }) } + +func TestHTTPSDialerReduceResult(t *testing.T) { + t.Run("we return the first conn in a list of conns and close the other conns", func(t *testing.T) { + var closed int + expect := &mocks.TLSConn{} // empty + connv := []model.TLSConn{ + expect, + &mocks.TLSConn{ + Conn: mocks.Conn{ + MockClose: func() error { + closed++ + return nil + }, + }, + }, + &mocks.TLSConn{ + Conn: mocks.Conn{ + MockClose: func() error { + closed++ + return nil + }, + }, + }, + } + + conn, err := httpsDialerReduceResult(connv, nil) + if err != nil { + t.Fatal(err) + } + + if conn != expect { + t.Fatal("unexpected conn") + } + + if closed != 2 { + t.Fatal("did not call close") + } + }) + + t.Run("we join together a list of errors", func(t *testing.T) { + expectErr := "connection_refused\ninterrupted" + errorv := []error{errors.New("connection_refused"), errors.New("interrupted")} + + conn, err := httpsDialerReduceResult(nil, errorv) + if err == nil || err.Error() != expectErr { + t.Fatal("unexpected err", err) + } + + if conn != nil { + t.Fatal("expected nil conn") + } + }) + + t.Run("with a single error we return such an error", func(t *testing.T) { + expected := errors.New("connection_refused") + errorv := []error{expected} + + conn, err := httpsDialerReduceResult(nil, errorv) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + + if conn != nil { + t.Fatal("expected nil conn") + } + }) + + t.Run("we return errDNSNoAnswer if we don't have any conns or errors to return", func(t *testing.T) { + conn, err := httpsDialerReduceResult(nil, nil) + if !errors.Is(err, errDNSNoAnswer) { + t.Fatal("unexpected error", err) + } + + if conn != nil { + t.Fatal("expected nil conn") + } + }) +} diff --git a/internal/enginenetx/httpsdialer_test.go b/internal/enginenetx/httpsdialer_test.go index 56f0fb3225..cd367c8d96 100644 --- a/internal/enginenetx/httpsdialer_test.go +++ b/internal/enginenetx/httpsdialer_test.go @@ -74,9 +74,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { // short indicates whether this is a short test short bool - // policy is the dialer policy - policy enginenetx.HTTPSDialerPolicy - // stats is the stats tracker to use. stats enginenetx.HTTPSDialerStatsTracker @@ -101,7 +98,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { { name: "net.SplitHostPort failure", short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, stats: &enginenetx.HTTPSDialerNullStatsTracker{}, endpoint: "www.example.com", // note: here the port is missing scenario: netemx.InternetScenario, @@ -112,17 +108,19 @@ func TestHTTPSDialerNetemQA(t *testing.T) { }, // This test case ensures that we handle the case of a nonexistent domain + // where we get a dns_no_answer error. The original DNS error is lost in + // background goroutines and what we report to the caller is just that there + // is no available IP address and tactic to attempt using. { name: "hd.policy.LookupTactics failure", short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, stats: &enginenetx.HTTPSDialerNullStatsTracker{}, endpoint: "www.example.nonexistent:443", // note: the domain does not exist scenario: netemx.InternetScenario, configureDPI: func(dpi *netem.DPIEngine) { // nothing }, - expectErr: "dns_nxdomain_error", + expectErr: "dns_no_answer", }, // This test case is the common case: all is good with multiple addresses to dial (I am @@ -130,7 +128,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { { name: "successful dial with multiple addresses", short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, stats: &enginenetx.HTTPSDialerNullStatsTracker{}, endpoint: "www.example.com:443", scenario: []*netemx.ScenarioDomainAddresses{{ @@ -157,7 +154,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { { name: "with TCP connect errors", short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, stats: &enginenetx.HTTPSDialerNullStatsTracker{}, endpoint: "www.example.com:443", scenario: []*netemx.ScenarioDomainAddresses{{ @@ -192,7 +188,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { { name: "with TLS handshake errors", short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, stats: &enginenetx.HTTPSDialerNullStatsTracker{}, endpoint: "www.example.com:443", scenario: []*netemx.ScenarioDomainAddresses{{ @@ -223,7 +218,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { { name: "with a TLS certificate valid for ANOTHER domain", short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, stats: &enginenetx.HTTPSDialerNullStatsTracker{}, endpoint: "wrong.host.badssl.com:443", scenario: []*netemx.ScenarioDomainAddresses{{ @@ -249,7 +243,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { { name: "with TLS certificate signed by an unknown authority", short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, stats: &enginenetx.HTTPSDialerNullStatsTracker{}, endpoint: "untrusted-root.badssl.com:443", scenario: []*netemx.ScenarioDomainAddresses{{ @@ -275,7 +268,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { { name: "with expired TLS certificate", short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, stats: &enginenetx.HTTPSDialerNullStatsTracker{}, endpoint: "expired.badssl.com:443", scenario: []*netemx.ScenarioDomainAddresses{{ @@ -299,9 +291,8 @@ func TestHTTPSDialerNetemQA(t *testing.T) { // This is a corner case: what if the context is canceled after the DNS lookup // but before we start dialing? Are we closing all goroutines and returning correctly? { - name: "with context being canceled in OnStarting", - short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, + name: "with context being canceled in OnStarting", + short: true, stats: &httpsDialerCancelingContextStatsTracker{ cancel: nil, flags: httpsDialerCancelingContextStatsTrackerOnStarting, @@ -322,15 +313,15 @@ func TestHTTPSDialerNetemQA(t *testing.T) { configureDPI: func(dpi *netem.DPIEngine) { // nothing }, - expectErr: "context canceled", + expectErr: "interrupted\ninterrupted", }, - // This is another corner case: what happens if the context is canceled after we - // have a good connection but before we're able to report it to the caller? + // This is another corner case: what happens if the context is canceled + // right after we eastablish a connection? Because of how the current code + // is written, the easiest thing to do is to just return the conn. { - name: "with context being canceled in OnSuccess for the first success", - short: true, - policy: &enginenetx.HTTPSDialerNullPolicy{}, + name: "with context being canceled in OnSuccess for the first success", + short: true, stats: &httpsDialerCancelingContextStatsTracker{ cancel: nil, flags: httpsDialerCancelingContextStatsTrackerOnSuccess, @@ -351,7 +342,7 @@ func TestHTTPSDialerNetemQA(t *testing.T) { configureDPI: func(dpi *netem.DPIEngine) { // nothing }, - expectErr: "context canceled", + expectErr: "", }} for _, tc := range allTestCases { @@ -382,12 +373,16 @@ func TestHTTPSDialerNetemQA(t *testing.T) { // create the getaddrinfo resolver resolver := netx.NewStdlibResolver(log.Log) + policy := &enginenetx.HTTPSDialerNullPolicy{ + Logger: log.Log, + Resolver: resolver, + } + // create the TLS dialer dialer := enginenetx.NewHTTPSDialer( log.Log, netx, - tc.policy, - resolver, + policy, tc.stats, ) defer dialer.CloseIdleConnections() @@ -428,9 +423,6 @@ func TestHTTPSDialerNetemQA(t *testing.T) { if tlsConn != nil { defer tlsConn.Close() } - - // wait for background connections to join - dialer.WaitGroup().Wait() }() // now verify that we have closed all the connections @@ -510,8 +502,10 @@ func TestHTTPSDialerHostNetworkQA(t *testing.T) { MockGetaddrinfoLookupANY: tproxy.GetaddrinfoLookupANY, MockGetaddrinfoResolverNetwork: tproxy.GetaddrinfoResolverNetwork, }}, - &enginenetx.HTTPSDialerNullPolicy{}, - resolver, + &enginenetx.HTTPSDialerNullPolicy{ + Logger: log.Log, + Resolver: resolver, + }, &enginenetx.HTTPSDialerNullStatsTracker{}, ) diff --git a/internal/enginenetx/httpsdialercore.go b/internal/enginenetx/httpsdialercore.go index 84a9a0bd64..c8f92dc627 100644 --- a/internal/enginenetx/httpsdialercore.go +++ b/internal/enginenetx/httpsdialercore.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "net" - "sync" "sync/atomic" "time" @@ -70,16 +69,8 @@ func (dt *HTTPSDialerTactic) Summary() string { // HTTPSDialerPolicy describes the policy used by the [*HTTPSDialer]. type HTTPSDialerPolicy interface { - // LookupTactics performs a DNS lookup for the given domain using the given resolver and - // returns either a list of tactics for dialing or an error. - // - // This function MUST NOT return an empty list and a nil error. If this happens the - // code inside [HTTPSDialer] will PANIC. - LookupTactics(ctx context.Context, domain, port string, reso model.Resolver) ([]*HTTPSDialerTactic, error) - - // Parallelism returns the number of goroutines to create when TLS dialing. The - // [HTTPSDialer] will PANIC if the returned number is less than 1. - Parallelism() int + // LookupTactics returns zero or more tactics for the given host and port. + LookupTactics(ctx context.Context, domain, port string) <-chan *HTTPSDialerTactic } // HTTPSDialerStatsTracker tracks what happens while dialing TLS connections. @@ -119,18 +110,11 @@ type HTTPSDialer struct { // policy defines the dialing policy to use. policy HTTPSDialerPolicy - // resolver is the DNS resolver to use. - resolver model.Resolver - // rootCAs contains the root certificate pool we should use. rootCAs *x509.CertPool // stats tracks what happens while dialing. stats HTTPSDialerStatsTracker - - // wg is the wait group for knowing when all goroutines - // started in the background joined (for testing). - wg *sync.WaitGroup } // NewHTTPSDialer constructs a new [*HTTPSDialer] instance. @@ -143,8 +127,6 @@ type HTTPSDialer struct { // // - policy defines the dialer policy; // -// - resolver is the resolver to use; -// // - stats tracks what happens while we're dialing. // // The returned [*HTTPSDialer] would use the underlying network's @@ -153,7 +135,6 @@ func NewHTTPSDialer( logger model.Logger, netx *netxlite.Netx, policy HTTPSDialerPolicy, - resolver model.Resolver, stats HTTPSDialerStatsTracker, ) *HTTPSDialer { return &HTTPSDialer{ @@ -162,26 +143,18 @@ func NewHTTPSDialer( Prefix: "HTTPSDialer: ", Logger: logger, }, - netx: netx, - policy: policy, - resolver: resolver, - rootCAs: netx.MaybeCustomUnderlyingNetwork().Get().DefaultCertPool(), - stats: stats, - wg: &sync.WaitGroup{}, + netx: netx, + policy: policy, + rootCAs: netx.MaybeCustomUnderlyingNetwork().Get().DefaultCertPool(), + stats: stats, } } var _ model.TLSDialer = &HTTPSDialer{} -// WaitGroup returns the [*sync.WaitGroup] tracking the number of background goroutines, -// which is definitely useful in testing to make sure we join all the goroutines. -func (hd *HTTPSDialer) WaitGroup() *sync.WaitGroup { - return hd.wg -} - // CloseIdleConnections implements model.TLSDialer. func (hd *HTTPSDialer) CloseIdleConnections() { - hd.resolver.CloseIdleConnections() + // nothing } // httpsDialerErrorOrConn contains either an error or a valid conn. @@ -193,6 +166,13 @@ type httpsDialerErrorOrConn struct { Err error } +// errDNSNoAnswer is the error returned when we have no tactic to try +var errDNSNoAnswer = netxlite.NewErrWrapper( + netxlite.ClassifyResolverError, + netxlite.DNSRoundTripOperation, + netxlite.ErrOODNSNoAnswer, +) + // DialTLSContext implements model.TLSDialer. func (hd *HTTPSDialer) DialTLSContext(ctx context.Context, network string, endpoint string) (net.Conn, error) { hostname, port, err := net.SplitHostPort(endpoint) @@ -205,123 +185,77 @@ func (hd *HTTPSDialer) DialTLSContext(ctx context.Context, network string, endpo ctx, cancel := context.WithCancel(ctx) defer cancel() - // See https://github.com/ooni/probe-cli/pull/1295#issuecomment-1731243994 for context - // on why here we MUST make sure we short-circuit IP addresses. - resoWithShortCircuit := &netxlite.ResolverShortCircuitIPAddr{Resolver: hd.resolver} - - logger := &logx.PrefixLogger{ - Prefix: fmt.Sprintf("[#%d] ", hd.idGenerator.Add(1)), - Logger: hd.logger, - } - ol := logx.NewOperationLogger(logger, "LookupTactics: %s", net.JoinHostPort(hostname, port)) - tactics, err := hd.policy.LookupTactics(ctx, hostname, port, resoWithShortCircuit) - if err != nil { - ol.Stop(err) - return nil, err - } - ol.Stop(tactics) - runtimex.Assert(len(tactics) >= 1, "expected at least one tactic here") - - emitter := hd.tacticsEmitter(ctx, tactics...) + // The emitter will emit tactics and then close the channel when done. We spawn 1+ workers + // that handle tactics in paralellel and posts on the collector channel. + emitter := hd.policy.LookupTactics(ctx, hostname, port) collector := make(chan *httpsDialerErrorOrConn) - - parallelism := hd.policy.Parallelism() - runtimex.Assert(parallelism >= 1, "expected parallelism to be >= 1") + joiner := make(chan any) + const parallelism = 16 for idx := 0; idx < parallelism; idx++ { - hd.wg.Add(1) - go func() { - defer hd.wg.Done() - hd.worker(ctx, hostname, emitter, collector) - }() + go hd.worker(ctx, joiner, emitter, collector) } + // wait until all goroutines have joined var ( - numDials = len(tactics) - errorv = []error{} + connv = []model.TLSConn{} + errorv = []error{} + numJoined = 0 ) - for idx := 0; idx < numDials; idx++ { + for numJoined < parallelism { select { - case <-ctx.Done(): - return nil, ctx.Err() + case <-joiner: + numJoined++ case result := <-collector: + // If the goroutine failed, record the error and continue processing results if result.Err != nil { errorv = append(errorv, result.Err) continue } - // Returning early cancels the context and this cancellation - // causes other background goroutines to interrupt their long - // running network operations or unblocks them while sending - return result.Conn, nil + // Save the conn and tell goroutines to stop ASAP + connv = append(connv, result.Conn) + cancel() } } - return nil, errors.Join(errorv...) + return httpsDialerReduceResult(connv, errorv) } -// tacticsEmitter returns a channel closed once we have emitted all the tactics or the context is done. -func (hd *HTTPSDialer) tacticsEmitter(ctx context.Context, tactics ...*HTTPSDialerTactic) <-chan *HTTPSDialerTactic { - out := make(chan *HTTPSDialerTactic) - - hd.wg.Add(1) - go func() { - defer hd.wg.Done() - defer close(out) - - for _, tactic := range tactics { - select { - case out <- tactic: - continue - - case <-ctx.Done(): - return - } +// httpsDialerReduceResult returns either an established conn or an error, using [errDNSNoAnswer] in +// case the list of connections and the list of errors are empty. +func httpsDialerReduceResult(connv []model.TLSConn, errorv []error) (model.TLSConn, error) { + switch { + case len(connv) >= 1: + for _, c := range connv[1:] { + c.Close() } - }() + return connv[0], nil + + case len(errorv) >= 1: + return nil, errors.Join(errorv...) - return out + default: + return nil, errDNSNoAnswer + } } // worker attempts to establish a TLS connection using and emits a single // [*httpsDialerErrorOrConn] for each tactic. -func (hd *HTTPSDialer) worker( - ctx context.Context, - hostname string, - reader <-chan *HTTPSDialerTactic, - writer chan<- *httpsDialerErrorOrConn, -) { - // Note: no need to be concerned with the wait group here because - // we're managing it inside DialTLSContext so Add and Done live together - - for { - select { - case tactic, good := <-reader: - if !good { - // This happens when the emitter goroutine has closed the channel - return - } - - logger := &logx.PrefixLogger{ - Prefix: fmt.Sprintf("[#%d] ", hd.idGenerator.Add(1)), - Logger: hd.logger, - } - conn, err := hd.dialTLS(ctx, logger, tactic) - - select { - case <-ctx.Done(): - if conn != nil { - conn.Close() // we own the connection - } - return +func (hd *HTTPSDialer) worker(ctx context.Context, joiner chan<- any, + reader <-chan *HTTPSDialerTactic, writer chan<- *httpsDialerErrorOrConn) { + // let the parent know that we terminated + defer func() { joiner <- true }() + + for tactic := range reader { + prefixLogger := &logx.PrefixLogger{ + Prefix: fmt.Sprintf("[#%d] ", hd.idGenerator.Add(1)), + Logger: hd.logger, + } - case writer <- &httpsDialerErrorOrConn{Conn: conn, Err: err}: - continue - } + conn, err := hd.dialTLS(ctx, prefixLogger, tactic) - case <-ctx.Done(): - return - } + writer <- &httpsDialerErrorOrConn{Conn: conn, Err: err} } } @@ -410,7 +344,7 @@ func httpsDialerTacticWaitReady(ctx context.Context, tactic *HTTPSDialerTactic) return nil case <-ctx.Done(): - return ctx.Err() + return netxlite.NewTopLevelGenericErrWrapper(ctx.Err()) } } diff --git a/internal/enginenetx/httpsdialernull.go b/internal/enginenetx/httpsdialernull.go index 93022d83f1..fdaed6dffa 100644 --- a/internal/enginenetx/httpsdialernull.go +++ b/internal/enginenetx/httpsdialernull.go @@ -6,19 +6,26 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// HTTPSDialerNullPolicy is the default "null" policy where we use the default -// resolver provided to LookupTactics and we use the correct SNI. +// HTTPSDialerNullPolicy is the default "null" policy where we use the +// given resolver and the domain as the SNI. +// +// The zero value is invalid; please, init all MANDATORY fields. // // We say that this is the "null" policy because this is what you would get // by default if you were not using any policy. // // This policy uses an Happy-Eyeballs-like algorithm. Dial attempts are -// staggered by 300 milliseconds and up to sixteen dial attempts could be -// active at the same time. Further dials will run once one of the -// sixteen active concurrent dials have failed to connect. -type HTTPSDialerNullPolicy struct{} +// staggered by httpsDialerHappyEyeballsDelay. +type HTTPSDialerNullPolicy struct { + // Logger is the MANDATORY logger. + Logger model.Logger + + // Resolver is the MANDATORY resolver. + Resolver model.Resolver +} var _ HTTPSDialerPolicy = &HTTPSDialerNullPolicy{} @@ -29,29 +36,36 @@ var _ HTTPSDialerPolicy = &HTTPSDialerNullPolicy{} const httpsDialerHappyEyeballsDelay = 900 * time.Millisecond // LookupTactics implements HTTPSDialerPolicy. -func (*HTTPSDialerNullPolicy) LookupTactics( - ctx context.Context, domain, port string, reso model.Resolver) ([]*HTTPSDialerTactic, error) { - addrs, err := reso.LookupHost(ctx, domain) - if err != nil { - return nil, err - } - - var tactics []*HTTPSDialerTactic - for idx, addr := range addrs { - tactics = append(tactics, &HTTPSDialerTactic{ - Endpoint: net.JoinHostPort(addr, port), - InitialDelay: happyEyeballsDelay(httpsDialerHappyEyeballsDelay, idx), - SNI: domain, - VerifyHostname: domain, - }) - } - - return tactics, nil -} - -// Parallelism implements HTTPSDialerPolicy. -func (*HTTPSDialerNullPolicy) Parallelism() int { - return 16 +func (p *HTTPSDialerNullPolicy) LookupTactics( + ctx context.Context, domain, port string) <-chan *HTTPSDialerTactic { + out := make(chan *HTTPSDialerTactic) + + go func() { + // make sure we close the output channel when done + defer close(out) + + // See https://github.com/ooni/probe-cli/pull/1295#issuecomment-1731243994 for context + // on why here we MUST make sure we short-circuit IP addresses. + resoWithShortCircuit := &netxlite.ResolverShortCircuitIPAddr{Resolver: p.Resolver} + + addrs, err := resoWithShortCircuit.LookupHost(ctx, domain) + if err != nil { + p.Logger.Warnf("resoWithShortCircuit.LookupHost: %s", err.Error()) + return + } + + for idx, addr := range addrs { + tactic := &HTTPSDialerTactic{ + Endpoint: net.JoinHostPort(addr, port), + InitialDelay: happyEyeballsDelay(httpsDialerHappyEyeballsDelay, idx), + SNI: domain, + VerifyHostname: domain, + } + out <- tactic + } + }() + + return out } // HTTPSDialerNullStatsTracker is the "null" [HTTPSDialerStatsTracker]. diff --git a/internal/enginenetx/httpsdialerstatic.go b/internal/enginenetx/httpsdialerstatic.go index 4b634c4a1e..644c4f5ebe 100644 --- a/internal/enginenetx/httpsdialerstatic.go +++ b/internal/enginenetx/httpsdialerstatic.go @@ -80,15 +80,18 @@ var _ HTTPSDialerPolicy = &HTTPSDialerStaticPolicy{} // LookupTactics implements HTTPSDialerPolicy. func (ldp *HTTPSDialerStaticPolicy) LookupTactics( - ctx context.Context, domain string, port string, reso model.Resolver) ([]*HTTPSDialerTactic, error) { + ctx context.Context, domain string, port string) <-chan *HTTPSDialerTactic { tactics, found := ldp.Root.Domains[domain] if !found { - return ldp.Fallback.LookupTactics(ctx, domain, port, reso) + return ldp.Fallback.LookupTactics(ctx, domain, port) } - return tactics, nil -} -// Parallelism implements HTTPSDialerPolicy. -func (ldp *HTTPSDialerStaticPolicy) Parallelism() int { - return 16 + out := make(chan *HTTPSDialerTactic) + go func() { + defer close(out) + for _, tactic := range tactics { + out <- tactic + } + }() + return out } diff --git a/internal/enginenetx/httpsdialerstatic_test.go b/internal/enginenetx/httpsdialerstatic_test.go index db742584bf..2fc14ef38e 100644 --- a/internal/enginenetx/httpsdialerstatic_test.go +++ b/internal/enginenetx/httpsdialerstatic_test.go @@ -3,10 +3,10 @@ package enginenetx import ( "context" "encoding/json" - "errors" "testing" "time" + "github.com/apex/log" "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/kvstore" "github.com/ooni/probe-cli/v3/internal/mocks" @@ -163,77 +163,80 @@ func TestHTTPSDialerStaticPolicy(t *testing.T) { }) t.Run("LookupTactics", func(t *testing.T) { - t.Run("we can lookup a static tactic", func(t *testing.T) { - expect := []*HTTPSDialerTactic{ - { - Endpoint: "162.55.247.208:443", - InitialDelay: 0, - SNI: "www.example.com", - VerifyHostname: "api.ooni.io", - }, - { - Endpoint: "162.55.247.208:443", - InitialDelay: 0, - SNI: "www.example.org", - VerifyHostname: "api.ooni.io", - }, - } - - p := &HTTPSDialerStaticPolicy{ - Fallback: nil, // explicitly nil such that there is a panic if we access it - Root: &HTTPSDialerStaticPolicyRoot{ - Domains: map[string][]*HTTPSDialerTactic{ - "api.ooni.io": expect, - }, - Version: HTTPSDialerStaticPolicyVersion, - }, - } + expectedTactic := &HTTPSDialerTactic{ + Endpoint: "162.55.247.208:443", + InitialDelay: 0, + SNI: "www.example.com", + VerifyHostname: "api.ooni.io", + } + staticPolicyRoot := &HTTPSDialerStaticPolicyRoot{ + Domains: map[string][]*HTTPSDialerTactic{ + "api.ooni.io": {expectedTactic}, + }, + Version: HTTPSDialerStaticPolicyVersion, + } + kvStore := &kvstore.Memory{} + rawStaticPolicyRoot := runtimex.Try1(json.Marshal(staticPolicyRoot)) + if err := kvStore.Set(HTTPSDialerStaticPolicyKey, rawStaticPolicyRoot); err != nil { + t.Fatal(err) + } + t.Run("with static policy", func(t *testing.T) { ctx := context.Background() - resolver := &mocks.Resolver{} // empty to cause panic if any method is invoked - got, err := p.LookupTactics(ctx, "api.ooni.io", "443", resolver) + + policy, err := NewHTTPSDialerStaticPolicy(kvStore, nil /* explictly to crash if used */) if err != nil { t.Fatal(err) } + tactics := policy.LookupTactics(ctx, "api.ooni.io", "443") + got := []*HTTPSDialerTactic{} + for tactic := range tactics { + t.Logf("%+v", tactic) + got = append(got, tactic) + } + + expect := []*HTTPSDialerTactic{expectedTactic} + if diff := cmp.Diff(expect, got); diff != "" { t.Fatal(diff) } }) t.Run("we fallback if needed", func(t *testing.T) { - expect := errors.New("mocked error") + ctx := context.Background() - resolver := &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, expect + fallback := &HTTPSDialerNullPolicy{ + Logger: log.Log, + Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"93.184.216.34"}, nil + }, }, } - p := &HTTPSDialerStaticPolicy{ - Fallback: &HTTPSDialerNullPolicy{}, - Root: &HTTPSDialerStaticPolicyRoot{ - Domains: nil, // empty so we fallback for all domains - Version: HTTPSDialerStaticPolicyVersion, - }, + policy, err := NewHTTPSDialerStaticPolicy(kvStore, fallback) + if err != nil { + t.Fatal(err) } - ctx := context.Background() - tactics, err := p.LookupTactics(ctx, "api.ooni.io", "443", resolver) - if !errors.Is(err, expect) { - t.Fatal("unexpected error", err) + tactics := policy.LookupTactics(ctx, "www.example.com", "443") + got := []*HTTPSDialerTactic{} + for tactic := range tactics { + t.Logf("%+v", tactic) + got = append(got, tactic) } - if len(tactics) != 0 { - t.Fatal("expected no tactics here") + expect := []*HTTPSDialerTactic{{ + Endpoint: "93.184.216.34:443", + InitialDelay: 0, + SNI: "www.example.com", + VerifyHostname: "www.example.com", + }} + + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) } }) }) - - t.Run("Parallelism", func(t *testing.T) { - p := &HTTPSDialerStaticPolicy{ /* empty */ } - if p.Parallelism() != 16 { - t.Fatal("unexpected parallelism") - } - }) } diff --git a/internal/enginenetx/network.go b/internal/enginenetx/network.go index e6b23a4290..4da83c6a20 100644 --- a/internal/enginenetx/network.go +++ b/internal/enginenetx/network.go @@ -14,6 +14,7 @@ import ( // Network is the network abstraction used by the OONI engine. type Network struct { + reso model.Resolver stats *HTTPSDialerStatsManager txp model.HTTPTransport } @@ -42,6 +43,9 @@ func (n *Network) Close() error { // make sure we close the transport's idle connections n.txp.CloseIdleConnections() + // same as above but for the resolver's connections + n.reso.CloseIdleConnections() + // make sure we sync stats to disk return n.stats.Close() } @@ -87,8 +91,7 @@ func NewNetwork( httpsDialer := NewHTTPSDialer( logger, &netxlite.Netx{Underlying: nil}, // nil means using netxlite's singleton - newHTTPSDialerPolicy(kvStore), - resolver, + newHTTPSDialerPolicy(kvStore, logger, resolver), stats, ) @@ -123,6 +126,7 @@ func NewNetwork( txp = bytecounter.WrapHTTPTransport(txp, counter) netx := &Network{ + reso: resolver, stats: stats, txp: txp, } @@ -130,9 +134,9 @@ func NewNetwork( } // newHTTPSDialerPolicy contains the logic to select the [HTTPSDialerPolicy] to use. -func newHTTPSDialerPolicy(kvStore model.KeyValueStore) HTTPSDialerPolicy { +func newHTTPSDialerPolicy(kvStore model.KeyValueStore, logger model.Logger, resolver model.Resolver) HTTPSDialerPolicy { // the fallback policy we're using is the "null" policy - fallback := &HTTPSDialerNullPolicy{} + fallback := &HTTPSDialerNullPolicy{logger, resolver} // make sure we honor a user-provided policy policy, err := NewHTTPSDialerStaticPolicy(kvStore, fallback) diff --git a/internal/enginenetx/network_internal_test.go b/internal/enginenetx/network_internal_test.go index b21d8e75be..3e667ff42a 100644 --- a/internal/enginenetx/network_internal_test.go +++ b/internal/enginenetx/network_internal_test.go @@ -27,6 +27,11 @@ func TestNetworkUnit(t *testing.T) { }, } netx := &Network{ + reso: &mocks.Resolver{ + MockCloseIdleConnections: func() { + // nothing + }, + }, stats: &HTTPSDialerStatsManager{ TimeNow: time.Now, kvStore: &kvstore.Memory{}, @@ -43,4 +48,34 @@ func TestNetworkUnit(t *testing.T) { t.Fatal("did not call the transport's CloseIdleConnections") } }) + + t.Run("Close calls the resolvers's CloseIdleConnections method", func(t *testing.T) { + var called bool + expected := &mocks.Resolver{ + MockCloseIdleConnections: func() { + called = true + }, + } + netx := &Network{ + reso: expected, + stats: &HTTPSDialerStatsManager{ + TimeNow: time.Now, + kvStore: &kvstore.Memory{}, + logger: model.DiscardLogger, + mu: sync.Mutex{}, + root: &HTTPSDialerStatsRootContainer{}, + }, + txp: &mocks.HTTPTransport{ + MockCloseIdleConnections: func() { + // nothing + }, + }, + } + if err := netx.Close(); err != nil { + t.Fatal(err) + } + if !called { + t.Fatal("did not call the transport's CloseIdleConnections") + } + }) }