diff --git a/config/backend.go b/config/backend.go index dff676da9..e385e3f3a 100644 --- a/config/backend.go +++ b/config/backend.go @@ -40,10 +40,12 @@ func (b Backend) Schema(inline bool) *hcl.BodySchema { Origin string `hcl:"origin,optional"` Hostname string `hcl:"hostname,optional"` Path string `hcl:"path,optional"` - RequestHeaders map[string]string `hcl:"request_headers,optional"` - ResponseHeaders map[string]string `hcl:"response_headers,optional"` SetRequestHeaders map[string]string `hcl:"set_request_headers,optional"` + AddRequestHeaders map[string]string `hcl:"add_request_headers,optional"` + DelRequestHeaders []string `hcl:"remove_request_headers,optional"` SetResponseHeaders map[string]string `hcl:"set_response_headers,optional"` + AddResponseHeaders map[string]string `hcl:"add_response_headers,optional"` + DelResponseHeaders []string `hcl:"remove_response_headers,optional"` AddQueryParams map[string]cty.Value `hcl:"add_query_params,optional"` DelQueryParams []string `hcl:"remove_query_params,optional"` SetQueryParams map[string]cty.Value `hcl:"set_query_params,optional"` diff --git a/docs/README.md b/docs/README.md index 56374fb06..bbc8aa3fb 100644 --- a/docs/README.md +++ b/docs/README.md @@ -348,7 +348,11 @@ A `backend` defines the connection to a local/remote backend service. Backends c | `origin` | URL to connect to for backend requests
⚠ must start with the scheme `http://...` || | `path` | changeable part of upstream URL || | `request_body_limit` | Limit to configure the maximum buffer size while accessing `req.post` or `req.json_body` content. Valid units are: `KiB, MiB, GiB`. | `64MiB` | -| `set_request_headers` | header map to define additional or override header for the `origin` request || +| `add_request_headers` | header map to define additional header values for the `origin` request || +| `add_response_headers` | same as `add_request_headers` for the client response || +| `remove_request_headers` | header list to define header to be removed from the `origin` request || +| `remove_response_headers` | same as `remove_request_headers` for the client response || +| `set_request_headers` | header map to override header for the `origin` request || | `set_response_headers` | same as `set_request_headers` for the client response || | [`openapi`](#openapi_block) | Definition for validating outgoing requests to the `origin` and incoming responses from the `origin`. || | [`remove_query_params`](#query_params) | a list of query parameters to be removed from the upstream request URL || diff --git a/handler/context_options.go b/handler/context_options.go index b5ab112b5..a858bbeac 100644 --- a/handler/context_options.go +++ b/handler/context_options.go @@ -7,10 +7,12 @@ import ( ) const ( - attrReqHeaders = "request_headers" - attrResHeaders = "response_headers" attrSetReqHeaders = "set_request_headers" + attrAddReqHeaders = "add_request_headers" + attrDelReqHeaders = "remove_request_headers" attrSetResHeaders = "set_response_headers" + attrAddResHeaders = "add_response_headers" + attrDelResHeaders = "remove_response_headers" attrAddQueryParams = "add_query_params" attrDelQueryParams = "remove_query_params" attrSetQueryParams = "set_query_params" diff --git a/handler/proxy.go b/handler/proxy.go index a8273a47e..c0f57a9c7 100644 --- a/handler/proxy.go +++ b/handler/proxy.go @@ -181,7 +181,7 @@ func (p *Proxy) getTransport(scheme, origin, hostname string) *http.Transport { func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { startTime := time.Now() - if p.options.CORS != nil && isCorsPreflightRequest(req) { + if isCorsPreflightRequest(req) { p.setCorsRespHeaders(rw.Header(), req) rw.WriteHeader(http.StatusNoContent) return @@ -406,108 +406,146 @@ func (p *Proxy) Director(req *http.Request) error { func (p *Proxy) SetRoundtripContext(req *http.Request, beresp *http.Response) { var ( - attrCtx = []string{attrReqHeaders, attrSetReqHeaders} - bereq *http.Request - headerCtx http.Header + attrCtxAdd = attrAddReqHeaders + attrCtxDel = attrDelReqHeaders + attrCtxSet = attrSetReqHeaders + bereq *http.Request + headerCtx http.Header ) if beresp != nil { - attrCtx = []string{attrResHeaders, attrSetResHeaders} + attrCtxAdd = attrAddResHeaders + attrCtxDel = attrDelResHeaders + attrCtxSet = attrSetResHeaders bereq = beresp.Request headerCtx = beresp.Header + + defer p.setCorsRespHeaders(headerCtx, req) } else if req != nil { headerCtx = req.Header - } - evalCtx := eval.NewHTTPContext(p.evalContext, p.bufferOption, req, bereq, beresp) - - // Remove blacklisted headers after evaluation to be accessible within our context configuration. - if attrCtx[0] == attrReqHeaders { - for _, key := range headerBlacklist { - headerCtx.Del(key) + // Remove blacklisted headers after evaluation to + // be accessible within our context configuration. + if attrCtxSet == attrSetReqHeaders { + for _, key := range headerBlacklist { + headerCtx.Del(key) + } } } allAttributes, attrOk := p.options.Context.(body.Attributes) + if !attrOk { + return + } + + evalCtx := eval.NewHTTPContext(p.evalContext, p.bufferOption, req, bereq, beresp) + + var modifyQuery bool + + u := *req.URL + u.RawQuery = strings.ReplaceAll(u.RawQuery, "+", "%2B") + values := u.Query() + + for _, attrs := range allAttributes.JustAllAttributes() { + // apply header values in hierarchical and logical order: delete, set, add + attr, ok := attrs[attrCtxDel] + if ok { + val, diags := attr.Expr.Value(evalCtx) + if seetie.SetSeverityLevel(diags).HasErrors() { + p.log.WithField("parse config", p.String()).Error(diags) + } + + for _, key := range seetie.ValueToStringSlice(val) { + k := http.CanonicalHeaderKey(key) + if k == "User-Agent" { + headerCtx[k] = []string{} + continue + } - // apply header values - for _, ctxName := range attrCtx { // headers - if !attrOk { - break + headerCtx.Del(k) + } } - for _, attrs := range allAttributes.JustAllAttributesWithName(ctxName) { - attr, ok := attrs[ctxName] - if !ok { - continue + attr, ok = attrs[attrCtxSet] + if ok { + options, diags := NewOptionsMap(evalCtx, attr) + if diags != nil { + p.log.WithField("parse config", p.String()).Error(diags) } + + for key, values := range options { + k := http.CanonicalHeaderKey(key) + headerCtx[k] = values + } + } + + attr, ok = attrs[attrCtxAdd] + if ok { options, diags := NewOptionsMap(evalCtx, attr) - if diags.HasErrors() { + if diags != nil { p.log.WithField("parse config", p.String()).Error(diags) - continue } - setHeaderFields(headerCtx, options) + + for key, values := range options { + k := http.CanonicalHeaderKey(key) + headerCtx[k] = append(headerCtx[k], values...) + } } - } - // apply query params in hierarchical and logical order: delete, set, add - if attrOk && req != nil && beresp == nil { // just one way -> origin - var modify bool + if req == nil || beresp != nil { // just one way -> origin + continue + } - u := *req.URL - u.RawQuery = strings.ReplaceAll(u.RawQuery, "+", "%2B") - values := u.Query() + // apply query params in hierarchical and logical order: delete, set, add + attr, ok = attrs[attrDelQueryParams] + if ok { + val, diags := attr.Expr.Value(evalCtx) + if seetie.SetSeverityLevel(diags).HasErrors() { + p.log.WithField("parse config", p.String()).Error(diags) + } - // not by name to ensure the order for all params - for _, attrs := range allAttributes.JustAllAttributes() { - attr, ok := attrs[attrDelQueryParams] - if ok { - val, diags := attr.Expr.Value(evalCtx) - if seetie.SetSeverityLevel(diags).HasErrors() { - p.log.WithField("parse config", p.String()).Error(diags) - } - for _, key := range seetie.ValueToStringSlice(val) { - values.Del(key) - } - modify = true + for _, key := range seetie.ValueToStringSlice(val) { + values.Del(key) } - attr, ok = attrs[attrSetQueryParams] - if ok { - options, diags := NewOptionsMap(evalCtx, attr) - if diags != nil { - p.log.WithField("parse config", p.String()).Error(diags) - } - for k, v := range options { - values[k] = v - } - modify = true + modifyQuery = true + } + + attr, ok = attrs[attrSetQueryParams] + if ok { + options, diags := NewOptionsMap(evalCtx, attr) + if diags != nil { + p.log.WithField("parse config", p.String()).Error(diags) } - attr, ok = attrs[attrAddQueryParams] - if ok { - options, diags := NewOptionsMap(evalCtx, attr) - if diags != nil { - p.log.WithField("parse config", p.String()).Error(diags) - } - for k, v := range options { - if _, ok = values[k]; !ok { - values[k] = v - } else { - values[k] = append(values[k], v...) - } - } - modify = true + for k, v := range options { + values[k] = v } + + modifyQuery = true } - if modify { - req.URL.RawQuery = strings.ReplaceAll(values.Encode(), "+", "%20") + attr, ok = attrs[attrAddQueryParams] + if ok { + options, diags := NewOptionsMap(evalCtx, attr) + if diags != nil { + p.log.WithField("parse config", p.String()).Error(diags) + } + + for k, v := range options { + if _, ok = values[k]; !ok { + values[k] = v + } else { + values[k] = append(values[k], v...) + } + } + + modifyQuery = true } } - if beresp != nil && isCorsRequest(req) { - p.setCorsRespHeaders(headerCtx, req) + if modifyQuery { + req.URL.RawQuery = strings.ReplaceAll(values.Encode(), "+", "%20") } } @@ -549,7 +587,7 @@ func isCorsRequest(req *http.Request) bool { } func isCorsPreflightRequest(req *http.Request) bool { - return isCorsRequest(req) && req.Method == http.MethodOptions && (req.Header.Get("Access-Control-Request-Method") != "" || req.Header.Get("Access-Control-Request-Headers") != "") + return req.Method == http.MethodOptions && (req.Header.Get("Access-Control-Request-Method") != "" || req.Header.Get("Access-Control-Request-Headers") != "") } func IsCredentialed(headers http.Header) bool { @@ -557,7 +595,7 @@ func IsCredentialed(headers http.Header) bool { } func (p *Proxy) setCorsRespHeaders(headers http.Header, req *http.Request) { - if p.options.CORS == nil { + if p.options.CORS == nil || !isCorsRequest(req) { return } requestOrigin := req.Header.Get("Origin") diff --git a/handler/proxy_cors_test.go b/handler/proxy_cors_test.go index da0d17593..6574af78c 100644 --- a/handler/proxy_cors_test.go +++ b/handler/proxy_cors_test.go @@ -155,18 +155,6 @@ func TestCORSOptions_isCorsPreflightRequest(t *testing.T) { map[string]string{"Origin": "https://www.example.com"}, false, }, - { - "OPTIONS, without Origin, with ACRM", - http.MethodOptions, - map[string]string{"Access-Control-Request-Method": "POST"}, - false, - }, - { - "OPTIONS, without Origin, with ACRH", - http.MethodOptions, - map[string]string{"Access-Control-Request-Headers": "Content-Type"}, - false, - }, { "POST, with Origin, with ACRM", http.MethodPost, diff --git a/handler/proxy_test.go b/handler/proxy_test.go index fd6f20b15..f4f0db947 100644 --- a/handler/proxy_test.go +++ b/handler/proxy_test.go @@ -954,16 +954,16 @@ func TestProxy_SetRoundtripContext_Null_Eval(t *testing.T) { for i, tc := range []testCase{ {"no eval", `path = "/"`, test.Header{}}, - {"json_body client field", `response_headers = { "x-client" = "my-val-x-${req.json_body.client}" }`, + {"json_body client field", `set_response_headers = { "x-client" = "my-val-x-${req.json_body.client}" }`, test.Header{ "x-client": "my-val-x-true", }}, - {"json_body non existing field", `response_headers = { + {"json_body non existing field", `set_response_headers = { "${beresp.json_body.not-there}" = "my-val-0-${beresp.json_body.origin}" "${req.json_body.client}-my-val-a" = "my-val-b-${beresp.json_body.client}" }`, test.Header{"true-my-val-a": ""}}, // since one reference is failing ('not-there') the whole block does - {"json_body null value", `response_headers = { "x-null" = "${beresp.json_body.nil}" }`, test.Header{"x-null": ""}}, + {"json_body null value", `set_response_headers = { "x-null" = "${beresp.json_body.nil}" }`, test.Header{"x-null": ""}}, } { t.Run(tc.name, func(st *testing.T) { h := test.New(st) @@ -1055,10 +1055,10 @@ func TestProxy_BufferingOptions(t *testing.T) { {"beresp validation", newOptions(), `path = "/"`, eval.BufferResponse}, {"bereq validation", newOptions(), `path = "/"`, eval.BufferRequest}, {"no validation", newOptions(), `path = "/"`, eval.BufferNone}, - {"req buffer json.body & beresp validation", newOptions(), `response_headers = { x-test = "${req.json_body.client}" }`, eval.BufferRequest | eval.BufferResponse}, - {"beresp buffer json.body & bereq validation", newOptions(), `response_headers = { x-test = "${beresp.json_body.origin}" }`, eval.BufferRequest | eval.BufferResponse}, - {"req buffer json.body & bereq validation", newOptions(), `response_headers = { x-test = "${req.json_body.client}" }`, eval.BufferRequest}, - {"beresp buffer json.body & beresp validation", newOptions(), `response_headers = { x-test = "${beresp.json_body.origin}" }`, eval.BufferResponse}, + {"req buffer json.body & beresp validation", newOptions(), `set_response_headers = { x-test = "${req.json_body.client}" }`, eval.BufferRequest | eval.BufferResponse}, + {"beresp buffer json.body & bereq validation", newOptions(), `set_response_headers = { x-test = "${beresp.json_body.origin}" }`, eval.BufferRequest | eval.BufferResponse}, + {"req buffer json.body & bereq validation", newOptions(), `set_response_headers = { x-test = "${req.json_body.client}" }`, eval.BufferRequest}, + {"beresp buffer json.body & beresp validation", newOptions(), `set_response_headers = { x-test = "${beresp.json_body.origin}" }`, eval.BufferResponse}, } { t.Run(tc.name, func(st *testing.T) { h := test.New(st) diff --git a/internal/test/test_backend.go b/internal/test/test_backend.go index 6d296dc80..9d6242a09 100644 --- a/internal/test/test_backend.go +++ b/internal/test/test_backend.go @@ -75,6 +75,9 @@ func createAnythingHandler(status int) func(rw http.ResponseWriter, req *http.Re rw.Header().Set("Content-Length", strconv.Itoa(len(respContent))) rw.Header().Set("Content-Type", "application/json") + rw.Header().Set("Remove-Me-1", "r1") + rw.Header().Set("Remove-Me-2", "r2") + rw.WriteHeader(status) _, _ = rw.Write(respContent) } diff --git a/server/http_integration_test.go b/server/http_integration_test.go index 90f417c80..7ced2777b 100644 --- a/server/http_integration_test.go +++ b/server/http_integration_test.go @@ -665,6 +665,98 @@ func TestHTTPServer_QueryParams(t *testing.T) { } } +func TestHTTPServer_RequestHeaders(t *testing.T) { + client := newClient() + + const confPath = "testdata/integration/endpoint_eval/" + + type expectation struct { + Headers http.Header + } + + type testCase struct { + file string + query string + exp expectation + } + + for _, tc := range []testCase{ + {"12_couper.hcl", "ae=ae&aeb=aeb&def=def&xyz=xyz", expectation{ + Headers: http.Header{ + "Aeb": []string{"aeb", "aeb"}, + "Aeb_a_and_b": []string{"A&B", "A&B"}, + "Aeb_empty": []string{"", ""}, + "Aeb_multi": []string{"str1", "str2", "str3", "str4"}, + "Aeb_noop": []string{"", ""}, + "Aeb_null": []string{"", ""}, + "Aeb_string": []string{"str", "str"}, + "Def_a_and_b": []string{"A&B", "A&B"}, + "Def_empty": []string{"", ""}, + "Def_multi": []string{"str1", "str2", "str3", "str4"}, + "Def_noop": []string{"", ""}, + "Def_null": []string{"", ""}, + "Def_string": []string{"str", "str"}, + "Def": []string{"def", "def"}, + "Foo": []string{""}, + "Xxx": []string{"aaa", "bbb"}, + }, + }}, + } { + shutdown, _ := newCouper(path.Join(confPath, tc.file), test.New(t)) + + t.Run("_"+tc.query, func(subT *testing.T) { + helper := test.New(subT) + + req, err := http.NewRequest(http.MethodGet, "http://example.com:8080?"+tc.query, nil) + helper.Must(err) + + res, err := client.Do(req) + helper.Must(err) + + if r1 := res.Header.Get("Remove-Me-1"); r1 != "" { + t.Errorf("Unexpected header %s", r1) + } + if r2 := res.Header.Get("Remove-Me-2"); r2 != "" { + t.Errorf("Unexpected header %s", r2) + } + + if s1 := res.Header.Get("Set-Me-1"); s1 != "s1" { + t.Errorf("Missing or invalid header Set-Me-1: %s", s1) + } + if s2 := res.Header.Get("Set-Me-2"); s2 != "s2" { + t.Errorf("Missing or invalid header Set-Me-2: %s", s2) + } + + if a1 := res.Header.Get("Add-Me-1"); a1 != "a1" { + t.Errorf("Missing or invalid header Add-Me-1: %s", a1) + } + if a2 := res.Header.Get("Add-Me-2"); a2 != "a2" { + t.Errorf("Missing or invalid header Add-Me-2: %s", a2) + } + + resBytes, err := ioutil.ReadAll(res.Body) + helper.Must(err) + + _ = res.Body.Close() + + var jsonResult expectation + err = json.Unmarshal(resBytes, &jsonResult) + if err != nil { + t.Errorf("unmarshal json: %v: got:\n%s", err, string(resBytes)) + } + + jsonResult.Headers.Del("User-Agent") + jsonResult.Headers.Del("X-Forwarded-For") + + if !reflect.DeepEqual(jsonResult, tc.exp) { + t.Errorf("\nwant: \n%#v\ngot: \n%#v\npayload:\n%s", tc.exp, jsonResult, string(resBytes)) + } + }) + + shutdown() + } +} + func TestHTTPServer_QueryEncoding(t *testing.T) { client := newClient() diff --git a/server/testdata/integration/endpoint_eval/12_couper.hcl b/server/testdata/integration/endpoint_eval/12_couper.hcl new file mode 100644 index 000000000..6a313bd40 --- /dev/null +++ b/server/testdata/integration/endpoint_eval/12_couper.hcl @@ -0,0 +1,75 @@ +server "api" { + api { + endpoint "/" { + backend "anything" { + remove_request_headers = [ "aeb_del", "CaseIns", req.query.xyz[0] ] + set_request_headers = { + aeb_string = "str" + aeb_multi = ["str1", "str2"] + aeb_a_and_b = "A&B" + aeb_noop = req.headers.noop + aeb_null = null + aeb_empty = "" + xxx = ["yyy", "xxx"] + xxx = ["aaa", "bbb"] + "${req.query.aeb[0]}" = "aeb" + } + add_request_headers = { + aeb_string = "str" + aeb_multi = ["str3", "str4"] + aeb_a_and_b = "A&B" + aeb_noop = req.headers.noop + aeb_null = null + aeb_empty = "" + "${req.query.aeb[0]}" = "aeb" + } + + remove_response_headers = [ "Remove-Me-2" ] + set_response_headers = { + "Set-Me-2" = "s2" + } + add_response_headers = { + "Add-Me-2" = "a2" + } + } + } + } +} + +definitions { + # backend origin within a definition block gets replaced with the integration test "anything" server. + backend "anything" { + origin = env.COUPER_TEST_BACKEND_ADDR + + remove_request_headers = [ "def_del" ] + set_request_headers = { + def_string = "str" + def_multi = ["str1", "str2"] + def_a_and_b = "A&B" + def_noop = req.headers.noop + def_null = null + def_empty = "" + xxx = "ddd" + "${req.query.def[0]}" = "def" + foo = req.query.foo[0] + } + add_request_headers = { + def_string = "str" + def_multi = ["str3", "str4"] + def_a_and_b = "A&B" + def_noop = req.headers.noop + def_null = null + def_empty = "" + xxx = "eee" + "${req.query.def[0]}" = "def" + } + + remove_response_headers = [ "remove-me-1" ] + set_response_headers = { + "set-me-1" = "s1" + } + add_response_headers = { + "add-me-1" = "a1" + } + } +}