diff --git a/frontend.go b/frontend.go index 62d15c7..fa71fb1 100644 --- a/frontend.go +++ b/frontend.go @@ -103,6 +103,15 @@ func (f *Frontend) load(ctx context.Context, key string, loader HTTPLoader) (Ent oldSpan := trace.FromContext(ctx) newContext := trace.NewContext(context.Background(), oldSpan) + deadline, hasDeadline := ctx.Deadline() + if hasDeadline { + var cancel context.CancelFunc + + newContext, cancel = context.WithDeadline(newContext, deadline) + + defer cancel() + } + newContextWithSpan, span := trace.StartSpan(newContext, "flamingo/httpcache/load") span.Annotate(nil, key) diff --git a/frontend_test.go b/frontend_test.go index 6b23ce7..e73fea8 100644 --- a/frontend_test.go +++ b/frontend_test.go @@ -3,14 +3,17 @@ package httpcache_test import ( "context" "net/http" + "net/http/httptest" "testing" "time" "flamingo.me/flamingo/v3/framework/flamingo" - "flamingo.me/httpcache/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "flamingo.me/httpcache/mocks" + "flamingo.me/httpcache" ) @@ -160,3 +163,84 @@ func TestFrontend_Get(t *testing.T) { }) } } + +func TestContextDeadlineExceeded(t *testing.T) { + t.Parallel() + + t.Run("exceeded, throw error", func(t *testing.T) { + t.Parallel() + + backend := new(mocks.Backend) + + backend.EXPECT().Get(testKey).Return(func() (httpcache.Entry, bool) { return httpcache.Entry{}, false }()) + + backend.EXPECT().Set(mock.Anything, mock.Anything).Return(nil) + + contextWithDeadline, cancel := context.WithDeadline(context.Background(), time.Now().Add(4*time.Second)) + t.Cleanup(cancel) + + f := new(httpcache.Frontend).Inject(new(flamingo.NullLogger)).SetBackend(backend) + got, err := f.Get(contextWithDeadline, testKey, loaderWithWaitingTime) + + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, httpcache.Entry{}, got) + }) + + t.Run("did not exceed, no error", func(t *testing.T) { + t.Parallel() + + backend := new(mocks.Backend) + + backend.EXPECT().Get(testKey).Return(func() (httpcache.Entry, bool) { return httpcache.Entry{}, false }()) + + backend.EXPECT().Set(mock.Anything, mock.Anything).Return(nil) + + contextWithDeadline, cancel := context.WithDeadline(context.Background(), time.Now().Add(6*time.Second)) + t.Cleanup(cancel) + + f := new(httpcache.Frontend).Inject(new(flamingo.NullLogger)).SetBackend(backend) + got, err := f.Get(contextWithDeadline, testKey, loaderWithWaitingTime) + + assert.NoError(t, err) + assert.Equal(t, []byte("body"), got.Body) + }) +} + +func loaderWithWaitingTime(ctx context.Context) (httpcache.Entry, error) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + + _, _ = w.Write([]byte("Test 123")) + })) + + defer server.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + return httpcache.Entry{}, err + } + + client := &http.Client{} + resp, err := client.Do(req) + + if resp != nil { + _ = resp.Body.Close() + } + + if err != nil { + return httpcache.Entry{}, err + } + + return httpcache.Entry{ + Meta: httpcache.Meta{ + LifeTime: time.Now().Add(10), + GraceTime: time.Now().Add(15), + Tags: nil, + }, + Header: nil, + Status: "200 OK", + StatusCode: http.StatusOK, + Body: []byte("body"), + }, nil +}