Skip to content

Commit

Permalink
Fix TimingKey Memory Leak #587
Browse files Browse the repository at this point in the history
  • Loading branch information
Zizhong Zhang committed May 3, 2021
1 parent 292349f commit 83bb325
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and what APIs have changed, if applicable.
- Strictly enforce Gradle version compatibility in the `pegasus` Gradle plugin.
- Minimum required Gradle version is now `1.0` (effectively backward-compatible).
- Minimum suggested Gradle version is now `5.2.1`
- Fix TimingKey Memory Leak

## [29.18.2] - 2021-04-28
- Fix bug in generated fluent client APIs when typerefs are used as association key params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ public void shutdown(final Callback<None> callback)

callback.onSuccess(None.none());
});
TimingKey.unregisterKey(TIMING_KEY);
}

@Override
Expand Down
12 changes: 11 additions & 1 deletion r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;

import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -189,4 +189,14 @@ void onStreamResponse(StreamResponse res,
void onStreamError(Exception ex,
RequestContext requestContext,
Map<String, String> wireAttrs);

/**
* Returns a copy of a list of RestFilters
*/
List<RestFilter> getRestFilters();

/**
* Returns a copy of a list of StreamFilters
*/
List<StreamFilter> getStreamFilters();
}
10 changes: 10 additions & 0 deletions r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ public FilterChain addLast(StreamFilter filter)
return new FilterChainImpl(_restFilters, doAddLast(_streamFilters, decorateStreamFilter(filter)));
}

@Override
public List<RestFilter> getRestFilters() {
return new ArrayList<RestFilter>(_restFilters);
}

@Override
public List<StreamFilter> getStreamFilters() {
return new ArrayList<StreamFilter>(_streamFilters);
}

