Skip to content

Commit d491c63

Browse files
Fix aggregation memory leak for CCS (#78404) (#78604)
When a CCS search is proxied, the memory for the aggregations on the proxy node would not be freed. Now only use the non-copying byte referencing version on the coordinating node, which itself ensures that memory is freed by calling `consumeAggs`. Relates #72309
1 parent 3487a66 commit d491c63

File tree

7 files changed

+205
-17
lines changed

7 files changed

+205
-17
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.search.ccs;
10+
11+
import org.elasticsearch.action.ActionFuture;
12+
import org.elasticsearch.action.search.ClearScrollRequest;
13+
import org.elasticsearch.action.search.SearchRequest;
14+
import org.elasticsearch.action.search.SearchResponse;
15+
import org.elasticsearch.client.Client;
16+
import org.elasticsearch.cluster.metadata.IndexMetadata;
17+
import org.elasticsearch.cluster.node.DiscoveryNode;
18+
import org.elasticsearch.common.settings.Settings;
19+
import org.elasticsearch.core.TimeValue;
20+
import org.elasticsearch.index.query.MatchAllQueryBuilder;
21+
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
22+
import org.elasticsearch.search.builder.SearchSourceBuilder;
23+
import org.elasticsearch.test.AbstractMultiClustersTestCase;
24+
import org.elasticsearch.test.InternalTestCluster;
25+
import org.elasticsearch.transport.TransportService;
26+
import org.hamcrest.Matchers;
27+
28+
import java.util.ArrayList;
29+
import java.util.Collection;
30+
import java.util.List;
31+
import java.util.stream.Collectors;
32+
import java.util.stream.StreamSupport;
33+
34+
import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
35+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
36+
import static org.hamcrest.Matchers.equalTo;
37+
38+
public class CrossClusterSearchLeakIT extends AbstractMultiClustersTestCase {
39+
40+
@Override
41+
protected Collection<String> remoteClusterAlias() {
42+
return org.elasticsearch.core.List.of("cluster_a");
43+
}
44+
45+
@Override
46+
protected boolean reuseClusters() {
47+
return false;
48+
}
49+
50+
private int indexDocs(Client client, String field, String index) {
51+
int numDocs = between(1, 200);
52+
for (int i = 0; i < numDocs; i++) {
53+
client.prepareIndex(index, "_doc").setSource(field, "v" + i).get();
54+
}
55+
client.admin().indices().prepareRefresh(index).get();
56+
return numDocs;
57+
}
58+
59+
/**
60+
* This test validates that we do not leak any memory when running CCS in various modes, actual validation is done by test framework
61+
* (leak detection)
62+
* <ul>
63+
* <li>proxy vs non-proxy</li>
64+
* <li>single-phase query-fetch or multi-phase</li>
65+
* <li>minimize roundtrip vs not</li>
66+
* <li>scroll vs no scroll</li>
67+
* </ul>
68+
*/
69+
public void testSearch() throws Exception {
70+
assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo")
71+
.addMapping("_doc", "f", "type=keyword")
72+
.setSettings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3))));
73+
indexDocs(client(LOCAL_CLUSTER), "ignored", "demo");
74+
final InternalTestCluster remoteCluster = cluster("cluster_a");
75+
int minRemotes = between(2, 5);
76+
remoteCluster.ensureAtLeastNumDataNodes(minRemotes);
77+
List<String> remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false)
78+
.filter(DiscoveryNode::canContainData)
79+
.map(DiscoveryNode::getName)
80+
.collect(Collectors.toList());
81+
assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(minRemotes));
82+
List<String> seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes);
83+
disconnectFromRemoteClusters();
84+
configureRemoteCluster("cluster_a", seedNodes);
85+
final Settings.Builder allocationFilter = Settings.builder();
86+
if (rarely()) {
87+
allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes));
88+
} else {
89+
// Provoke using proxy connections
90+
allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes));
91+
}
92+
assertAcked(client("cluster_a").admin().indices().prepareCreate("prod")
93+
.addMapping("_doc", "f", "type=keyword")
94+
.setSettings(Settings.builder().put(allocationFilter.build())
95+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3))));
96+
assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod")
97+
.setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut());
98+
int docs = indexDocs(client("cluster_a"), "f", "prod");
99+
100+
List<ActionFuture<SearchResponse>> futures = new ArrayList<>();
101+
for (int i = 0; i < 10; ++i) {
102+
String[] indices = randomBoolean() ? new String[] { "demo", "cluster_a:prod" } : new String[] { "cluster_a:prod" };
103+
final SearchRequest searchRequest = new SearchRequest(indices);
104+
searchRequest.allowPartialSearchResults(false);
105+
boolean scroll = randomBoolean();
106+
searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())
107+
.aggregation(terms("f").field("f").size(docs + between(scroll ? 1 : 0, 10))).size(between(0, 1000)));
108+
if (scroll) {
109+
searchRequest.scroll("30s");
110+
}
111+
searchRequest.setCcsMinimizeRoundtrips(rarely());
112+
futures.add(client(LOCAL_CLUSTER).search(searchRequest));
113+
}
114+
115+
for (ActionFuture<SearchResponse> future : futures) {
116+
SearchResponse searchResponse = future.get();
117+
if (searchResponse.getScrollId() != null) {
118+
ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
119+
clearScrollRequest.scrollIds(org.elasticsearch.core.List.of(searchResponse.getScrollId()));
120+
client(LOCAL_CLUSTER).clearScroll(clearScrollRequest).get();
121+
}
122+
123+
Terms terms = searchResponse.getAggregations().get("f");
124+
assertThat(terms.getBuckets().size(), equalTo(docs));
125+
for (Terms.Bucket bucket : terms.getBuckets()) {
126+
assertThat(bucket.getDocCount(), equalTo(1L));
127+
}
128+
}
129+
}
130+
131+
@Override
132+
protected void configureRemoteCluster(String clusterAlias, Collection<String> seedNodes) throws Exception {
133+
if (rarely()) {
134+
super.configureRemoteCluster(clusterAlias, seedNodes);
135+
} else {
136+
final Settings.Builder settings = Settings.builder();
137+
final String seedNode = randomFrom(seedNodes);
138+
final TransportService transportService = cluster(clusterAlias).getInstance(TransportService.class, seedNode);
139+
final String seedAddress = transportService.boundAddress().publishAddress().toString();
140+
141+
settings.put("cluster.remote." + clusterAlias + ".mode", "proxy");
142+
settings.put("cluster.remote." + clusterAlias + ".proxy_address", seedAddress);
143+
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();
144+
}
145+
}
146+
}

