Skip to content

Commit

Permalink
Handshake returns cleaned peer certificates (#5311)
Browse files Browse the repository at this point in the history
* Pass through clean certificates in Handshake

* Actual do the work

* Add test that cleaner is called

* Defer work

* cleanup

* Clean certs in deprecated method also

* Revert more

* Inline

* Review comments
  • Loading branch information
yschimke authored and squarejesse committed Sep 7, 2019
1 parent 6f17886 commit ba2c676
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 37 deletions.
26 changes: 14 additions & 12 deletions okhttp/src/main/java/okhttp3/CertificatePinner.kt
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ import javax.net.ssl.SSLPeerUnverifiedException
@Suppress("NAME_SHADOWING")
class CertificatePinner internal constructor(
private val pins: Set<Pin>,
private val certificateChainCleaner: CertificateChainCleaner?
internal val certificateChainCleaner: CertificateChainCleaner?
) {
/**
* Confirms that at least one of the certificates pinned for `hostname` is in `peerCertificates`.
Expand All @@ -138,29 +138,31 @@ class CertificatePinner internal constructor(
*/
@Throws(SSLPeerUnverifiedException::class)
fun check(hostname: String, peerCertificates: List<Certificate>) {
var peerCertificates = peerCertificates
return check(hostname) {
(certificateChainCleaner?.clean(peerCertificates, hostname) ?: peerCertificates)
.map { it as X509Certificate }
}
}

internal fun check(hostname: String, cleanedPeerCertificatesFn: () -> List<X509Certificate>) {
val pins = findMatchingPins(hostname)
if (pins.isEmpty()) return

if (certificateChainCleaner != null) {
peerCertificates = certificateChainCleaner.clean(peerCertificates, hostname)
}
val peerCertificates = cleanedPeerCertificatesFn()

for (peerCertificate in peerCertificates) {
val x509Certificate = peerCertificate as X509Certificate

// Lazily compute the hashes for each certificate.
var sha1: ByteString? = null
var sha256: ByteString? = null

for (pin in pins) {
when (pin.hashAlgorithm) {
"sha256/" -> {
if (sha256 == null) sha256 = x509Certificate.toSha256ByteString()
if (sha256 == null) sha256 = peerCertificate.toSha256ByteString()
if (pin.hash == sha256) return // Success!
}
"sha1/" -> {
if (sha1 == null) sha1 = x509Certificate.toSha1ByteString()
if (sha1 == null) sha1 = peerCertificate.toSha1ByteString()
if (pin.hash == sha1) return // Success!
}
else -> throw AssertionError("unsupported hashAlgorithm: ${pin.hashAlgorithm}")
Expand All @@ -172,8 +174,8 @@ class CertificatePinner internal constructor(
val message = buildString {
append("Certificate pinning failure!")
append("\n Peer certificate chain:")
for (c in 0 until peerCertificates.size) {
val x509Certificate = peerCertificates[c] as X509Certificate
for (element in peerCertificates) {
val x509Certificate = element as X509Certificate
append("\n ")
append(pin(x509Certificate))
append(": ")
Expand All @@ -195,7 +197,7 @@ class CertificatePinner internal constructor(
ReplaceWith("check(hostname, peerCertificates.toList())")
)
@Throws(SSLPeerUnverifiedException::class)
inline fun check(hostname: String, vararg peerCertificates: Certificate) {
fun check(hostname: String, vararg peerCertificates: Certificate) {
check(hostname, peerCertificates.toList())
}

Expand Down
43 changes: 22 additions & 21 deletions okhttp/src/main/java/okhttp3/Handshake.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/
package okhttp3

import okhttp3.internal.toImmutableList
import okhttp3.internal.immutableListOf
import okhttp3.internal.toImmutableList
import java.io.IOException
import java.security.Principal
import java.security.cert.Certificate
Expand All @@ -31,7 +31,7 @@ import javax.net.ssl.SSLSession
* This value object describes a completed handshake. Use [ConnectionSpec] to set policy for new
* handshakes.
*/
class Handshake private constructor(
class Handshake internal constructor(
/**
* Returns the TLS version used for this connection. This value wasn't tracked prior to OkHttp
* 3.0. For responses cached by preceding versions this returns [TlsVersion.SSL_3_0].
Expand All @@ -41,12 +41,16 @@ class Handshake private constructor(
/** Returns the cipher suite used for the connection. */
@get:JvmName("cipherSuite") val cipherSuite: CipherSuite,

/** Returns a possibly-empty list of certificates that identify the remote peer. */
@get:JvmName("peerCertificates") val peerCertificates: List<Certificate>,

/** Returns a possibly-empty list of certificates that identify this peer. */
@get:JvmName("localCertificates") val localCertificates: List<Certificate>
@get:JvmName(
"localCertificates") val localCertificates: List<Certificate>,

// Delayed provider of peerCertificates, to allow lazy cleaning.
peerCertificatesFn: () -> List<Certificate>
) {
/** Returns a possibly-empty list of certificates that identify the remote peer. */
@get:JvmName("peerCertificates") val peerCertificates: List<Certificate> by lazy(
peerCertificatesFn)

@JvmName("-deprecated_tlsVersion")
@Deprecated(
Expand Down Expand Up @@ -146,26 +150,22 @@ class Handshake private constructor(
if ("NONE" == tlsVersionString) throw IOException("tlsVersion == NONE")
val tlsVersion = TlsVersion.forJavaName(tlsVersionString)

val peerCertificates: Array<Certificate>? = try {
peerCertificates
val peerCertificatesCopy = try {
peerCertificates.toImmutableList()
} catch (_: SSLPeerUnverifiedException) {
null
listOf<Certificate>()
}

val peerCertificatesList = if (peerCertificates != null) {
immutableListOf(*peerCertificates)
} else {
emptyList()
}
return Handshake(tlsVersion, cipherSuite,
localCertificates.toImmutableList()) { peerCertificatesCopy }
}

val localCertificates = localCertificates
val localCertificatesList = if (localCertificates != null) {
immutableListOf(*localCertificates)
private fun Array<out Certificate>?.toImmutableList(): List<Certificate> {
return if (this != null) {
immutableListOf(*this)
} else {
emptyList()
}

return Handshake(tlsVersion, cipherSuite, peerCertificatesList, localCertificatesList)
}

@Throws(IOException::class)
Expand All @@ -183,8 +183,9 @@ class Handshake private constructor(
peerCertificates: List<Certificate>,
localCertificates: List<Certificate>
): Handshake {
return Handshake(tlsVersion, cipherSuite, peerCertificates.toImmutableList(),
localCertificates.toImmutableList())
val peerCertificatesCopy = peerCertificates.toImmutableList()
return Handshake(tlsVersion, cipherSuite, localCertificates.toImmutableList()
) { peerCertificatesCopy }
}
}
}
16 changes: 12 additions & 4 deletions okhttp/src/main/java/okhttp3/internal/connection/RealConnection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import okhttp3.Response
import okhttp3.Route
import okhttp3.internal.EMPTY_RESPONSE
import okhttp3.internal.closeQuietly
import okhttp3.internal.toHostHeader
import okhttp3.internal.http.ExchangeCodec
import okhttp3.internal.http1.Http1ExchangeCodec
import okhttp3.internal.http2.ConnectionShutdownException
Expand All @@ -44,6 +43,7 @@ import okhttp3.internal.http2.Http2Stream
import okhttp3.internal.http2.StreamResetException
import okhttp3.internal.platform.Platform
import okhttp3.internal.tls.OkHostnameVerifier
import okhttp3.internal.toHostHeader
import okhttp3.internal.userAgent
import okhttp3.internal.ws.RealWebSocket
import okio.BufferedSink
Expand Down Expand Up @@ -370,9 +370,18 @@ class RealConnection(
}
}

val certificatePinner = address.certificatePinner!!

handshake = Handshake(unverifiedHandshake.tlsVersion, unverifiedHandshake.cipherSuite,
unverifiedHandshake.localCertificates) {
certificatePinner.certificateChainCleaner!!.clean(unverifiedHandshake.peerCertificates,
address.url.host)
}

// Check that the certificate pinner is satisfied by the certificates presented.
address.certificatePinner!!.check(address.url.host,
unverifiedHandshake.peerCertificates)
certificatePinner.check(address.url.host) {
handshake!!.peerCertificates.map { it as X509Certificate }
}

// Success! Save the handshake and the ALPN protocol.
val maybeProtocol = if (connectionSpec.supportsTlsExtensions) {
Expand All @@ -383,7 +392,6 @@ class RealConnection(
socket = sslSocket
source = sslSocket.source().buffer()
sink = sslSocket.sink().buffer()
handshake = unverifiedHandshake
protocol = if (maybeProtocol != null) Protocol.get(maybeProtocol) else Protocol.HTTP_1_1
success = true
} finally {
Expand Down

0 comments on commit ba2c676

Please sign in to comment.