Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API SSZ content negotiation #10760

Merged
merged 8 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 56 additions & 12 deletions beacon-chain/rpc/apimiddleware/custom_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"

Expand All @@ -17,8 +18,10 @@ import (
)

const (
versionHeader = "Eth-Consensus-Version"
grpcVersionHeader = "Grpc-metadata-Eth-Consensus-Version"
versionHeader = "Eth-Consensus-Version"
grpcVersionHeader = "Grpc-metadata-Eth-Consensus-Version"
jsonMediaType = "application/json"
octetStreamMediaType = "application/octet-stream"
)

type sszConfig struct {
Expand Down Expand Up @@ -99,7 +102,12 @@ func handleGetSSZ(
req *http.Request,
config sszConfig,
) (handled bool) {
if !sszRequested(req) {
ssz, err := sszRequested(req)
if err != nil {
apimiddleware.WriteError(w, apimiddleware.InternalServerError(err), nil)
return true
}
if !ssz {
return false
}

Expand Down Expand Up @@ -193,17 +201,53 @@ func handlePostSSZ(
return true
}

func sszRequested(req *http.Request) bool {
func sszRequested(req *http.Request) (bool, error) {
accept, ok := req.Header["Accept"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: would be slightly more concise if you used https://pkg.go.dev/net/http#Header.Values

accept := req.Header.Values("Accept")
if len(accept) == 0 {
    return false, nil
}

if !ok {
return false
return false, nil
}
if len(accept) == 0 {
return false, nil
}
for _, v := range accept {
if v == "application/octet-stream" {
return true
types := strings.Split(accept[0], ",")
// match a number with optional decimals
regex, err := regexp.Compile("q=\\d(\\.\\d)?")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only needs to be compiled once for the life of the program - you can avoid recompiling on every call by declaring a variable above the function in package scope, using https://pkg.go.dev/regexp#MustCompile

if err != nil {
return false, err
}
currentType, currentPriority := "", 1.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this cause the initial empty value to take precedence over a string where all values include a priority? I think this causes a bug; if you had an accept string like application/octet-stream;q=0.95,application/json;q=0.9", the function would return false.

I think initializing currentPriority to 0 is going to simplify things. It would fix the bug just mentioned. You also won't need to check if currentType == "" { on line 227, because the priority comparison will implicitly handle that.

for _, t := range types {
values := strings.Split(t, ";")
name := values[0]
if name != jsonMediaType && name != octetStreamMediaType {
continue
}
// no params specified
if len(values) == 1 {
if currentType == "" {
currentType = name
continue
}
priority := 1.0
if priority > currentPriority {
currentType, currentPriority = name, priority
}
continue
}
params := values[1]
match := regex.Find([]byte(params))
if match != nil {
priority, err := strconv.ParseFloat(strings.Split(string(match), "=")[1], 32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're going to take on the overhead of using a regex I would suggest simplifying things by using the String regexp methods and using a capture group to extract the desired value, via https://pkg.go.dev/regexp#Regexp.FindAllStringSubmatch

prioRx := regexp.MustCompile(`q=(\d+(?:\.\d+)?)`) // note that I added a `+` to the `\d` shorthands, otherwise you'll only capture the first character
pg := prioRx.FindAllStringSubmatch(params, 1)
if len(pg) != 1 {
    continue
}
priority, err := strconv.ParseFloat(pg[0][1])
...

more self-contained example: https://go.dev/play/p/NRswRIwv2Tr

if err != nil {
return false, err
}
if priority > currentPriority {
currentType, currentPriority = name, priority
}
}
}
return false

return currentType == octetStreamMediaType, nil
}

func sszPosted(req *http.Request) bool {
Expand All @@ -214,7 +258,7 @@ func sszPosted(req *http.Request) bool {
if len(ct) != 1 {
return false
}
return ct[0] == "application/octet-stream"
return ct[0] == octetStreamMediaType
}

func prepareSSZRequestForProxying(m *apimiddleware.ApiProxyMiddleware, endpoint apimiddleware.Endpoint, req *http.Request) apimiddleware.ErrorJson {
Expand Down Expand Up @@ -252,7 +296,7 @@ func preparePostedSSZData(req *http.Request) apimiddleware.ErrorJson {
}
req.Body = io.NopCloser(bytes.NewBuffer(data))
req.ContentLength = int64(len(data))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Type", jsonMediaType)
return nil
}

Expand Down Expand Up @@ -280,7 +324,7 @@ func writeSSZResponseHeaderAndBody(grpcResp *http.Response, w http.ResponseWrite
}
}
w.Header().Set("Content-Length", strconv.Itoa(len(respSsz)))
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Type", octetStreamMediaType)
w.Header().Set("Content-Disposition", "attachment; filename="+fileName)
w.Header().Set(versionHeader, respVersion)
if statusCodeHeader != "" {
Expand Down
81 changes: 71 additions & 10 deletions beacon-chain/rpc/apimiddleware/custom_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -34,28 +35,88 @@ func (t testSSZResponseJson) SSZData() string {
func TestSSZRequested(t *testing.T) {
t.Run("ssz_requested", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{"application/octet-stream"}
result := sszRequested(request)
request.Header["Accept"] = []string{octetStreamMediaType}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, true, result)
})

t.Run("ssz_content_type_first", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{fmt.Sprintf("%s,%s", octetStreamMediaType, jsonMediaType)}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, true, result)
})

t.Run("ssz_content_type_preferred_1", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{fmt.Sprintf("%s;q=0.9,%s", jsonMediaType, octetStreamMediaType)}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, true, result)
})

t.Run("multiple_content_types", func(t *testing.T) {
t.Run("ssz_content_type_preferred_2", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{"application/json", "application/octet-stream"}
result := sszRequested(request)
request.Header["Accept"] = []string{fmt.Sprintf("%s;q=0.9,%s", jsonMediaType, octetStreamMediaType)}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, true, result)
})

t.Run("other_content_type_preferred", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{fmt.Sprintf("%s,%s;q=0.9", jsonMediaType, octetStreamMediaType)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test case like application/octet-stream;q=0.95,application/json;q=0.9

result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, false, result)
})

t.Run("other_params", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{fmt.Sprintf("%s,%s;q=0.9,otherparam=xyz", jsonMediaType, octetStreamMediaType)}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, false, result)
})

t.Run("no_header", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
result := sszRequested(request)
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, false, result)
})

