From 8d3960213ce7a607361ada7b27a05f7ed0f53e65 Mon Sep 17 00:00:00 2001 From: Tung Wu Date: Fri, 3 Nov 2023 12:54:27 +0800 Subject: [PATCH] Rewrite http client instanceFollowRedirects doesn't works for post requests, so rewriting the client to follow redirects correctly --- .../com/oursky/authgear/data/HttpClient.kt | 65 ++++++--- .../data/assetlink/AssetLinkRepoHttp.kt | 13 +- .../authgear/data/oauth/OAuthRepoHttp.kt | 134 +++++------------- 3 files changed, 84 insertions(+), 128 deletions(-) diff --git a/sdk/src/main/java/com/oursky/authgear/data/HttpClient.kt b/sdk/src/main/java/com/oursky/authgear/data/HttpClient.kt index e00958ab..2696d424 100644 --- a/sdk/src/main/java/com/oursky/authgear/data/HttpClient.kt +++ b/sdk/src/main/java/com/oursky/authgear/data/HttpClient.kt @@ -8,6 +8,8 @@ import org.json.JSONException import org.json.JSONObject import java.net.HttpURLConnection import java.net.URL +import java.net.URLDecoder +import java.nio.charset.StandardCharsets internal class HttpClient { companion object { @@ -16,28 +18,59 @@ internal class HttpClient { url: URL, method: String, headers: Map, + body: ByteArray? = null, followRedirect: Boolean = true, - callback: (conn: HttpURLConnection) -> T + callback: (responseBody: ByteArray?) -> T ): T { - val conn = url.openConnection() as HttpURLConnection - try { - conn.requestMethod = method - conn.doInput = true - - if (!followRedirect) { + var responseBody: ByteArray? = null + var currentUrl = url + // Follow redirects by max. 5 times + val maxRedirects = 5 + for (i in 1..maxRedirects) { + val conn = currentUrl.openConnection() as HttpURLConnection + try { + conn.requestMethod = method + conn.doInput = true + // We handle redirects below conn.instanceFollowRedirects = false - } + headers.forEach { (key, value) -> + conn.setRequestProperty(key, value) + } + if (method != "GET" && method != "HEAD" && body != null) { + conn.doOutput = true + conn.outputStream.use { + it.write(body) + } + } + // Follow redirects + // We need this because instanceFollowRedirects do not follow redirects on POST requests + if (conn.responseCode in 300..399 && followRedirect) { + val location = conn.getHeaderField("Location") + val locationUtf8 = URLDecoder.decode(location, "UTF-8") + val next = URL(currentUrl, locationUtf8) + currentUrl = next + if (i >= maxRedirects) { + throw AuthgearException("maximum count of redirect reached") + } + continue + } - if (method != "GET" && method != "HEAD") { - conn.doOutput = true - } - headers.forEach { (key, value) -> - conn.setRequestProperty(key, value) + conn.errorStream?.use { + val responseString = String(it.readBytes(), StandardCharsets.UTF_8) + throwErrorIfNeeded(conn, responseString) + } + conn.inputStream.use { + val bytes = it.readBytes() + responseBody = bytes + val responseString = String(bytes, StandardCharsets.UTF_8) + throwErrorIfNeeded(conn, responseString) + } + break + } finally { + conn.disconnect() } - return callback(conn) - } finally { - conn.disconnect() } + return callback(responseBody) } fun throwErrorIfNeeded(conn: HttpURLConnection, responseString: String) { diff --git a/sdk/src/main/java/com/oursky/authgear/data/assetlink/AssetLinkRepoHttp.kt b/sdk/src/main/java/com/oursky/authgear/data/assetlink/AssetLinkRepoHttp.kt index 54d3c077..5ea811b9 100644 --- a/sdk/src/main/java/com/oursky/authgear/data/assetlink/AssetLinkRepoHttp.kt +++ b/sdk/src/main/java/com/oursky/authgear/data/assetlink/AssetLinkRepoHttp.kt @@ -21,16 +21,9 @@ internal class AssetLinkRepoHttp : AssetLinkRepo { url = URL(assetLinkUri.toString()), method = "GET", headers = hashMapOf() - ) { conn -> - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - HttpClient.json.decodeFromString(responseString) - } + ) { respBody -> + val responseString = String(respBody!!, StandardCharsets.UTF_8) + HttpClient.json.decodeFromString(responseString) } return result } diff --git a/sdk/src/main/java/com/oursky/authgear/data/oauth/OAuthRepoHttp.kt b/sdk/src/main/java/com/oursky/authgear/data/oauth/OAuthRepoHttp.kt index 85131800..e0a5c8ba 100644 --- a/sdk/src/main/java/com/oursky/authgear/data/oauth/OAuthRepoHttp.kt +++ b/sdk/src/main/java/com/oursky/authgear/data/oauth/OAuthRepoHttp.kt @@ -33,16 +33,12 @@ internal class OAuthRepoHttp : OAuthRepo { val configAfterAcquire = this.config if (configAfterAcquire != null) return configAfterAcquire val url = URL(URL(endpoint), "/.well-known/openid-configuration") - val newConfig: OidcConfiguration = HttpClient.fetch(url = url, method = "GET", headers = emptyMap()) { conn -> - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - HttpClient.json.decodeFromString(responseString) - } + val newConfig: OidcConfiguration = HttpClient.fetch( + url = url, + method = "GET", + headers = emptyMap()) { respBody -> + val responseString = String(respBody!!, StandardCharsets.UTF_8) + HttpClient.json.decodeFromString(responseString) } this.config = newConfig return newConfig @@ -72,20 +68,11 @@ internal class OAuthRepoHttp : OAuthRepo { return HttpClient.fetch( url = URL(config.tokenEndpoint), method = "POST", + body = body.toFormData().toByteArray(StandardCharsets.UTF_8), headers = headers - ) { conn -> - conn.outputStream.use { - it.write(body.toFormData().toByteArray(StandardCharsets.UTF_8)) - } - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - HttpClient.json.decodeFromString(responseString) - } + ) { respBody -> + val responseString = String(respBody!!, StandardCharsets.UTF_8) + HttpClient.json.decodeFromString(responseString) } } @@ -101,20 +88,9 @@ internal class OAuthRepoHttp : OAuthRepo { headers = mutableMapOf( "authorization" to "Bearer $accessToken", "content-type" to "application/x-www-form-urlencoded" - ) - ) { conn -> - conn.outputStream.use { - it.write(body.toFormData().toByteArray(StandardCharsets.UTF_8)) - } - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - } + ), + body = body.toFormData().toByteArray(StandardCharsets.UTF_8) + ) { } } override fun oidcRevocationRequest(refreshToken: String) { @@ -126,19 +102,9 @@ internal class OAuthRepoHttp : OAuthRepo { method = "POST", headers = mutableMapOf( "content-type" to "application/x-www-form-urlencoded" - ) - ) { conn -> - conn.outputStream.use { - it.write(body.toFormData().toByteArray(StandardCharsets.UTF_8)) - } - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } + ), + body = body.toFormData().toByteArray(StandardCharsets.UTF_8) + ) { } } @@ -150,16 +116,9 @@ internal class OAuthRepoHttp : OAuthRepo { headers = mutableMapOf( "authorization" to "Bearer $accessToken" ) - ) { conn -> - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - HttpClient.json.decodeFromString(responseString) - } + ) { respBody -> + val responseString = String(respBody!!, StandardCharsets.UTF_8) + HttpClient.json.decodeFromString(responseString) } } @@ -171,20 +130,11 @@ internal class OAuthRepoHttp : OAuthRepo { method = "POST", headers = mutableMapOf( "content-type" to "application/json" - ) - ) { conn -> - conn.outputStream.use { - it.write(HttpClient.json.encodeToString(body).toByteArray(StandardCharsets.UTF_8)) - } - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - HttpClient.json.decodeFromString(responseString) - } + ), + body = HttpClient.json.encodeToString(body).toByteArray(StandardCharsets.UTF_8) + ) { respBody -> + val responseString = String(respBody!!, StandardCharsets.UTF_8) + HttpClient.json.decodeFromString(responseString) } return response.result } @@ -197,20 +147,11 @@ internal class OAuthRepoHttp : OAuthRepo { method = "POST", headers = mutableMapOf( "content-type" to "application/json" - ) - ) { conn -> - conn.outputStream.use { - it.write(HttpClient.json.encodeToString(body).toByteArray(StandardCharsets.UTF_8)) - } - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - HttpClient.json.decodeFromString(responseString) - } + ), + body = HttpClient.json.encodeToString(body).toByteArray(StandardCharsets.UTF_8) + ) { respBody -> + val responseString = String(respBody!!, StandardCharsets.UTF_8) + HttpClient.json.decodeFromString(responseString) } return response.result } @@ -225,19 +166,8 @@ internal class OAuthRepoHttp : OAuthRepo { method = "POST", headers = mutableMapOf( "content-type" to "application/x-www-form-urlencoded" - ) - ) { conn -> - conn.outputStream.use { - it.write(body.toFormData().toByteArray(StandardCharsets.UTF_8)) - } - conn.errorStream?.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - conn.inputStream.use { - val responseString = String(it.readBytes(), StandardCharsets.UTF_8) - HttpClient.throwErrorIfNeeded(conn, responseString) - } - } + ), + body = body.toFormData().toByteArray(StandardCharsets.UTF_8) + ) { } } }