Skip to content

Commit

Permalink
Optimize warning header de-duplication (elastic#37725)
Browse files Browse the repository at this point in the history
Now that warning headers no longer contain a timestamp of when the
warning was generated, we no longer need to extract the warning value
from the warning to determine whether or not the warning value is
duplicated. Instead, we can compare strings directly.

Further, when de-duplicating warning headers, are constantly rebuilding
sets. Instead of doing that, we can carry about the set with us and
rebuild it if we find a new warning value.

This commit applies both of these optimizations.
  • Loading branch information
jasontedor authored Jan 24, 2019
1 parent feab59d commit 7517e3a
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ void deprecated(final Set<ThreadContext> threadContexts, final String message, f
while (iterator.hasNext()) {
try {
final ThreadContext next = iterator.next();
next.addResponseHeader("Warning", warningHeaderValue, DeprecationLogger::extractWarningValueFromWarningHeader);
next.addResponseHeader("Warning", warningHeaderValue);
} catch (final IllegalStateException e) {
// ignored; it should be removed shortly
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Collector;
import java.util.stream.Stream;

import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT;
Expand Down Expand Up @@ -258,11 +262,11 @@ public Map<String, String> getHeaders() {
* @return Never {@code null}.
*/
public Map<String, List<String>> getResponseHeaders() {
Map<String, List<String>> responseHeaders = threadLocal.get().responseHeaders;
Map<String, Set<String>> responseHeaders = threadLocal.get().responseHeaders;
HashMap<String, List<String>> map = new HashMap<>(responseHeaders.size());

for (Map.Entry<String, List<String>> entry : responseHeaders.entrySet()) {
map.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
for (Map.Entry<String, Set<String>> entry : responseHeaders.entrySet()) {
map.put(entry.getKey(), Collections.unmodifiableList(new ArrayList<>(entry.getValue())));
}

return Collections.unmodifiableMap(map);
Expand Down Expand Up @@ -405,7 +409,7 @@ default void restore() {
private static final class ThreadContextStruct {
private final Map<String, String> requestHeaders;
private final Map<String, Object> transientHeaders;
private final Map<String, List<String>> responseHeaders;
private final Map<String, Set<String>> responseHeaders;
private final boolean isSystemContext;
private long warningHeadersSize; //saving current warning headers' size not to recalculate the size with every new warning header
private ThreadContextStruct(StreamInput in) throws IOException {
Expand All @@ -416,7 +420,23 @@ private ThreadContextStruct(StreamInput in) throws IOException {
}

this.requestHeaders = requestHeaders;
this.responseHeaders = in.readMapOfLists(StreamInput::readString, StreamInput::readString);
this.responseHeaders = in.readMap(StreamInput::readString, input -> {
final int size = input.readVInt();
if (size == 0) {
return Collections.emptySet();
} else if (size == 1) {
return Collections.singleton(input.readString());
} else {
// use a linked hash set to preserve order
final LinkedHashSet<String> values = new LinkedHashSet<>(size);
for (int i = 0; i < size; i++) {
final String value = input.readString();
final boolean added = values.add(value);
assert added : value;
}
return values;
}
});
this.transientHeaders = Collections.emptyMap();
isSystemContext = false; // we never serialize this it's a transient flag
this.warningHeadersSize = 0L;
Expand All @@ -430,7 +450,7 @@ private ThreadContextStruct setSystemContext() {
}

private ThreadContextStruct(Map<String, String> requestHeaders,
Map<String, List<String>> responseHeaders,
Map<String, Set<String>> responseHeaders,
Map<String, Object> transientHeaders, boolean isSystemContext) {
this.requestHeaders = requestHeaders;
this.responseHeaders = responseHeaders;
Expand All @@ -440,7 +460,7 @@ private ThreadContextStruct(Map<String, String> requestHeaders,
}

private ThreadContextStruct(Map<String, String> requestHeaders,
Map<String, List<String>> responseHeaders,
Map<String, Set<String>> responseHeaders,
Map<String, Object> transientHeaders, boolean isSystemContext,
long warningHeadersSize) {
this.requestHeaders = requestHeaders;
Expand Down Expand Up @@ -481,19 +501,19 @@ private ThreadContextStruct putHeaders(Map<String, String> headers) {
}
}

private ThreadContextStruct putResponseHeaders(Map<String, List<String>> headers) {
private ThreadContextStruct putResponseHeaders(Map<String, Set<String>> headers) {
assert headers != null;
if (headers.isEmpty()) {
return this;
}
final Map<String, List<String>> newResponseHeaders = new HashMap<>(this.responseHeaders);
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
final Map<String, Set<String>> newResponseHeaders = new HashMap<>(this.responseHeaders);
for (Map.Entry<String, Set<String>> entry : headers.entrySet()) {
String key = entry.getKey();
final List<String> existingValues = newResponseHeaders.get(key);
final Set<String> existingValues = newResponseHeaders.get(key);
if (existingValues != null) {
List<String> newValues = Stream.concat(entry.getValue().stream(),
existingValues.stream()).distinct().collect(Collectors.toList());
newResponseHeaders.put(key, Collections.unmodifiableList(newValues));
final Set<String> newValues =
Stream.concat(entry.getValue().stream(), existingValues.stream()).collect(LINKED_HASH_SET_COLLECTOR);
newResponseHeaders.put(key, Collections.unmodifiableSet(newValues));
} else {
newResponseHeaders.put(key, entry.getValue());
}
Expand Down Expand Up @@ -523,20 +543,19 @@ private ThreadContextStruct putResponse(final String key, final String value, fi
}
}

final Map<String, List<String>> newResponseHeaders = new HashMap<>(this.responseHeaders);
final List<String> existingValues = newResponseHeaders.get(key);
final Map<String, Set<String>> newResponseHeaders;
final Set<String> existingValues = responseHeaders.get(key);
if (existingValues != null) {
final Set<String> existingUniqueValues = existingValues.stream().map(uniqueValue).collect(Collectors.toSet());
assert existingValues.size() == existingUniqueValues.size() :
"existing values: [" + existingValues + "], existing unique values [" + existingUniqueValues + "]";
if (existingUniqueValues.contains(uniqueValue.apply(value))) {
if (existingValues.contains(uniqueValue.apply(value))) {
return this;
}
final List<String> newValues = new ArrayList<>(existingValues);
newValues.add(value);
newResponseHeaders.put(key, Collections.unmodifiableList(newValues));
// preserve insertion order
final Set<String> newValues = Stream.concat(existingValues.stream(), Stream.of(value)).collect(LINKED_HASH_SET_COLLECTOR);
newResponseHeaders = new HashMap<>(responseHeaders);
newResponseHeaders.put(key, Collections.unmodifiableSet(newValues));
} else {
newResponseHeaders.put(key, Collections.singletonList(value));
newResponseHeaders = new HashMap<>(responseHeaders);
newResponseHeaders.put(key, Collections.singleton(value));
}

//check if we can add another warning header - if max count within limits
Expand Down Expand Up @@ -588,7 +607,7 @@ private void writeTo(StreamOutput out, Map<String, String> defaultHeaders) throw
out.writeString(entry.getValue());
}

out.writeMapOfLists(responseHeaders, StreamOutput::writeString, StreamOutput::writeString);
out.writeMap(responseHeaders, StreamOutput::writeString, StreamOutput::writeStringCollection);
}
}

Expand Down Expand Up @@ -751,4 +770,40 @@ public AbstractRunnable unwrap() {
return in;
}
}

private static final Collector<String, Set<String>, Set<String>> LINKED_HASH_SET_COLLECTOR = new LinkedHashSetCollector<>();

private static class LinkedHashSetCollector<T> implements Collector<T, Set<T>, Set<T>> {
@Override
public Supplier<Set<T>> supplier() {
return LinkedHashSet::new;
}

@Override
public BiConsumer<Set<T>, T> accumulator() {
return Set::add;
}

@Override
public BinaryOperator<Set<T>> combiner() {
return (left, right) -> {
left.addAll(right);
return left;
};
}

@Override
public Function<Set<T>, Set<T>> finisher() {
return Function.identity();
}

private static final Set<Characteristics> CHARACTERISTICS =
Collections.unmodifiableSet(EnumSet.of(Collector.Characteristics.IDENTITY_FINISH));

@Override
public Set<Characteristics> characteristics() {
return CHARACTERISTICS;
}
}

}

0 comments on commit 7517e3a

Please sign in to comment.