Skip to content

Commit

Permalink
[linkedin#573] BackupRequest doesn't work with streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
Zizhong Zhang committed Mar 27, 2021
1 parent ce1716e commit 55df38c
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@
import com.linkedin.d2.balancer.properties.ServiceProperties;
import com.linkedin.d2.balancer.strategies.LoadBalancerStrategy.ExcludedHostHints;
import com.linkedin.d2.balancer.util.LoadBalancerUtil;
import com.linkedin.data.ByteString;
import com.linkedin.r2.filter.R2Constants;
import com.linkedin.r2.message.Request;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.stream.entitystream.ByteStringWriter;
import com.linkedin.r2.message.stream.entitystream.EntityStreams;
import com.linkedin.r2.message.stream.entitystream.FullEntityObserver;
import com.linkedin.r2.util.NamedThreadFactory;
import java.net.URI;
import java.util.List;
Expand Down Expand Up @@ -380,6 +384,29 @@ public void streamRequest(StreamRequest request, Callback<StreamResponse> callba
@Override
public void streamRequest(StreamRequest request, RequestContext requestContext, Callback<StreamResponse> callback)
{
// Buffering stream request raises concerns on memory usage and performance.
// Currently only support backup requests with IS_FULL_REQUEST.
if (!isFullRequest(requestContext)) {
_d2Client.streamRequest(request, requestContext, callback);
return;
}
if (!isBuffered(requestContext)) {
final FullEntityObserver observer = new FullEntityObserver(new Callback<ByteString>()
{
@Override
public void onError(Throwable e)
{
LOG.warn("Failed to record request's entity for retrying backup request.");
}

@Override
public void onSuccess(ByteString result)
{
requestContext.putLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY, result);
}
});
request.getEntityStream().addObserver(observer);
}
if (_isD2Async)
{
requestAsync(request, requestContext, _d2Client::streamRequest, callback);
Expand Down Expand Up @@ -524,6 +551,7 @@ public DecoratedCallback(R request, RequestContext requestContext, DecoratorClie
executorService.schedule(this::maybeSendBackupRequest, delayNano, TimeUnit.NANOSECONDS);
}

@SuppressWarnings("unchecked")
private void maybeSendBackupRequest()
{
Set<URI> exclusionSet = ExcludedHostHints.getRequestContextExcludedHosts(_requestContext);
Expand All @@ -532,9 +560,23 @@ private void maybeSendBackupRequest()
if (exclusionSet != null)
{
exclusionSet.forEach(uri -> ExcludedHostHints.addRequestContextExcludedHost(_backupRequestContext, uri));
if (_request instanceof StreamRequest && !isBuffered(_requestContext)) {
return;
}
if (!_done.get() && _strategy.isBackupRequestAllowed())
{
_client.doRequest(_request, _backupRequestContext, new Callback<T>()
boolean needCloneRC = false;
R request = _request;
if (_request instanceof StreamRequest) {
StreamRequest req = (StreamRequest)_request;
req = req.builder()
.build(EntityStreams.newEntityStream(new ByteStringWriter(
(ByteString) _requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY)
)));
request = (R)req;
needCloneRC = true;
}
_client.doRequest(request, needCloneRC ? _requestContext.clone(): _requestContext, new Callback<T>()
{
@Override
public void onSuccess(T result)
Expand Down Expand Up @@ -721,4 +763,15 @@ public boolean equals(Object obj)

}

private static boolean isFullRequest(RequestContext requestContext)
{
Object isFullRequest = requestContext.getLocalAttr(R2Constants.IS_FULL_REQUEST);
return isFullRequest != null && (Boolean)isFullRequest;
}

private static boolean isBuffered(RequestContext requestContext)
{
Object bufferedBody = requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY);
return bufferedBody != null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
import com.linkedin.d2.backuprequests.TestTrackingBackupRequestsStrategy;
import com.linkedin.d2.backuprequests.TrackingBackupRequestsStrategy;
import com.linkedin.d2.balancer.KeyMapper;
import com.linkedin.d2.balancer.LoadBalancer;
import com.linkedin.d2.balancer.ServiceUnavailableException;
import com.linkedin.d2.balancer.StaticLoadBalancerState;
import com.linkedin.d2.balancer.LoadBalancer;
import com.linkedin.d2.balancer.properties.PartitionData;
import com.linkedin.d2.balancer.properties.ServiceProperties;
import com.linkedin.d2.balancer.simple.SimpleLoadBalancer;
Expand All @@ -51,11 +51,17 @@
import com.linkedin.r2.message.rest.RestRequestBuilder;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.rest.RestResponseBuilder;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamRequestBuilder;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.stream.StreamResponseBuilder;
import com.linkedin.r2.message.stream.entitystream.ByteStringWriter;
import com.linkedin.r2.message.stream.entitystream.DrainReader;
import com.linkedin.r2.message.stream.entitystream.EntityStreams;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import com.linkedin.r2.transport.common.bridge.common.TransportCallback;
import com.linkedin.r2.transport.common.bridge.common.TransportResponseImpl;
import com.linkedin.util.clock.SystemClock;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
Expand All @@ -69,6 +75,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
Expand All @@ -84,11 +91,7 @@
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNotSame;
import static org.testng.Assert.assertSame;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.*;


public class TestBackupRequestsClient
Expand All @@ -98,6 +101,7 @@ public class TestBackupRequestsClient
private static final String CLUSTER_NAME = "testCluster";
private static final String PATH = "";
private static final String STRATEGY_NAME = "degrader";
private static final String BUFFERED_HEADER = "buffered";
private static final ByteString CONTENT = ByteString.copy(new byte[8092]);

private ScheduledExecutorService _executor;
Expand Down Expand Up @@ -126,6 +130,105 @@ public void testRequest(boolean isD2Async) throws Exception
assertEquals(response.get().getStatus(), 200);
}

@Test(invocationCount = 3, dataProvider = "isD2Async")
public void testStreamRequestWithNoIsFullRequest(boolean isD2Async) throws Exception {
int responseDelayNano = 100000000; //1s till response comes back
int backupDelayNano = 50000000; // make backup request after 0.5 second
Deque<URI> hostsReceivingRequest = new ConcurrentLinkedDeque<>();
BackupRequestsClient client =
createAlwaysBackupClientWithHosts(Arrays.asList("http://test1.com:123", "http://test2.com:123"),
hostsReceivingRequest, responseDelayNano, backupDelayNano, isD2Async);

URI uri = URI.create("d2://testService");

// if there is no IS_FULL_REQUEST set, backup requests will not happen
StreamRequest streamRequest =
new StreamRequestBuilder(uri).build(EntityStreams.newEntityStream(new ByteStringWriter(CONTENT)));
RequestContext context = new RequestContext();
context.putLocalAttr(R2Constants.OPERATION, "get");
RequestContext context1 = context.clone();

CountDownLatch latch = new CountDownLatch(1);
AtomicReference<AssertionError> failure = new AtomicReference<>();

client.streamRequest(streamRequest, context1, new Callback<StreamResponse>() {
@Override
public void onError(Throwable e) {
failure.set(new AssertionError("Callback onError"));
latch.countDown();
}

@Override
public void onSuccess(StreamResponse result) {
try {
assertEquals(result.getStatus(), 200);
assertEquals(result.getHeader("buffered"), "false");
assertEquals(hostsReceivingRequest.size(), 1);
assertEquals(new HashSet<>(hostsReceivingRequest).size(), 1);
hostsReceivingRequest.clear();
} catch (AssertionError e) {
failure.set(e);
}
latch.countDown();
}
});

latch.await(2, TimeUnit.SECONDS);
if (failure.get() != null) {
throw failure.get();
}
}

@Test(invocationCount = 3, dataProvider = "isD2Async")
public void testStreamRequestWithIsFullRequest(boolean isD2Async) throws Exception {
int responseDelayNano = 500000000; //5s till response comes back
int backupDelayNano = 100000000; // make backup request after 1 second
Deque<URI> hostsReceivingRequest = new ConcurrentLinkedDeque<>();
BackupRequestsClient client =
createAlwaysBackupClientWithHosts(Arrays.asList("http://test1.com:123", "http://test2.com:123"),
hostsReceivingRequest, responseDelayNano, backupDelayNano, isD2Async);

URI uri = URI.create("d2://testService");

// if there is IS_FULL_REQUEST set, backup requests will happen
StreamRequest streamRequest =
new StreamRequestBuilder(uri).build(EntityStreams.newEntityStream(new ByteStringWriter(CONTENT)));
RequestContext context = new RequestContext();
context.putLocalAttr(R2Constants.OPERATION, "get");
context.putLocalAttr(R2Constants.IS_FULL_REQUEST, true);
RequestContext context1 = context.clone();

CountDownLatch latch = new CountDownLatch(1);
AtomicReference<AssertionError> failure = new AtomicReference<>();

client.streamRequest(streamRequest, context1, new Callback<StreamResponse>() {
@Override
public void onError(Throwable e) {
failure.set(new AssertionError("Callback onError"));
latch.countDown();
}

@Override
public void onSuccess(StreamResponse result) {
try {
assertEquals(result.getStatus(), 200);
assertEquals(result.getHeader("buffered"), "true");
assertEquals(hostsReceivingRequest.size(), 2);
assertEquals(new HashSet<>(hostsReceivingRequest).size(), 2);
hostsReceivingRequest.clear();
} catch (AssertionError e) {
failure.set(e);
}
latch.countDown();
}
});

latch.await(6, TimeUnit.SECONDS);
if (failure.get() != null) {
throw failure.get();
}
}

/**
* Backup Request should still work when a hint is given together with the flag indicating that the hint is only a preference, not requirement.
*/
Expand Down Expand Up @@ -629,6 +732,31 @@ public void restRequest(RestRequest request,
() -> callback.onResponse(TransportResponseImpl.success(new RestResponseBuilder().build())), responseDelayNano,
TimeUnit.NANOSECONDS);
}

@Override
public void streamRequest(StreamRequest request,
RequestContext requestContext,
Map<String, String> wireAttrs,
TransportCallback<StreamResponse> callback) {
// whenever a trackerClient is used to make request, record down it's hostname
hostsReceivingRequestList.add(uri);
if (null != requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY)) {
callback.onResponse(TransportResponseImpl.success(new StreamResponseBuilder().setHeader(
BUFFERED_HEADER, String.valueOf(requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY) != null)
).build(EntityStreams.emptyStream())));
return;
}
request.getEntityStream().setReader(new DrainReader(){
public void onDone() {
// delay response to allow backup request to happen
_executor.schedule(
() -> callback.onResponse(TransportResponseImpl.success(new StreamResponseBuilder().setHeader(
BUFFERED_HEADER, String.valueOf(requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY) != null)
).build(EntityStreams.emptyStream()))), responseDelayNano,
TimeUnit.NANOSECONDS);
}
});
}
};
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class R2Constants
public static final int DEFAULT_DATA_CHUNK_SIZE = 8192;
public static final boolean DEFAULT_REST_OVER_STREAM = false;
public static final String RETRY_MESSAGE_ATTRIBUTE_KEY = "RETRY";
public static final String BACKUP_REQUEST_BUFFERED_BODY = "BACKUP_REQUEST_BUFFERED_BODY";
@Deprecated
public static final String EXPECTED_SERVER_CERT_PRINCIPAL_NAME = "EXPECTED_SERVER_CERT_PRINCIPAL_NAME";
public static final String REQUESTED_SSL_SESSION_VALIDATOR = "REQUESTED_SSL_SESSION_VALIDATOR";
Expand Down

0 comments on commit 55df38c

Please sign in to comment.