diff --git a/netx/httptransport/httptransport.go b/netx/httptransport/httptransport.go index 82dab505..6544fb1d 100644 --- a/netx/httptransport/httptransport.go +++ b/netx/httptransport/httptransport.go @@ -40,8 +40,9 @@ type Resolver interface { // field of Config is nil/empty, we will use a suitable default. type Config struct { BogonIsError bool // default: bogon is not error - ByteCounter *bytecounter.Counter // default: no byte counting - ContextByteCounting bool // default: no context byte counting + ByteCounter *bytecounter.Counter // default: no explicit byte counting + CacheResolutions bool // default: no caching + ContextByteCounting bool // default: no implicit byte counting Dialer Dialer // default: dialer.DNSDialer Logger Logger // default: no logging ProxyURL *url.URL // default: no proxy @@ -71,6 +72,9 @@ func New(config Config) RoundTripper { if config.Saver != nil { r = resolver.SaverResolver{Resolver: r, Saver: config.Saver} } + if config.CacheResolutions { + r = &resolver.CacheResolver{Resolver: r} + } config.Resolver = r } if config.Dialer == nil { diff --git a/netx/httptransport/integration_test.go b/netx/httptransport/integration_test.go index c890e4dc..03f96561 100644 --- a/netx/httptransport/integration_test.go +++ b/netx/httptransport/integration_test.go @@ -21,6 +21,7 @@ func TestIntegrationSuccess(t *testing.T) { txp := httptransport.New(httptransport.Config{ BogonIsError: true, ByteCounter: counter, + CacheResolutions: true, ContextByteCounting: true, Logger: log.Log, Saver: saver, diff --git a/netx/resolver/cache.go b/netx/resolver/cache.go new file mode 100644 index 00000000..6c122f14 --- /dev/null +++ b/netx/resolver/cache.go @@ -0,0 +1,44 @@ +package resolver + +import ( + "context" + "sync" +) + +// CacheResolver is a resolver that caches successful replies. +type CacheResolver struct { + Resolver + mu sync.Mutex + cache map[string][]string +} + +// LookupHost implements Resolver.LookupHost +func (r *CacheResolver) LookupHost( + ctx context.Context, hostname string) ([]string, error) { + if entry := r.Get(hostname); entry != nil { + return entry, nil + } + entry, err := r.Resolver.LookupHost(ctx, hostname) + if err != nil { + return nil, err + } + r.Set(hostname, entry) + return entry, nil +} + +// Get gets the currently configured entry for domain, or nil +func (r *CacheResolver) Get(domain string) []string { + r.mu.Lock() + defer r.mu.Unlock() + return r.cache[domain] +} + +// Set allows to pre-populate the cache +func (r *CacheResolver) Set(domain string, addresses []string) { + r.mu.Lock() + if r.cache == nil { + r.cache = make(map[string][]string) + } + r.cache[domain] = addresses + r.mu.Unlock() +} diff --git a/netx/resolver/cache_test.go b/netx/resolver/cache_test.go new file mode 100644 index 00000000..87d88b0c --- /dev/null +++ b/netx/resolver/cache_test.go @@ -0,0 +1,53 @@ +package resolver_test + +import ( + "context" + "errors" + "testing" + + "github.com/ooni/probe-engine/netx/resolver" +) + +func TestUnitCacheFailure(t *testing.T) { + expected := errors.New("mocked error") + var r resolver.Resolver = resolver.FakeResolver{ + Err: expected, + } + r = &resolver.CacheResolver{Resolver: r} + addrs, err := r.LookupHost(context.Background(), "www.google.com") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil addrs here") + } +} + +func TestUnitCacheHitSuccess(t *testing.T) { + var r resolver.Resolver = resolver.FakeResolver{ + Err: errors.New("mocked error"), + } + cache := &resolver.CacheResolver{Resolver: r} + cache.Set("dns.google.com", []string{"8.8.8.8"}) + addrs, err := cache.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("not the result we expected") + } +} + +func TestUnitCacheMissSuccess(t *testing.T) { + var r resolver.Resolver = resolver.FakeResolver{ + Result: []string{"8.8.8.8"}, + } + r = &resolver.CacheResolver{Resolver: r} + addrs, err := r.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("not the result we expected") + } +}