diff --git a/experiment/sniblocking/sniblocking.go b/experiment/sniblocking/sniblocking.go index 9a1d3b5d..34390909 100644 --- a/experiment/sniblocking/sniblocking.go +++ b/experiment/sniblocking/sniblocking.go @@ -10,6 +10,7 @@ import ( "math/rand" "net" "net/url" + "sync" "time" "github.com/ooni/probe-engine/internal/netxlogger" @@ -21,7 +22,7 @@ import ( const ( testName = "sni_blocking" - testVersion = "0.0.2" + testVersion = "0.0.3" ) // Config contains the experiment config. @@ -38,6 +39,7 @@ type Config struct { type Subresult struct { BytesReceived int64 `json:"-"` BytesSent int64 `json:"-"` + Cached bool `json:"-"` Failure *string `json:"failure"` NetworkEvents oonidatamodel.NetworkEventsList `json:"network_events"` Queries oonidatamodel.DNSQueriesList `json:"queries"` @@ -55,7 +57,9 @@ type TestKeys struct { } type measurer struct { + cache map[string]Subresult config Config + mu sync.Mutex } func (m *measurer) ExperimentName() string { @@ -66,14 +70,13 @@ func (m *measurer) ExperimentVersion() string { return testVersion } -func measureone( +func (m *measurer) measureone( ctx context.Context, - output chan<- Subresult, handler modelx.Handler, beginning time.Time, sni string, thaddr string, -) { +) Subresult { // slightly delay the measurement gen := rand.New(rand.NewSource(time.Now().UnixNano())) sleeptime := time.Duration(gen.Intn(250)) * time.Millisecond @@ -81,11 +84,10 @@ func measureone( case <-time.After(sleeptime): case <-ctx.Done(): s := "generic_timeout_error" - output <- Subresult{ + return Subresult{ Failure: &s, SNI: sni, } - return } // perform the measurement result := oonitemplates.TLSConnect(ctx, oonitemplates.TLSConnectConfig{ @@ -110,7 +112,33 @@ func measureone( s := result.Error.Error() smk.Failure = &s } + return smk +} + +func (m *measurer) measureonewithcache( + ctx context.Context, + output chan<- Subresult, + handler modelx.Handler, + beginning time.Time, + sni string, + thaddr string, +) { + cachekey := sni + thaddr + m.mu.Lock() + smk, okay := m.cache[cachekey] + m.mu.Unlock() + if okay { + output <- smk + return + } + smk = m.measureone(ctx, handler, beginning, sni, thaddr) output <- smk + smk.BytesReceived = 0 // don't count them more than once + smk.BytesSent = 0 // ditto + smk.Cached = true + m.mu.Lock() + m.cache[cachekey] = smk + m.mu.Unlock() } func (m *measurer) startall( @@ -119,7 +147,7 @@ func (m *measurer) startall( ) <-chan Subresult { outputs := make(chan Subresult, len(inputs)) for _, input := range inputs { - go measureone( + go m.measureonewithcache( ctx, outputs, netxlogger.NewHandler(sess.Logger()), measurement.MeasurementStartTimeSaved, input, m.config.TestHelperAddress, @@ -153,7 +181,9 @@ func processall( sentBytes += smk.BytesSent receivedBytes += smk.BytesReceived current++ - sess.Logger().Infof("sni_blocking: %s: %s", smk.SNI, asString(smk.Failure)) + sess.Logger().Infof( + "sni_blocking: %s: %s [cached: %+v]", smk.SNI, + asString(smk.Failure), smk.Cached) if current >= len(inputs) { break } @@ -184,6 +214,11 @@ func (m *measurer) Run( measurement *model.Measurement, callbacks model.ExperimentCallbacks, ) error { + m.mu.Lock() + if m.cache == nil { + m.cache = make(map[string]Subresult) + } + m.mu.Unlock() if m.config.ControlSNI == "" { return errors.New("Experiment requires ControlSNI") } @@ -200,7 +235,10 @@ func (m *measurer) Run( return err } measurement.Input = maybeParsed - inputs := []string{m.config.ControlSNI, measurement.Input} + inputs := []string{m.config.ControlSNI} + if measurement.Input != m.config.ControlSNI { + inputs = append(inputs, measurement.Input) + } ctx, cancel := context.WithTimeout(ctx, 10*time.Second*time.Duration(len(inputs))) defer cancel() outputs := m.startall(ctx, sess, measurement, inputs) diff --git a/experiment/sniblocking/sniblocking_test.go b/experiment/sniblocking/sniblocking_test.go index 2c36049d..1ff74138 100644 --- a/experiment/sniblocking/sniblocking_test.go +++ b/experiment/sniblocking/sniblocking_test.go @@ -22,7 +22,7 @@ func TestUnitNewExperimentMeasurer(t *testing.T) { if measurer.ExperimentName() != "sni_blocking" { t.Fatal("unexpected name") } - if measurer.ExperimentVersion() != "0.0.2" { + if measurer.ExperimentVersion() != "0.0.3" { t.Fatal("unexpected version") } } @@ -42,7 +42,7 @@ func TestUnitMeasurerMeasureNoControlSNI(t *testing.T) { func TestUnitMeasurerMeasureNoMeasurementInput(t *testing.T) { measurer := NewExperimentMeasurer(Config{ - ControlSNI: "ps.ooni.io", + ControlSNI: "example.com", }) err := measurer.Run( context.Background(), @@ -59,7 +59,7 @@ func TestUnitMeasurerMeasureWithInvalidInput(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately cancel the context measurer := NewExperimentMeasurer(Config{ - ControlSNI: "ps.ooni.io", + ControlSNI: "example.com", }) measurement := &model.Measurement{ Input: "\t", @@ -79,7 +79,7 @@ func TestUnitMeasurerMeasureWithCancelledContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately cancel the context measurer := NewExperimentMeasurer(Config{ - ControlSNI: "ps.ooni.io", + ControlSNI: "example.com", }) measurement := &model.Measurement{ Input: "kernel.org", @@ -98,56 +98,85 @@ func TestUnitMeasurerMeasureWithCancelledContext(t *testing.T) { func TestUnitMeasureoneCancelledContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately cancel the context - outputs := make(chan Subresult, 1) - measureone( + result := new(measurer).measureone( ctx, - outputs, netxlogger.NewHandler(log.Log), time.Now(), "kernel.org", - "ps.ooni.io:443", + "example.com:443", ) - for result := range outputs { - if *result.Failure != "generic_timeout_error" { - t.Fatal("unexpected failure") - } - if result.SNI != "kernel.org" { - t.Fatal("unexpected SNI") - } - if result.BytesReceived != 0 { - t.Fatal("expected to receive bytes") - } - if result.BytesSent != 0 { - t.Fatal("expected to send bytes") - } - break + if *result.Failure != "generic_timeout_error" { + t.Fatal("unexpected failure") + } + if result.SNI != "kernel.org" { + t.Fatal("unexpected SNI") + } + if result.BytesReceived != 0 { + t.Fatal("expected to receive bytes") + } + if result.BytesSent != 0 { + t.Fatal("expected to send bytes") } } func TestUnitMeasureoneSuccess(t *testing.T) { - outputs := make(chan Subresult, 1) - measureone( + result := new(measurer).measureone( context.Background(), - outputs, netxlogger.NewHandler(log.Log), time.Now(), "kernel.org", - "ps.ooni.io:443", + "example.com:443", ) - for result := range outputs { + if *result.Failure != "ssl_invalid_hostname" { + t.Fatal("unexpected failure") + } + if result.SNI != "kernel.org" { + t.Fatal("unexpected SNI") + } + if result.BytesReceived <= 0 { + t.Fatal("expected to receive bytes") + } + if result.BytesSent <= 0 { + t.Fatal("expected to send bytes") + } +} + +func TestUnitMeasureonewithcacheWorks(t *testing.T) { + measurer := &measurer{cache: make(map[string]Subresult)} + output := make(chan Subresult, 2) + for i := 0; i < 2; i++ { + measurer.measureonewithcache( + context.Background(), + output, + netxlogger.NewHandler(log.Log), + time.Now(), + "kernel.org", + "example.com:443", + ) + } + for _, expected := range []bool{false, true} { + result := <-output + if result.Cached != expected { + t.Fatal("unexpected cached") + } if *result.Failure != "ssl_invalid_hostname" { t.Fatal("unexpected failure") } if result.SNI != "kernel.org" { t.Fatal("unexpected SNI") } - if result.BytesReceived <= 0 { + if result.BytesReceived <= 0 && !result.Cached { t.Fatal("expected to receive bytes") } - if result.BytesSent <= 0 { + if result.BytesSent <= 0 && !result.Cached { t.Fatal("expected to send bytes") } - break + if result.BytesReceived != 0 && result.Cached { + t.Fatal("expected to not receive bytes") + } + if result.BytesSent != 0 && result.Cached { + t.Fatal("expected to not send bytes") + } } } @@ -174,9 +203,9 @@ func TestUnitProcessallPanicsIfInvalidSNI(t *testing.T) { outputs, measurement, handler.NewPrinterCallbacks(log.Log), - []string{"kernel.org", "ps.ooni.io"}, + []string{"kernel.org", "example.com"}, newsession(), - "ps.ooni.io", + "example.com", ) }