private RestFilter decorateRestFilter(RestFilter filter)
{
return new TimedRestFilter(filter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.timing.TimingImportance;
import java.util.Arrays;
import java.util.List;
import java.util.Map;


Expand All @@ -31,7 +33,7 @@
*
* @author Xialin Zhu
*/
/* package private */ class TimedRestFilter implements RestFilter
public class TimedRestFilter implements RestFilter
{
protected static final String ON_REQUEST_SUFFIX = "onRequest";
protected static final String ON_RESPONSE_SUFFIX = "onResponse";
Expand All @@ -41,6 +43,7 @@
private final TimingKey _onRequestTimingKey;
private final TimingKey _onResponseTimingKey;
private final TimingKey _onErrorTimingKey;
private boolean _shared;

/**
* Registers {@link TimingKey}s for {@link com.linkedin.r2.message.timing.TimingNameConstants#TIMED_REST_FILTER}.
Expand All @@ -61,6 +64,7 @@ public TimedRestFilter(RestFilter restFilter)
_restFilter.getClass().getSimpleName(), TimingImportance.LOW);
_onErrorTimingKey = TimingKey.registerNewKey(timingKeyPrefix + ON_ERROR_SUFFIX + timingKeyPostfix,
_restFilter.getClass().getSimpleName(), TimingImportance.LOW);
_shared = false;
}

@Override
Expand Down Expand Up @@ -91,4 +95,16 @@ public void onRestError(Throwable ex,
TimingContextUtil.markTiming(requestContext, _onErrorTimingKey);
_restFilter.onRestError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter));
}

public void setShared() {
_shared = true;
}

public void onShutdown() {
if (!_shared) {
TimingKey.unregisterKey(_onErrorTimingKey);
TimingKey.unregisterKey(_onRequestTimingKey);
TimingKey.unregisterKey(_onResponseTimingKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@
*
* @author Xialin Zhu
*/
/* package private */ class TimedStreamFilter implements StreamFilter
public class TimedStreamFilter implements StreamFilter
{
private final StreamFilter _streamFilter;
private final TimingKey _onRequestTimingKey;
private final TimingKey _onResponseTimingKey;
private final TimingKey _onErrorTimingKey;
private boolean _shared;

/**
* Registers {@link TimingKey}s for {@link com.linkedin.r2.message.timing.TimingNameConstants#TIMED_STREAM_FILTER}.
Expand All @@ -60,6 +61,7 @@ public TimedStreamFilter(StreamFilter streamFilter)
filterClassName, TimingImportance.LOW);
_onErrorTimingKey = TimingKey.registerNewKey(timingKeyPrefix + ON_ERROR_SUFFIX + timingKeyPostfix,
filterClassName, TimingImportance.LOW);
_shared = false;
}

@Override
Expand Down Expand Up @@ -91,4 +93,16 @@ public void onStreamError(Throwable ex,
TimingContextUtil.markTiming(requestContext, _onErrorTimingKey);
_streamFilter.onStreamError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter));
}

public void setShared() {
_shared = true;
}

public void onShutdown() {
if (!_shared) {
TimingKey.unregisterKey(_onErrorTimingKey);
TimingKey.unregisterKey(_onRequestTimingKey);
TimingKey.unregisterKey(_onResponseTimingKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
/* $Id$ */
package com.linkedin.r2.filter.transport;


import com.linkedin.common.callback.Callback;
import com.linkedin.common.util.None;
import com.linkedin.r2.filter.FilterChain;
import com.linkedin.r2.filter.TimedRestFilter;
import com.linkedin.r2.filter.TimedStreamFilter;
import com.linkedin.r2.filter.message.rest.RestFilter;
import com.linkedin.r2.filter.message.stream.StreamFilter;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.Response;
import com.linkedin.r2.message.rest.RestRequest;
Expand All @@ -29,12 +32,15 @@
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.timing.FrameworkTimingKeys;
import com.linkedin.r2.message.timing.TimingContextUtil;
import com.linkedin.r2.message.timing.TimingKey;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import com.linkedin.r2.transport.common.bridge.common.TransportCallback;

import com.linkedin.r2.transport.common.bridge.common.TransportResponse;
import java.util.Collection;
import java.util.List;
import java.util.Map;


/**
* {@link TransportClient} adapter which composes a {@link TransportClient}
* and a {@link FilterChain}.
Expand Down Expand Up @@ -94,6 +100,12 @@ public void streamRequest(StreamRequest request,
public void shutdown(Callback<None> callback)
{
_client.shutdown(callback);

_filters.getStreamFilters().stream().filter(TimedStreamFilter.class::isInstance)
.map(TimedStreamFilter.class::cast).forEach(TimedStreamFilter::onShutdown);

_filters.getRestFilters().stream().filter(TimedRestFilter.class::isInstance)
.map(TimedRestFilter.class::cast).forEach(TimedRestFilter::onShutdown);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package com.linkedin.r2.message.timing;

import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;

import com.linkedin.r2.message.RequestContext;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;


/**
Expand All @@ -32,6 +35,7 @@
public class TimingKey
{
private static final Map<String, TimingKey> _pool = new ConcurrentHashMap<>();
private static final ExecutorService _unregisterExecutor = Executors.newFixedThreadPool(1);

private final String _name;
private final String _type;
Expand Down Expand Up @@ -130,4 +134,26 @@ public static TimingKey registerNewKey(String uniqueName, String type, TimingImp
{
return registerNewKey(new TimingKey(uniqueName, type, timingImportance));
}

/**
* Unregister a TimingKey to reclaim the memory
*
*/
public static void unregisterKey(TimingKey key)
{
_unregisterExecutor.submit(new Callable<Void>() {
public Void call() throws Exception {
_pool.remove(key.getName());
return null;
}
});
}

/**
* Return how many registered keys, for testing purpose.
*/
public static int getCount() {
return _pool.size();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ public void onSuccess(None result)
{
callback.onError(new IllegalStateException("Shutdown has already been requested."));
}
TimingKey.unregisterKey(TIMING_KEY);
}

private void sendStreamRequestAsRestRequest(StreamRequest request, RequestContext requestContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.linkedin.r2.filter.CompressionConfig;
import com.linkedin.r2.filter.FilterChain;
import com.linkedin.r2.filter.FilterChains;
import com.linkedin.r2.filter.TimedRestFilter;
import com.linkedin.r2.filter.TimedStreamFilter;
import com.linkedin.r2.filter.compression.ClientCompressionFilter;
import com.linkedin.r2.filter.compression.ClientCompressionHelper;
import com.linkedin.r2.filter.compression.ClientStreamCompressionFilter;
Expand Down Expand Up @@ -685,6 +687,11 @@ private HttpClientFactory(FilterChain filters,
{
_channelPoolManagerFactory = new ConnectionSharingChannelPoolManagerFactory(_channelPoolManagerFactory);
}

_filters.getStreamFilters().stream().filter(TimedStreamFilter.class::isInstance)
.map(TimedStreamFilter.class::cast).forEach(TimedStreamFilter::setShared);
_filters.getRestFilters().stream().filter(TimedRestFilter.class::isInstance)
.map(TimedRestFilter.class::cast).forEach(TimedRestFilter::setShared);
}

public static class Builder
Expand Down Expand Up @@ -958,7 +965,6 @@ public TransportClient getClient(Map<String, ? extends Object> properties)
properties = new HashMap<String,Object>(properties);
sslContext = coerceAndRemoveFromMap(HTTP_SSL_CONTEXT, properties, SSLContext.class);
sslParameters = coerceAndRemoveFromMap(HTTP_SSL_PARAMS, properties, SSLParameters.class);

return getClient(properties, sslContext, sslParameters);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ public void onSuccess(None result)
_shutdownTimeout);
_jmxManager.onProviderShutdown(_channelPoolManager);
_jmxManager.onProviderShutdown(_sslChannelPoolManager);
TimingKey.unregisterKey(TIMING_KEY);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestRequestBuilder;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.timing.TimingKey;
import com.linkedin.r2.testutils.server.HttpServerBuilder;
import com.linkedin.r2.transport.common.Client;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
Expand Down Expand Up @@ -93,20 +94,25 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion
{
server.start();
List<Client> clients = new ArrayList<>();

int savedTimingKeyCount = TimingKey.getCount();
for (int i = 0; i < 100; i++)
{
HashMap<String, String> properties = new HashMap<>();
properties.put(HttpClientFactory.HTTP_PROTOCOL_VERSION, protocolVersion);
clients.add(new TransportClientAdapter(factory.getClient(properties), restOverStream));
}

int addedTimingKeyCount = TimingKey.getCount() - savedTimingKeyCount;
// In current implementation, one client can have around 30 TimingKeys by default.
Assert.assertTrue(addedTimingKeyCount >= 30 * clients.size());
for (Client c : clients)
{
RestRequest r = new RestRequestBuilder(new URI(URI)).build();
c.restRequest(r).get(30, TimeUnit.SECONDS);
}
Assert.assertEquals(httpServerStatsProvider.requestCount(), expectedRequests);

savedTimingKeyCount = TimingKey.getCount();
for (Client c : clients)
{
FutureCallback<None> callback = new FutureCallback<>();
Expand All @@ -117,6 +123,8 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion
FutureCallback<None> factoryShutdown = new FutureCallback<>();
factory.shutdown(factoryShutdown);
factoryShutdown.get(30, TimeUnit.SECONDS);
int removedTimingKeyCount = savedTimingKeyCount - TimingKey.getCount();
Assert.assertEquals(addedTimingKeyCount, removedTimingKeyCount);
}
finally
{
Expand Down

0 comments on commit 83bb325

Please sign in to comment.