Skip to content

Commit

Permalink
fix: Ensure interceptors don't drain request body stream before netwo…
Browse files Browse the repository at this point in the history
…rk call
  • Loading branch information
Ndiritu committed Aug 22, 2024
1 parent e4dd84e commit 6c2d170
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 0 deletions.
1 change: 1 addition & 0 deletions components/http/okHttp/gradle/dependencies.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies {
// Use JUnit Jupiter Engine for testing.
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine'
testImplementation 'org.mockito:mockito-inline:5.2.0'
testImplementation 'com.squareup.okhttp3:logging-interceptor:4.12.0'


// This dependency is used internally, and not exposed to consumers on their own compile classpath.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import com.microsoft.kiota.ApiException;
import com.microsoft.kiota.HttpMethod;
import com.microsoft.kiota.NativeResponseHandler;
import com.microsoft.kiota.RequestInformation;
import com.microsoft.kiota.authentication.AuthenticationProvider;
import com.microsoft.kiota.serialization.Parsable;
Expand All @@ -19,13 +20,17 @@
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Dispatcher;
import okhttp3.Interceptor;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Protocol;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.logging.HttpLoggingInterceptor;
import okhttp3.logging.HttpLoggingInterceptor.Level;

import okio.Buffer;
import okio.Okio;

import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -401,6 +406,135 @@ void getRequestFromRequestInformationWithoutContentLengthOverrideWithEmptyPayloa
assertEquals(0, request.body().contentLength());
}

@Test
void buildsNativeRequestSupportingMultipleWrites() throws Exception {
final var authenticationProviderMock = mock(AuthenticationProvider.class);
final var requestInformation = new RequestInformation();
requestInformation.setUri(new URI("https://localhost"));
var requestBodyJson = "{\"name\":\"value\",\"array\":[\"1\",\"2\",\"3\"]}";
ByteArrayInputStream content =
new ByteArrayInputStream(requestBodyJson.getBytes(StandardCharsets.UTF_8));
requestInformation.setStreamContent(content, "application/json");
requestInformation.httpMethod = HttpMethod.PUT;

final var adapter = new OkHttpRequestAdapter(authenticationProviderMock);
final var request =
adapter.getRequestFromRequestInformation(
requestInformation, mock(Span.class), mock(Span.class));

final var requestBody = request.body();
assertNotNull(requestBody);
var buffer = new Buffer();
requestBody.writeTo(buffer);
assertEquals(requestBodyJson, buffer.readUtf8());

// Second write to the buffer to ensure the body is not consumed
buffer = new Buffer();
requestBody.writeTo(buffer);
assertEquals(requestBodyJson, buffer.readUtf8());
}

@Test
void buildsNativeRequestSupportingOneShotWrite() throws Exception {
final var authenticationProviderMock = mock(AuthenticationProvider.class);
final var testFile = new File("./src/test/resources/helloWorld.txt");
final var requestInformation = new RequestInformation();

requestInformation.setUri(new URI("https://localhost"));
requestInformation.httpMethod = HttpMethod.PUT;
final var contentLength = testFile.length();
requestInformation.headers.add("Content-Length", String.valueOf(contentLength));
try (FileInputStream content = new FileInputStream(testFile)) {
requestInformation.setStreamContent(content, "application/octet-stream");

final var adapter = new OkHttpRequestAdapter(authenticationProviderMock);
final var request =
adapter.getRequestFromRequestInformation(
requestInformation, mock(Span.class), mock(Span.class));

final var requestBody = request.body();
assertNotNull(requestBody);
var buffer = new Buffer();
requestBody.writeTo(buffer);
assertEquals(contentLength, buffer.size());

// Second write to the buffer to ensure the body is not consumed
buffer = new Buffer();
requestBody.writeTo(buffer);
assertEquals(0, buffer.size());
}
}

