Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TimingKey Memory Leak #587 #588

Merged
merged 1 commit into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and what APIs have changed, if applicable.
- Strictly enforce Gradle version compatibility in the `pegasus` Gradle plugin.
- Minimum required Gradle version is now `1.0` (effectively backward-compatible).
- Minimum suggested Gradle version is now `5.2.1`
- Fix TimingKey Memory Leak

## [29.18.2] - 2021-04-28
- Fix bug in generated fluent client APIs when typerefs are used as association key params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ public void shutdown(final Callback<None> callback)

callback.onSuccess(None.none());
});
TimingKey.unregisterKey(TIMING_KEY);
}

@Override
Expand Down
12 changes: 11 additions & 1 deletion r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;

import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -189,4 +189,14 @@ void onStreamResponse(StreamResponse res,
void onStreamError(Exception ex,
RequestContext requestContext,
Map<String, String> wireAttrs);

/**
* Returns a copy of a list of RestFilters
*/
List<RestFilter> getRestFilters();

/**
* Returns a copy of a list of StreamFilters
*/
List<StreamFilter> getStreamFilters();
}
10 changes: 10 additions & 0 deletions r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ public FilterChain addLast(StreamFilter filter)
return new FilterChainImpl(_restFilters, doAddLast(_streamFilters, decorateStreamFilter(filter)));
}

@Override
public List<RestFilter> getRestFilters() {
return new ArrayList<RestFilter>(_restFilters);
}

@Override
public List<StreamFilter> getStreamFilters() {
return new ArrayList<StreamFilter>(_streamFilters);
}

