Skip to content

Commit

Permalink
Use semaphores in TaskUpdateRequestManager to avoid blocking task u…
Browse files Browse the repository at this point in the history
…pdate threads (#685)
  • Loading branch information
jbern0rd authored Mar 28, 2024
1 parent 2811256 commit 55e32c5
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ All notable changes to this project will be documented in this file.
- Move task metrics from `TaskUpdateManager` to `TaskService`. (#676)
- Fail fast when tasks are detected past their contribution or final deadline. (#677)
- Mitigate potential race conditions by enforcing `currentStatus` value when updating a task. (#681)
- Use semaphores in `TaskUpdateRequestManager` to avoid blocking task update threads. (#685)

### Quality

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020 IEXEC BLOCKCHAIN TECH
* Copyright 2020-2024 IEXEC BLOCKCHAIN TECH
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,13 @@

package com.iexec.core.task.update;

import com.iexec.common.utils.ContextualLockRunner;
import com.iexec.core.task.Task;
import com.iexec.core.task.TaskService;
import lombok.extern.slf4j.Slf4j;
import net.jodah.expiringmap.ExpiringMap;
import org.springframework.stereotype.Component;

import java.util.Optional;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

Expand All @@ -42,8 +42,10 @@ public class TaskUpdateRequestManager {
*/
private static final int TASK_UPDATE_THREADS_POOL_SIZE = Runtime.getRuntime().availableProcessors() * 2;

private final ContextualLockRunner<String> taskExecutionLockRunner =
new ContextualLockRunner<>(LONGEST_TASK_TIMEOUT.getSeconds(), TimeUnit.SECONDS);
// Working with semaphore to guarantee at most 1 item in queue and 1 running thread
private final ExpiringMap<String, Semaphore> taskExecutionLockRunner = ExpiringMap.builder()
.expiration(LONGEST_TASK_TIMEOUT.getSeconds(), TimeUnit.SECONDS)
.build();

final TaskUpdatePriorityBlockingQueue queue = new TaskUpdatePriorityBlockingQueue();
// Both `corePoolSize` and `maximumPoolSize` should be set to `TASK_UPDATE_THREADS_POOL_SIZE`.
Expand Down Expand Up @@ -86,24 +88,32 @@ public synchronized boolean publishRequest(String chainTaskId) {
log.debug("Request already published [chainTaskId:{}]", chainTaskId);
return false;
}
final Optional<Task> oTask = taskService.getTaskByChainTaskId(chainTaskId);
if (oTask.isEmpty()) {
log.warn("No such task. [chainTaskId: {}]", chainTaskId);
final Task task = taskService.getTaskByChainTaskId(chainTaskId).orElse(null);
if (task == null) {
log.warn("No such task [chainTaskId: {}]", chainTaskId);
return false;
}

final Task task = oTask.get();
// Add semaphore to expiring map if missing
taskExecutionLockRunner.putIfAbsent(chainTaskId, new Semaphore(1));

taskUpdateExecutor.execute(new TaskUpdate(task, this::updateTask));
log.debug("Published task update request" +
" [chainTaskId:{}, currentStatus:{}, contributionDeadline:{}, queueSize:{}]",
log.debug("Published task update request [chainTaskId:{}, currentStatus:{}, contributionDeadline:{}, queueSize:{}]",
chainTaskId, task.getCurrentStatus(), task.getContributionDeadline(), queue.size());
return true;
}

private void updateTask(String chainTaskId) {
taskExecutionLockRunner.acceptWithLock(
chainTaskId,
taskUpdateManager::updateTask
);
if (!taskExecutionLockRunner.get(chainTaskId).tryAcquire()) {
log.debug("Could not acquire lock for task update [chainTaskId:{}]", chainTaskId);
return;
}
try {
log.debug("Acquire lock for task update [chainTaskId:{}]", chainTaskId);
taskUpdateManager.updateTask(chainTaskId);
} finally {
log.debug("Release lock for task update [chainTaskId:{}]", chainTaskId);
taskExecutionLockRunner.get(chainTaskId).release();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,31 +1,57 @@
/*
* Copyright 2020-2024 IEXEC BLOCKCHAIN TECH
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.iexec.core.task.update;

import com.iexec.core.task.Task;
import com.iexec.core.task.TaskService;
import com.iexec.core.task.TaskStatus;
import lombok.extern.slf4j.Slf4j;
import net.jodah.expiringmap.ExpiringMap;
import org.assertj.core.api.Assertions;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.boot.test.system.CapturedOutput;
import org.springframework.boot.test.system.OutputCaptureExtension;
import org.springframework.test.util.ReflectionTestUtils;

import java.util.*;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.mockito.Mockito.when;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.*;

@Slf4j
@ExtendWith(OutputCaptureExtension.class)
class TaskUpdateRequestManagerTests {

public static final String CHAIN_TASK_ID = "chainTaskId";

@Mock
private TaskService taskService;
@Mock
private TaskUpdateManager taskUpdateManager;

@InjectMocks
private TaskUpdateRequestManager taskUpdateRequestManager;
Expand All @@ -37,43 +63,69 @@ void init() {

// region publishRequest()
@Test
void shouldPublishRequest() throws ExecutionException, InterruptedException {
void shouldPublishRequest(CapturedOutput output) {
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));

boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
await().atMost(5L, TimeUnit.SECONDS)
.until(() -> output.getOut().contains("Acquire lock for task update [chainTaskId:chainTaskId]")
&& output.getOut().contains("Release lock for task update [chainTaskId:chainTaskId]"));

Assertions.assertThat(publishRequestStatus).isTrue();
assertThat(publishRequestStatus).isTrue();
verify(taskUpdateManager).updateTask(CHAIN_TASK_ID);
}

@Test
void shouldNotPublishRequestSinceEmptyTaskId() throws ExecutionException, InterruptedException {
boolean publishRequestStatus = taskUpdateRequestManager.publishRequest("");
void shouldPublishRequestButNotAcquireLock(CapturedOutput output) {
final ExpiringMap<String, Semaphore> locks = ExpiringMap.builder()
.expiration(30L, TimeUnit.SECONDS)
.build();
locks.putIfAbsent(CHAIN_TASK_ID, new Semaphore(1));
assertThat(locks.get(CHAIN_TASK_ID).tryAcquire()).isTrue();
ReflectionTestUtils.setField(taskUpdateRequestManager, "taskExecutionLockRunner", locks);
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));

final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
await().atMost(5L, TimeUnit.SECONDS)
.until(() -> output.getOut().contains("Could not acquire lock for task update [chainTaskId:chainTaskId]"));

Assertions.assertThat(publishRequestStatus).isFalse();
assertThat(publishRequestStatus).isTrue();
verifyNoInteractions(taskUpdateManager);
}

@Test
void shouldNotPublishRequestSinceItemAlreadyAdded() throws ExecutionException, InterruptedException {
void shouldNotPublishRequestSinceEmptyTaskId() {
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest("");

assertThat(publishRequestStatus).isFalse();
verifyNoInteractions(taskService, taskUpdateManager);
}

@Test
void shouldNotPublishRequestSinceItemAlreadyAdded() {
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));
taskUpdateRequestManager.queue.add(
buildTaskUpdate(CHAIN_TASK_ID, null, null, null)
);

boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);

Assertions.assertThat(publishRequestStatus).isFalse();
assertThat(publishRequestStatus).isFalse();
verifyNoInteractions(taskUpdateManager);
}

@Test
void shouldNotPublishRequestSinceTaskDoesNotExist() throws ExecutionException, InterruptedException {
void shouldNotPublishRequestSinceTaskDoesNotExist() {
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
.thenReturn(Optional.empty());

boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);

Assertions.assertThat(publishRequestStatus).isFalse();
assertThat(publishRequestStatus).isFalse();
verifyNoInteractions(taskUpdateManager);
}
// endregion

Expand Down Expand Up @@ -110,20 +162,18 @@ void shouldNotUpdateAtTheSameTime() {
.collect(Collectors.toList());

updates.forEach(taskUpdateRequestManager.taskUpdateExecutor::execute);
Awaitility
.await()
.timeout(30, TimeUnit.SECONDS)
await().timeout(30, TimeUnit.SECONDS)
.until(() -> callsOrder.size() == callsPerUpdate * updates.size());

Assertions.assertThat(callsOrder).hasSize(callsPerUpdate * updates.size());
assertThat(callsOrder).hasSize(callsPerUpdate * updates.size());

// We loop through calls order and see if all calls for a given update have finished
// before another update starts for this task.
// Two updates for different tasks can run at the same time.
Map<String, Map<Integer, Integer>> foundTaskUpdates = new HashMap<>();

for (int updateId : callsOrder) {
System.out.println("[taskId:" + taskForUpdateId.get(updateId) + ", updateId:" + updateId + "]");
log.info("[taskId:{}, updateId:{}]", taskForUpdateId.get(updateId), updateId);
final Map<Integer, Integer> foundOutputsForKeyGroup = foundTaskUpdates.computeIfAbsent(taskForUpdateId.get(updateId), (key) -> new HashMap<>());
for (int alreadyFound : foundOutputsForKeyGroup.keySet()) {
if (!Objects.equals(alreadyFound, updateId) && foundOutputsForKeyGroup.get(alreadyFound) < callsPerUpdate) {
Expand Down Expand Up @@ -161,14 +211,12 @@ void shouldGetInOrderForStatus() throws InterruptedException {
queue.addAll(tasks);

final List<TaskUpdate> prioritizedTasks = queue.takeAll();
Assertions.assertThat(prioritizedTasks)
.containsExactly(
completedTask,
consensusReachedTask,
runningTask,
initializedTask,
initializingTask
);
assertThat(prioritizedTasks).containsExactly(
completedTask,
consensusReachedTask,
runningTask,
initializedTask,
initializingTask);
}

@Test
Expand All @@ -192,14 +240,8 @@ void shouldGetInOrderForContributionDeadline() throws InterruptedException {
queue.addAll(tasks);

final List<TaskUpdate> prioritizedTasks = queue.takeAll();
Assertions.assertThat(prioritizedTasks)
.containsExactly(
t1,
t2,
t3,
t4,
t5
);
assertThat(prioritizedTasks).containsExactly(
t1, t2, t3, t4, t5);
}

@Test
Expand All @@ -219,13 +261,8 @@ void shouldGetInOrderForStatusAndContributionDeadline() throws InterruptedExcept
queue.addAll(tasks);

final List<TaskUpdate> prioritizedTasks = queue.takeAll();
Assertions.assertThat(prioritizedTasks)
.containsExactly(
t3,
t4,
t1,
t2
);
assertThat(prioritizedTasks).containsExactly(
t3, t4, t1, t2);
}
// endregion

Expand All @@ -234,12 +271,11 @@ private TaskUpdate buildTaskUpdate(String chainTaskId,
Date contributionDeadline,
Consumer<String> taskUpdater) {
return new TaskUpdate(
Task
.builder()
Task.builder()
.chainTaskId(chainTaskId)
.currentStatus(status).
contributionDeadline(contributionDeadline).
build(),
.currentStatus(status)
.contributionDeadline(contributionDeadline)
.build(),
taskUpdater
);
}
Expand Down

0 comments on commit 55e32c5

Please sign in to comment.