diff --git a/kong/dao/cassandra/apis.lua b/kong/dao/cassandra/apis.lua index 527ecb3cf3a..02fed92c073 100644 --- a/kong/dao/cassandra/apis.lua +++ b/kong/dao/cassandra/apis.lua @@ -1,13 +1,32 @@ local BaseDao = require "kong.dao.cassandra.base_dao" local constants = require "kong.constants" local PluginsConfigurations = require "kong.dao.cassandra.plugins_configurations" +local url = require "socket.url" + +local function validate_target_url(value) + local parsed_url = url.parse(value) + if parsed_url.scheme and parsed_url.host then + parsed_url.scheme = parsed_url.scheme:lower() + if parsed_url.scheme == "http" or parsed_url.scheme == "https" then + parsed_url.path = parsed_url.path or "/" + + print(url.build(parsed_url)) + + return true, nil, { target_url = url.build(parsed_url)} + else + return false, "Supported protocols are HTTP and HTTPS" + end + end + + return false, "Invalid target URL" +end local SCHEMA = { id = { type = constants.DATABASE_TYPES.ID }, name = { type = "string", unique = true, queryable = true, default = function(api_t) return api_t.public_dns end }, public_dns = { type = "string", required = true, unique = true, queryable = true, regex = "([a-zA-Z0-9-]+(\\.[a-zA-Z0-9-]+)*)" }, - target_url = { type = "string", required = true }, + target_url = { type = "string", required = true, func = validate_target_url }, created_at = { type = constants.DATABASE_TYPES.TIMESTAMP } } diff --git a/kong/dao/schemas.lua b/kong/dao/schemas.lua index 3a36db9ceb9..a24b2f28980 100644 --- a/kong/dao/schemas.lua +++ b/kong/dao/schemas.lua @@ -126,13 +126,15 @@ function _M.validate(t, schema, is_update) end end + local require_satisfied = true -- Check required fields are set if v.required and (t[column] == nil or t[column] == "") then errors = utils.add_error(errors, column, column.." is required") + require_satisfied = false end - -- Check field against a custom function - if v.func and type(v.func) == "function" then + -- Check field against a custom function only if the value requirement has been satisfied + if require_satisfied and v.func and type(v.func) == "function" then local ok, err, new_fields = v.func(t[column], t) if not ok or err then errors = utils.add_error(errors, column, err) diff --git a/kong/tools/utils.lua b/kong/tools/utils.lua index 05e2d3a1dfe..1cfc5b8c73b 100644 --- a/kong/tools/utils.lua +++ b/kong/tools/utils.lua @@ -67,7 +67,7 @@ end function _M.add_error(errors, k, v) if not errors then errors = {} end - if errors and errors[k] then + if errors and errors[k] and v then local list = {} table.insert(list, errors[k]) table.insert(list, v) diff --git a/spec/integration/admin_api/admin_api_spec.lua b/spec/integration/admin_api/admin_api_spec.lua index 6be06c86fe9..3817adb2d2b 100644 --- a/spec/integration/admin_api/admin_api_spec.lua +++ b/spec/integration/admin_api/admin_api_spec.lua @@ -1,6 +1,7 @@ local json = require "cjson" local http_client = require "kong.tools.http_client" local spec_helper = require "spec.spec_helpers" +local cjson = require "cjson" local CREATED_IDS = {} local ENDPOINTS = { @@ -89,6 +90,58 @@ describe("Admin API", function() end) + describe("APIs Schema", function() + + it("should return error with wrong target_url", function() + local response, status = http_client.post(spec_helper.API_URL.."/apis", { + public_dns = "hello.com", + target_url = "asdasd" + }) + + assert.are.equal(400, status) + assert.are.equal("Invalid target URL", cjson.decode(response).target_url) + end) + + it("should return error with wrong target_url protocol", function() + local response, status = http_client.post(spec_helper.API_URL.."/apis", { + public_dns = "hello.com", + target_url = "wot://hello.com/" + }) + + assert.are.equal(400, status) + assert.are.equal("Supported protocols are HTTP and HTTPS", cjson.decode(response).target_url) + end) + + it("should work without a path", function() + local response, status = http_client.post(spec_helper.API_URL.."/apis", { + public_dns = "hello.com", + target_url = "http://hello.com" + }) + + local body = cjson.decode(response) + assert.are.equal(201, status) + assert.are.equal("http://hello.com/", body.target_url) + + -- Clean up + http_client.delete(spec_helper.API_URL.."/apis/"..body.id) + end) + + it("should work without upper case protocol", function() + local response, status = http_client.post(spec_helper.API_URL.."/apis", { + public_dns = "hello2.com", + target_url = "HTTP://hello.com/world" + }) + + local body = cjson.decode(response) + assert.are.equal(201, status) + assert.are.equal("http://hello.com/world", cjson.decode(response).target_url) + + -- Clean up + http_client.delete(spec_helper.API_URL.."/apis/"..body.id) + end) + + end) + describe("POST", function() describe("application/x-www-form-urlencoded", function() test_for_each_endpoint(function(endpoint, base_url) @@ -419,4 +472,5 @@ describe("Admin API", function() end) end) end) + end)