Skip to content

Commit

Permalink
refactor: make use of contexts in more places
Browse files Browse the repository at this point in the history
- `CacheControl.FlushCaches`
- `Querier.Query`
- `Resolver.Resolve`

Besides all the API churn, this leads to `ParallelBestResolver`,
`StrictResolver` and `UpstreamResolver` simplification: timeouts only
need to be setup in one place, `UpstreamResolver`.

We also benefit from using HTTP request contexts, so if the client
closes the connection we stop processing on our side.
  • Loading branch information
ThinkChaos committed Nov 21, 2023
1 parent e4ebc16 commit eae99ec
Show file tree
Hide file tree
Showing 52 changed files with 797 additions and 628 deletions.
12 changes: 6 additions & 6 deletions api/api_interface_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ type ListRefresher interface {
}

type Querier interface {
Query(question string, qType dns.Type) (*model.Response, error)
Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error)
}

type CacheControl interface {
FlushCaches()
FlushCaches(ctx context.Context)
}

func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
Expand Down Expand Up @@ -137,13 +137,13 @@ func (i *OpenAPIInterfaceImpl) ListRefresh(_ context.Context,
return ListRefresh200Response{}, nil
}

func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObject) (QueryResponseObject, error) {
func (i *OpenAPIInterfaceImpl) Query(ctx context.Context, request QueryRequestObject) (QueryResponseObject, error) {
qType := dns.Type(dns.StringToType[request.Body.Type])
if qType == dns.Type(dns.TypeNone) {
return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil
}

resp, err := i.querier.Query(dns.Fqdn(request.Body.Query), qType)
resp, err := i.querier.Query(ctx, dns.Fqdn(request.Body.Query), qType)
if err != nil {
return nil, err
}
Expand All @@ -156,10 +156,10 @@ func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObje
}), nil
}

func (i *OpenAPIInterfaceImpl) CacheFlush(_ context.Context,
func (i *OpenAPIInterfaceImpl) CacheFlush(ctx context.Context,
_ CacheFlushRequestObject,
) (CacheFlushResponseObject, error) {
i.cacheControl.FlushCaches()
i.cacheControl.FlushCaches(ctx)

return CacheFlush200Response{}, nil
}
39 changes: 22 additions & 17 deletions api/api_interface_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"time"

// . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
Expand Down Expand Up @@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus {
return args.Get(0).(BlockingStatus)
}

func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) {
args := m.Called(question, qType)
func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
args := m.Called(ctx, question, qType)

return args.Get(0).(*model.Response), args.Error(1)
}

