diff --git a/beacon-chain/rpc/apimiddleware/custom_handlers.go b/beacon-chain/rpc/apimiddleware/custom_handlers.go index 24d5b6937328..201551b521f3 100644 --- a/beacon-chain/rpc/apimiddleware/custom_handlers.go +++ b/beacon-chain/rpc/apimiddleware/custom_handlers.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "regexp" "strconv" "strings" @@ -17,10 +18,15 @@ import ( ) const ( - versionHeader = "Eth-Consensus-Version" - grpcVersionHeader = "Grpc-metadata-Eth-Consensus-Version" + versionHeader = "Eth-Consensus-Version" + grpcVersionHeader = "Grpc-metadata-Eth-Consensus-Version" + jsonMediaType = "application/json" + octetStreamMediaType = "application/octet-stream" ) +// match a number with optional decimals +var priorityRegex = regexp.MustCompile(`q=(\d+(?:\.\d+)?)`) + type sszConfig struct { fileName string responseJson sszResponse @@ -99,7 +105,12 @@ func handleGetSSZ( req *http.Request, config sszConfig, ) (handled bool) { - if !sszRequested(req) { + ssz, err := sszRequested(req) + if err != nil { + apimiddleware.WriteError(w, apimiddleware.InternalServerError(err), nil) + return true + } + if !ssz { return false } @@ -193,17 +204,42 @@ func handlePostSSZ( return true } -func sszRequested(req *http.Request) bool { - accept, ok := req.Header["Accept"] - if !ok { - return false - } - for _, v := range accept { - if v == "application/octet-stream" { - return true +func sszRequested(req *http.Request) (bool, error) { + accept := req.Header.Values("Accept") + if len(accept) == 0 { + return false, nil + } + types := strings.Split(accept[0], ",") + currentType, currentPriority := "", 0.0 + for _, t := range types { + values := strings.Split(t, ";") + name := values[0] + if name != jsonMediaType && name != octetStreamMediaType { + continue + } + // no params specified + if len(values) == 1 { + priority := 1.0 + if priority > currentPriority { + currentType, currentPriority = name, priority + } + continue + } + params := values[1] + match := priorityRegex.FindAllStringSubmatch(params, 1) + if len(match) != 1 { + continue + } + priority, err := strconv.ParseFloat(match[0][1], 32) + if err != nil { + return false, err + } + if priority > currentPriority { + currentType, currentPriority = name, priority } } - return false + + return currentType == octetStreamMediaType, nil } func sszPosted(req *http.Request) bool { @@ -214,7 +250,7 @@ func sszPosted(req *http.Request) bool { if len(ct) != 1 { return false } - return ct[0] == "application/octet-stream" + return ct[0] == octetStreamMediaType } func prepareSSZRequestForProxying(m *apimiddleware.ApiProxyMiddleware, endpoint apimiddleware.Endpoint, req *http.Request) apimiddleware.ErrorJson { @@ -252,7 +288,7 @@ func preparePostedSSZData(req *http.Request) apimiddleware.ErrorJson { } req.Body = io.NopCloser(bytes.NewBuffer(data)) req.ContentLength = int64(len(data)) - req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Type", jsonMediaType) return nil } @@ -280,7 +316,7 @@ func writeSSZResponseHeaderAndBody(grpcResp *http.Response, w http.ResponseWrite } } w.Header().Set("Content-Length", strconv.Itoa(len(respSsz))) - w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Type", octetStreamMediaType) w.Header().Set("Content-Disposition", "attachment; filename="+fileName) w.Header().Set(versionHeader, respVersion) if statusCodeHeader != "" { diff --git a/beacon-chain/rpc/apimiddleware/custom_handlers_test.go b/beacon-chain/rpc/apimiddleware/custom_handlers_test.go index b8e6fa35b973..35b8e36f0a7b 100644 --- a/beacon-chain/rpc/apimiddleware/custom_handlers_test.go +++ b/beacon-chain/rpc/apimiddleware/custom_handlers_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -34,28 +35,88 @@ func (t testSSZResponseJson) SSZData() string { func TestSSZRequested(t *testing.T) { t.Run("ssz_requested", func(t *testing.T) { request := httptest.NewRequest("GET", "http://foo.example", nil) - request.Header["Accept"] = []string{"application/octet-stream"} - result := sszRequested(request) + request.Header["Accept"] = []string{octetStreamMediaType} + result, err := sszRequested(request) + require.NoError(t, err) + assert.Equal(t, true, result) + }) + + t.Run("ssz_content_type_first", func(t *testing.T) { + request := httptest.NewRequest("GET", "http://foo.example", nil) + request.Header["Accept"] = []string{fmt.Sprintf("%s,%s", octetStreamMediaType, jsonMediaType)} + result, err := sszRequested(request) + require.NoError(t, err) + assert.Equal(t, true, result) + }) + + t.Run("ssz_content_type_preferred_1", func(t *testing.T) { + request := httptest.NewRequest("GET", "http://foo.example", nil) + request.Header["Accept"] = []string{fmt.Sprintf("%s;q=0.9,%s", jsonMediaType, octetStreamMediaType)} + result, err := sszRequested(request) + require.NoError(t, err) assert.Equal(t, true, result) }) - t.Run("multiple_content_types", func(t *testing.T) { + t.Run("ssz_content_type_preferred_2", func(t *testing.T) { request := httptest.NewRequest("GET", "http://foo.example", nil) - request.Header["Accept"] = []string{"application/json", "application/octet-stream"} - result := sszRequested(request) + request.Header["Accept"] = []string{fmt.Sprintf("%s;q=0.95,%s;q=0.9", octetStreamMediaType, jsonMediaType)} + result, err := sszRequested(request) + require.NoError(t, err) assert.Equal(t, true, result) }) + t.Run("other_content_type_preferred", func(t *testing.T) { + request := httptest.NewRequest("GET", "http://foo.example", nil) + request.Header["Accept"] = []string{fmt.Sprintf("%s,%s;q=0.9", jsonMediaType, octetStreamMediaType)} + result, err := sszRequested(request) + require.NoError(t, err) + assert.Equal(t, false, result) + }) + + t.Run("other_params", func(t *testing.T) { + request := httptest.NewRequest("GET", "http://foo.example", nil) + request.Header["Accept"] = []string{fmt.Sprintf("%s,%s;q=0.9,otherparam=xyz", jsonMediaType, octetStreamMediaType)} + result, err := sszRequested(request) + require.NoError(t, err) + assert.Equal(t, false, result) + }) + t.Run("no_header", func(t *testing.T) { request := httptest.NewRequest("GET", "http://foo.example", nil) - result := sszRequested(request) + result, err := sszRequested(request) + require.NoError(t, err) + assert.Equal(t, false, result) + }) + + t.Run("empty_header", func(t *testing.T) { + request := httptest.NewRequest("GET", "http://foo.example", nil) + request.Header["Accept"] = []string{} + result, err := sszRequested(request) + require.NoError(t, err) + assert.Equal(t, false, result) + }) + + t.Run("empty_header_value", func(t *testing.T) { + request := httptest.NewRequest("GET", "http://foo.example", nil) + request.Header["Accept"] = []string{""} + result, err := sszRequested(request) + require.NoError(t, err) assert.Equal(t, false, result) }) t.Run("other_content_type", func(t *testing.T) { request := httptest.NewRequest("GET", "http://foo.example", nil) - request.Header["Accept"] = []string{"application/json"} - result := sszRequested(request) + request.Header["Accept"] = []string{"application/other"} + result, err := sszRequested(request) + require.NoError(t, err) + assert.Equal(t, false, result) + }) + + t.Run("garbage", func(t *testing.T) { + request := httptest.NewRequest("GET", "http://foo.example", nil) + request.Header["Accept"] = []string{"This is Sparta!!!"} + result, err := sszRequested(request) + require.NoError(t, err) assert.Equal(t, false, result) }) } @@ -82,7 +143,7 @@ func TestPreparePostedSszData(t *testing.T) { preparePostedSSZData(request) assert.Equal(t, int64(19), request.ContentLength) - assert.Equal(t, "application/json", request.Header.Get("Content-Type")) + assert.Equal(t, jsonMediaType, request.Header.Get("Content-Type")) } func TestSerializeMiddlewareResponseIntoSSZ(t *testing.T) { @@ -138,7 +199,7 @@ func TestWriteSSZResponseHeaderAndBody(t *testing.T) { v, ok = writer.Header()["Content-Type"] require.Equal(t, true, ok, "header not found") require.Equal(t, 1, len(v), "wrong number of header values") - assert.Equal(t, "application/octet-stream", v[0]) + assert.Equal(t, octetStreamMediaType, v[0]) v, ok = writer.Header()["Content-Disposition"] require.Equal(t, true, ok, "header not found") require.Equal(t, 1, len(v), "wrong number of header values")