Skip to content

Commit

Permalink
[Improve][Zeta] Spilt the classloader of task group
Browse files Browse the repository at this point in the history
  • Loading branch information
Hisoka-X committed Sep 4, 2024
1 parent 9a603ea commit 700f2ac
Show file tree
Hide file tree
Showing 21 changed files with 258 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.seatunnel.engine.server.execution.TaskGroup;
import org.apache.seatunnel.engine.server.execution.TaskGroupContext;
import org.apache.seatunnel.engine.server.execution.TaskGroupLocation;
import org.apache.seatunnel.engine.server.execution.TaskGroupUtils;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.execution.TaskTracker;
import org.apache.seatunnel.engine.server.metrics.SeaTunnelMetricsContext;
Expand Down Expand Up @@ -324,33 +325,50 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
taskImmutableInfo.getExecutionId()));
TaskGroup taskGroup = null;
try {
Set<ConnectorJarIdentifier> connectorJarIdentifiers =
List<Set<ConnectorJarIdentifier>> connectorJarIdentifiersList =
taskImmutableInfo.getConnectorJarIdentifiers();
Set<URL> jars = new HashSet<>();
ClassLoader classLoader;
if (!CollectionUtils.isEmpty(connectorJarIdentifiers)) {
// Prioritize obtaining the jar package file required for the current task execution
// from the local, if it does not exist locally, it will be downloaded from the
// master node.
jars =
serverConnectorPackageClient.getConnectorJarFromLocal(
connectorJarIdentifiers);
} else if (!CollectionUtils.isEmpty(taskImmutableInfo.getJars())) {
jars = taskImmutableInfo.getJars();
}
classLoader =
classLoaderService.getClassLoader(
taskImmutableInfo.getJobId(), Lists.newArrayList(jars));
if (jars.isEmpty()) {
taskGroup =
nodeEngine.getSerializationService().toObject(taskImmutableInfo.getGroup());
} else {
taskGroup =
CustomClassLoadedObject.deserializeWithCustomClassLoader(
nodeEngine.getSerializationService(),
classLoader,
taskImmutableInfo.getGroup());
List<Data> taskData = taskImmutableInfo.getTasksData();
ConcurrentHashMap<Long, ClassLoader> classLoaders = new ConcurrentHashMap<>();
List<Task> tasks = new ArrayList<>();
ConcurrentHashMap<Long, Collection<URL>> taskJars = new ConcurrentHashMap<>();
for (int i = 0; i < taskData.size(); i++) {
Set<URL> jars = new HashSet<>();
Set<ConnectorJarIdentifier> connectorJarIdentifiers =
connectorJarIdentifiersList.get(i);
if (!CollectionUtils.isEmpty(connectorJarIdentifiers)) {
// Prioritize obtaining the jar package file required for the current task
// execution
// from the local, if it does not exist locally, it will be downloaded from the
// master node.
jars =
serverConnectorPackageClient.getConnectorJarFromLocal(
connectorJarIdentifiers);
} else if (!CollectionUtils.isEmpty(taskImmutableInfo.getJars().get(i))) {
jars = taskImmutableInfo.getJars().get(i);
}
ClassLoader classLoader =
classLoaderService.getClassLoader(
taskImmutableInfo.getJobId(), Lists.newArrayList(jars));
Task task;
if (jars.isEmpty()) {
task = nodeEngine.getSerializationService().toObject(taskData.get(i));
} else {
task =
CustomClassLoadedObject.deserializeWithCustomClassLoader(
nodeEngine.getSerializationService(),
classLoader,
taskData.get(i));
}
tasks.add(task);
classLoaders.put(task.getTaskID(), classLoader);
taskJars.put(task.getTaskID(), jars);
}
taskGroup =
TaskGroupUtils.createTaskGroup(
taskImmutableInfo.getTaskGroupType(),
taskImmutableInfo.getTaskGroupLocation(),
taskImmutableInfo.getTaskGroupName(),
tasks);

logger.info(
String.format(
Expand All @@ -364,7 +382,7 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
"TaskGroupLocation: %s already exists",
taskGroup.getTaskGroupLocation()));
}
deployLocalTask(taskGroup, classLoader, jars);
deployLocalTask(taskGroup, classLoaders, taskJars);
return TaskDeployState.success();
}
} catch (Throwable t) {
Expand All @@ -382,12 +400,16 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
@Deprecated
public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
@NonNull TaskGroup taskGroup) {
return deployLocalTask(
taskGroup, Thread.currentThread().getContextClassLoader(), emptyList());
Long taskId = taskGroup.getTasks().iterator().next().getTaskID();
ConcurrentHashMap<Long, ClassLoader> classLoaders = new ConcurrentHashMap<>();
classLoaders.put(taskId, Thread.currentThread().getContextClassLoader());
return deployLocalTask(taskGroup, classLoaders, new ConcurrentHashMap<>());
}

