Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion checkstyle/suppressions.xml
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
files="Murmur3.java"/>

<suppress checks="(NPathComplexity|CyclomaticComplexity)"
files="KStreamSlidingWindowAggregate.java"/>
files="(KStreamSlidingWindowAggregate|RackAwareTaskAssignor).java"/>

<!-- suppress FinalLocalVariable outside of the streams package. -->
<suppress checks="FinalLocalVariable"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ public Edge(final V destination, final int capacity, final int cost, final int r
this(destination, capacity, cost, residualFlow, flow, true);
}

public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow,
final boolean forwardEdge) {
public Edge(final V destination,
final int capacity,
final int cost,
final int residualFlow,
final int flow,
final boolean forwardEdge) {
Objects.requireNonNull(destination);
if (capacity < 0) {
throw new IllegalArgumentException("Edge capacity cannot be negative");
Expand Down Expand Up @@ -72,8 +76,11 @@ public boolean equals(final Object other) {

final Graph<?>.Edge otherEdge = (Graph<?>.Edge) other;

return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity
&& cost == otherEdge.cost && residualFlow == otherEdge.residualFlow && flow == otherEdge.flow
return destination.equals(otherEdge.destination)
&& capacity == otherEdge.capacity
&& cost == otherEdge.cost
&& residualFlow == otherEdge.residualFlow
&& flow == otherEdge.flow
&& forwardEdge == otherEdge.forwardEdge;
}

Expand All @@ -84,8 +91,15 @@ public int hashCode() {

@Override
public String toString() {
return "{destination= " + destination + ", capacity=" + capacity + ", cost=" + cost
+ ", residualFlow=" + residualFlow + ", flow=" + flow + ", forwardEdge=" + forwardEdge;
return "Edge {"
+ "destination= " + destination
+ ", capacity=" + capacity
+ ", cost=" + cost
+ ", residualFlow=" + residualFlow
+ ", flow=" + flow
+ ", forwardEdge=" + forwardEdge
+ "}";

}
}

Expand All @@ -106,12 +120,13 @@ public void addEdge(final V u, final V v, final int capacity, final int cost, fi
addEdge(u, new Edge(v, capacity, cost, capacity - flow, flow));
}

public Set<V> nodes() {
public SortedSet<V> nodes() {
return nodes;
}

public Map<V, Edge> edges(final V node) {
return adjList.get(node);
public SortedMap<V, Edge> edges(final V node) {
final SortedMap<V, Edge> edge = adjList.get(node);
return edge == null ? new TreeMap<>() : edge;
}

public boolean isResidualGraph() {
Expand All @@ -126,12 +141,12 @@ public void setSinkNode(final V node) {
sinkNode = node;
}

public int totalCost() {
int totalCost = 0;
public long totalCost() {
long totalCost = 0;
for (final Map.Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
final SortedMap<V, Edge> edges = nodeEdges.getValue();
for (final Entry<V, Edge> nodeEdge : edges.entrySet()) {
totalCost += nodeEdge.getValue().cost * nodeEdge.getValue().flow;
for (final Edge nodeEdge : edges.values()) {
totalCost += (long) nodeEdge.cost * nodeEdge.flow;
}
}
return totalCost;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
*/
package org.apache.kafka.streams.processor.internals.assignment;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.UUID;
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.Node;
Expand All @@ -39,28 +44,30 @@
public class RackAwareTaskAssignor {
private static final Logger log = LoggerFactory.getLogger(RackAwareTaskAssignor.class);

private static final int SOURCE_ID = -1;

private final Cluster fullMetadata;
private final Map<TaskId, Set<TopicPartition>> partitionsForTask;
private final Map<UUID, Map<String, Optional<String>>> processRacks;
private final Map<UUID, Map<String, Optional<String>>> racksForProcess;
private final AssignmentConfigs assignmentConfigs;
private final Map<TopicPartition, Set<String>> racksForPartition;
private final InternalTopicManager internalTopicManager;

public RackAwareTaskAssignor(final Cluster fullMetadata,
final Map<TaskId, Set<TopicPartition>> partitionsForTask,
final Map<Subtopology, Set<TaskId>> tasksForTopicGroup,
final Map<UUID, Map<String, Optional<String>>> processRacks,
final Map<UUID, Map<String, Optional<String>>> racksForProcess,
final InternalTopicManager internalTopicManager,
final AssignmentConfigs assignmentConfigs) {
this.fullMetadata = fullMetadata;
this.partitionsForTask = partitionsForTask;
this.processRacks = processRacks;
this.racksForProcess = racksForProcess;
this.internalTopicManager = internalTopicManager;
this.assignmentConfigs = assignmentConfigs;
this.racksForPartition = new HashMap<>();
}

public synchronized boolean canEnableRackAwareAssignorForActiveTasks() {
public synchronized boolean canEnableRackAwareAssignor() {
/*
TODO: enable this after we add the config
if (StreamsConfig.RACK_AWARE_ASSSIGNMENT_STRATEGY_NONE.equals(assignmentConfigs.rackAwareAssignmentStrategy)) {
Expand All @@ -74,11 +81,7 @@ public synchronized boolean canEnableRackAwareAssignorForActiveTasks() {
}

return validateTopicPartitionRack();
}

public boolean canEnableRackAwareAssignorForStandbyTasks() {
// TODO
return false;
// TODO: add changelog topic, standby task validation
}

// Visible for testing. This method also checks if all TopicPartitions exist in cluster
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo one line below: D[e]scribe

Expand Down Expand Up @@ -159,7 +162,7 @@ public boolean validateClientRack() {
* 1. RackId exist for all clients
* 2. Different consumerId for same process should have same rackId
*/
for (final Map.Entry<UUID, Map<String, Optional<String>>> entry : processRacks.entrySet()) {
for (final Map.Entry<UUID, Map<String, Optional<String>>> entry : racksForProcess.entrySet()) {
final UUID processId = entry.getKey();
KeyValue<String, String> previousRackInfo = null;
for (final Map.Entry<String, Optional<String>> rackEntry : entry.getValue().entrySet()) {
Expand All @@ -185,4 +188,213 @@ public boolean validateClientRack() {
}
return true;
}

private int getCost(final TaskId taskId, final UUID processId, final boolean inCurrentAssignment, final int trafficCost, final int nonOverlapCost) {
final Map<String, Optional<String>> clientRacks = racksForProcess.get(processId);
if (clientRacks == null) {
throw new IllegalStateException("Client " + processId + " doesn't exist in processRacks");
}
final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We filter() already for isPresent -- seems we only seen the first check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. This is to mute some warning in Intellij, Checkstyle or spotBugs

throw new IllegalStateException("Client " + processId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
}

final String clientRack = clientRackOpt.get().get();
final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
if (topicPartitions == null || topicPartitions.isEmpty()) {
throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
}

int cost = 0;
for (final TopicPartition tp : topicPartitions) {
final Set<String> tpRacks = racksForPartition.get(tp);
if (tpRacks == null || tpRacks.isEmpty()) {
throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
}
if (!tpRacks.contains(clientRack)) {
cost += trafficCost;
}
}

if (!inCurrentAssignment) {
cost += nonOverlapCost;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I can follow?

My understanding was, that we say cost == 0 if we can assign to the same rack, and otherwise cost is 1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above explanation of non_overlap_cost

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left already a comment above about adding a better explanation. Also wondering, if we could find a more descriptive name? Unfortunately, I don't have a good idea either.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or moveTaskCost? moveAssignmentCost? Have no strong preference here...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could work.

}

return cost;
}

private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
return clientList.size() + taskIdList.size();
}

// For testing. canEnableRackAwareAssignor must be called first
long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final int trafficCost, final int nonOverlapCost) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add JavaDocs? It's a little unclear what this method does. Also maybe move activeTasks as first parameter, as they are the main input (all others are metadata)?

/*
Compute the cost for the provided {@code activeTasks}. The passed in active tasks must be contained in {@code clientState}`.
*/

final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems constructActiveTaskGraph is only reading but not modifying clientList -- why do we need to pass a deep-copy? Could we pass a Set instead of a List?

Looking into constructActiveTaskGraph, it seems we access by index and try to make thing deterministic. Is this the reason why we need a list here? If yes, given that we go from keySet to list, is this translation actually deterministic itself (could be given that it's a sorted-map).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. List is trying to make it deterministic. keySet of SortedMap to list should maintain order. That's why line 241 requires sorted map.

final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question as for clientList.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also to make it deterministic. activeTasks is a SortedSet

final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
clientStates, new HashMap<>(), new HashMap<>(), trafficCost, nonOverlapCost);
return graph.totalCost();
}

/**
* Optimize active task assignment for rack awareness. canEnableRackAwareAssignor must be called first.
* {@code trafficCost} and {@code nonOverlapCost} balance cross rack traffic optimization and task movement.
* If we set {@code trafficCost} to a larger number, we are more likely to compute an assignment with less
* cross rack traffic. However, tasks may be shuffled a lot across clients. If we set {@code nonOverlapCost}
* to a larger number, we are more likely to compute an assignment with similar to input assignment. However,
* cross rack traffic can be higher. In extreme case, if we set {@code nonOverlapCost} to 0 and @{code trafficCost}
* to a positive value, the computed assignment will be minimum for cross rack traffic. If we set {@code trafficCost} to 0,
* and {@code nonOverlapCost} to a positive value, the computed assignment should be the same as input
* @param clientStates Client states
* @param activeTasks Tasks to reassign if needed. They must be assigned already in clientStates
* @param trafficCost Cost of cross rack traffic for each TopicPartition
* @param nonOverlapCost Cost of assign a task to a different client
* @return Total cost after optimization
*/
public long optimizeActiveTasks(final SortedMap<UUID, ClientState> clientStates,
final SortedSet<TaskId> activeTasks,
final int trafficCost,
final int nonOverlapCost) {
if (activeTasks.isEmpty()) {
return 0;
}

final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seem unnecessary to extract and pass expliclity, as we pass clientStates anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just for getting index easily later. Since node id in graph is integer (index), it's easier get reference back to UUID and ClientState using index. Otherwise, we need to maintain an index to UUID map I think

Copy link
Member

@mjsax mjsax Jul 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, if it's useful for constructStatefulActiveTaskGraph to have such a list, we should construct this list inside constructStatefulActiveTaskGraph but not pass it in? Otherwise, we leak an optimization from constructStatefulActiveTaskGraph to the caller what seems not ideal?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also used on line 286.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question as above: why do we need to make a deep copy into a list? Can't we just pass clientStates.keySet() instead? Seems both methods only read but don't modify.

final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
final Map<TaskId, UUID> taskClientMap = new HashMap<>();
final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
clientStates, taskClientMap, originalAssignedTaskNumber, trafficCost, nonOverlapCost);

graph.solveMinCostFlow();
final long cost = graph.totalCost();

assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, taskIdList,
clientStates, originalAssignedTaskNumber, taskClientMap);

return cost;
}

private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> activeTasks,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final int trafficCost,
final int nonOverlapCost) {
final Graph<Integer> graph = new Graph<>();

for (final TaskId taskId : activeTasks) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
if (clientState.getValue().hasAssignedTask(taskId)) {
originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
}
}
}

// Make task and client Node id in graph deterministic
for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
final TaskId taskId = taskIdList.get(taskNodeId);
for (int j = 0; j < clientList.size(); j++) {
final int clientNodeId = taskIdList.size() + j;
final UUID processId = clientList.get(j);

final int flow = clientStates.get(processId).hasAssignedTask(taskId) ? 1 : 0;
final int cost = getCost(taskId, processId, flow == 1, trafficCost, nonOverlapCost);
if (flow == 1) {
if (taskClientMap.containsKey(taskId)) {
throw new IllegalArgumentException("Task " + taskId + " assigned to multiple clients "
+ processId + ", " + taskClientMap.get(taskId));
}
taskClientMap.put(taskId, processId);
}

graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
}
if (!taskClientMap.containsKey(taskId)) {
throw new IllegalArgumentException("Task " + taskId + " not assigned to any client");
}
}

final int sinkId = getSinkID(clientList, taskIdList);
for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
}

// It's possible that some clients have 0 task assign. These clients will have 0 tasks assigned
// even though it may have higher traffic cost. This is to maintain the original assigned task count
for (int i = 0; i < clientList.size(); i++) {
final int clientNodeId = taskIdList.size() + i;
final int capacity = originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
// Flow equals to capacity for edges to sink
graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
}

graph.setSourceNode(SOURCE_ID);
graph.setSinkNode(sinkId);

return graph;
}

private void assignActiveTaskFromMinCostFlow(final Graph<Integer> graph,
final SortedSet<TaskId> activeTasks,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap) {
int tasksAssigned = 0;
for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
final TaskId taskId = taskIdList.get(taskNodeId);
final Map<Integer, Graph<Integer>.Edge> edges = graph.edges(taskNodeId);
for (final Graph<Integer>.Edge edge : edges.values()) {
if (edge.flow > 0) {
tasksAssigned++;
final int clientIndex = edge.destination - taskIdList.size();
final UUID processId = clientList.get(clientIndex);
final UUID originalProcessId = taskClientMap.get(taskId);

// Don't need to assign this task to other client
if (processId.equals(originalProcessId)) {
break;
}

clientStates.get(originalProcessId).unassignActive(taskId);
clientStates.get(processId).assignActive(taskId);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we break here (ie, inside the if block)? There should be only one edge with flow 1?

Or even replace the for loop over the edges with a graph.edges(taskNodeId).values().stream().filter(e.flow == 1).findFirst (and throw if we don't find any edge with flow == 1)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. There should be only one edge. I didn't break here for the validations below to catch anything wrong

}
}

// Validate task assigned
if (tasksAssigned != activeTasks.size()) {
throw new IllegalStateException("Computed active task assignment number "
+ tasksAssigned + " is different size " + activeTasks.size());
}

// Validate original assigned task number matches
final Map<UUID, Integer> assignedTaskNumber = new HashMap<>();
for (final TaskId taskId : activeTasks) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
if (clientState.getValue().hasAssignedTask(taskId)) {
assignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
}
}
}

if (originalAssignedTaskNumber.size() != assignedTaskNumber.size()) {
throw new IllegalStateException("There are " + originalAssignedTaskNumber.size() + " clients have "
+ " active tasks before assignment, but " + assignedTaskNumber.size() + " clients have"
+ " active tasks after assignment");
}

for (final Entry<UUID, Integer> originalCapacity : originalAssignedTaskNumber.entrySet()) {
final int capacity = assignedTaskNumber.getOrDefault(originalCapacity.getKey(), 0);
if (!Objects.equals(originalCapacity.getValue(), capacity)) {
throw new IllegalStateException("There are " + originalCapacity.getValue() + " tasks assigned to"
+ " client " + originalCapacity.getKey() + " before assignment, but " + capacity + " tasks "
+ " are assigned to it after assignment");
}
}
}
}
Loading