Skip to content

Commit

Permalink
Fix TimingKey Memory Leak #587
Browse files Browse the repository at this point in the history
  • Loading branch information
Zizhong Zhang committed Apr 2, 2021
1 parent 64d8f76 commit 8a3bd30
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and what APIs have changed, if applicable.
- Update fluent client APIs to include projection mask as input parameter.
- Update projection mask builder APIs to support updating the mask objects.
- Added support for checking if a nested type supports new ProjectionMask API before generating new typesafe APIs for them.
- Fix TimingKey Memory Leak

## [29.17.0] - 2021-03-23
- Implement D2 cluster subsetting.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ public void shutdown(final Callback<None> callback)

callback.onSuccess(None.none());
});
TimingKey.unregisterKey(TIMING_KEY);
}

@Override
Expand Down
12 changes: 11 additions & 1 deletion r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;

import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -189,4 +189,14 @@ void onStreamResponse(StreamResponse res,
void onStreamError(Exception ex,
RequestContext requestContext,
Map<String, String> wireAttrs);

/**
* Returns a copy of a list of RestFilters
*/
List<RestFilter> getRestFilters();

/**
* Returns a copy of a list of StreamFilters
*/
List<StreamFilter> getStreamFilters();
}
10 changes: 10 additions & 0 deletions r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ public FilterChain addLast(StreamFilter filter)
return new FilterChainImpl(_restFilters, doAddLast(_streamFilters, decorateStreamFilter(filter)));
}

@Override
public List<RestFilter> getRestFilters() {
return new ArrayList<RestFilter>(_restFilters);
}

@Override
public List<StreamFilter> getStreamFilters() {
return new ArrayList<StreamFilter>(_streamFilters);
}