server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ public void sendExecuteQuery(Transport.Connection connection, final ShardSearchR
138138
// we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request
139139
// this used to be the QUERY_AND_FETCH which doesn't exist anymore.
140140
final boolean fetchDocuments = request.numberOfShards() == 1;
141-
Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new;
141+
Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : in -> new QuerySearchResult(in, true);
142142

143143
final ActionListener<? super SearchPhaseResult> handler = responseWrapper.apply(connection, listener);
144144
transportService.sendChildRequest(connection, QUERY_ACTION_NAME, request, task,

server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
package org.elasticsearch.common.io.stream;
1010

1111
import org.elasticsearch.Version;
12+
import org.elasticsearch.common.bytes.BytesReference;
1213
import org.elasticsearch.common.bytes.ReleasableBytesReference;
1314
import org.elasticsearch.core.Releasable;
1415

@@ -50,6 +51,12 @@ public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Read
5051
return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readReleasableBytesReference());
5152
}
5253

54+
public static <T extends Writeable> DelayableWriteable<T> referencing(Writeable.Reader<T> reader, StreamInput in) throws IOException {
55+
try (ReleasableBytesReference serialized = in.readReleasableBytesReference()) {
56+
return new Referencing<>(deserialize(reader, in.getVersion(), in.namedWriteableRegistry(), serialized));
57+
}
58+
}
59+
5360
private DelayableWriteable() {}
5461

5562
/**
@@ -67,7 +74,7 @@ private DelayableWriteable() {}
6774
* {@code true} if the {@linkplain Writeable} is being stored in
6875
* serialized form, {@code false} otherwise.
6976
*/
70-
abstract boolean isSerialized();
77+
public abstract boolean isSerialized();
7178

