diff --git a/rest/api_test.go b/rest/api_test.go index 169df9fbbf..6458df6b27 100644 --- a/rest/api_test.go +++ b/rest/api_test.go @@ -256,26 +256,50 @@ func TestCORSOrigin(t *testing.T) { reqHeaders := map[string]string{ "Origin": tc.origin, } - response := rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders) - assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusBadRequest) - require.Contains(t, response.Body.String(), invalidDatabaseName) - - response = rt.SendRequestWithHeaders("GET", "/{{.db}}/", "", reqHeaders) - assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusUnauthorized) - require.Contains(t, response.Body.String(), ErrLoginRequired.Message) - - response = rt.SendRequestWithHeaders("GET", "/notadb/", "", reqHeaders) - assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusUnauthorized) - require.Contains(t, response.Body.String(), ErrLoginRequired.Message) - - // admin port doesn't have CORS - response = rt.SendAdminRequestWithHeaders("GET", "/{{.keyspace}}/_all_docs", "", reqHeaders) - assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusOK) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/{{.keyspace}}/", "", reqHeaders) + assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusBadRequest) + require.Contains(t, response.Body.String(), invalidDatabaseName) + } else { + RequireStatus(t, response, http.StatusNoContent) + + } + } + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders) + assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusUnauthorized) + require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/notadb/", "", reqHeaders) + assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusUnauthorized) + require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + } else { + RequireStatus(t, response, http.StatusNoContent) + + } + } + + for _, method := range []string{http.MethodGet, http.MethodOptions} { + // admin port doesn't have CORS + response := rt.SendAdminRequestWithHeaders(method, "/{{.keyspace}}/_all_docs", "", reqHeaders) + assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusOK) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } // test with a config without * should reject non-matches sc := rt.ServerContext() defer func() { @@ -284,8 +308,10 @@ func TestCORSOrigin(t *testing.T) { sc.Config.API.CORS.Origin = []string{"http://example.com", "http://staging.example.com"} if !base.StringSliceContains(sc.Config.API.CORS.Origin, tc.origin) { - response = rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders) - assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/{{.keyspace}}/", "", reqHeaders) + assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + } } }) } diff --git a/rest/cors_test.go b/rest/cors_test.go index e777f7ece8..5c46d660eb 100644 --- a/rest/cors_test.go +++ b/rest/cors_test.go @@ -39,25 +39,46 @@ func TestCORSDynamicSet(t *testing.T) { reqHeaders := map[string]string{ "Origin": "http://example.com", } - response := rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusBadRequest) - require.Contains(t, response.Body.String(), invalidDatabaseName) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/{{.keyspace}}/", "", reqHeaders) + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusBadRequest) + require.Contains(t, response.Body.String(), invalidDatabaseName) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } // successful request - response = rt.SendUserRequestWithHeaders("GET", "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusOK) - - response = rt.SendRequestWithHeaders("GET", "/{{.db}}/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusUnauthorized) - require.Contains(t, response.Body.String(), ErrLoginRequired.Message) - - response = rt.SendUserRequestWithHeaders("GET", "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusOK) - + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendUserRequestWithHeaders(method, "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusOK) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders) + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusUnauthorized) + require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendUserRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusOK) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } dbConfig = rt.NewDbConfig() dbConfig.CORS = &auth.CORSConfig{ Origin: []string{"http://example.org"}, @@ -67,52 +88,109 @@ func TestCORSDynamicSet(t *testing.T) { RequireStatus(t, resp, http.StatusCreated) // this falls back to the server config CORS without the user being authenticated - response = rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusBadRequest) - require.Contains(t, response.Body.String(), invalidDatabaseName) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/{{.keyspace}}/", "", reqHeaders) + if method == http.MethodGet { + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + RequireStatus(t, response, http.StatusBadRequest) + require.Contains(t, response.Body.String(), invalidDatabaseName) + } else { + // information leak: the options request knows about the database and knows it doesn't match + require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + RequireStatus(t, response, http.StatusNoContent) + } + } // successful request - mismatched headers - response = rt.SendUserRequestWithHeaders("GET", "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusOK) - - response = rt.SendRequestWithHeaders("GET", "/{{.db}}/", "", reqHeaders) - RequireStatus(t, response, http.StatusUnauthorized) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - require.Contains(t, response.Body.String(), ErrLoginRequired.Message) - - response = rt.SendRequestWithHeaders("GET", "/notadb/", "", reqHeaders) - RequireStatus(t, response, http.StatusUnauthorized) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendUserRequestWithHeaders(method, "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) + require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusOK) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } - response = rt.SendUserRequestWithHeaders("GET", "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusOK) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusUnauthorized) + require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + } else { + RequireStatus(t, response, http.StatusNoContent) + // information leak: the options request knows about the database and knows it doesn't match + require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + } + } + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/notadb/", "", reqHeaders) + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusUnauthorized) + require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendUserRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) + require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusOK) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } // successful request - matched headers reqHeaders = map[string]string{ "Origin": "http://example.org", } - response = rt.SendUserRequestWithHeaders("GET", "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "http://example.org", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusOK) - - response = rt.SendRequestWithHeaders("GET", "/{{.db}}/", "", reqHeaders) - require.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusUnauthorized) - require.Contains(t, response.Body.String(), ErrLoginRequired.Message) - response = rt.SendRequestWithHeaders("GET", "/notadb/", "", reqHeaders) - require.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusUnauthorized) - require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendUserRequestWithHeaders(method, "/{{.keyspace}}/_all_docs", "", reqHeaders, username, RestTesterDefaultUserPassword) + require.Equal(t, "http://example.org", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusOK) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } - response = rt.SendUserRequestWithHeaders("GET", "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) - require.Equal(t, "http://example.org", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusOK) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders) + if method == http.MethodGet { + require.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) + RequireStatus(t, response, http.StatusUnauthorized) + require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + } else { + // information leak: the options request knows about the database and knows it doesn't match + require.Equal(t, "http://example.org", response.Header().Get("Access-Control-Allow-Origin")) + RequireStatus(t, response, http.StatusNoContent) + } + } + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/notadb/", "", reqHeaders) + require.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusUnauthorized) + require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendUserRequestWithHeaders(method, "/{{.db}}/", "", reqHeaders, username, RestTesterDefaultUserPassword) + require.Equal(t, "http://example.org", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusOK) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } } func TestCORSNoMux(t *testing.T) { @@ -123,30 +201,43 @@ func TestCORSNoMux(t *testing.T) { "Origin": "http://example.com", } // this method doesn't exist - response := rt.SendRequestWithHeaders("GET", "/_notanendpoint/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusNotFound) - require.Contains(t, response.Body.String(), "unknown URL") + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/_notanendpoint/", "", reqHeaders) + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + RequireStatus(t, response, http.StatusNotFound) + require.Contains(t, response.Body.String(), "unknown URL") + } // admin port shouldn't populate CORS - response = rt.SendAdminRequestWithHeaders("GET", "/_notanendpoint/", "", reqHeaders) - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusNotFound) - require.Contains(t, response.Body.String(), "unknown URL") - + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendAdminRequestWithHeaders(method, "/_notanendpoint/", "", reqHeaders) + require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + RequireStatus(t, response, http.StatusNotFound) + require.Contains(t, response.Body.String(), "unknown URL") + } // this method doesn't exist - response = rt.SendRequestWithHeaders(http.MethodDelete, "/notadb/", "", reqHeaders) - require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) - RequireStatus(t, response, http.StatusMethodNotAllowed) - require.Equal(t, strconv.Itoa(rt.ServerContext().Config.API.CORS.MaxAge), response.Header().Get("Access-Control-Max-Age")) - require.Equal(t, "GET, HEAD, POST, PUT", response.Header().Get("Access-Control-Allow-Methods")) - - response = rt.SendAdminRequestWithHeaders(http.MethodDelete, "/_stats/", "", reqHeaders) - RequireStatus(t, response, http.StatusMethodNotAllowed) - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) - require.Equal(t, "", response.Header().Get("Access-Control-Max-Age")) - require.Equal(t, "", response.Header().Get("Access-Control-Allow-Methods")) + for _, method := range []string{http.MethodDelete, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, "/notadb/", "", reqHeaders) + require.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin")) + if method == http.MethodDelete { + RequireStatus(t, response, http.StatusMethodNotAllowed) + } else { + RequireStatus(t, response, http.StatusNoContent) + } + require.Equal(t, strconv.Itoa(rt.ServerContext().Config.API.CORS.MaxAge), response.Header().Get("Access-Control-Max-Age")) + require.Equal(t, "GET, HEAD, POST, PUT", response.Header().Get("Access-Control-Allow-Methods")) + } + for _, method := range []string{http.MethodDelete, http.MethodOptions} { + response := rt.SendAdminRequestWithHeaders(method, "/_stats/", "", reqHeaders) + if method == http.MethodGet { + RequireStatus(t, response, http.StatusMethodNotAllowed) + } + + require.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", response.Header().Get("Access-Control-Max-Age")) + require.Equal(t, "", response.Header().Get("Access-Control-Allow-Methods")) + } } func TestCORSUserNoAccess(t *testing.T) { @@ -165,16 +256,30 @@ func TestCORSUserNoAccess(t *testing.T) { response := rt.SendAdminRequest(http.MethodPut, "/"+rt.GetDatabase().Name+"/_user/"+alice, `{"name": "`+alice+`", "password": "`+RestTesterDefaultUserPassword+`"}`) + RequireStatus(t, response, http.StatusCreated) for _, endpoint := range []string{"/{{.db}}/", "/notadb/"} { t.Run(endpoint, func(t *testing.T) { reqHeaders := map[string]string{ "Origin": "http://couchbase.com", } - response = rt.SendRequestWithHeaders(http.MethodGet, endpoint, "", reqHeaders) - RequireStatus(t, response, http.StatusUnauthorized) - require.Contains(t, response.Body.String(), ErrLoginRequired.Message) - assert.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, endpoint, "", reqHeaders) + if method == http.MethodOptions && endpoint == "/{{.db}}/" { + // information leak: the options request knows about the database and knows it doesn't match + assert.Equal(t, "http://couchbase.com", response.Header().Get("Access-Control-Allow-Origin")) + } else { + assert.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) + } + + if method == http.MethodGet { + RequireStatus(t, response, http.StatusUnauthorized) + require.Contains(t, response.Body.String(), ErrLoginRequired.Message) + + } else { + RequireStatus(t, response, http.StatusNoContent) + } + } }) } } @@ -197,18 +302,20 @@ func TestCORSOriginPerDatabase(t *testing.T) { defer rt.Close() testCases := []struct { - name string - endpoint string - origin string - headerResponse string - responseCode int + name string + endpoint string + origin string + headerResponse string + headerResponseOptions string + responseCode int }{ { - name: "CORS origin allowed couchbase", - endpoint: "/{{.db}}/", - origin: "http://couchbase.com", - headerResponse: "http://couchbase.com", - responseCode: http.StatusOK, + name: "CORS origin allowed couchbase", + endpoint: "/{{.db}}/", + origin: "http://couchbase.com", + headerResponse: "http://couchbase.com", + headerResponseOptions: "http://couchbase.com", + responseCode: http.StatusOK, }, { name: "CORS origin allowed example.com", @@ -232,18 +339,20 @@ func TestCORSOriginPerDatabase(t *testing.T) { responseCode: http.StatusOK, }, { - name: "root url allow couchbase", - endpoint: "/", - origin: "http://couchbase.com", - headerResponse: "*", - responseCode: http.StatusOK, + name: "root url allow couchbase", + endpoint: "/", + origin: "http://couchbase.com", + headerResponse: "*", + headerResponseOptions: "*", + responseCode: http.StatusOK, }, { - name: "root url allow example.com", - endpoint: "/", - origin: "http://example.com", - headerResponse: "http://example.com", - responseCode: http.StatusOK, + name: "root url allow example.com", + endpoint: "/", + origin: "http://example.com", + headerResponse: "http://example.com", + headerResponseOptions: "http://example.com", + responseCode: http.StatusOK, }, } for _, test := range testCases { @@ -251,9 +360,15 @@ func TestCORSOriginPerDatabase(t *testing.T) { reqHeaders := map[string]string{ "Origin": test.origin, } - response := rt.SendRequestWithHeaders(http.MethodGet, test.endpoint, "", reqHeaders) - require.Equal(t, test.responseCode, response.Code) - require.Equal(t, test.headerResponse, response.Header().Get("Access-Control-Allow-Origin")) + for _, method := range []string{http.MethodGet, http.MethodOptions} { + response := rt.SendRequestWithHeaders(method, test.endpoint, "", reqHeaders) + if method == http.MethodGet { + require.Equal(t, test.responseCode, response.Code) + } else { + require.Equal(t, http.StatusNoContent, response.Code) + } + require.Equal(t, test.headerResponse, response.Header().Get("Access-Control-Allow-Origin")) + } }) } diff --git a/rest/routing.go b/rest/routing.go index 19739f16c3..79a511b219 100644 --- a/rest/routing.go +++ b/rest/routing.go @@ -368,18 +368,30 @@ func wrapRouter(sc *ServerContext, privs handlerPrivs, router *mux.Router) http. h.logRequestLine() // Inject CORS if enabled and requested and not admin port - cors := sc.Config.API.CORS - if privs != adminPrivs && cors != nil { - cors.AddResponseHeaders(rq, response) - } - // What methods would have matched? var options []string + var keyspace string for _, method := range []string{"GET", "HEAD", "POST", "PUT", "DELETE"} { - if wouldMatch(router, rq, method) { + found, matchedKeyspace := wouldMatch(router, rq, method) + if found { options = append(options, method) + if keyspace == "" && matchedKeyspace != "" { + keyspace = matchedKeyspace + } + } + } + + cors := sc.Config.API.CORS + dbName, _, _, _ := ParseKeyspace(keyspace) + if dbName != "" { + db, err := h.server.GetActiveDatabase(dbName) + if err == nil { + cors = db.CORS } } + if cors != nil && privs != adminPrivs && privs != metricsPrivs { + cors.AddResponseHeaders(rq, response) + } if len(options) == 0 { h.writeStatus(http.StatusNotFound, "unknown URL") } else { @@ -410,10 +422,25 @@ func FixQuotedSlashes(rq *http.Request) { } } -func wouldMatch(router *mux.Router, rq *http.Request, method string) bool { +func wouldMatch(router *mux.Router, rq *http.Request, method string) (found bool, keyspace string) { savedMethod := rq.Method rq.Method = method defer func() { rq.Method = savedMethod }() var matchInfo mux.RouteMatch - return router.Match(rq, &matchInfo) + found = router.Match(rq, &matchInfo) + // If a match is found, check for any db/keyspace path variable in the resolved match. Some paths may + // match routes with different path variables depending on the method. + if found { + matchVars := matchInfo.Vars + if dbName, ok := matchVars["db"]; ok { + keyspace = dbName + } else if keyspaceName, ok := matchVars["keyspace"]; ok { + keyspace = keyspaceName + } else if targetDbName, ok := matchVars["targetdb"]; ok { + keyspace = targetDbName + } else if newDbName, ok := matchVars["newdb"]; ok { + keyspace = newDbName + } + } + return found, keyspace }