From 69519243fe2c03800931a2d1e8d094c39effc02c Mon Sep 17 00:00:00 2001 From: jsign Date: Sat, 16 Nov 2019 17:10:57 -0300 Subject: [PATCH] fix: services/httpd: parse correctly Accept header with extra test case Signed-off-by: jsign --- services/httpd/accept.go | 80 +++++++ services/httpd/response_writer.go | 54 ++++- services/httpd/response_writer_test.go | 299 +++++++++++++++---------- 3 files changed, 297 insertions(+), 136 deletions(-) create mode 100644 services/httpd/accept.go diff --git a/services/httpd/accept.go b/services/httpd/accept.go new file mode 100644 index 00000000000..8d6540fe40b --- /dev/null +++ b/services/httpd/accept.go @@ -0,0 +1,80 @@ +// This file is an adaptation of https://github.com/markusthoemmes/goautoneg. +// The copyright and license header are reproduced below. +// +// Copyright [yyyy] [name of copyright owner] +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// http://www.apache.org/licenses/LICENSE-2.0 + +package httpd + +import ( + "mime" + "sort" + "strconv" + "strings" +) + +// accept is a structure to represent a clause in an HTTP Accept Header. +type accept struct { + Type, SubType string + Q float64 + Params map[string]string +} + +// parseAccept parses the given string as an Accept header as defined in +// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.1. +// Some rules are only loosely applied and might not be as strict as defined in the RFC. +func parseAccept(headers []string) []accept { + var res []accept + for _, header := range headers { + parts := strings.Split(header, ",") + for _, part := range parts { + mt, params, err := mime.ParseMediaType(part) + if err != nil { + continue + } + + accept := accept{ + Q: 1.0, // "[...] The default value is q=1" + Params: params, + } + + // A media-type is defined as + // "*/*" | ( type "/" "*" ) | ( type "/" subtype ) + types := strings.Split(mt, "/") + switch { + // This case is not defined in the spec keep it to mimic the original code. + case len(types) == 1 && types[0] == "*": + accept.Type = "*" + accept.SubType = "*" + case len(types) == 2: + accept.Type = types[0] + accept.SubType = types[1] + default: + continue + } + + if qVal, ok := params["q"]; ok { + // A parsing failure will set Q to 0. + accept.Q, _ = strconv.ParseFloat(qVal, 64) + delete(params, "q") + } + + res = append(res, accept) + } + } + sort.SliceStable(res, func(i, j int) bool { + return res[i].Q > res[j].Q + }) + return res +} diff --git a/services/httpd/response_writer.go b/services/httpd/response_writer.go index 78ad0cb3c1b..11cc5fa5d3f 100644 --- a/services/httpd/response_writer.go +++ b/services/httpd/response_writer.go @@ -20,27 +20,57 @@ type ResponseWriter interface { http.ResponseWriter } +type formatter interface { + WriteResponse(w io.Writer, resp Response) error +} + +type supportedContentType struct { + full string + acceptType string + acceptSubType string + formatter func(pretty bool) formatter +} + +var ( + csvFormatFactory = func(pretty bool) formatter { return &csvFormatter{statementID: -1} } + msgpackFormatFactory = func(pretty bool) formatter { return &msgpackFormatter{} } + jsonFormatFactory = func(pretty bool) formatter { return &jsonFormatter{Pretty: pretty} } + + contentTypes = []supportedContentType{ + {full: "application/json", acceptType: "application", acceptSubType: "json", formatter: jsonFormatFactory}, + {full: "application/csv", acceptType: "application", acceptSubType: "csv", formatter: csvFormatFactory}, + {full: "text/csv", acceptType: "text", acceptSubType: "csv", formatter: csvFormatFactory}, + {full: "application/x-msgpack", acceptType: "application", acceptSubType: "x-msgpack", formatter: msgpackFormatFactory}, + } + defaultContentType = contentTypes[0] +) + // NewResponseWriter creates a new ResponseWriter based on the Accept header // in the request that wraps the ResponseWriter. func NewResponseWriter(w http.ResponseWriter, r *http.Request) ResponseWriter { pretty := r.URL.Query().Get("pretty") == "true" rw := &responseWriter{ResponseWriter: w} - switch r.Header.Get("Accept") { - case "application/csv", "text/csv": - w.Header().Add("Content-Type", "text/csv") - rw.formatter = &csvFormatter{statementID: -1} - case "application/x-msgpack": - w.Header().Add("Content-Type", "application/x-msgpack") - rw.formatter = &msgpackFormatter{} - case "application/json": - fallthrough - default: - w.Header().Add("Content-Type", "application/json") - rw.formatter = &jsonFormatter{Pretty: pretty} + + acceptHeaders := parseAccept(r.Header["Accept"]) + for _, accept := range acceptHeaders { + for _, ct := range contentTypes { + if match(accept, ct) { + w.Header().Add("Content-Type", ct.full) + rw.formatter = ct.formatter(pretty) + return rw + } + } } + w.Header().Add("Content-Type", defaultContentType.full) + rw.formatter = defaultContentType.formatter(pretty) return rw } +func match(ah accept, sct supportedContentType) bool { + return (ah.Type == "*" || ah.Type == sct.acceptType) && + (ah.SubType == "*" || ah.SubType == sct.acceptSubType) +} + // WriteError is a convenience function for writing an error response to the ResponseWriter. func WriteError(w ResponseWriter, err error) (int, error) { return w.WriteResponse(Response{Err: err}) diff --git a/services/httpd/response_writer_test.go b/services/httpd/response_writer_test.go index 5aa62290e46..f1e7014673f 100644 --- a/services/httpd/response_writer_test.go +++ b/services/httpd/response_writer_test.go @@ -19,63 +19,79 @@ import ( ) func TestResponseWriter_CSV(t *testing.T) { - header := make(http.Header) - header.Set("Accept", "text/csv") - r := &http.Request{ - Header: header, - URL: &url.URL{}, + tableTest := []struct { + header string + }{ + {header: "*/csv"}, + {header: "text/*"}, + {header: "text/csv"}, + {header: "text/csv,application/json"}, + {header: "text/csv;q=1,application/json"}, + {header: "text/csv;q=0.9,application/json;q=0.8"}, + {header: "application/json;q=0.8,text/csv;q=0.9"}, } - w := httptest.NewRecorder() - writer := httpd.NewResponseWriter(w, r) - n, err := writer.WriteResponse(httpd.Response{ - Results: []*query.Result{ - { - StatementID: 0, - Series: []*models.Row{ - { - Name: "cpu", - Tags: map[string]string{ - "host": "server01", - "region": "uswest", - }, - Columns: []string{"time", "value"}, - Values: [][]interface{}{ - {time.Unix(0, 10), float64(2.5)}, - {time.Unix(0, 20), int64(5)}, - {time.Unix(0, 30), nil}, - {time.Unix(0, 40), "foobar"}, - {time.Unix(0, 50), true}, - {time.Unix(0, 60), false}, - {time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, - }, - }, + for _, testCase := range tableTest { + testCase := testCase + t.Run(testCase.header, func(t *testing.T) { + t.Parallel() + header := make(http.Header) + header.Set("Accept", testCase.header) + r := &http.Request{ + Header: header, + URL: &url.URL{}, + } + w := httptest.NewRecorder() + + writer := httpd.NewResponseWriter(w, r) + n, err := writer.WriteResponse(httpd.Response{ + Results: []*query.Result{ { - Name: "cpu", - Tags: map[string]string{ - "host": "", - "region": "", - }, - Columns: []string{"time", "value"}, - Values: [][]interface{}{ - {time.Unix(0, 10), float64(2.5)}, - {time.Unix(0, 20), int64(5)}, - {time.Unix(0, 30), nil}, - {time.Unix(0, 40), "foobar"}, - {time.Unix(0, 50), true}, - {time.Unix(0, 60), false}, - {time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, + StatementID: 0, + Series: []*models.Row{ + { + Name: "cpu", + Tags: map[string]string{ + "host": "server01", + "region": "uswest", + }, + Columns: []string{"time", "value"}, + Values: [][]interface{}{ + {time.Unix(0, 10), float64(2.5)}, + {time.Unix(0, 20), int64(5)}, + {time.Unix(0, 30), nil}, + {time.Unix(0, 40), "foobar"}, + {time.Unix(0, 50), true}, + {time.Unix(0, 60), false}, + {time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, + }, + }, + { + Name: "cpu", + Tags: map[string]string{ + "host": "", + "region": "", + }, + Columns: []string{"time", "value"}, + Values: [][]interface{}{ + {time.Unix(0, 10), float64(2.5)}, + {time.Unix(0, 20), int64(5)}, + {time.Unix(0, 30), nil}, + {time.Unix(0, 40), "foobar"}, + {time.Unix(0, 50), true}, + {time.Unix(0, 60), false}, + {time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, + }, + }, }, }, }, - }, - }, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - if got, want := w.Body.String(), `name,tags,time,value + if got, want := w.Body.String(), `name,tags,time,value cpu,"host=server01,region=uswest",10,2.5 cpu,"host=server01,region=uswest",20,5 cpu,"host=server01,region=uswest",30, @@ -91,99 +107,134 @@ cpu,,50,true cpu,,60,false cpu,,70,9223372036854775808 `; got != want { - t.Errorf("unexpected output:\n\ngot=%v\nwant=%s", got, want) - } else if got, want := n, len(want); got != want { - t.Errorf("unexpected output length: got=%d want=%d", got, want) + t.Errorf("unexpected output:\n\ngot=%v\nwant=%s", got, want) + } else if got, want := n, len(want); got != want { + t.Errorf("unexpected output length: got=%d want=%d", got, want) + } + }) } } func TestResponseWriter_MessagePack(t *testing.T) { - header := make(http.Header) - header.Set("Accept", "application/x-msgpack") - r := &http.Request{ - Header: header, - URL: &url.URL{}, + tableTest := []struct { + header string + }{ + {header: "*/x-msgpack"}, + {header: "application/x-msgpack"}, + {header: "application/x-msgpack,application/json"}, + {header: "application/x-msgpack;q=1,application/json"}, + {header: "application/x-msgpack;q=0.9,application/json;q=0.8"}, + {header: "application/json;q=0.8,application/x-msgpack;q=0.9"}, } - w := httptest.NewRecorder() - writer := httpd.NewResponseWriter(w, r) - _, err := writer.WriteResponse(httpd.Response{ - Results: []*query.Result{ - { - StatementID: 0, - Series: []*models.Row{ + for _, testCase := range tableTest { + testCase := testCase + t.Run(testCase.header, func(t *testing.T) { + t.Parallel() + header := make(http.Header) + header.Set("Accept", testCase.header) + r := &http.Request{ + Header: header, + URL: &url.URL{}, + } + w := httptest.NewRecorder() + + writer := httpd.NewResponseWriter(w, r) + _, err := writer.WriteResponse(httpd.Response{ + Results: []*query.Result{ { - Name: "cpu", - Tags: map[string]string{ - "host": "server01", - }, - Columns: []string{"time", "value"}, - Values: [][]interface{}{ - {time.Unix(0, 10), float64(2.5)}, - {time.Unix(0, 20), int64(5)}, - {time.Unix(0, 30), nil}, - {time.Unix(0, 40), "foobar"}, - {time.Unix(0, 50), true}, - {time.Unix(0, 60), false}, - {time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, + StatementID: 0, + Series: []*models.Row{ + { + Name: "cpu", + Tags: map[string]string{ + "host": "server01", + }, + Columns: []string{"time", "value"}, + Values: [][]interface{}{ + {time.Unix(0, 10), float64(2.5)}, + {time.Unix(0, 20), int64(5)}, + {time.Unix(0, 30), nil}, + {time.Unix(0, 40), "foobar"}, + {time.Unix(0, 50), true}, + {time.Unix(0, 60), false}, + {time.Unix(0, 70), uint64(math.MaxInt64 + 1)}, + }, + }, }, }, }, - }, - }, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - // The reader always reads times as time.Local so encode the expected response - // as JSON and insert it into the expected values. - values, err := json.Marshal([][]interface{}{ - {time.Unix(0, 10).Local(), float64(2.5)}, - {time.Unix(0, 20).Local(), int64(5)}, - {time.Unix(0, 30).Local(), nil}, - {time.Unix(0, 40).Local(), "foobar"}, - {time.Unix(0, 50).Local(), true}, - {time.Unix(0, 60).Local(), false}, - {time.Unix(0, 70).Local(), uint64(math.MaxInt64 + 1)}, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + // The reader always reads times as time.Local so encode the expected response + // as JSON and insert it into the expected values. + values, err := json.Marshal([][]interface{}{ + {time.Unix(0, 10).Local(), float64(2.5)}, + {time.Unix(0, 20).Local(), int64(5)}, + {time.Unix(0, 30).Local(), nil}, + {time.Unix(0, 40).Local(), "foobar"}, + {time.Unix(0, 50).Local(), true}, + {time.Unix(0, 60).Local(), false}, + {time.Unix(0, 70).Local(), uint64(math.MaxInt64 + 1)}, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - reader := msgp.NewReader(w.Body) - var buf bytes.Buffer - if _, err := reader.WriteToJSON(&buf); err != nil { - t.Fatalf("unexpected error: %s", err) - } - want := fmt.Sprintf(`{"results":[{"statement_id":0,"series":[{"name":"cpu","tags":{"host":"server01"},"columns":["time","value"],"values":%s}]}]}`, string(values)) - if got := strings.TrimSpace(buf.String()); got != want { - t.Fatalf("unexpected output:\n\ngot=%v\nwant=%v", got, want) + reader := msgp.NewReader(w.Body) + var buf bytes.Buffer + if _, err := reader.WriteToJSON(&buf); err != nil { + t.Fatalf("unexpected error: %s", err) + } + want := fmt.Sprintf(`{"results":[{"statement_id":0,"series":[{"name":"cpu","tags":{"host":"server01"},"columns":["time","value"],"values":%s}]}]}`, string(values)) + if got := strings.TrimSpace(buf.String()); got != want { + t.Fatalf("unexpected output:\n\ngot=%v\nwant=%v", got, want) + } + }) } } func TestResponseWriter_MessagePack_Error(t *testing.T) { - header := make(http.Header) - header.Set("Accept", "application/x-msgpack") - r := &http.Request{ - Header: header, - URL: &url.URL{}, + tableTest := []struct { + header string + }{ + {header: "application/x-msgpack"}, + {header: "application/x-msgpack,application/json"}, + {header: "application/x-msgpack;q=1,application/json"}, + {header: "application/x-msgpack;q=0.9,application/json;q=0.8"}, + {header: "application/json;q=0.8,application/x-msgpack;q=0.9"}, } - w := httptest.NewRecorder() - writer := httpd.NewResponseWriter(w, r) - writer.WriteResponse(httpd.Response{ - Err: fmt.Errorf("test error"), - }) + for _, testCase := range tableTest { + testCase := testCase + t.Run(testCase.header, func(t *testing.T) { + t.Parallel() + header := make(http.Header) + header.Set("Accept", testCase.header) + r := &http.Request{ + Header: header, + URL: &url.URL{}, + } + w := httptest.NewRecorder() - reader := msgp.NewReader(w.Body) - var buf bytes.Buffer - if _, err := reader.WriteToJSON(&buf); err != nil { - t.Fatalf("unexpected error: %s", err) - } - want := fmt.Sprintf(`{"error":"test error"}`) - if have := strings.TrimSpace(buf.String()); have != want { - t.Fatalf("unexpected output: %s != %s", have, want) + writer := httpd.NewResponseWriter(w, r) + writer.WriteResponse(httpd.Response{ + Err: fmt.Errorf("test error"), + }) + + reader := msgp.NewReader(w.Body) + var buf bytes.Buffer + if _, err := reader.WriteToJSON(&buf); err != nil { + t.Fatalf("unexpected error: %s", err) + } + want := fmt.Sprintf(`{"error":"test error"}`) + if have := strings.TrimSpace(buf.String()); have != want { + t.Fatalf("unexpected output: %s != %s", have, want) + } + }) } }