Skip to content

Commit

Permalink
Add LEAST_WASTED task memory killer
Browse files Browse the repository at this point in the history
Add task low memory killer which picks ups a task which has highest
used-memory/task-runtime ratio.
  • Loading branch information
losipiuk committed May 12, 2022
1 parent d8a965a commit a58fc50
Show file tree
Hide file tree
Showing 6 changed files with 417 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.trino.memory;

import com.google.common.collect.ImmutableSet;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.operator.RetryPolicy;
import io.trino.operator.TaskStats;
import io.trino.spi.QueryId;
import io.trino.spi.memory.MemoryPoolInfo;

import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.AbstractMap.SimpleEntry;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Comparator.comparing;

public class LeastWastedEffortTaskLowMemoryKiller
implements LowMemoryKiller
{
private static final long MIN_WALL_TIME = Duration.of(30, ChronoUnit.SECONDS).toMillis();

@Override
public Optional<KillTarget> chooseTargetToKill(List<RunningQueryInfo> runningQueries, List<MemoryInfo> nodes)
{
Set<QueryId> queriesWithTaskRetryPolicy = runningQueries.stream()
.filter(query -> query.getRetryPolicy() == RetryPolicy.TASK)
.map(RunningQueryInfo::getQueryId)
.collect(toImmutableSet());

if (queriesWithTaskRetryPolicy.isEmpty()) {
return Optional.empty();
}

ImmutableSet.Builder<TaskId> tasksToKillBuilder = ImmutableSet.builder();

Map<TaskId, TaskInfo> taskInfos = runningQueries.stream()
.filter(queryInfo -> queriesWithTaskRetryPolicy.contains(queryInfo.getQueryId()))
.flatMap(queryInfo -> queryInfo.getTaskInfos().entrySet().stream())
.collect(toImmutableMap(
Map.Entry::getKey,
Map.Entry::getValue));

for (MemoryInfo node : nodes) {
MemoryPoolInfo memoryPool = node.getPool();
if (memoryPool == null) {
continue;
}
if (memoryPool.getFreeBytes() + memoryPool.getReservedRevocableBytes() > 0) {
continue;
}

memoryPool.getTaskMemoryReservations().entrySet().stream()
.map(entry -> new SimpleEntry<>(TaskId.valueOf(entry.getKey()), entry.getValue()))
.filter(entry -> queriesWithTaskRetryPolicy.contains(entry.getKey().getQueryId()))
.max(comparing(entry -> {
TaskId taskId = entry.getKey();
Long memoryUsed = entry.getValue();
long wallTime = 0;
if (taskInfos.containsKey(taskId)) {
TaskStats stats = taskInfos.get(taskId).getStats();
wallTime = stats.getTotalScheduledTime().toMillis() + stats.getTotalBlockedTime().toMillis();
}
wallTime = Math.max(wallTime, MIN_WALL_TIME); // only look at memory consumption for fairly short-lived tasks
return (double) memoryUsed / wallTime;
}))
.map(SimpleEntry::getKey)
.ifPresent(tasksToKillBuilder::add);
}
Set<TaskId> tasksToKill = tasksToKillBuilder.build();
if (tasksToKill.isEmpty()) {
return Optional.empty();
}
return Optional.of(KillTarget.selectedTasks(tasksToKill));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ public enum LowMemoryTaskKillerPolicy
{
NONE,
TOTAL_RESERVATION_ON_BLOCKED_NODES,
LEAST_WASTED,
/**/;

public static LowMemoryTaskKillerPolicy fromString(String value)
Expand All @@ -212,6 +213,8 @@ public static LowMemoryTaskKillerPolicy fromString(String value)
return NONE;
case "total-reservation-on-blocked-nodes":
return TOTAL_RESERVATION_ON_BLOCKED_NODES;
case "least-wasted":
return LEAST_WASTED;
}

throw new IllegalArgumentException(format("Unrecognized value: '%s'", value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import io.trino.failuredetector.FailureDetectorModule;
import io.trino.memory.ClusterMemoryManager;
import io.trino.memory.ForMemoryManager;
import io.trino.memory.LeastWastedEffortTaskLowMemoryKiller;
import io.trino.memory.LowMemoryKiller;
import io.trino.memory.LowMemoryKiller.ForQueryLowMemoryKiller;
import io.trino.memory.LowMemoryKiller.ForTaskLowMemoryKiller;
Expand Down Expand Up @@ -220,6 +221,7 @@ protected void setup(Binder binder)

bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.NONE, NoneLowMemoryKiller.class);
bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES, TotalReservationOnBlockedNodesTaskLowMemoryKiller.class);
bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.LEAST_WASTED, LeastWastedEffortTaskLowMemoryKiller.class);
bindLowMemoryQueryKiller(LowMemoryQueryKillerPolicy.NONE, NoneLowMemoryKiller.class);
bindLowMemoryQueryKiller(LowMemoryQueryKillerPolicy.TOTAL_RESERVATION, TotalReservationLowMemoryKiller.class);
bindLowMemoryQueryKiller(LowMemoryQueryKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES, TotalReservationOnBlockedNodesQueryLowMemoryKiller.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.client.NodeVersion;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.metadata.InternalNode;
import io.trino.operator.RetryPolicy;
import io.trino.spi.QueryId;
Expand All @@ -31,6 +32,8 @@
import java.util.Map;
import java.util.Set;

import static com.google.common.collect.ImmutableMap.toImmutableMap;

public final class LowMemoryKillerTestingUtils
{
private LowMemoryKillerTestingUtils() {}
Expand Down Expand Up @@ -103,17 +106,28 @@ static List<LowMemoryKiller.RunningQueryInfo> toRunningQueryInfoList(Map<String,
}

static List<LowMemoryKiller.RunningQueryInfo> toRunningQueryInfoList(Map<String, Map<String, Long>> queries, Set<String> queriesWithTaskLevelRetries)
{
return toRunningQueryInfoList(queries, queriesWithTaskLevelRetries, ImmutableMap.of());
}

static List<LowMemoryKiller.RunningQueryInfo> toRunningQueryInfoList(Map<String, Map<String, Long>> queries, Set<String> queriesWithTaskLevelRetries, Map<String, Map<Integer, TaskInfo>> taskInfos)
{
ImmutableList.Builder<LowMemoryKiller.RunningQueryInfo> result = ImmutableList.builder();
for (Map.Entry<String, Map<String, Long>> entry : queries.entrySet()) {
String queryId = entry.getKey();
long totalReservation = entry.getValue().values().stream()
.mapToLong(x -> x)
.sum();

Map<TaskId, TaskInfo> queryTaskInfos = taskInfos.getOrDefault(queryId, ImmutableMap.of()).entrySet().stream()
.collect(toImmutableMap(
taskEntry -> taskId(queryId, taskEntry.getKey()),
Map.Entry::getValue));

result.add(new LowMemoryKiller.RunningQueryInfo(
new QueryId(queryId),
totalReservation,
ImmutableMap.of(),
queryTaskInfos,
queriesWithTaskLevelRetries.contains(queryId) ? RetryPolicy.TASK : RetryPolicy.NONE));
}
return result.build();
Expand Down
Loading

0 comments on commit a58fc50

Please sign in to comment.