7279
/**
7380
* Returns the serialized size of the inner {@link Writeable}.
@@ -104,7 +111,7 @@ public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry regis
104111
}
105112

106113
@Override
107-
boolean isSerialized() {
114+
public boolean isSerialized() {
108115
return false;
109116
}
110117

@@ -169,11 +176,7 @@ public void writeTo(StreamOutput out) throws IOException {
169176
@Override
170177
public T expand() {
171178
try {
172-
try (StreamInput in = registry == null ?
173-
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
174-
in.setVersion(serializedAtVersion);
175-
return reader.read(in);
176-
}
179+
return deserialize(reader, serializedAtVersion, registry, serialized);
177180
} catch (IOException e) {
178181
throw new RuntimeException("unexpected error expanding serialized delayed writeable", e);
179182
}
@@ -185,7 +188,7 @@ public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry regis
185188
}
186189

187190
@Override
188-
boolean isSerialized() {
191+
public boolean isSerialized() {
189192
return true;
190193
}
191194

@@ -214,6 +217,15 @@ public static long getSerializedSize(Writeable ref) {
214217
}
215218
}
216219

220+
private static <T> T deserialize(Reader<T> reader, Version serializedAtVersion, NamedWriteableRegistry registry,
221+
BytesReference serialized) throws IOException {
222+
try (StreamInput in =
223+
registry == null ? serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
224+
in.setVersion(serializedAtVersion);
225+
return reader.read(in);
226+
}
227+
}
228+
217229
private static class CountingStreamOutput extends StreamOutput {
218230
long size = 0;
219231

server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.elasticsearch.search.suggest.Suggest;
3636

3737
public final class QuerySearchResult extends SearchPhaseResult {
38-
3938
private int from;
4039
private int size;
4140
private TopDocsAndMaxScore topDocsAndMaxScore;
@@ -67,6 +66,15 @@ public QuerySearchResult() {
6766
}
6867

6968
public QuerySearchResult(StreamInput in) throws IOException {
69+
this(in, false);
70+
}
71+
72+
/**
73+
* Read the object, but using a delayed aggregations field when delayedAggregations=true. Using this, the caller must ensure that
74+
* either `consumeAggs` or `releaseAggs` is called if `hasAggs() == true`.
75+
* @param delayedAggregations whether to use delayed aggregations or not
76+
*/
77+
public QuerySearchResult(StreamInput in, boolean delayedAggregations) throws IOException {
7078
super(in);
7179
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
7280
isNull = in.readBoolean();
@@ -75,7 +83,7 @@ public QuerySearchResult(StreamInput in) throws IOException {
7583
}
7684
if (isNull == false) {
7785
ShardSearchContextId id = new ShardSearchContextId(in);
78-
readFromWithId(id, in);
86+
readFromWithId(id, in, delayedAggregations);
7987
}
8088
}
8189

@@ -318,6 +326,10 @@ public boolean hasSearchContext() {
318326
}
319327

