From 91705c432e89edbda120db76740925c54a8d6c5c Mon Sep 17 00:00:00 2001 From: Joel Unzain Date: Wed, 18 Oct 2017 19:08:33 -0700 Subject: [PATCH] Added logic and unit tests for a simple version of Header forwarding (from responses from api.xmidt to some TR1D1UM caller) --- src/tr1d1um/http.go | 18 ++++++++++++++++++ src/tr1d1um/http_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/src/tr1d1um/http.go b/src/tr1d1um/http.go index b932d7bf..f64c1d41 100644 --- a/src/tr1d1um/http.go +++ b/src/tr1d1um/http.go @@ -7,6 +7,7 @@ import ( "github.com/Comcast/webpa-common/logging" "github.com/go-kit/kit/log" "github.com/gorilla/mux" + "strings" ) //ConversionHandler wraps the main WDMP -> WRP conversion method @@ -67,5 +68,22 @@ func (ch *ConversionHandler) ServeHTTP(origin http.ResponseWriter, req *http.Req ctx, cancel := context.WithTimeout(req.Context(), ch.sender.GetRespTimeout()) response, err := ch.sender.Send(ch, origin, wdmpPayload, req.WithContext(ctx)) cancel() + + ForwardHeadersByPrefix("X", origin, response) + ch.sender.HandleResponse(ch, err, response, origin) } + +// Helper functions + +//ForwardHeadersByPrefix forwards header values whose keys start with the given prefix from some response +//into an responseWriter +func ForwardHeadersByPrefix(prefix string, origin http.ResponseWriter, resp *http.Response){ + for headerKey, headerValues := range resp.Header { + if strings.HasPrefix(headerKey, prefix) { + for _, headerValue := range headerValues { + origin.Header().Add(headerKey, headerValue) + } + } + } +} diff --git a/src/tr1d1um/http_test.go b/src/tr1d1um/http_test.go index 510e317e..b2debc30 100644 --- a/src/tr1d1um/http_test.go +++ b/src/tr1d1um/http_test.go @@ -107,6 +107,46 @@ func TestConversionHandler(t *testing.T) { }) } +func TestForwardHeadersByPrefix(t *testing.T) { + t.Run("NoHeaders", func(t *testing.T) { + assert := assert.New(t) + origin := httptest.NewRecorder() + + ForwardHeadersByPrefix("H", origin, resp) + assert.Empty(origin.Header()) + }) + + t.Run("MultipleHeadersFiltered", func(t *testing.T) { + assert := assert.New(t) + origin := httptest.NewRecorder() + resp := &http.Response{Header:http.Header{}} + + resp.Header.Add("Helium", "3") + resp.Header.Add("Hydrogen", "5") + resp.Header.Add("Hydrogen", "6") + + ForwardHeadersByPrefix("He", origin, resp) + assert.NotEmpty(origin.Header()) + assert.EqualValues(1, len(origin.Header())) + assert.EqualValues("3", origin.Header().Get("Helium")) + }) + + t.Run("MultipleHeadersFilteredFullArray", func(t *testing.T) { + assert := assert.New(t) + origin := httptest.NewRecorder() + resp := &http.Response{Header:http.Header{}} + + resp.Header.Add("Helium", "3") + resp.Header.Add("Hydrogen", "5") + resp.Header.Add("Hydrogen", "6") + + ForwardHeadersByPrefix("H", origin, resp) + assert.NotEmpty(origin.Header()) + assert.EqualValues(2, len(origin.Header())) + assert.EqualValues([]string{"5","6"}, origin.Header()["Hydrogen"]) + }) +} + func SetUpTest(encodeArg interface{}, req *http.Request) { recorder := httptest.NewRecorder() timeout := time.Nanosecond