Skip to content

Commit

Permalink
Rewrite http client
Browse files Browse the repository at this point in the history
instanceFollowRedirects doesn't works for post requests, so rewriting the client to follow redirects correctly
  • Loading branch information
tung2744 committed Nov 3, 2023
1 parent de6ba44 commit 8d39602
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 128 deletions.
65 changes: 49 additions & 16 deletions sdk/src/main/java/com/oursky/authgear/data/HttpClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import org.json.JSONException
import org.json.JSONObject
import java.net.HttpURLConnection
import java.net.URL
import java.net.URLDecoder
import java.nio.charset.StandardCharsets

internal class HttpClient {
companion object {
Expand All @@ -16,28 +18,59 @@ internal class HttpClient {
url: URL,
method: String,
headers: Map<String, String>,
body: ByteArray? = null,
followRedirect: Boolean = true,
callback: (conn: HttpURLConnection) -> T
callback: (responseBody: ByteArray?) -> T
): T {
val conn = url.openConnection() as HttpURLConnection
try {
conn.requestMethod = method
conn.doInput = true

if (!followRedirect) {
var responseBody: ByteArray? = null
var currentUrl = url
// Follow redirects by max. 5 times
val maxRedirects = 5
for (i in 1..maxRedirects) {
val conn = currentUrl.openConnection() as HttpURLConnection
try {
conn.requestMethod = method
conn.doInput = true
// We handle redirects below
conn.instanceFollowRedirects = false
}
headers.forEach { (key, value) ->
conn.setRequestProperty(key, value)
}
if (method != "GET" && method != "HEAD" && body != null) {
conn.doOutput = true
conn.outputStream.use {
it.write(body)
}
}
// Follow redirects
// We need this because instanceFollowRedirects do not follow redirects on POST requests
if (conn.responseCode in 300..399 && followRedirect) {
val location = conn.getHeaderField("Location")
val locationUtf8 = URLDecoder.decode(location, "UTF-8")
val next = URL(currentUrl, locationUtf8)
currentUrl = next
if (i >= maxRedirects) {
throw AuthgearException("maximum count of redirect reached")
}
continue
}

if (method != "GET" && method != "HEAD") {
conn.doOutput = true
}
headers.forEach { (key, value) ->
conn.setRequestProperty(key, value)
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val bytes = it.readBytes()
responseBody = bytes
val responseString = String(bytes, StandardCharsets.UTF_8)
throwErrorIfNeeded(conn, responseString)
}
break
} finally {
conn.disconnect()
}
return callback(conn)
} finally {
conn.disconnect()
}
return callback(responseBody)
}

fun throwErrorIfNeeded(conn: HttpURLConnection, responseString: String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,9 @@ internal class AssetLinkRepoHttp : AssetLinkRepo {
url = URL(assetLinkUri.toString()),
method = "GET",
headers = hashMapOf()
) { conn ->
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
HttpClient.json.decodeFromString(responseString)
}
) { respBody ->
val responseString = String(respBody!!, StandardCharsets.UTF_8)
HttpClient.json.decodeFromString(responseString)
}
return result
}
Expand Down
134 changes: 32 additions & 102 deletions sdk/src/main/java/com/oursky/authgear/data/oauth/OAuthRepoHttp.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,12 @@ internal class OAuthRepoHttp : OAuthRepo {
val configAfterAcquire = this.config
if (configAfterAcquire != null) return configAfterAcquire
val url = URL(URL(endpoint), "/.well-known/openid-configuration")
val newConfig: OidcConfiguration = HttpClient.fetch(url = url, method = "GET", headers = emptyMap()) { conn ->
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
HttpClient.json.decodeFromString(responseString)
}
val newConfig: OidcConfiguration = HttpClient.fetch(
url = url,
method = "GET",
headers = emptyMap()) { respBody ->
val responseString = String(respBody!!, StandardCharsets.UTF_8)
HttpClient.json.decodeFromString(responseString)
}
this.config = newConfig
return newConfig
Expand Down Expand Up @@ -72,20 +68,11 @@ internal class OAuthRepoHttp : OAuthRepo {
return HttpClient.fetch(
url = URL(config.tokenEndpoint),
method = "POST",
body = body.toFormData().toByteArray(StandardCharsets.UTF_8),
headers = headers
) { conn ->
conn.outputStream.use {
it.write(body.toFormData().toByteArray(StandardCharsets.UTF_8))
}
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
HttpClient.json.decodeFromString(responseString)
}
) { respBody ->
val responseString = String(respBody!!, StandardCharsets.UTF_8)
HttpClient.json.decodeFromString(responseString)
}
}

Expand All @@ -101,20 +88,9 @@ internal class OAuthRepoHttp : OAuthRepo {
headers = mutableMapOf(
"authorization" to "Bearer $accessToken",
"content-type" to "application/x-www-form-urlencoded"
)
) { conn ->
conn.outputStream.use {
it.write(body.toFormData().toByteArray(StandardCharsets.UTF_8))
}
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
}
),
body = body.toFormData().toByteArray(StandardCharsets.UTF_8)
) { }
}