t.Run("empty_header", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, false, result)
})

t.Run("empty_header_value", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{""}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, false, result)
})

t.Run("other_content_type", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{"application/json"}
result := sszRequested(request)
request.Header["Accept"] = []string{"application/other"}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, false, result)
})

t.Run("garbage", func(t *testing.T) {
request := httptest.NewRequest("GET", "http://foo.example", nil)
request.Header["Accept"] = []string{"This is Sparta!!!"}
result, err := sszRequested(request)
require.NoError(t, err)
assert.Equal(t, false, result)
})
}
Expand All @@ -82,7 +143,7 @@ func TestPreparePostedSszData(t *testing.T) {

preparePostedSSZData(request)
assert.Equal(t, int64(19), request.ContentLength)
assert.Equal(t, "application/json", request.Header.Get("Content-Type"))
assert.Equal(t, jsonMediaType, request.Header.Get("Content-Type"))
}

func TestSerializeMiddlewareResponseIntoSSZ(t *testing.T) {
Expand Down Expand Up @@ -138,7 +199,7 @@ func TestWriteSSZResponseHeaderAndBody(t *testing.T) {
v, ok = writer.Header()["Content-Type"]
require.Equal(t, true, ok, "header not found")
require.Equal(t, 1, len(v), "wrong number of header values")
assert.Equal(t, "application/octet-stream", v[0])
assert.Equal(t, octetStreamMediaType, v[0])
v, ok = writer.Header()["Content-Disposition"]
require.Equal(t, true, ok, "header not found")
require.Equal(t, 1, len(v), "wrong number of header values")
Expand Down