Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure reading request body does not drain the input stream if stream is not resettable #1537

Merged
merged 5 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions components/http/okHttp/gradle/dependencies.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies {
// Use JUnit Jupiter Engine for testing.
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine'
testImplementation 'org.mockito:mockito-inline:5.2.0'
testImplementation 'com.squareup.okhttp3:logging-interceptor:4.12.0'


// This dependency is used internally, and not exposed to consumers on their own compile classpath.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,11 @@ public MediaType contentType() {
}
}

@Override
public boolean isOneShot() {
return !requestInfo.content.markSupported();
}

@Override
public long contentLength() throws IOException {
final Set<String> contentLength =
Expand All @@ -923,6 +928,9 @@ public long contentLength() throws IOException {
@Override
public void writeTo(@Nonnull BufferedSink sink) throws IOException {
sink.writeAll(Okio.source(requestInfo.content));
if (!isOneShot()) {
requestInfo.content.reset();
}
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import com.microsoft.kiota.ApiException;
import com.microsoft.kiota.HttpMethod;
import com.microsoft.kiota.NativeResponseHandler;
import com.microsoft.kiota.RequestInformation;
import com.microsoft.kiota.authentication.AuthenticationProvider;
import com.microsoft.kiota.serialization.Parsable;
Expand All @@ -19,13 +20,17 @@
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Dispatcher;
import okhttp3.Interceptor;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Protocol;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.logging.HttpLoggingInterceptor;
import okhttp3.logging.HttpLoggingInterceptor.Level;

import okio.Buffer;
import okio.Okio;

import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -401,6 +406,135 @@ void getRequestFromRequestInformationWithoutContentLengthOverrideWithEmptyPayloa
assertEquals(0, request.body().contentLength());
}

@Test
void buildsNativeRequestSupportingMultipleWrites() throws Exception {
final var authenticationProviderMock = mock(AuthenticationProvider.class);
final var requestInformation = new RequestInformation();
requestInformation.setUri(new URI("https://localhost"));
var requestBodyJson = "{\"name\":\"value\",\"array\":[\"1\",\"2\",\"3\"]}";
ByteArrayInputStream content =
new ByteArrayInputStream(requestBodyJson.getBytes(StandardCharsets.UTF_8));
requestInformation.setStreamContent(content, "application/json");
requestInformation.httpMethod = HttpMethod.PUT;

final var adapter = new OkHttpRequestAdapter(authenticationProviderMock);
final var request =
adapter.getRequestFromRequestInformation(
requestInformation, mock(Span.class), mock(Span.class));

final var requestBody = request.body();
assertNotNull(requestBody);
var buffer = new Buffer();
requestBody.writeTo(buffer);
assertEquals(requestBodyJson, buffer.readUtf8());

// Second write to the buffer to ensure the body is not consumed
buffer = new Buffer();
requestBody.writeTo(buffer);
assertEquals(requestBodyJson, buffer.readUtf8());
}

@Test
void buildsNativeRequestSupportingOneShotWrite() throws Exception {
final var authenticationProviderMock = mock(AuthenticationProvider.class);
final var testFile = new File("./src/test/resources/helloWorld.txt");
final var requestInformation = new RequestInformation();

requestInformation.setUri(new URI("https://localhost"));
requestInformation.httpMethod = HttpMethod.PUT;
final var contentLength = testFile.length();
requestInformation.headers.add("Content-Length", String.valueOf(contentLength));
try (FileInputStream content = new FileInputStream(testFile)) {
requestInformation.setStreamContent(content, "application/octet-stream");

final var adapter = new OkHttpRequestAdapter(authenticationProviderMock);
final var request =
adapter.getRequestFromRequestInformation(
requestInformation, mock(Span.class), mock(Span.class));

final var requestBody = request.body();
assertNotNull(requestBody);
var buffer = new Buffer();
requestBody.writeTo(buffer);
assertEquals(contentLength, buffer.size());

// Second write to the buffer to ensure the body is not consumed
buffer = new Buffer();
requestBody.writeTo(buffer);
assertEquals(0, buffer.size());
}
}

@Test
void loggingInterceptorDoesNotDrainRequestBodyForMarkableStreams() throws Exception {
var loggingInterceptor = new HttpLoggingInterceptor();
loggingInterceptor.setLevel(Level.BODY);

var okHttpClient =
KiotaClientFactory.create()
.addInterceptor(loggingInterceptor)
.addInterceptor(new MockResponseHandler())
.build();

final var authenticationProviderMock = mock(AuthenticationProvider.class);
authenticationProviderMock.authenticateRequest(
any(RequestInformation.class), any(Map.class));
var requestAdapter =
new OkHttpRequestAdapter(authenticationProviderMock, null, null, okHttpClient);

final var requestInformation = new RequestInformation();
requestInformation.setUri(new URI("https://localhost"));
var requestBodyJson = "{\"name\":\"value\",\"array\":[\"1\",\"2\",\"3\"]}";
ByteArrayInputStream content =
new ByteArrayInputStream(requestBodyJson.getBytes(StandardCharsets.UTF_8));
requestInformation.setStreamContent(content, "application/json");
requestInformation.httpMethod = HttpMethod.PUT;
var nativeResponseHandler = new NativeResponseHandler();
requestInformation.setResponseHandler(nativeResponseHandler);

var mockEntity = creatMockEntity();
requestAdapter.send(requestInformation, null, node -> mockEntity);
var nativeResponse = (Response) nativeResponseHandler.getValue();
assertNotNull(nativeResponse);
assertEquals(requestBodyJson, nativeResponse.body().source().readUtf8());
}

@Test
void loggingInterceptorDoesNotDrainRequestBodyForNonMarkableStreams() throws Exception {
var loggingInterceptor = new HttpLoggingInterceptor();
loggingInterceptor.setLevel(Level.BODY);

var okHttpClient =
KiotaClientFactory.create()
.addInterceptor(loggingInterceptor)
.addInterceptor(new MockResponseHandler())
.build();

final var authenticationProviderMock = mock(AuthenticationProvider.class);
authenticationProviderMock.authenticateRequest(
any(RequestInformation.class), any(Map.class));
var requestAdapter =
new OkHttpRequestAdapter(authenticationProviderMock, null, null, okHttpClient);

final var requestInformation = new RequestInformation();
requestInformation.setUri(new URI("https://localhost"));
requestInformation.httpMethod = HttpMethod.PUT;
var nativeResponseHandler = new NativeResponseHandler();
requestInformation.setResponseHandler(nativeResponseHandler);

final var testFile = new File("./src/test/resources/helloWorld.txt");
final var contentLength = testFile.length();

try (FileInputStream content = new FileInputStream(testFile)) {
requestInformation.setStreamContent(content, "application/octet-stream");
var mockEntity = creatMockEntity();
requestAdapter.send(requestInformation, null, node -> mockEntity);
var nativeResponse = (Response) nativeResponseHandler.getValue();
assertNotNull(nativeResponse);
assertEquals(contentLength, nativeResponse.body().source().readByteArray().length);
}
}

public static OkHttpClient getMockClient(final Response response) throws IOException {
final OkHttpClient mockClient = mock(OkHttpClient.class);
final Call remoteCall = mock(Call.class);
Expand Down Expand Up @@ -440,4 +574,33 @@ public ParseNodeFactory creatMockParseNodeFactory(
when(mockFactory.getValidContentType()).thenReturn(validContentType);
return mockFactory;
}

// Returns request body as response body
static class MockResponseHandler implements Interceptor {
@Override
public Response intercept(Chain chain) throws IOException {
final var request = chain.request();
final var requestBody = request.body();
if (request != null && requestBody != null) {
final var buffer = new Buffer();
requestBody.writeTo(buffer);
return new Response.Builder()
.code(200)
.message("OK")
.protocol(Protocol.HTTP_1_1)
.request(request)
.body(
ResponseBody.create(
buffer.readByteArray(),
MediaType.parse("application/json")))
.build();
}
return new Response.Builder()
.code(200)
.message("OK")
.protocol(Protocol.HTTP_1_1)
.request(request)
.build();
}
}
}
Loading