override fun oidcRevocationRequest(refreshToken: String) {
Expand All @@ -126,19 +102,9 @@ internal class OAuthRepoHttp : OAuthRepo {
method = "POST",
headers = mutableMapOf(
"content-type" to "application/x-www-form-urlencoded"
)
) { conn ->
conn.outputStream.use {
it.write(body.toFormData().toByteArray(StandardCharsets.UTF_8))
}
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
),
body = body.toFormData().toByteArray(StandardCharsets.UTF_8)
) {
}
}

Expand All @@ -150,16 +116,9 @@ internal class OAuthRepoHttp : OAuthRepo {
headers = mutableMapOf(
"authorization" to "Bearer $accessToken"
)
) { conn ->
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
HttpClient.json.decodeFromString(responseString)
}
) { respBody ->
val responseString = String(respBody!!, StandardCharsets.UTF_8)
HttpClient.json.decodeFromString(responseString)
}
}

Expand All @@ -171,20 +130,11 @@ internal class OAuthRepoHttp : OAuthRepo {
method = "POST",
headers = mutableMapOf(
"content-type" to "application/json"
)
) { conn ->
conn.outputStream.use {
it.write(HttpClient.json.encodeToString(body).toByteArray(StandardCharsets.UTF_8))
}
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
HttpClient.json.decodeFromString(responseString)
}
),
body = HttpClient.json.encodeToString(body).toByteArray(StandardCharsets.UTF_8)
) { respBody ->
val responseString = String(respBody!!, StandardCharsets.UTF_8)
HttpClient.json.decodeFromString(responseString)
}
return response.result
}
Expand All @@ -197,20 +147,11 @@ internal class OAuthRepoHttp : OAuthRepo {
method = "POST",
headers = mutableMapOf(
"content-type" to "application/json"
)
) { conn ->
conn.outputStream.use {
it.write(HttpClient.json.encodeToString(body).toByteArray(StandardCharsets.UTF_8))
}
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
HttpClient.json.decodeFromString(responseString)
}
),
body = HttpClient.json.encodeToString(body).toByteArray(StandardCharsets.UTF_8)
) { respBody ->
val responseString = String(respBody!!, StandardCharsets.UTF_8)
HttpClient.json.decodeFromString(responseString)
}
return response.result
}
Expand All @@ -225,19 +166,8 @@ internal class OAuthRepoHttp : OAuthRepo {
method = "POST",
headers = mutableMapOf(
"content-type" to "application/x-www-form-urlencoded"
)
) { conn ->
conn.outputStream.use {
it.write(body.toFormData().toByteArray(StandardCharsets.UTF_8))
}
conn.errorStream?.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
conn.inputStream.use {
val responseString = String(it.readBytes(), StandardCharsets.UTF_8)
HttpClient.throwErrorIfNeeded(conn, responseString)
}
}
),
body = body.toFormData().toByteArray(StandardCharsets.UTF_8)
) { }
}
}

0 comments on commit 8d39602

Please sign in to comment.