320328
public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOException {
329+
readFromWithId(id, in, false);
330+
}
331+
332+
private void readFromWithId(ShardSearchContextId id, StreamInput in, boolean delayedAggregations) throws IOException {
321333
this.contextId = id;
322334
from = in.readVInt();
323335
size = in.readVInt();
@@ -344,7 +356,11 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOExc
344356
}
345357
} else {
346358
if (hasAggs) {
347-
aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in);
359+
if (delayedAggregations) {
360+
aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in);
361+
} else {
362+
aggregations = DelayableWriteable.referencing(InternalAggregations::readFrom, in);
363+
}
348364
}
349365
}
350366
if (in.readBoolean()) {
@@ -371,6 +387,8 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOExc
371387

372388
@Override
373389
public void writeTo(StreamOutput out) throws IOException {
390+
// we do not know that it is being sent over transport, but this at least protects all writes from happening, including sending.
391+
assert aggregations == null || aggregations.isSerialized() == false : "cannot send serialized version since it will leak";
374392
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
375393
out.writeBoolean(isNull);
376394
}

server/src/test/java/org/elasticsearch/common/io/stream/DelayableWriteableTests.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,12 @@ public void testRoundTripFromDelayedWithNamedWriteable() throws IOException {
134134
public void testRoundTripFromDelayedFromOldVersion() throws IOException {
135135
Example e = new Example(randomAlphaOfLength(5));
136136
DelayableWriteable<Example> original = roundTrip(DelayableWriteable.referencing(e), Example::new, randomOldVersion());
137-
assertTrue(original.isSerialized());
138137
roundTripTestCase(original, Example::new);
139138
}
140139

141140
public void testRoundTripFromDelayedFromOldVersionWithNamedWriteable() throws IOException {
142141
NamedHolder n = new NamedHolder(new Example(randomAlphaOfLength(5)));
143142
DelayableWriteable<NamedHolder> original = roundTrip(DelayableWriteable.referencing(n), NamedHolder::new, randomOldVersion());
144-
assertTrue(original.isSerialized());
145143
roundTripTestCase(original, NamedHolder::new);
146144
}
147145

@@ -160,14 +158,20 @@ public void testAsSerializedIsNoopOnSerialized() throws IOException {
160158

161159
private <T extends Writeable> void roundTripTestCase(DelayableWriteable<T> original, Writeable.Reader<T> reader) throws IOException {
162160
DelayableWriteable<T> roundTripped = roundTrip(original, reader, Version.CURRENT);
163-
assertTrue(roundTripped.isSerialized());
164161
assertThat(roundTripped.expand(), equalTo(original.expand()));
165162
}
166163

167164
private <T extends Writeable> DelayableWriteable<T> roundTrip(DelayableWriteable<T> original,
168165
Writeable.Reader<T> reader, Version version) throws IOException {
169-
return copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
166+
DelayableWriteable<T> delayed = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
170167
in -> DelayableWriteable.delayed(reader, in), version);
168+
assertTrue(delayed.isSerialized());
169+
170+
DelayableWriteable<T> referencing = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
171+
in -> DelayableWriteable.referencing(reader, in), version);
172+
assertFalse(referencing.isSerialized());
173+
174+
return randomFrom(delayed, referencing);
171175
}
172176

173177
@Override

server/src/test/java/org/elasticsearch/search/query/QuerySearchResultTests.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import java.util.Base64;
4040

4141
import static java.util.Collections.emptyList;
42+
import static org.hamcrest.Matchers.is;
43+
import static org.hamcrest.Matchers.nullValue;
4244

4345
public class QuerySearchResultTests extends ESTestCase {
4446

@@ -74,7 +76,9 @@ private static QuerySearchResult createTestInstance() throws Exception {
7476

7577
public void testSerialization() throws Exception {
7678
QuerySearchResult querySearchResult = createTestInstance();
77-
QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry, QuerySearchResult::new);
79+
boolean delayed = randomBoolean();
80+
QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry,
81+
delayed ? in -> new QuerySearchResult(in, true) : QuerySearchResult::new);
7882
assertEquals(querySearchResult.getContextId().getId(), deserialized.getContextId().getId());
7983
assertNull(deserialized.getSearchShardTarget());
8084
assertEquals(querySearchResult.topDocs().maxScore, deserialized.topDocs().maxScore, 0f);
@@ -83,9 +87,11 @@ public void testSerialization() throws Exception {
8387
assertEquals(querySearchResult.size(), deserialized.size());
8488
assertEquals(querySearchResult.hasAggs(), deserialized.hasAggs());
8589
if (deserialized.hasAggs()) {
90+
assertThat(deserialized.aggregations().isSerialized(), is(delayed));
8691
Aggregations aggs = querySearchResult.consumeAggs();
8792
Aggregations deserializedAggs = deserialized.consumeAggs();
8893
assertEquals(aggs.asList(), deserializedAggs.asList());
94+
assertThat(deserialized.aggregations(), is(nullValue()));
8995
}
9096
assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly());
9197
}

test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ protected void disconnectFromRemoteClusters() throws Exception {
133133
for (String clusterAlias : clusterAliases) {
134134
if (clusterAlias.equals(LOCAL_CLUSTER) == false) {
135135
settings.putNull("cluster.remote." + clusterAlias + ".seeds");
136+
settings.putNull("cluster.remote." + clusterAlias + ".mode");
137+
settings.putNull("cluster.remote." + clusterAlias + ".proxy_address");
136138
}
137139
}
138140
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();

0 commit comments

Comments
 (0)