private RestFilter decorateRestFilter(RestFilter filter)
{
return new TimedRestFilter(filter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.timing.TimingImportance;
import java.util.Arrays;
import java.util.List;
import java.util.Map;


Expand All @@ -31,7 +33,7 @@
*
* @author Xialin Zhu
*/
/* package private */ class TimedRestFilter implements RestFilter
public class TimedRestFilter implements RestFilter
{
protected static final String ON_REQUEST_SUFFIX = "onRequest";
protected static final String ON_RESPONSE_SUFFIX = "onResponse";
Expand All @@ -41,6 +43,7 @@
private final TimingKey _onRequestTimingKey;
private final TimingKey _onResponseTimingKey;
private final TimingKey _onErrorTimingKey;
private boolean _shared;

/**
* Registers {@link TimingKey}s for {@link com.linkedin.r2.message.timing.TimingNameConstants#TIMED_REST_FILTER}.
Expand All @@ -61,6 +64,7 @@ public TimedRestFilter(RestFilter restFilter)
_restFilter.getClass().getSimpleName(), TimingImportance.LOW);
_onErrorTimingKey = TimingKey.registerNewKey(timingKeyPrefix + ON_ERROR_SUFFIX + timingKeyPostfix,
_restFilter.getClass().getSimpleName(), TimingImportance.LOW);
_shared = false;
}

@Override
Expand Down Expand Up @@ -91,4 +95,16 @@ public void onRestError(Throwable ex,
TimingContextUtil.markTiming(requestContext, _onErrorTimingKey);
_restFilter.onRestError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter));
}

public void setShared() {
_shared = true;
}

public void onShutdown() {
if (!_shared) {
TimingKey.unregisterKey(_onErrorTimingKey);
TimingKey.unregisterKey(_onRequestTimingKey);
TimingKey.unregisterKey(_onResponseTimingKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@
*
* @author Xialin Zhu
*/
/* package private */ class TimedStreamFilter implements StreamFilter
public class TimedStreamFilter implements StreamFilter
zizhong marked this conversation as resolved.
Show resolved Hide resolved
{
private final StreamFilter _streamFilter;
private final TimingKey _onRequestTimingKey;
private final TimingKey _onResponseTimingKey;
private final TimingKey _onErrorTimingKey;
private boolean _shared;

/**
* Registers {@link TimingKey}s for {@link com.linkedin.r2.message.timing.TimingNameConstants#TIMED_STREAM_FILTER}.
Expand All @@ -60,6 +61,7 @@ public TimedStreamFilter(StreamFilter streamFilter)
filterClassName, TimingImportance.LOW);
_onErrorTimingKey = TimingKey.registerNewKey(timingKeyPrefix + ON_ERROR_SUFFIX + timingKeyPostfix,
filterClassName, TimingImportance.LOW);
_shared = false;
}

@Override
Expand Down Expand Up @@ -91,4 +93,16 @@ public void onStreamError(Throwable ex,
TimingContextUtil.markTiming(requestContext, _onErrorTimingKey);
_streamFilter.onStreamError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter));
}

public void setShared() {
_shared = true;
}

public void onShutdown() {
if (!_shared) {
TimingKey.unregisterKey(_onErrorTimingKey);
TimingKey.unregisterKey(_onRequestTimingKey);
TimingKey.unregisterKey(_onResponseTimingKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
/* $Id$ */
package com.linkedin.r2.filter.transport;


import com.linkedin.common.callback.Callback;
import com.linkedin.common.util.None;
import com.linkedin.r2.filter.FilterChain;
import com.linkedin.r2.filter.TimedRestFilter;
import com.linkedin.r2.filter.TimedStreamFilter;
import com.linkedin.r2.filter.message.rest.RestFilter;
import com.linkedin.r2.filter.message.stream.StreamFilter;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.Response;
import com.linkedin.r2.message.rest.RestRequest;
Expand All @@ -29,12 +32,15 @@
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.timing.FrameworkTimingKeys;
import com.linkedin.r2.message.timing.TimingContextUtil;
import com.linkedin.r2.message.timing.TimingKey;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import com.linkedin.r2.transport.common.bridge.common.TransportCallback;

import com.linkedin.r2.transport.common.bridge.common.TransportResponse;
import java.util.Collection;
import java.util.List;
import java.util.Map;


/**
* {@link TransportClient} adapter which composes a {@link TransportClient}
* and a {@link FilterChain}.
Expand Down Expand Up @@ -94,6 +100,12 @@ public void streamRequest(StreamRequest request,
public void shutdown(Callback<None> callback)
{
_client.shutdown(callback);

_filters.getStreamFilters().stream().filter(TimedStreamFilter.class::isInstance)
.map(TimedStreamFilter.class::cast).forEach(TimedStreamFilter::onShutdown);

_filters.getRestFilters().stream().filter(TimedRestFilter.class::isInstance)
.map(TimedRestFilter.class::cast).forEach(TimedRestFilter::onShutdown);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package com.linkedin.r2.message.timing;

import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;

import com.linkedin.r2.message.RequestContext;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;


/**
Expand All @@ -32,6 +35,7 @@
public class TimingKey
{
private static final Map<String, TimingKey> _pool = new ConcurrentHashMap<>();
private static final ExecutorService _unregisterExecutor = Executors.newFixedThreadPool(1);
zizhong marked this conversation as resolved.
Show resolved Hide resolved

private final String _name;
private final String _type;
Expand Down Expand Up @@ -130,4 +134,26 @@ public static TimingKey registerNewKey(String uniqueName, String type, TimingImp
{
return registerNewKey(new TimingKey(uniqueName, type, timingImportance));
}

/**
* Unregister a TimingKey to reclaim the memory
*
*/
public static void unregisterKey(TimingKey key)
{
_unregisterExecutor.submit(new Callable<Void>() {
public Void call() throws Exception {
_pool.remove(key.getName());
return null;
}
});
}

/**
* Return how many registered keys, for testing purpose.
*/
public static int getCount() {
return _pool.size();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ public void onSuccess(None result)
{
callback.onError(new IllegalStateException("Shutdown has already been requested."));
}
TimingKey.unregisterKey(TIMING_KEY);
}

private void sendStreamRequestAsRestRequest(StreamRequest request, RequestContext requestContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.linkedin.r2.filter.CompressionConfig;
import com.linkedin.r2.filter.FilterChain;
import com.linkedin.r2.filter.FilterChains;
import com.linkedin.r2.filter.TimedRestFilter;
import com.linkedin.r2.filter.TimedStreamFilter;
import com.linkedin.r2.filter.compression.ClientCompressionFilter;
import com.linkedin.r2.filter.compression.ClientCompressionHelper;
import com.linkedin.r2.filter.compression.ClientStreamCompressionFilter;
Expand Down Expand Up @@ -685,6 +687,11 @@ private HttpClientFactory(FilterChain filters,
{
_channelPoolManagerFactory = new ConnectionSharingChannelPoolManagerFactory(_channelPoolManagerFactory);
}

_filters.getStreamFilters().stream().filter(TimedStreamFilter.class::isInstance)
.map(TimedStreamFilter.class::cast).forEach(TimedStreamFilter::setShared);
_filters.getRestFilters().stream().filter(TimedRestFilter.class::isInstance)
.map(TimedRestFilter.class::cast).forEach(TimedRestFilter::setShared);
}

public static class Builder
Expand Down Expand Up @@ -958,7 +965,6 @@ public TransportClient getClient(Map<String, ? extends Object> properties)
properties = new HashMap<String,Object>(properties);
sslContext = coerceAndRemoveFromMap(HTTP_SSL_CONTEXT, properties, SSLContext.class);
sslParameters = coerceAndRemoveFromMap(HTTP_SSL_PARAMS, properties, SSLParameters.class);

return getClient(properties, sslContext, sslParameters);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ public void onSuccess(None result)
_shutdownTimeout);
_jmxManager.onProviderShutdown(_channelPoolManager);
_jmxManager.onProviderShutdown(_sslChannelPoolManager);
TimingKey.unregisterKey(TIMING_KEY);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestRequestBuilder;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.timing.TimingKey;
import com.linkedin.r2.testutils.server.HttpServerBuilder;
import com.linkedin.r2.transport.common.Client;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
Expand Down Expand Up @@ -93,20 +94,25 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion
{
server.start();
List<Client> clients = new ArrayList<>();

int savedTimingKeyCount = TimingKey.getCount();
for (int i = 0; i < 100; i++)
{
HashMap<String, String> properties = new HashMap<>();
properties.put(HttpClientFactory.HTTP_PROTOCOL_VERSION, protocolVersion);
clients.add(new TransportClientAdapter(factory.getClient(properties), restOverStream));
}

int addedTimingKeyCount = TimingKey.getCount() - savedTimingKeyCount;
// In current implementation, one client can have around 30 TimingKeys by default.
Assert.assertTrue(addedTimingKeyCount >= 30 * clients.size());
for (Client c : clients)
{
RestRequest r = new RestRequestBuilder(new URI(URI)).build();
c.restRequest(r).get(30, TimeUnit.SECONDS);
}
Assert.assertEquals(httpServerStatsProvider.requestCount(), expectedRequests);

savedTimingKeyCount = TimingKey.getCount();
for (Client c : clients)
{
FutureCallback<None> callback = new FutureCallback<>();
Expand All @@ -117,6 +123,8 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion
FutureCallback<None> factoryShutdown = new FutureCallback<>();
factory.shutdown(factoryShutdown);
factoryShutdown.get(30, TimeUnit.SECONDS);
int removedTimingKeyCount = savedTimingKeyCount - TimingKey.getCount();
Assert.assertEquals(addedTimingKeyCount, removedTimingKeyCount);
}
finally
{
Expand Down