diff --git a/topdown/http.go b/topdown/http.go index cad1bf8cba..4d8b224599 100644 --- a/topdown/http.go +++ b/topdown/http.go @@ -505,7 +505,7 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt if len(tlsCaCert) != 0 { tlsCaCert = bytes.Replace(tlsCaCert, []byte("\\n"), []byte("\n"), -1) - pool, err := addCACertsFromBytes(tlsConfig.RootCAs, []byte(tlsCaCert)) + pool, err := addCACertsFromBytes(tlsConfig.RootCAs, tlsCaCert) if err != nil { return nil, nil, err } @@ -692,21 +692,21 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { return nil, nil } - headers, err := parseResponseHeaders(cachedRespData.Headers) - if err != nil { - return nil, err - } - - // check the freshness of the cached response - if isCachedResponseFresh(c.bctx, headers, c.forceCacheParams) { + if getCurrentTime(c.bctx).Before(cachedRespData.ExpiresAt) { return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode) } + var err error c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.key) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } + headers, err := parseResponseHeaders(cachedRespData.Headers) + if err != nil { + return nil, err + } + // check with the server if the stale response is still up-to-date. // If server returns a new response (ie. status_code=200), update the cache with the new response // If server returns an unmodified response (ie. status_code=304), update the headers for the existing response @@ -727,6 +727,12 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { } } + expiresAt, err := expiryFromHeaders(result.Header) + if err != nil { + return nil, err + } + cachedRespData.ExpiresAt = expiresAt + cachingMode, err := getCachingMode(c.key) if err != nil { return nil, err @@ -753,7 +759,7 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { return nil, err } - if err := insertIntoHTTPSendInterQueryCache(c.bctx, c.key, result, respBody, c.forceCacheParams != nil); err != nil { + if err := insertIntoHTTPSendInterQueryCache(c.bctx, c.key, result, respBody, c.forceCacheParams); err != nil { return nil, err } @@ -761,8 +767,8 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { } // insertIntoHTTPSendInterQueryCache inserts given key and value in the inter-query cache -func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp *http.Response, respBody []byte, force bool) error { - if resp == nil || (!force && !canStore(resp.Header)) { +func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) error { + if resp == nil || (!forceCaching(cacheParams) && !canStore(resp.Header)) { return nil } @@ -781,9 +787,9 @@ func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp var pcv cache.InterQueryCacheValue if cachingMode == defaultCachingMode { - pcv, err = newInterQueryCacheValue(resp, respBody, force) + pcv, err = newInterQueryCacheValue(bctx, resp, respBody, cacheParams) } else { - pcv, err = newInterQueryCacheData(resp, respBody, force) + pcv, err = newInterQueryCacheData(bctx, resp, respBody, cacheParams) } if err != nil { @@ -855,15 +861,15 @@ func getCachingMode(req ast.Object) (cachingMode, error) { return "", fmt.Errorf("invalid value specified for %v field: %v", key.String(), string(s)) } } - return cachingMode(defaultCachingMode), nil + return defaultCachingMode, nil } type interQueryCacheValue struct { Data []byte } -func newInterQueryCacheValue(resp *http.Response, respBody []byte, force bool) (*interQueryCacheValue, error) { - data, err := newInterQueryCacheData(resp, respBody, force) +func newInterQueryCacheValue(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheValue, error) { + data, err := newInterQueryCacheData(bctx, resp, respBody, cacheParams) if err != nil { return nil, err } @@ -893,20 +899,48 @@ type interQueryCacheData struct { Status string StatusCode int Headers http.Header + ExpiresAt time.Time +} + +func forceCaching(cacheParams *forceCacheParams) bool { + return cacheParams != nil && cacheParams.forceCacheDurationSeconds > 0 +} + +func expiryFromHeaders(headers http.Header) (time.Time, error) { + var expiresAt time.Time + maxAge, err := parseMaxAgeCacheDirective(parseCacheControlHeader(headers)) + if err != nil { + return time.Time{}, err + } + if maxAge != -1 { + createdAt, err := getResponseHeaderDate(headers) + if err != nil { + return time.Time{}, err + } + expiresAt = createdAt.Add(time.Second * time.Duration(maxAge)) + } else { + expiresAt = getResponseHeaderExpires(headers) + } + return expiresAt, nil } -func newInterQueryCacheData(resp *http.Response, respBody []byte, force bool) (*interQueryCacheData, error) { - if force { - now := time.Now().UTC() - resp.Header["Date"] = []string{now.Format(http.TimeFormat)} +func newInterQueryCacheData(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheData, error) { + var expiresAt time.Time + + if forceCaching(cacheParams) { + createdAt := getCurrentTime(bctx) + expiresAt = createdAt.Add(time.Second * time.Duration(cacheParams.forceCacheDurationSeconds)) } else { - _, err := parseResponseHeaders(resp.Header) + var err error + expiresAt, err = expiryFromHeaders(resp.Header) if err != nil { return nil, err } } - cv := interQueryCacheData{RespBody: respBody, + cv := interQueryCacheData{ + ExpiresAt: expiresAt, + RespBody: respBody, Status: resp.Status, StatusCode: resp.StatusCode, Headers: resp.Header} @@ -1013,58 +1047,6 @@ func canStore(headers http.Header) bool { return true } -func isCachedResponseFresh(bctx BuiltinContext, headers *responseHeaders, cacheParams *forceCacheParams) bool { - if headers.date.IsZero() { - return false - } - - currentTime := getCurrentTime(bctx) - if currentTime.IsZero() { - return false - } - - currentAge := currentTime.Sub(headers.date) - - // The time.Sub operation uses wall clock readings and - // not monotonic clock readings as the parsed version of the response time - // does not contain monotonic clock readings. This can result in negative durations. - // Another scenario where a negative duration can occur, is when a server sets the Date - // response header. As per https://tools.ietf.org/html/rfc7231#section-7.1.1.2, - // an origin server MUST NOT send a Date header field if it does not - // have a clock capable of providing a reasonable approximation of the - // current instance in Coordinated Universal Time. - // Hence, consider the cached response as stale if a negative duration is encountered. - if currentAge < 0 { - return false - } - - if cacheParams != nil { - // override the cache directives set by the server - maxAgeDur := time.Second * time.Duration(cacheParams.forceCacheDurationSeconds) - if maxAgeDur > currentAge { - return true - } - } else { - // Check "max-age" cache directive. - // The "max-age" response directive indicates that the response is to be - // considered stale after its age is greater than the specified number - // of seconds. - if headers.maxAge != -1 { - maxAgeDur := time.Second * time.Duration(headers.maxAge) - if maxAgeDur > currentAge { - return true - } - } else { - // Check "Expires" header. - // Note: "max-age" if set, takes precedence over "Expires" - if headers.expires.Sub(headers.date) > currentAge { - return true - } - } - } - return false -} - func getCurrentTime(bctx BuiltinContext) time.Time { var current time.Time @@ -1280,7 +1262,7 @@ func (c *interQueryCache) InsertIntoCache(value *http.Response) (ast.Value, erro } // fallback to the http send cache if error encountered while inserting response in inter-query cache - err = insertIntoHTTPSendInterQueryCache(c.bctx, c.key, value, respBody, c.forceCacheParams != nil) + err = insertIntoHTTPSendInterQueryCache(c.bctx, c.key, value, respBody, c.forceCacheParams) if err != nil { insertIntoHTTPSendCache(c.bctx, c.key, result) } diff --git a/topdown/http_test.go b/topdown/http_test.go index 399b12beb2..3a53b78e1c 100644 --- a/topdown/http_test.go +++ b/topdown/http_test.go @@ -13,7 +13,7 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" + "io" "math" "net" "net/http" @@ -405,7 +405,7 @@ func TestHTTPPostRequest(t *testing.T) { contentType := r.Header.Get("Content-Type") - bs, err := ioutil.ReadAll(r.Body) + bs, err := io.ReadAll(r.Body) if err != nil { t.Fatal(err) } @@ -1268,13 +1268,12 @@ func TestHTTPSendIntraQueryForceCaching(t *testing.T) { }, } - data := loadSmallTestData() - - t0 := time.Now().UTC() - opts := setTime(t0) + data := map[string]interface{}{} for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { + t0 := time.Now().UTC() + opts := setTime(t0) var requests []*http.Request ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1301,7 +1300,7 @@ func TestHTTPSendIntraQueryForceCaching(t *testing.T) { // eval first), so expect 2x the total request count the test case specified. actualCount := len(requests) / 2 if actualCount != tc.expectedReqCount { - t.Fatalf("Expected to get %d requests, got %d", tc.expectedReqCount, actualCount) + t.Errorf("Expected to get %d requests, got %d", tc.expectedReqCount, actualCount) } }) } @@ -1358,11 +1357,10 @@ func TestHTTPSendIntraQueryCachingModifiedResp(t *testing.T) { data := loadSmallTestData() - t0 := time.Now() - opts := setTime(t0) - for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { + t0 := time.Now().UTC() + opts := setTime(t0) var requests []*http.Request ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1373,7 +1371,7 @@ func TestHTTPSendIntraQueryCachingModifiedResp(t *testing.T) { headers[k] = v } - headers.Set("Date", t0.Format(time.RFC850)) + headers.Set("Date", t0.Format(http.TimeFormat)) etag := w.Header().Get("etag") @@ -1428,11 +1426,10 @@ func TestHTTPSendIntraQueryCachingNewResp(t *testing.T) { data := loadSmallTestData() - t0 := time.Now() - opts := setTime(t0) - for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { + t0 := time.Now().UTC() + opts := setTime(t0) var requests []*http.Request ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1443,7 +1440,7 @@ func TestHTTPSendIntraQueryCachingNewResp(t *testing.T) { headers[k] = v } - headers.Set("Date", t0.Format(time.RFC850)) + headers.Set("Date", t0.Format(http.TimeFormat)) etag := w.Header().Get("etag") @@ -1618,19 +1615,6 @@ func TestGetResponseHeaderDateEmpty(t *testing.T) { } } -func TestIsCachedResponseFreshZeroTime(t *testing.T) { - zeroTime := new(time.Time) - result := isCachedResponseFresh(BuiltinContext{}, &responseHeaders{date: *zeroTime}, nil) - if result { - t.Fatal("Expected stale cache response") - } - - result = isCachedResponseFresh(BuiltinContext{Time: ast.NullTerm()}, &responseHeaders{date: time.Now()}, nil) - if result { - t.Fatal("Expected stale cache response") - } -} - func TestParseMaxAgeCacheDirective(t *testing.T) { tests := []struct { note string @@ -1849,10 +1833,13 @@ func TestInterQueryCheckCacheError(t *testing.T) { } func TestNewInterQueryCacheValue(t *testing.T) { + date := "Wed, 31 Dec 2115 07:28:00 GMT" + maxAge := 290304000 + headers := make(http.Header) headers.Set("test-header", "test-value") - headers.Set("Cache-Control", "max-age=290304000, public") - headers.Set("Date", "Wed, 31 Dec 2115 07:28:00 GMT") + headers.Set("Cache-Control", fmt.Sprintf("max-age=%d, public", maxAge)) + headers.Set("Date", date) // test data var b = []byte(`[{"ID": "1", "Firstname": "John"}]`) @@ -1862,18 +1849,23 @@ func TestNewInterQueryCacheValue(t *testing.T) { StatusCode: http.StatusOK, Header: headers, Request: &http.Request{Method: "Get"}, - Body: ioutil.NopCloser(bytes.NewBuffer(b)), + Body: io.NopCloser(bytes.NewBuffer(b)), } - result, err := newInterQueryCacheValue(response, b, false) + result, err := newInterQueryCacheValue(BuiltinContext{}, response, b, &forceCacheParams{}) if err != nil { t.Fatalf("Unexpected error %v", err) } - cvd := interQueryCacheData{RespBody: b, + dateTime, _ := http.ParseTime(date) + + cvd := interQueryCacheData{ + RespBody: b, Status: "200 OK", StatusCode: http.StatusOK, - Headers: headers} + Headers: headers, + ExpiresAt: dateTime.Add(time.Duration(maxAge) * time.Second), + } cvdBytes, err := json.Marshal(cvd) if err != nil { @@ -1945,7 +1937,7 @@ func TestHTTPSClient(t *testing.T) { localServerKeyFile = "testdata/server-key.pem" ) - caCertPEM, err := ioutil.ReadFile(localCaFile) + caCertPEM, err := os.ReadFile(localCaFile) if err != nil { t.Fatal(err) } @@ -2036,17 +2028,17 @@ func TestHTTPSClient(t *testing.T) { t.Fatal(err) } - ca, err := ioutil.ReadFile(localCaFile) + ca, err := os.ReadFile(localCaFile) if err != nil { t.Fatal(err) } - cert, err := ioutil.ReadFile(localClientCertFile) + cert, err := os.ReadFile(localClientCertFile) if err != nil { t.Fatal(err) } - key, err := ioutil.ReadFile(localClientKeyFile) + key, err := os.ReadFile(localClientKeyFile) if err != nil { t.Fatal(err) } @@ -2247,7 +2239,7 @@ func TestHTTPSNoClientCerts(t *testing.T) { localServerKeyFile = "testdata/server-key.pem" ) - caCertPEM, err := ioutil.ReadFile(localCaFile) + caCertPEM, err := os.ReadFile(localCaFile) if err != nil { t.Fatal(err) } @@ -2315,7 +2307,7 @@ func TestHTTPSNoClientCerts(t *testing.T) { t.Fatal(err) } - ca, err := ioutil.ReadFile(localCaFile) + ca, err := os.ReadFile(localCaFile) if err != nil { t.Fatal(err) } @@ -2455,7 +2447,7 @@ func TestCertSelectionLogic(t *testing.T) { ) // Set up Environment - caCertPEM, err := ioutil.ReadFile(localCaFile) + caCertPEM, err := os.ReadFile(localCaFile) if err != nil { t.Fatal(err) } @@ -2465,7 +2457,7 @@ func TestCertSelectionLogic(t *testing.T) { t.Fatal("failed to parse CA cert") } - ca, err := ioutil.ReadFile(localCaFile) + ca, err := os.ReadFile(localCaFile) if err != nil { t.Fatal(err) } @@ -2659,14 +2651,14 @@ func TestSocketHTTPGetRequest(t *testing.T) { people = append(people, Person{ID: "1", Firstname: "John"}) // Create a local socket - tmpF, err := ioutil.TempFile("", "") + tmpF, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } socketPath := tmpF.Name() tmpF.Close() - os.Remove(socketPath) + _ = os.Remove(socketPath) socket, err := net.Listen("unix", socketPath) if err != nil {