func (m *CacheControlMock) FlushCaches() {
_ = m.Called()
func (m *CacheControlMock) FlushCaches(ctx context.Context) {
_ = m.Called(ctx)
}

var _ = Describe("API implementation tests", func() {
Expand All @@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() {
listRefreshMock *ListRefreshMock
cacheControlMock *CacheControlMock
sut *OpenAPIInterfaceImpl

ctx context.Context
cancelFn context.CancelFunc
)

BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)

blockingControlMock = &BlockingControlMock{}
querierMock = &QuerierMock{}
listRefreshMock = &ListRefreshMock{}
Expand All @@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() {
)
Expect(err).Should(Succeed())

querierMock.On("Query", "google.com.", A).Return(&model.Response{
querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{
Res: queryResponse,
Reason: "reason",
}, nil)

resp, err := sut.Query(context.Background(), QueryRequestObject{
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "google.com", Type: "A",
},
Expand All @@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() {
})

It("should return 400 on wrong parameter", func() {
resp, err := sut.Query(context.Background(), QueryRequestObject{
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "google.com",
Type: "WRONGTYPE",
Expand All @@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() {
listRefreshMock.On("RefreshLists").Return(nil)

resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed())
var resp200 ListRefresh200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 500 on failure", func() {
listRefreshMock.On("RefreshLists").Return(errors.New("failed"))

resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed())
var resp500 ListRefresh500TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp500))
Expand All @@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() {
duration := "3s"
grroups := "gr1,gr2"

resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{
Duration: &duration,
Groups: &grroups,
Expand All @@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() {

It("should return 400 on failure", func() {
blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed"))
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{})
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{})
Expect(err).Should(Succeed())
var resp400 DisableBlocking400TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp400))
Expand All @@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() {

It("should return 400 on wrong duration parameter", func() {
wrongDuration := "4sds"
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{
Duration: &wrongDuration,
},
Expand All @@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() {
blockingControlMock.On("EnableBlocking").Return()

resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{})
resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{})
Expect(err).Should(Succeed())
var resp200 EnableBlocking200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() {
AutoEnableInSec: 47,
})

resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{})
resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{})
Expect(err).Should(Succeed())
var resp200 BlockingStatus200JSONResponse
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() {
Describe("Cache API", func() {
When("Cache flush is called", func() {
It("should return 200 on success", func() {
cacheControlMock.On("FlushCaches").Return()
resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{})
cacheControlMock.On("FlushCaches", ctx).Return()
resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{})
Expect(err).Should(Succeed())
var resp200 CacheFlush200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand Down
6 changes: 3 additions & 3 deletions cache/expirationcache/expiration_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type Options struct {
// OnExpirationCallback will be called just before an element gets expired and will
// be removed from cache. This function can return new value and TTL to leave the
// element in the cache or nil to remove it
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
type OnExpirationCallback[T any] func(ctx context.Context, key string) (val *T, ttl time.Duration)

// OnCacheHitCallback will be called on cache get if entry was found
type OnCacheHitCallback func(key string)
Expand All @@ -58,7 +58,7 @@ func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
l, _ := lru.New(defaultSize)
c := &ExpiringLRUCache[T]{
cleanUpInterval: defaultCleanUpInterval,
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
preExpirationFn: func(ctx context.Context, key string) (val *T, ttl time.Duration) {
return nil, 0
},
onCacheHit: func(key string) {},
Expand Down Expand Up @@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() {
var keysToDelete []string

for _, key := range expiredKeys {
newVal, newTTL := e.preExpirationFn(key)
newVal, newTTL := e.preExpirationFn(context.Background(), key)
if newVal != nil {
e.Put(key, newVal, newTTL)
} else {
Expand Down
6 changes: 3 additions & 3 deletions cache/expirationcache/expiration_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() {
Describe("preExpiration function", func() {
When("function is defined", func() {
It("should update the value and TTL if function returns values", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "v2"

return &v2, time.Second
Expand All @@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() {
})

It("should update the value and TTL if function returns values on cleanup if element is expired", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "val2"

return &v2, time.Second
Expand All @@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() {
})

It("should delete the key if function returns nil", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
return nil, 0
}
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)
Expand Down
10 changes: 6 additions & 4 deletions cache/expirationcache/prefetching_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ type cacheValue[T any] struct {
type OnEntryReloadedCallback func(key string)

// ReloadEntryFn reloads a prefetched entry by key
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)
type ReloadEntryFn[T any] func(ctx context.Context, key string) (*T, time.Duration)

type PrefetchingOptions[T any] struct {
Options
ReloadFn func(cacheKey string) (*T, time.Duration)
ReloadFn ReloadEntryFn[T]
PrefetchThreshold int
PrefetchExpires time.Duration
PrefetchMaxItemsCount int
Expand Down Expand Up @@ -70,9 +70,11 @@ func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
}

func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
func (e *PrefetchingExpiringLRUCache[T]) onExpired(
ctx context.Context, cacheKey string,
) (val *cacheValue[T], ttl time.Duration) {
if e.shouldPrefetch(cacheKey) {
loadedVal, ttl := e.reloadFn(cacheKey)
loadedVal, ttl := e.reloadFn(ctx, cacheKey)
if loadedVal != nil {
if e.onPrefetchEntryReloaded != nil {
e.onPrefetchEntryReloaded(cacheKey)
Expand Down
8 changes: 4 additions & 4 deletions cache/expirationcache/prefetching_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down Expand Up @@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand All @@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() {
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down Expand Up @@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down
Loading

0 comments on commit eae99ec

Please sign in to comment.