diff --git a/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceHttpUrlConnection.java b/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceHttpUrlConnection.java index 0365b5b67b..63ed0a20c2 100644 --- a/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceHttpUrlConnection.java +++ b/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceHttpUrlConnection.java @@ -73,8 +73,6 @@ interface EmbraceHttpUrlConnection { @Nullable InputStream getErrorStream(); - boolean shouldInterceptHeaderRetrieval(@Nullable String key); - @Nullable String getHeaderField(int n); diff --git a/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceHttpUrlConnectionOverride.java b/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceHttpUrlConnectionOverride.java index 8832ba75d8..19ce346f1f 100644 --- a/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceHttpUrlConnectionOverride.java +++ b/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceHttpUrlConnectionOverride.java @@ -28,7 +28,7 @@ public String getHeaderByName(@NonNull String name) { public String getOverriddenURL(@NonNull String pathOverride) { try { return new URL(connection.getURL().getProtocol(), connection.getURL().getHost(), - connection.getURL().getPort(), pathOverride).toString(); + connection.getURL().getPort(), pathOverride + "?" + connection.getURL().getQuery()).toString(); } catch (MalformedURLException e) { InternalStaticEmbraceLogger.logError("Failed to override path of " + connection.getURL() + " with " + pathOverride); diff --git a/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceUrlConnectionDelegate.java b/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceUrlConnectionDelegate.java index 9628c39a78..48b19c398f 100644 --- a/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceUrlConnectionDelegate.java +++ b/embrace-android-sdk/src/main/java/io/embrace/android/embracesdk/internal/network/http/EmbraceUrlConnectionDelegate.java @@ -63,12 +63,12 @@ class EmbraceUrlConnectionDelegate implements Embra /** * The content encoding HTTP header. */ - private static final String CONTENT_ENCODING = "Content-Encoding"; + static final String CONTENT_ENCODING = "Content-Encoding"; /** * The content length HTTP header. */ - private static final String CONTENT_LENGTH = "Content-Length"; + static final String CONTENT_LENGTH = "Content-Length"; /** * Reference to the wrapped connection. @@ -140,8 +140,7 @@ class EmbraceUrlConnectionDelegate implements Embra @Nullable private volatile String traceparent = null; - @Nullable - private volatile byte[] responseBody = null; + private final boolean isSDKStarted; /** * Wraps an existing {@link HttpURLConnection} with the Embrace network logic. @@ -160,6 +159,7 @@ public EmbraceUrlConnectionDelegate(@NonNull T connection, boolean enableWrapIoS this.embrace = embrace; this.createdTime = embrace.getInternalInterface().getSdkCurrentTime(); this.callId = UUID.randomUUID().toString(); + this.isSDKStarted = embrace.isStarted(); } @Override @@ -169,13 +169,15 @@ public void addRequestProperty(@NonNull String key, @Nullable String value) { @Override public void connect() throws IOException { - identifyTraceId(); - try { - if (embrace.getInternalInterface().isNetworkSpanForwardingEnabled()) { - traceparent = connection.getRequestProperty(TRACEPARENT_HEADER_NAME); + if (isSDKStarted) { + identifyTraceId(); + try { + if (embrace.getInternalInterface().isNetworkSpanForwardingEnabled()) { + traceparent = connection.getRequestProperty(TRACEPARENT_HEADER_NAME); + } + } catch (Exception e) { + // Ignore traceparent if there was a problem obtaining it } - } catch (Exception e) { - // Ignore traceparent if there was a problem obtaining it } this.connection.connect(); } @@ -183,7 +185,7 @@ public void connect() throws IOException { @Override public void disconnect() { // The network call must be logged before we close the transport - internalLogNetworkCall(this.createdTime); + internalLogNetworkCall(createdTime); this.connection.disconnect(); } @@ -287,8 +289,7 @@ public InputStream getErrorStream() { return getWrappedInputStream(this.connection.getErrorStream()); } - @Override - public boolean shouldInterceptHeaderRetrieval(@Nullable String key) { + private boolean shouldInterceptHeaderRetrieval(@Nullable String key) { return shouldUncompressGzip() && key != null && (key.equalsIgnoreCase(CONTENT_ENCODING) || key.equalsIgnoreCase(CONTENT_LENGTH)); } @@ -358,7 +359,7 @@ public long getHeaderFieldLong(@NonNull String name, long defaultValue) { @Nullable public Map> getHeaderFields() { final long startTime = embrace.getInternalInterface().getSdkCurrentTime(); - cacheResponseData(); + cacheNetworkCallData(); internalLogNetworkCall(startTime); return headerFields.get(); } @@ -377,7 +378,7 @@ private R retrieveHeaderField(@Nullable String name, } R result = action.invoke(); - cacheResponseData(); + cacheNetworkCallData(); internalLogNetworkCall(startTime); return result; } @@ -473,7 +474,7 @@ public String getRequestProperty(@NonNull String key) { public int getResponseCode() { identifyTraceId(); long startTime = embrace.getInternalInterface().getSdkCurrentTime(); - cacheResponseData(); + cacheNetworkCallData(); internalLogNetworkCall(startTime); return responseCode.get(); } @@ -484,7 +485,7 @@ public String getResponseMessage() throws IOException { identifyTraceId(); long startTime = embrace.getInternalInterface().getSdkCurrentTime(); String responseMsg = this.connection.getResponseMessage(); - cacheResponseData(); + cacheNetworkCallData(); internalLogNetworkCall(startTime); return responseMsg; } @@ -547,7 +548,9 @@ public boolean usingProxy() { * ignored. */ synchronized void internalLogNetworkCall(long startTime) { - internalLogNetworkCall(startTime, embrace.getInternalInterface().getSdkCurrentTime(), false, null); + if (isSDKStarted) { + internalLogNetworkCall(startTime, embrace.getInternalInterface().getSdkCurrentTime(), false, null); + } } /** @@ -658,8 +661,7 @@ private CountingInputStreamWithCallback countingInputStream(InputStream inputStr hasNetworkCaptureRules(), (bytesCount, responseBody) -> { if (this.startTime != null && this.endTime != null) { - this.responseBody = responseBody; - cacheResponseData(); + cacheNetworkCallData(responseBody); internalLogNetworkCall( this.startTime, this.endTime, @@ -695,7 +697,7 @@ private boolean shouldUncompressGzip() { } private void identifyTraceId() { - if (traceId == null) { + if (isSDKStarted && traceId == null) { try { traceId = getRequestProperty(embrace.getTraceIdHeader()); } catch (Exception e) { @@ -808,13 +810,13 @@ private InputStream getWrappedInputStream(InputStream connectionInputStream) { countingInputStream(new BufferedInputStream(connectionInputStream)) : connectionInputStream; } - cacheResponseData(); + cacheNetworkCallData(); internalLogNetworkCall(startTime); return in; } private boolean hasNetworkCaptureRules() { - if (this.connection.getURL() == null) { + if (!isSDKStarted || this.connection.getURL() == null) { return false; } String url = this.connection.getURL().toString(); @@ -823,11 +825,17 @@ private boolean hasNetworkCaptureRules() { return embrace.getInternalInterface().shouldCaptureNetworkBody(url, method); } + private void cacheNetworkCallData() { + if (isSDKStarted) { + cacheNetworkCallData(null); + } + } + /** * Cache values from response at the first point of availability so that we won't try to retrieve these values when the response * is not available. */ - private void cacheResponseData() { + private void cacheNetworkCallData(@Nullable byte[] responseBody) { if (headerFields.get() == null) { synchronized (headerFields) { if (headerFields.get() == null) { @@ -874,26 +882,42 @@ private void cacheResponseData() { } } - if (shouldCaptureNetworkData() && networkCaptureData.get() == null) { + if (shouldCaptureNetworkData()) { // If we don't have network capture rules, it's unnecessary to save these values synchronized (networkCaptureData) { - if (shouldCaptureNetworkData() && networkCaptureData.get() == null) { + if (shouldCaptureNetworkData()) { try { - Map requestHeaders = this.requestHeaders; - String requestQueryParams = connection.getURL().getQuery(); - byte[] requestBody = this.outputStream != null ? this.outputStream.getRequestBody() : null; - Map responseHeaders = getProcessedHeaders(headerFields.get()); - - networkCaptureData.set( - new NetworkCaptureData( - requestHeaders, - requestQueryParams, - requestBody, - responseHeaders, - responseBody, - null - ) - ); + NetworkCaptureData existingData = networkCaptureData.get(); + if (existingData == null) { + Map requestHeaders = this.requestHeaders; + String requestQueryParams = connection.getURL().getQuery(); + byte[] requestBody = this.outputStream != null ? this.outputStream.getRequestBody() : null; + Map responseHeaders = getProcessedHeaders(headerFields.get()); + + networkCaptureData.set( + new NetworkCaptureData( + requestHeaders, + requestQueryParams, + requestBody, + responseHeaders, + responseBody, + null + ) + ); + } else if (responseBody != null) { + // Update the response body field in the cached networkCaptureData object if a subsequent call + // is update to update the network logging with this data. + networkCaptureData.set( + new NetworkCaptureData( + existingData.getRequestHeaders(), + existingData.getRequestQueryParams(), + existingData.getCapturedRequestBody(), + existingData.getResponseHeaders(), + responseBody, + null + ) + ); + } } catch (Exception e) { lastConnectionAccessException = e; } @@ -903,6 +927,7 @@ private void cacheResponseData() { } private boolean shouldCaptureNetworkData() { - return hasNetworkCaptureRules() && (enableWrapIoStreams || inputStreamAccessException != null); + return (hasNetworkCaptureRules() && (enableWrapIoStreams || inputStreamAccessException != null)) && + (networkCaptureData.get() == null || networkCaptureData.get().getCapturedResponseBody() == null); } } diff --git a/embrace-android-sdk/src/test/java/io/embrace/android/embracesdk/internal/network/http/EmbraceUrlConnectionDelegateTest.kt b/embrace-android-sdk/src/test/java/io/embrace/android/embracesdk/internal/network/http/EmbraceUrlConnectionDelegateTest.kt index fc36485e76..dab3be514e 100644 --- a/embrace-android-sdk/src/test/java/io/embrace/android/embracesdk/internal/network/http/EmbraceUrlConnectionDelegateTest.kt +++ b/embrace-android-sdk/src/test/java/io/embrace/android/embracesdk/internal/network/http/EmbraceUrlConnectionDelegateTest.kt @@ -1,8 +1,12 @@ package io.embrace.android.embracesdk.internal.network.http import io.embrace.android.embracesdk.Embrace +import io.embrace.android.embracesdk.config.behavior.NetworkBehavior.Companion.CONFIG_TRACE_ID_HEADER_DEFAULT_VALUE import io.embrace.android.embracesdk.config.behavior.NetworkSpanForwardingBehavior.Companion.TRACEPARENT_HEADER_NAME import io.embrace.android.embracesdk.internal.EmbraceInternalInterface +import io.embrace.android.embracesdk.internal.network.http.EmbraceHttpPathOverride.PATH_OVERRIDE +import io.embrace.android.embracesdk.internal.network.http.EmbraceUrlConnectionDelegate.CONTENT_ENCODING +import io.embrace.android.embracesdk.internal.network.http.EmbraceUrlConnectionDelegate.CONTENT_LENGTH import io.embrace.android.embracesdk.network.EmbraceNetworkRequest import io.embrace.android.embracesdk.network.http.HttpMethod import io.mockk.CapturingSlot @@ -11,38 +15,44 @@ import io.mockk.mockk import io.mockk.slot import io.mockk.verify import org.junit.Assert.assertEquals -import org.junit.Assert.assertNotNull import org.junit.Assert.assertNull import org.junit.Assert.assertThrows import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Test import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream import java.io.IOException import java.io.InputStream +import java.net.URL import java.util.concurrent.TimeoutException +import java.util.zip.GZIPInputStream +import java.util.zip.GZIPOutputStream import javax.net.ssl.HttpsURLConnection internal class EmbraceUrlConnectionDelegateTest { private lateinit var mockEmbrace: Embrace private lateinit var mockInternalInterface: EmbraceInternalInterface - private lateinit var mockConnection: HttpsURLConnection private lateinit var capturedCallId: MutableList private lateinit var capturedEmbraceNetworkRequest: CapturingSlot - private lateinit var embraceUrlConnectionDelegate: EmbraceUrlConnectionDelegate - private lateinit var embraceUrlConnectionDelegateUnwrapped: EmbraceUrlConnectionDelegate private var fakeTimeMs = REQUEST_TIME - private var shouldCaptureNetworkBody = false + private var isSDKStarted = false + private var shouldCaptureNetworkBody = true private var isNetworkSpanForwardingEnabled = false + private var traceIdHeaderName = CONFIG_TRACE_ID_HEADER_DEFAULT_VALUE @Before fun setup() { mockEmbrace = mockk(relaxed = true) every { mockEmbrace.internalInterface } answers { mockInternalInterface } + every { mockEmbrace.isStarted } answers { isSDKStarted } + every { mockEmbrace.traceIdHeader } answers { traceIdHeaderName } fakeTimeMs = REQUEST_TIME - shouldCaptureNetworkBody = false + isSDKStarted = true + shouldCaptureNetworkBody = true isNetworkSpanForwardingEnabled = false + traceIdHeaderName = CONFIG_TRACE_ID_HEADER_DEFAULT_VALUE capturedCallId = mutableListOf() capturedEmbraceNetworkRequest = slot() mockInternalInterface = mockk(relaxed = true) @@ -52,110 +62,313 @@ internal class EmbraceUrlConnectionDelegateTest { } answers { } every { mockInternalInterface.isNetworkSpanForwardingEnabled() } answers { isNetworkSpanForwardingEnabled } every { mockInternalInterface.getSdkCurrentTime() } answers { fakeTimeMs } - mockConnection = createMockConnection() - embraceUrlConnectionDelegate = - EmbraceUrlConnectionDelegate(mockConnection, true, mockEmbrace) - embraceUrlConnectionDelegateUnwrapped = - EmbraceUrlConnectionDelegate(mockConnection, false, mockEmbrace) } @Test - fun `completed network call logged exactly once if connection connected with wrapped output stream`() { - executeRequest() - verifyTwoCallsRecordedWithSameCallId() - with(capturedEmbraceNetworkRequest.captured) { - assertEquals(HttpMethod.POST.name, httpMethod) - assertEquals(REQUEST_TIME, startTime) - assertEquals(REQUEST_TIME, endTime) - assertEquals(HTTP_OK, responseCode) - assertEquals(requestBodySize.toLong(), bytesSent) - assertEquals(responseBodySize.toLong(), bytesReceived) - assertNull(errorType) - } + fun `completed successful requests with compressed responses from a wrapped stream are recorded properly`() { + executeRequest( + connection = createMockGzipConnection(), + wrappedIoStream = true + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = HTTP_OK, + responseBodySize = responseBodySize, + requestSize = requestBodySize, + networkDataCaptured = true, + responseBody = responseBodyText + ) } @Test - fun `completed network call logged twice once if connection connected with wrapped output stream and network body captured`() { - shouldCaptureNetworkBody = true - executeRequest() + fun `completed successful requests with uncompressed responses from a wrapped stream are recorded properly`() { + executeRequest( + connection = createMockUncompressedConnection(), + wrappedIoStream = true + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = HTTP_OK, + responseBodySize = responseBodySize, + requestSize = requestBodySize, + networkDataCaptured = true, + responseBody = responseBodyText + ) + } + + @Test + fun `completed successful requests with compressed responses from an unwrapped output streams are recorded properly`() { + executeRequest( + connection = createMockGzipConnection(), + wrappedIoStream = false + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = HTTP_OK, + responseBodySize = gzippedResponseBodySize, + requestSize = 0 + ) + } + + @Test + fun `completed successful requests with uncompressed responses from an unwrapped output streams are recorded properly`() { + executeRequest( + connection = createMockUncompressedConnection(), + wrappedIoStream = false + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = HTTP_OK, + responseBodySize = responseBodySize, + requestSize = 0 + ) + } + + @Test + fun `incomplete network request with uncompressed responses from a wrapped output stream are recorded properly`() { + executeRequest( + connection = createMockUncompressedConnection(), + wrappedIoStream = true, + exceptionOnInputStream = true + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = null, + responseBodySize = 0, + requestSize = 0, + errorType = IO_ERROR, + errorMessage = "nope" + ) + } + + @Test + fun `incomplete network request with compressed responses from a wrapped output stream are recorded properly`() { + executeRequest( + connection = createMockGzipConnection(), + wrappedIoStream = true, + exceptionOnInputStream = true + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = null, + responseBodySize = 0, + requestSize = 0, + errorType = IO_ERROR, + errorMessage = "nope" + ) + } + + @Test + fun `incomplete network request with uncompressed responses from an unwrapped output stream are recorded properly`() { + executeRequest( + connection = createMockUncompressedConnection(), + wrappedIoStream = false, + exceptionOnInputStream = true + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = null, + responseBodySize = 0, + requestSize = 0, + errorType = IO_ERROR, + errorMessage = "nope" + ) + } + + @Test + fun `incomplete network request with compressed responses from an unwrapped output stream are recorded properly`() { + executeRequest( + connection = createMockGzipConnection(), + wrappedIoStream = false, + exceptionOnInputStream = true + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = null, + responseBodySize = 0, + requestSize = 0, + errorType = IO_ERROR, + errorMessage = "nope" + ) + } + + @Test + fun `completed unsuccessful requests are recorded properly`() { + executeRequest( + connection = createMockGzipConnection(expectedResponseCode = 500), + wrappedIoStream = true + ) + validateWholeRequest( + url = url.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = 500, + responseBodySize = responseBodySize, + requestSize = requestBodySize, + networkDataCaptured = true, + responseBody = responseBodyText + ) + } + + @Test + fun `completed requests with custom paths are recorded properly`() { + executeRequest( + connection = createMockConnectionWithPathOverride(), + wrappedIoStream = true + ) + validateWholeRequest( + url = customUrl.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = HTTP_OK, + responseBodySize = responseBodySize, + requestSize = requestBodySize, + networkDataCaptured = true, + responseBody = responseBodyText + ) + } + + @Test + fun `incomplete requests with custom paths are recorded properly`() { + executeRequest( + connection = createMockConnectionWithPathOverride(), + wrappedIoStream = true, + exceptionOnInputStream = true + ) + validateWholeRequest( + url = customUrl.toString(), + startTime = REQUEST_TIME, + endTime = REQUEST_TIME, + httpMethod = HttpMethod.POST.name, + httpStatus = null, + responseBodySize = 0, + requestSize = 0, + errorType = IO_ERROR, + errorMessage = "nope" + ) + } + + @Test + fun `completed requests are not recorded if the SDK has not started`() { + isSDKStarted = false + executeRequest( + connection = createMockGzipConnection(), + wrappedIoStream = true + ) + verify(exactly = 0) { mockInternalInterface.recordAndDeduplicateNetworkRequest(any(), any()) } + } + + @Test + fun `incomplete requests are not recorded if the SDK has not started`() { + isSDKStarted = false + executeRequest( + connection = createMockGzipConnection(), + wrappedIoStream = true, + exceptionOnInputStream = true + ) + verify(exactly = 0) { mockInternalInterface.recordAndDeduplicateNetworkRequest(any(), any()) } + } + + @Test + fun `completed network call logged twice with same callId with a wrapped output stream`() { + executeRequest( + connection = createMockUncompressedConnection(), + wrappedIoStream = true + ) verifyTwoCallsRecordedWithSameCallId() - with(capturedEmbraceNetworkRequest.captured) { - assertEquals(HttpMethod.POST.name, httpMethod) - assertEquals(HTTP_OK, responseCode) - assertEquals(requestBodySize.toLong(), bytesSent) - assertEquals(responseBodySize.toLong(), bytesReceived) - assertNotNull(networkCaptureData) - assertNull(errorType) - } } @Test - fun `completed network call logged exactly once with no request size if connection connected with unwrapped output stream`() { - executeRequest(delegate = embraceUrlConnectionDelegateUnwrapped) + fun `completed network call logged exactly once with unwrapped output stream`() { + executeRequest( + connection = createMockUncompressedConnection(), + wrappedIoStream = false + ) verify(exactly = 1) { mockInternalInterface.recordAndDeduplicateNetworkRequest(any(), any()) } assertTrue(capturedCallId[0].isNotBlank()) - with(capturedEmbraceNetworkRequest.captured) { - assertEquals(HttpMethod.POST.name, httpMethod) - assertEquals(HTTP_OK, responseCode) - assertEquals(0L, bytesSent) - assertEquals(responseBodySize.toLong(), bytesReceived) - assertNull(errorType) - } } @Test - fun `incomplete network call logged exactly once and response data not accessed if connection connected`() { - executeRequest(exceptionOnInputStream = true) + fun `incomplete network call logged exactly once wrapped output stream`() { + executeRequest( + connection = createMockUncompressedConnection(), + wrappedIoStream = true, + exceptionOnInputStream = true + ) verify(exactly = 1) { mockInternalInterface.recordAndDeduplicateNetworkRequest(any(), any()) } - assertTrue(capturedCallId[0].isNotBlank()) - verify(exactly = 0) { mockConnection.responseCode } - verify(exactly = 0) { mockConnection.contentLength } - verify(exactly = 0) { mockConnection.headerFields } - with(capturedEmbraceNetworkRequest.captured) { - assertEquals(HttpMethod.POST.name, httpMethod) - assertNull(responseCode) - assertEquals(null, bytesSent) - assertEquals(null, bytesReceived) - assertEquals(IO_ERROR, errorType) - } } @Test - fun `disconnect called with uninitialized connection results in error request capture and no response access`() { - embraceUrlConnectionDelegate.disconnect() - verifyIncompleteRequestLogged() + fun `disconnect called with previously not connected connection results in error request capture and no response access`() { + val mockConnection = createMockUncompressedConnection() + EmbraceUrlConnectionDelegate(mockConnection, true, mockEmbrace).disconnect() + verifyIncompleteRequestLogged(mockConnection) verify(exactly = 1) { mockInternalInterface.recordAndDeduplicateNetworkRequest(any(), any()) } assertEquals(1, capturedCallId.size) } @Test fun `incomplete network request logged when there's a failure in accessing the response content length`() { + val mockConnection = createMockUncompressedConnection() every { mockConnection.contentLength } answers { throw TimeoutException() } - executeRequest() - verifyIncompleteRequestLogged(errorType = TIMEOUT_ERROR, noResponseAccess = false) + + executeRequest(connection = mockConnection, wrappedIoStream = true) + verifyIncompleteRequestLogged(mockConnection = mockConnection, errorType = TIMEOUT_ERROR, noResponseAccess = false) verifyTwoCallsRecordedWithSameCallId() } @Test fun `incomplete network request logged when there's a failure in accessing the response code`() { + val mockConnection = createMockUncompressedConnection() every { mockConnection.responseCode } answers { throw TimeoutException() } - executeRequest() - verifyIncompleteRequestLogged(errorType = TIMEOUT_ERROR, noResponseAccess = false) + + executeRequest(connection = mockConnection, wrappedIoStream = true) + verifyIncompleteRequestLogged(mockConnection = mockConnection, errorType = TIMEOUT_ERROR, noResponseAccess = false) verifyTwoCallsRecordedWithSameCallId() } @Test fun `incomplete network request logged when there's a failure in accessing the response headers`() { + val mockConnection = createMockUncompressedConnection() every { mockConnection.headerFields } answers { throw TimeoutException() } - executeRequest() - verifyIncompleteRequestLogged(errorType = TIMEOUT_ERROR, noResponseAccess = false) + + executeRequest(connection = mockConnection, wrappedIoStream = true) + verifyIncompleteRequestLogged(mockConnection = mockConnection, errorType = TIMEOUT_ERROR, noResponseAccess = false) verifyTwoCallsRecordedWithSameCallId() } @Test fun `complete network request logged when network data capture is off even if reading request body throws exception`() { + val mockConnection = createMockUncompressedConnection() every { (mockConnection.outputStream as CountingOutputStream).requestBody } answers { throw NullPointerException() } - executeRequest() + + executeRequest(connection = mockConnection, wrappedIoStream = true) with(capturedEmbraceNetworkRequest.captured) { assertEquals(HTTP_OK, responseCode) assertNull(errorType) @@ -163,67 +376,158 @@ internal class EmbraceUrlConnectionDelegateTest { } @Test - fun `check traceheaders are not forwarded by default`() { - executeRequest() + fun `check traceparents are not forwarded by default`() { + executeRequest( + connection = createMockConnectionWithTraceparent(), + wrappedIoStream = true + ) assertNull(capturedEmbraceNetworkRequest.captured.w3cTraceparent) assertEquals(HTTP_OK, capturedEmbraceNetworkRequest.captured.responseCode) } @Test - fun `check traceheaders are not forwarded on errors by default`() { - executeRequest(exceptionOnInputStream = true) + fun `check traceparents are not forwarded on errors by default`() { + executeRequest( + connection = createMockConnectionWithTraceparent(), + wrappedIoStream = true, + exceptionOnInputStream = true + ) assertNull(capturedEmbraceNetworkRequest.captured.responseCode) assertEquals(IO_ERROR, capturedEmbraceNetworkRequest.captured.errorType) assertNull(capturedEmbraceNetworkRequest.captured.w3cTraceparent) } @Test - fun `check traceheaders are forwarded if feature flag is on`() { + fun `check traceparents are forwarded if feature flag is on`() { isNetworkSpanForwardingEnabled = true - executeRequest() + executeRequest( + connection = createMockConnectionWithTraceparent(), + wrappedIoStream = true + ) assertEquals(HTTP_OK, capturedEmbraceNetworkRequest.captured.responseCode) assertEquals(TRACEPARENT, capturedEmbraceNetworkRequest.captured.w3cTraceparent) } @Test - fun `check traceheaders are forwarded on errors if feature flag is on`() { + fun `check traceparents are forwarded on errors if feature flag is on`() { isNetworkSpanForwardingEnabled = true - executeRequest(exceptionOnInputStream = true) + executeRequest( + connection = createMockConnectionWithTraceparent(), + wrappedIoStream = true, + exceptionOnInputStream = true + ) assertNull(capturedEmbraceNetworkRequest.captured.responseCode) assertEquals(TRACEPARENT, capturedEmbraceNetworkRequest.captured.w3cTraceparent) assertEquals(IO_ERROR, capturedEmbraceNetworkRequest.captured.errorType) } - private fun createMockConnection(): HttpsURLConnection { - val connection: HttpsURLConnection = mockk(relaxed = true) - val mockOutputStream: CountingOutputStream = mockk(relaxed = true) - val inputStream: InputStream = ByteArrayInputStream(responseBody) - every { mockOutputStream.requestBody } answers { requestBody } - every { connection.outputStream } answers { mockOutputStream } - every { connection.getRequestProperty(TRACEPARENT_HEADER_NAME) } answers { TRACEPARENT } - every { connection.requestMethod } answers { HttpMethod.POST.name } - every { connection.responseCode } answers { HTTP_OK } - every { connection.contentLength } answers { responseBodySize } - every { connection.headerFields } answers { - mapOf( - Pair("Content-Encoding", listOf("gzip")), - Pair("Content-Length", listOf(responseBodySize.toString())), - Pair("myHeader", listOf("myValue")) + @Test + fun `check traceIds are logged if a custom header name is specified`() { + traceIdHeaderName = "my-trace-id-header" + executeRequest( + connection = createMockGzipConnection( + extraRequestHeaders = mapOf(Pair("my-trace-id-header", listOf(customTraceId))) + ), + wrappedIoStream = true + ) + assertEquals(HTTP_OK, capturedEmbraceNetworkRequest.captured.responseCode) + assertEquals(customTraceId, capturedEmbraceNetworkRequest.captured.traceId) + } + + private fun createMockConnectionWithPathOverride() = createMockGzipConnection( + extraRequestHeaders = mapOf(Pair(PATH_OVERRIDE, listOf(customPath))) + ) + + private fun createMockConnectionWithTraceparent() = createMockGzipConnection( + extraRequestHeaders = mapOf(Pair(TRACEPARENT_HEADER_NAME, listOf(TRACEPARENT))) + ) + + private fun createMockUncompressedConnection(): HttpsURLConnection { + return createMockConnection( + inputStream = ByteArrayInputStream(responseBodyBytes), + expectedResponseSize = responseBodySize, + expectedResponseCode = HTTP_OK + ) + } + + private fun createMockGzipConnection( + expectedResponseCode: Int = HTTP_OK, + extraRequestHeaders: Map> = emptyMap() + ): HttpsURLConnection { + return createMockConnection( + inputStream = GZIPInputStream(ByteArrayInputStream(gzippedResponseBodyBytes)), + extraRequestHeaders = extraRequestHeaders, + expectedResponseSize = gzippedResponseBodySize, + expectedResponseCode = expectedResponseCode, + extraResponseHeaders = mapOf( + Pair(CONTENT_ENCODING, listOf("gzip")) ) + ) + } + + private fun createMockConnection( + inputStream: InputStream, + extraRequestHeaders: Map> = emptyMap(), + expectedResponseSize: Int, + expectedResponseCode: Int, + extraResponseHeaders: Map> = emptyMap() + ): HttpsURLConnection { + val mockConnection: HttpsURLConnection = mockk(relaxed = true) + every { mockConnection.inputStream } answers { inputStream } + every { mockConnection.contentLength } answers { expectedResponseSize } + + val outputStream = ByteArrayOutputStream(requestBodySize) + outputStream.write(requestBodyBytes) + every { mockConnection.outputStream } answers { outputStream } + + val requestHeaders = mutableMapOf( + Pair(requestHeaderName, listOf(requestHeaderValue)), + Pair(CONFIG_TRACE_ID_HEADER_DEFAULT_VALUE, listOf(defaultTraceId)) + ) + + if (extraRequestHeaders.isNotEmpty()) { + requestHeaders += extraRequestHeaders } - every { connection.inputStream } answers { inputStream } - return connection + + val responseHeaders = mutableMapOf( + Pair(CONTENT_LENGTH, listOf(expectedResponseSize.toString())), + Pair(responseHeaderName, listOf(responseHeaderValue)) + ) + + if (extraResponseHeaders.isNotEmpty()) { + responseHeaders += extraResponseHeaders + } + + every { mockConnection.requestProperties } answers { requestHeaders } + every { mockConnection.headerFields } answers { responseHeaders } + every { mockConnection.url } answers { url } + every { mockConnection.getRequestProperty(TRACEPARENT_HEADER_NAME) } answers { + requestHeaders[TRACEPARENT_HEADER_NAME]?.get(0) + } + every { mockConnection.getRequestProperty(traceIdHeaderName) } answers { + requestHeaders[traceIdHeaderName]?.get(0) + } + every { mockConnection.getRequestProperty(PATH_OVERRIDE) } answers { + requestHeaders[PATH_OVERRIDE]?.get(0) + } + every { mockConnection.requestMethod } answers { HttpMethod.POST.name } + every { mockConnection.responseCode } answers { expectedResponseCode } + + return mockConnection } private fun executeRequest( - delegate: EmbraceUrlConnectionDelegate = embraceUrlConnectionDelegate, + connection: HttpsURLConnection, + wrappedIoStream: Boolean = false, exceptionOnInputStream: Boolean = false ) { + val delegate = EmbraceUrlConnectionDelegate(connection, wrappedIoStream, mockEmbrace) with(delegate) { connect() - outputStream?.write(requestBody) + setRequestProperty(requestHeaderName, requestHeaderValue) + outputStream?.write(requestBodyBytes) if (exceptionOnInputStream) { - every { mockConnection.inputStream } answers { throw IOException() } + every { connection.inputStream } answers { throw IOException("nope") } assertThrows(IOException::class.java) { inputStream } } else { val input = inputStream @@ -237,7 +541,63 @@ internal class EmbraceUrlConnectionDelegateTest { } } - private fun verifyIncompleteRequestLogged(errorType: String = "UnknownState", noResponseAccess: Boolean = true) { + @Suppress("LongParameterList") + private fun validateWholeRequest( + url: String, + httpMethod: String, + startTime: Long, + endTime: Long, + httpStatus: Int?, + requestSize: Int?, + responseBodySize: Int?, + errorType: String? = null, + errorMessage: String? = null, + traceId: String = defaultTraceId, + w3cTraceparent: String? = null, + networkDataCaptured: Boolean = false, + responseBody: String? = null + ) { + with(capturedEmbraceNetworkRequest) { + assertEquals(url, captured.url) + assertEquals(httpMethod, captured.httpMethod) + assertEquals(startTime, captured.startTime) + assertEquals(endTime, captured.endTime) + assertEquals(httpStatus, captured.responseCode) + assertEquals(requestSize?.toLong(), captured.bytesOut) + assertEquals(responseBodySize?.toLong(), captured.bytesIn) + assertEquals(errorType, captured.errorType) + assertEquals(errorMessage, captured.errorMessage) + assertEquals(traceId, captured.traceId) + assertEquals(w3cTraceparent, captured.w3cTraceparent) + if (networkDataCaptured) { + validateNetworkCaptureData(responseBody) + } else { + assertNull(captured.networkCaptureData) + } + } + } + + private fun validateNetworkCaptureData(responseBody: String?) { + with(checkNotNull(capturedEmbraceNetworkRequest.captured.networkCaptureData)) { + assertEquals(requestHeaderValue, checkNotNull(requestHeaders)[requestHeaderName]) + assertEquals(responseHeaderValue, checkNotNull(responseHeaders)[responseHeaderName]) + assertEquals(defaultQueryString, requestQueryParams) + assertEquals(requestBodyText, capturedRequestBody?.toString(Charsets.UTF_8)) + if (responseBody == null) { + assertNull(capturedRequestBody) + } else { + assertEquals(responseBody, capturedResponseBody?.toString(Charsets.UTF_8)) + } + + assertNull(dataCaptureErrorMessage) + } + } + + private fun verifyIncompleteRequestLogged( + mockConnection: HttpsURLConnection, + errorType: String = "UnknownState", + noResponseAccess: Boolean = true + ) { if (noResponseAccess) { verify(exactly = 0) { mockConnection.responseCode } verify(exactly = 0) { mockConnection.contentLength } @@ -254,13 +614,39 @@ internal class EmbraceUrlConnectionDelegateTest { } companion object { + private fun ByteArray.toGzipByteArray(): ByteArray { + return ByteArrayOutputStream().use { byteArrayStream -> + GZIPOutputStream(byteArrayStream).use { gzipStream -> + gzipStream.write(this) + gzipStream.finish() + } + byteArrayStream.toByteArray() + } + } + private const val TRACEPARENT = "00-3c72a77a7b51af6fb3778c06d4c165ce-4c1d710fffc88e35-01" private const val HTTP_OK = 200 private const val REQUEST_TIME = 1692201601000L - private val requestBody = "test".toByteArray() - private val requestBodySize = requestBody.size - private val responseBody = "responseresponse".toByteArray() - private val responseBodySize = responseBody.size + private const val requestBodyText = "test" + private const val requestHeaderName = "requestHeader" + private const val requestHeaderValue = "requestHeaderVal" + private const val defaultQueryString = "param=yesPlease" + private const val defaultPath = "/test/default-path" + private const val customPath = "/test/custom-path" + private const val defaultHost = "embrace.io" + private const val responseBodyText = "derpderpderpderp" + private const val responseHeaderName = "responseHeader" + private const val responseHeaderValue = "responseHeaderVal" + private const val defaultTraceId = "default-trace-id" + private const val customTraceId = "custom-trace-id" + private val url = URL("https", defaultHost, 1881, "$defaultPath?$defaultQueryString") + private val customUrl = URL("https", defaultHost, 1881, "$customPath?$defaultQueryString") + private val requestBodyBytes = requestBodyText.toByteArray() + private val requestBodySize = requestBodyBytes.size + private val responseBodyBytes = responseBodyText.toByteArray() + private val responseBodySize = responseBodyBytes.size + private val gzippedResponseBodyBytes = responseBodyBytes.toGzipByteArray() + private val gzippedResponseBodySize = gzippedResponseBodyBytes.size private val IO_ERROR = checkNotNull(IOException::class.java.canonicalName) private val TIMEOUT_ERROR = checkNotNull(TimeoutException::class.java.canonicalName) }