Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize TaskBatcher behavior in case of a datacenter failure. #41407

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.LinkedHashSet;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand All @@ -46,7 +49,7 @@ public abstract class TaskBatcher {
private final Logger logger;
private final PrioritizedEsThreadPoolExecutor threadExecutor;
// package visible for tests
final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new HashMap<>();
final ConcurrentMap<Object, Map<IdentityWrapper, BatchedTask>> tasksPerBatchingKey = new ConcurrentHashMap<>();

public TaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) {
this.logger = logger;
Expand All @@ -61,25 +64,33 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
assert tasks.stream().allMatch(t -> t.batchingKey == firstTask.batchingKey) :
"tasks submitted in a batch should share the same batching key: " + tasks;
// convert to an identity map to check for dups based on task identity
final Map<Object, BatchedTask> tasksIdentity = tasks.stream().collect(Collectors.toMap(
BatchedTask::getTask,
final Map<IdentityWrapper, BatchedTask> toAdd = tasks.stream().collect(Collectors.toMap(
t -> new IdentityWrapper(t.getTask()),
Function.identity(),
(a, b) -> { throw new IllegalStateException("cannot add duplicate task: " + a); },
IdentityHashMap::new));

synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.computeIfAbsent(firstTask.batchingKey,
k -> new LinkedHashSet<>(tasks.size()));
for (BatchedTask existing : existingTasks) {
// check that there won't be two tasks with the same identity for the same batching key
BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
if (duplicateTask != null) {
throw new IllegalStateException("task [" + duplicateTask.describeTasks(
Collections.singletonList(existing)) + "] with source [" + duplicateTask.source + "] is already queued");
LinkedHashMap::new));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why LinkedHashMap? Do you care about insertion order?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have been trying to understand that, and I am not 100% sure whether we should care.
The original version did preserve insertion order via a LinkedHashSet, though, and I was not brave enough to relax those semantics.


tasksPerBatchingKey.merge(
firstTask.batchingKey,
toAdd,
(oldValue, newValue) -> {
final Map<IdentityWrapper, BatchedTask> merged =
new LinkedHashMap<>(oldValue.size() + newValue.size());
merged.putAll(oldValue);
merged.putAll(newValue);

if (merged.size() != oldValue.size() + newValue.size()) {
// Find the duplicate
oldValue.forEach((k, existing) -> {
final BatchedTask duplicateTask = newValue.get(k);
if (duplicateTask != null) {
throw new IllegalStateException("task [" + duplicateTask.describeTasks(
Collections.singletonList(existing)) + "] with source [" + duplicateTask.source + "] is already queued");
}
});
}
}
existingTasks.addAll(tasks);
}
return merged;
});

if (timeout != null) {
threadExecutor.execute(firstTask, timeout, () -> onTimeoutInternal(tasks, timeout));
Expand All @@ -89,29 +100,42 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
}

private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue timeout) {
final ArrayList<BatchedTask> toRemove = new ArrayList<>();
final Set<IdentityWrapper> ids = new HashSet<>(tasks.size());
final List<BatchedTask> toRemove = new ArrayList<>(tasks.size());
for (BatchedTask task : tasks) {
if (task.processed.getAndSet(true) == false) {
logger.debug("task [{}] timed out after [{}]", task.source, timeout);
ids.add(new IdentityWrapper(task.getTask()));
toRemove.add(task);
}
}
if (toRemove.isEmpty() == false) {
BatchedTask firstTask = toRemove.get(0);
Object batchingKey = firstTask.batchingKey;
assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey) :
"tasks submitted in a batch should share the same batching key: " + tasks;
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.get(batchingKey);
if (existingTasks != null) {
existingTasks.removeAll(toRemove);
if (existingTasks.isEmpty()) {
tasksPerBatchingKey.remove(batchingKey);
}
}
}
onTimeout(toRemove, timeout);

if (toRemove.isEmpty()) {
return;
}

final BatchedTask firstTask = toRemove.get(0);
final Object batchingKey = firstTask.batchingKey;
assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey) :
"tasks submitted in a batch should share the same batching key: " + tasks;
tasksPerBatchingKey.computeIfPresent(
batchingKey,
(k, v) -> {
if (v.size() == ids.size() && ids.containsAll(v.keySet())) {
// Special case when all the tasks timed out
return null;
} else {
final Map<IdentityWrapper, BatchedTask> merged = new LinkedHashMap<>(v.size());
v.forEach((id, task) -> {
if (ids.contains(id) == false) {
merged.put(id, task);
}
});
return merged;
}
});

onTimeout(toRemove, timeout);
}

/**
Expand All @@ -120,23 +144,21 @@ private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue time
*/
protected abstract void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout);

void runIfNotProcessed(BatchedTask updateTask) {
private void runIfNotProcessed(BatchedTask updateTask) {
// if this task is already processed, it shouldn't execute other tasks with same batching key that arrived later,
// to give other tasks with different batching key a chance to execute.
if (updateTask.processed.get() == false) {
final List<BatchedTask> toExecute = new ArrayList<>();
final Map<String, List<BatchedTask>> processTasksBySource = new HashMap<>();
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
if (pending != null) {
for (BatchedTask task : pending) {
if (task.processed.getAndSet(true) == false) {
logger.trace("will process {}", task);
toExecute.add(task);
processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task);
} else {
logger.trace("skipping {}, already processed", task);
}
final Map<IdentityWrapper, BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
if (pending != null) {
for (BatchedTask task : pending.values()) {
if (task.processed.getAndSet(true) == false) {
logger.trace("will process {}", task);
toExecute.add(task);
processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task);
} else {
logger.trace("skipping {}, already processed", task);
}
}
}
Expand Down Expand Up @@ -204,4 +226,27 @@ public Object getTask() {
return task;
}
}

/**
 * Uses wrapped {@link Object} identity for {@link #equals(Object)} and {@link #hashCode()},
 * so map lookups treat two distinct-but-equal task objects as different entries.
 */
private static final class IdentityWrapper {
    private final Object object;

    private IdentityWrapper(final Object object) {
        this.object = object;
    }

    @Override
    public boolean equals(final Object o) {
        // Honor the Object#equals contract: return false for null or a foreign type.
        // The previous assert-then-cast threw ClassCastException/NullPointerException
        // when assertions were disabled (the production default).
        if (o instanceof IdentityWrapper == false) {
            return false;
        }
        final IdentityWrapper that = (IdentityWrapper) o;
        // Reference identity, not Object#equals, is the whole point of this wrapper.
        return object == that.object;
    }

    @Override
    public int hashCode() {
        // Must be consistent with identity-based equals.
        return System.identityHashCode(object);
    }
}
}