Skip to content

Commit

Permalink
Fixes #439
Browse files Browse the repository at this point in the history
  • Loading branch information
whiskeysierra committed Feb 24, 2019
1 parent 470d5d5 commit 4c2ca26
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import lombok.AllArgsConstructor;
import lombok.Getter;

import java.io.IOException;
import java.util.function.Predicate;

@AllArgsConstructor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ final class LocalResponse extends HttpServletResponseWrapper implements HttpResp

private final String protocolVersion;

private Tee tee;
private Tee body;
private Tee buffer;
private boolean used; // point of no return, once we exposed our stream, we need to buffer

LocalResponse(final HttpServletResponse response, final String protocolVersion) {
super(response);
Expand Down Expand Up @@ -60,39 +62,51 @@ public Charset getCharset() {

@Override
public HttpResponse withBody() throws IOException {
if (tee == null) {
this.tee = new Tee(super.getOutputStream());
if (body == null) {
bufferIfNecessary();
this.body = buffer;
}
return this;
}

private void bufferIfNecessary() throws IOException {
if (buffer == null) {
this.buffer = new Tee(super.getOutputStream());
}
}

@Override
public HttpResponse withoutBody() {
this.tee = null;
this.body = null;
if (!used) {
this.buffer = null;
}
return this;
}

@Override
public ServletOutputStream getOutputStream() throws IOException {
if (tee == null) {
if (buffer == null) {
return super.getOutputStream();
} else {
return tee.getOutputStream();
this.used = true;
return buffer.getOutputStream();
}
}

@Override
public PrintWriter getWriter() throws IOException {
if (tee == null) {
if (buffer == null) {
return super.getWriter();
} else {
return tee.getWriter(this::getCharset);
this.used = true;
return buffer.getWriter(this::getCharset);
}
}

@Override
public byte[] getBody() {
return tee == null ? new byte[0] : tee.getBytes();
return body == null ? new byte[0] : body.getBytes();
}

private static class Tee {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ final class RemoteRequest extends HttpServletRequestWrapper implements HttpReque
private final FormRequestMode formRequestMode = FormRequestMode.fromProperties();

private byte[] body;
private byte[] buffered;

RemoteRequest(final HttpServletRequest request) {
super(request);
Expand Down Expand Up @@ -93,28 +94,35 @@ public Charset getCharset() {
@Override
public HttpRequest withBody() throws IOException {
if (body == null) {
bufferIfNecessary();
this.body = buffered;
}

return this;
}

private void bufferIfNecessary() throws IOException {
if (buffered == null) {
if (isFormRequest()) {
switch (formRequestMode) {
case PARAMETER:
this.body = reconstructBodyFromParameters();
return this;
this.buffered = reconstructBodyFromParameters();
return;
case OFF:
this.body = new byte[0];
return this;
this.buffered = new byte[0];
return;
default:
break;
}
}

this.body = ByteStreams.toByteArray(super.getInputStream());
this.buffered = ByteStreams.toByteArray(super.getInputStream());
}

return this;
}

@Override
public HttpRequest withoutBody() {
this.body = new byte[0];
this.body = null;
return this;
}

Expand Down Expand Up @@ -148,9 +156,10 @@ static String encode(final String s, final String charset) {
}

@Override
public ServletInputStream getInputStream() {
// TODO we need the ability to not buffer but still allow downstream filters/servlets to read the original request
return new ServletInputStreamAdapter(new ByteArrayInputStream(body));
public ServletInputStream getInputStream() throws IOException {
return buffered == null ?
super.getInputStream() :
new ServletInputStreamAdapter(new ByteArrayInputStream(buffered));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.mockito.ArgumentCaptor;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.zalando.logbook.Correlation;
import org.zalando.logbook.DefaultHttpLogFormatter;
import org.zalando.logbook.DefaultSink;
import org.zalando.logbook.HttpLogFormatter;
Expand All @@ -13,6 +14,8 @@
import org.zalando.logbook.HttpRequest;
import org.zalando.logbook.HttpResponse;
import org.zalando.logbook.Logbook;
import org.zalando.logbook.Precorrelation;
import org.zalando.logbook.Sink;
import org.zalando.logbook.Strategy;

import javax.servlet.DispatcherType;
Expand Down Expand Up @@ -52,14 +55,28 @@ final class AsyncDispatchTest {
.strategy(new Strategy() {
@Override
public HttpRequest process(final HttpRequest request) throws IOException {
request.getBody();
return request.withBody().withBody();
return request.withBody().withBody().withoutBody().withBody();
}

@Override
public void write(final Precorrelation precorrelation, final HttpRequest request,
final Sink sink) throws IOException {

request.withoutBody().withBody();
sink.write(precorrelation, request);
}

@Override
public HttpResponse process(final HttpRequest request, final HttpResponse response) throws IOException {
response.getBody();
return response.withBody().withBody();
return response.withBody().withBody().withoutBody().withBody();
}

@Override
public void write(final Correlation correlation, final HttpRequest request,
final HttpResponse response, final Sink sink) throws IOException {

response.withoutBody().withBody();
sink.write(correlation, request, response);
}
})
.sink(new DefaultSink(formatter, writer))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

Expand All @@ -17,8 +18,10 @@
import java.util.Objects;
import java.util.concurrent.Callable;

import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE;

@RestController
@RequestMapping(value = "/api", produces = MediaType.APPLICATION_JSON_VALUE)
@RequestMapping(path = "/api", produces = MediaType.APPLICATION_JSON_VALUE)
public class ExampleController {

@RequestMapping("/sync")
Expand All @@ -28,6 +31,11 @@ public ResponseEntity<Message> message() {
return ResponseEntity.ok(message);
}

@RequestMapping(path = "/echo", consumes = TEXT_PLAIN_VALUE, produces = TEXT_PLAIN_VALUE)
public ResponseEntity<String> echo(@RequestBody final String message) {
return ResponseEntity.ok(message);
}

@RequestMapping("/async")
public Callable<ResponseEntity<Message>> returnMessage() {
return () -> {
Expand All @@ -48,7 +56,7 @@ public void error() {
throw new UnsupportedOperationException();
}

@RequestMapping(value = "/read-byte", produces = MediaType.TEXT_PLAIN_VALUE)
@RequestMapping(path = "/read-byte", produces = TEXT_PLAIN_VALUE)
public void readByte(final HttpServletRequest request, final HttpServletResponse response) throws IOException {

final ServletInputStream input = request.getInputStream();
Expand All @@ -63,7 +71,7 @@ public void readByte(final HttpServletRequest request, final HttpServletResponse
}
}

@RequestMapping(value = "/read-bytes", produces = MediaType.TEXT_PLAIN_VALUE)
@RequestMapping(path = "/read-bytes", produces = TEXT_PLAIN_VALUE)
public void readBytes(final HttpServletRequest request, final HttpServletResponse response) throws IOException {

final ServletInputStream input = request.getInputStream();
Expand All @@ -80,12 +88,12 @@ public void readBytes(final HttpServletRequest request, final HttpServletRespons
}
}

@RequestMapping(value = "/stream", produces = MediaType.TEXT_PLAIN_VALUE)
@RequestMapping(path = "/stream", produces = TEXT_PLAIN_VALUE)
public void stream(final HttpServletRequest request, final HttpServletResponse response) throws IOException {
ByteStreams.copy(request.getInputStream(), response.getOutputStream());
}

@RequestMapping(value = "/reader", produces = MediaType.TEXT_PLAIN_VALUE)
@RequestMapping(path = "/reader", produces = TEXT_PLAIN_VALUE)
public void reader(final HttpServletRequest request, final HttpServletResponse response) throws IOException {
try (final PrintWriter writer = response.getWriter()) {
copy(request.getReader(), writer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.hamcrest.MockitoHamcrest.argThat;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request;
Expand Down Expand Up @@ -90,7 +91,6 @@ void shouldLogAuthorizedResponseOnce() throws Exception {

@ParameterizedTest
@ValueSource(ints = {401, 403})
@SuppressWarnings("unchecked")
void shouldFormatUnauthorizedRequestOnce(final int status) throws Exception {
securityFilter.setStatus(status);

Expand All @@ -101,7 +101,6 @@ void shouldFormatUnauthorizedRequestOnce(final int status) throws Exception {

@ParameterizedTest
@ValueSource(ints = {401, 403})
@SuppressWarnings("unchecked")
void shouldFormatUnauthorizedResponseOnce(final int status) throws Exception {
securityFilter.setStatus(status);

Expand Down Expand Up @@ -164,4 +163,11 @@ void shouldHandleUnauthorizedAsyncDispatchRequest() throws Exception {
.andReturn()));
}

@Test
void shouldEcho() throws Exception {
mvc.perform(get("/api/echo").content("Hello, world!"));

verify(writer).write(any(Precorrelation.class), argThat(containsString("Hello, world!")));
}

}

0 comments on commit 4c2ca26

Please sign in to comment.