public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
@NonNull TaskGroup taskGroup, @NonNull ClassLoader classLoader, Collection<URL> jars) {
@NonNull TaskGroup taskGroup,
@NonNull ConcurrentHashMap<Long, ClassLoader> classLoaders,
ConcurrentHashMap<Long, Collection<URL>> jars) {
CompletableFuture<TaskExecutionState> resultFuture = new CompletableFuture<>();
try {
taskGroup.init();
Expand Down Expand Up @@ -426,7 +448,7 @@ public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
}));
executionContexts.put(
taskGroup.getTaskGroupLocation(),
new TaskGroupContext(taskGroup, classLoader, jars));
new TaskGroupContext(taskGroup, classLoaders, jars));
cancellationFutures.put(taskGroup.getTaskGroupLocation(), cancellationFuture);
submitThreadShareTask(executionTracker, byCooperation.get(true));
submitBlockingTask(executionTracker, byCooperation.get(false));
Expand Down Expand Up @@ -709,7 +731,8 @@ public void run() {
ClassLoader classLoader =
executionContexts
.get(taskGroupExecutionTracker.taskGroup.getTaskGroupLocation())
.getClassLoader();
.getClassLoaders()
.get(tracker.task.getTaskID());
ClassLoader oldClassLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(classLoader);
final Task t = tracker.task;
Expand Down Expand Up @@ -817,7 +840,8 @@ public void run() {
myThread.setContextClassLoader(
executionContexts
.get(taskGroupExecutionTracker.taskGroup.getTaskGroupLocation())
.getClassLoader());
.getClassLoaders()
.get(taskTracker.task.getTaskID()));
call = taskTracker.task.call();
synchronized (timer) {
timer.timerStop();
Expand Down Expand Up @@ -1012,8 +1036,10 @@ void taskDone(Task task) {

private void recycleClassLoader(TaskGroupLocation taskGroupLocation) {
TaskGroupContext context = executionContexts.get(taskGroupLocation);
executionContexts.get(taskGroupLocation).setClassLoader(null);
classLoaderService.releaseClassLoader(taskGroupLocation.getJobId(), context.getJars());
executionContexts.get(taskGroupLocation).setClassLoaders(null);
for (Collection<URL> jars : context.getJars().values()) {
classLoaderService.releaseClassLoader(taskGroupLocation.getJobId(), jars);
}
}

boolean executionCompletedExceptionally() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ public void runInternal() throws Exception {
.getExecutionContext(taskLocation.getTaskGroupLocation());
Task task = groupContext.getTaskGroup().getTask(taskLocation.getTaskID());
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(groupContext.getClassLoader());
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader(task.getTaskID()));

task.notifyCheckpointEnd(checkpointId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ public void runInternal() throws Exception {
.getExecutionContext(taskLocation.getTaskGroupLocation());
Task task = groupContext.getTaskGroup().getTask(taskLocation.getTaskID());
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(groupContext.getClassLoader());
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader(task.getTaskID()));
if (successful) {
task.notifyCheckpointComplete(checkpointId);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ public void runInternal() throws Exception {
() -> {
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader());
groupContext.getClassLoader(
task.getTaskID()));
try {
log.debug(
"NotifyTaskRestoreOperation.restoreState "
Expand Down
Loading

0 comments on commit 700f2ac

Please sign in to comment.