@Test
void loggingInterceptorDoesNotDrainRequestBodyForMarkableStreams() throws Exception {
var loggingInterceptor = new HttpLoggingInterceptor();
loggingInterceptor.setLevel(Level.BODY);

var okHttpClient =
KiotaClientFactory.create()
.addInterceptor(loggingInterceptor)
.addInterceptor(new MockResponseHandler())
.build();

final var authenticationProviderMock = mock(AuthenticationProvider.class);
authenticationProviderMock.authenticateRequest(
any(RequestInformation.class), any(Map.class));
var requestAdapter =
new OkHttpRequestAdapter(authenticationProviderMock, null, null, okHttpClient);

final var requestInformation = new RequestInformation();
requestInformation.setUri(new URI("https://localhost"));
var requestBodyJson = "{\"name\":\"value\",\"array\":[\"1\",\"2\",\"3\"]}";
ByteArrayInputStream content =
new ByteArrayInputStream(requestBodyJson.getBytes(StandardCharsets.UTF_8));
requestInformation.setStreamContent(content, "application/json");
requestInformation.httpMethod = HttpMethod.PUT;
var nativeResponseHandler = new NativeResponseHandler();
requestInformation.setResponseHandler(nativeResponseHandler);

var mockEntity = creatMockEntity();
requestAdapter.send(requestInformation, null, (node) -> mockEntity);
var nativeResponse = (Response) nativeResponseHandler.getValue();
assertNotNull(nativeResponse);
assertEquals(requestBodyJson, nativeResponse.body().source().readUtf8());
}

@Test
void loggingInterceptorDoesNotDrainRequestBodyForNonMarkableStreams() throws Exception {
var loggingInterceptor = new HttpLoggingInterceptor();
loggingInterceptor.setLevel(Level.BODY);

var okHttpClient =
KiotaClientFactory.create()
.addInterceptor(loggingInterceptor)
.addInterceptor(new MockResponseHandler())
.build();

final var authenticationProviderMock = mock(AuthenticationProvider.class);
authenticationProviderMock.authenticateRequest(
any(RequestInformation.class), any(Map.class));
var requestAdapter =
new OkHttpRequestAdapter(authenticationProviderMock, null, null, okHttpClient);

final var requestInformation = new RequestInformation();
requestInformation.setUri(new URI("https://localhost"));
requestInformation.httpMethod = HttpMethod.PUT;
var nativeResponseHandler = new NativeResponseHandler();
requestInformation.setResponseHandler(nativeResponseHandler);

final var testFile = new File("./src/test/resources/helloWorld.txt");
final var contentLength = testFile.length();

try (FileInputStream content = new FileInputStream(testFile)) {
requestInformation.setStreamContent(content, "application/octet-stream");
var mockEntity = creatMockEntity();
requestAdapter.send(requestInformation, null, (node) -> mockEntity);
var nativeResponse = (Response) nativeResponseHandler.getValue();
assertNotNull(nativeResponse);
assertEquals(contentLength, nativeResponse.body().source().readByteArray().length);
}
}

public static OkHttpClient getMockClient(final Response response) throws IOException {
final OkHttpClient mockClient = mock(OkHttpClient.class);
final Call remoteCall = mock(Call.class);
Expand Down Expand Up @@ -440,4 +574,33 @@ public ParseNodeFactory creatMockParseNodeFactory(
when(mockFactory.getValidContentType()).thenReturn(validContentType);
return mockFactory;
}

// Returns request body as response body
static class MockResponseHandler implements Interceptor {
@Override
public Response intercept(Chain chain) throws IOException {
final var request = chain.request();
final var requestBody = request.body();
if (request != null && requestBody != null) {
final var buffer = new Buffer();
requestBody.writeTo(buffer);
return new Response.Builder()
.code(200)
.message("OK")
.protocol(Protocol.HTTP_1_1)
.request(request)
.body(
ResponseBody.create(
buffer.readByteArray(),
MediaType.parse("application/json")))
.build();
}
return new Response.Builder()
.code(200)
.message("OK")
.protocol(Protocol.HTTP_1_1)
.request(request)
.build();
}
}
}

0 comments on commit 6c2d170

Please sign in to comment.