private RestFilter decorateRestFilter(RestFilter filter)
{
return new TimedRestFilter(filter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.timing.TimingImportance;
import java.util.Arrays;
import java.util.List;
import java.util.Map;


Expand All @@ -31,7 +33,7 @@
*
* @author Xialin Zhu
*/
/* package private */ class TimedRestFilter implements RestFilter
public class TimedRestFilter implements RestFilter
{
protected static final String ON_REQUEST_SUFFIX = "onRequest";
protected static final String ON_RESPONSE_SUFFIX = "onResponse";
Expand Down Expand Up @@ -91,4 +93,8 @@ public void onRestError(Throwable ex,
TimingContextUtil.markTiming(requestContext, _onErrorTimingKey);
_restFilter.onRestError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter));
}

public List<TimingKey> getTimingKeyList() {
return Arrays.asList(_onErrorTimingKey, _onRequestTimingKey, _onResponseTimingKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.timing.TimingImportance;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static com.linkedin.r2.filter.TimedRestFilter.ON_ERROR_SUFFIX;
Expand All @@ -34,7 +36,7 @@
*
* @author Xialin Zhu
*/
/* package private */ class TimedStreamFilter implements StreamFilter
public class TimedStreamFilter implements StreamFilter
{
private final StreamFilter _streamFilter;
private final TimingKey _onRequestTimingKey;
Expand Down Expand Up @@ -91,4 +93,8 @@ public void onStreamError(Throwable ex,
TimingContextUtil.markTiming(requestContext, _onErrorTimingKey);
_streamFilter.onStreamError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter));
}

public List<TimingKey> getTimingKeyList() {
return Arrays.asList(_onErrorTimingKey, _onRequestTimingKey, _onResponseTimingKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
/* $Id$ */
package com.linkedin.r2.filter.transport;


import com.linkedin.common.callback.Callback;
import com.linkedin.common.util.None;
import com.linkedin.r2.filter.FilterChain;
import com.linkedin.r2.filter.TimedRestFilter;
import com.linkedin.r2.filter.TimedStreamFilter;
import com.linkedin.r2.filter.message.rest.RestFilter;
import com.linkedin.r2.filter.message.stream.StreamFilter;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.Response;
import com.linkedin.r2.message.rest.RestRequest;
Expand All @@ -29,12 +32,15 @@
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.timing.FrameworkTimingKeys;
import com.linkedin.r2.message.timing.TimingContextUtil;
import com.linkedin.r2.message.timing.TimingKey;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import com.linkedin.r2.transport.common.bridge.common.TransportCallback;

import com.linkedin.r2.transport.common.bridge.common.TransportResponse;
import java.util.Collection;
import java.util.List;
import java.util.Map;


/**
* {@link TransportClient} adapter which composes a {@link TransportClient}
* and a {@link FilterChain}.
Expand All @@ -46,15 +52,17 @@ public class FilterChainClient implements TransportClient
{
private final TransportClient _client;
private final FilterChain _filters;
private final FilterChain _sharedFilters;

/**
* Construct a new instance by composing the specified {@link TransportClient}
* and {@link FilterChain}.
*
* @param client the {@link TransportClient} to be composed.
* @param filters the {@link FilterChain} to be composed.
* @param sharedFilters the {@link FilterChain} can be used by other clients.
*/
public FilterChainClient(TransportClient client, FilterChain filters)
public FilterChainClient(TransportClient client, FilterChain filters, FilterChain sharedFilters)
{
_client = client;

Expand All @@ -66,6 +74,7 @@ public FilterChainClient(TransportClient client, FilterChain filters)
.addLastRest(requestFilter)
.addFirst(responseFilter)
.addLast(requestFilter);
_sharedFilters = sharedFilters;
}

@Override
Expand Down Expand Up @@ -94,6 +103,27 @@ public void streamRequest(StreamRequest request,
public void shutdown(Callback<None> callback)
{
_client.shutdown(callback);

List<StreamFilter> streamFilters = _filters.getStreamFilters();
List<RestFilter> restFilters = _filters.getRestFilters();
List<StreamFilter> sharedStreamFilters = _sharedFilters.getStreamFilters();
List<RestFilter> sharedRestFilters = _sharedFilters.getRestFilters();

streamFilters.stream()
.filter(filter -> !sharedStreamFilters.contains(filter))
.filter(TimedStreamFilter.class::isInstance)
.map(TimedStreamFilter.class::cast)
.map(TimedStreamFilter::getTimingKeyList)
.flatMap(Collection::stream)
.forEach(TimingKey::unregisterKey);

restFilters.stream()
.filter(filter -> !sharedRestFilters.contains(filter))
.filter(TimedRestFilter.class::isInstance)
.map(TimedRestFilter.class::cast)
.map(TimedRestFilter::getTimingKeyList)
.flatMap(Collection::stream)
.forEach(TimingKey::unregisterKey);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package com.linkedin.r2.message.timing;

import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;

import com.linkedin.r2.message.RequestContext;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;


/**
Expand All @@ -32,6 +35,7 @@
public class TimingKey
{
private static final Map<String, TimingKey> _pool = new ConcurrentHashMap<>();
private static final ExecutorService _unregisterExecutor = Executors.newFixedThreadPool(1);

private final String _name;
private final String _type;
Expand Down Expand Up @@ -130,4 +134,26 @@ public static TimingKey registerNewKey(String uniqueName, String type, TimingImp
{
return registerNewKey(new TimingKey(uniqueName, type, timingImportance));
}

/**
* Unregister a TimingKey to reclaim the memory
*
*/
public static void unregisterKey(TimingKey key)
{
_unregisterExecutor.submit(new Callable<Void>() {
public Void call() throws Exception {
_pool.remove(key.getName());
return null;
}
});
}

/**
* Return how many registered keys, for testing purpose.
*/
public static int getCount() {
return _pool.size();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ public void onSuccess(None result)
{
callback.onError(new IllegalStateException("Shutdown has already been requested."));
}
TimingKey.unregisterKey(TIMING_KEY);
}

private void sendStreamRequestAsRestRequest(StreamRequest request, RequestContext requestContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,6 @@ public TransportClient getClient(Map<String, ? extends Object> properties)
properties = new HashMap<String,Object>(properties);
sslContext = coerceAndRemoveFromMap(HTTP_SSL_CONTEXT, properties, SSLContext.class);
sslParameters = coerceAndRemoveFromMap(HTTP_SSL_PARAMS, properties, SSLParameters.class);

return getClient(properties, sslContext, sslParameters);
}

Expand Down Expand Up @@ -1121,7 +1120,7 @@ private TransportClient getClient(Map<String, ? extends Object> properties,
filters = filters.addLastRest(disruptFilter);
filters = filters.addLast(disruptFilter);

client = new FilterChainClient(client, filters);
client = new FilterChainClient(client, filters, _filters);
client = new FactoryClient(client);
synchronized (_mutex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ public void onSuccess(None result)
_shutdownTimeout);
_jmxManager.onProviderShutdown(_channelPoolManager);
_jmxManager.onProviderShutdown(_sslChannelPoolManager);
TimingKey.unregisterKey(TIMING_KEY);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestRequestBuilder;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.timing.TimingKey;
import com.linkedin.r2.testutils.server.HttpServerBuilder;
import com.linkedin.r2.transport.common.Client;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
Expand Down Expand Up @@ -93,20 +94,25 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion
{
server.start();
List<Client> clients = new ArrayList<>();

int savedTimingKeyCount = TimingKey.getCount();
for (int i = 0; i < 100; i++)
{
HashMap<String, String> properties = new HashMap<>();
properties.put(HttpClientFactory.HTTP_PROTOCOL_VERSION, protocolVersion);
clients.add(new TransportClientAdapter(factory.getClient(properties), restOverStream));
}

int addedTimingKeyCount = TimingKey.getCount() - savedTimingKeyCount;
// In current implementation, one client can have around 30 TimingKeys by default.
Assert.assertTrue(addedTimingKeyCount >= 30 * clients.size());
for (Client c : clients)
{
RestRequest r = new RestRequestBuilder(new URI(URI)).build();
c.restRequest(r).get(30, TimeUnit.SECONDS);
}
Assert.assertEquals(httpServerStatsProvider.requestCount(), expectedRequests);

savedTimingKeyCount = TimingKey.getCount();
for (Client c : clients)
{
FutureCallback<None> callback = new FutureCallback<>();
Expand All @@ -117,6 +123,8 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion
FutureCallback<None> factoryShutdown = new FutureCallback<>();
factory.shutdown(factoryShutdown);
factoryShutdown.get(30, TimeUnit.SECONDS);
int removedTimingKeyCount = savedTimingKeyCount - TimingKey.getCount();
Assert.assertEquals(addedTimingKeyCount, removedTimingKeyCount);
}
finally
{
Expand Down

0 comments on commit 8a3bd30

Please sign in to comment.