Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use semaphores in TaskUpdateRequestManager to avoid blocking task update threads #685

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ All notable changes to this project will be documented in this file.
- Move task metrics from `TaskUpdateManager` to `TaskService`. (#676)
- Fail fast when tasks are detected past their contribution or final deadline. (#677)
- Mitigate potential race conditions by enforcing `currentStatus` value when updating a task. (#681)
- Use semaphores in `TaskUpdateRequestManager` to avoid blocking task update threads. (#685)

### Quality

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020 IEXEC BLOCKCHAIN TECH
* Copyright 2020-2024 IEXEC BLOCKCHAIN TECH
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,13 @@

package com.iexec.core.task.update;

import com.iexec.common.utils.ContextualLockRunner;
import com.iexec.core.task.Task;
import com.iexec.core.task.TaskService;
import lombok.extern.slf4j.Slf4j;
import net.jodah.expiringmap.ExpiringMap;
import org.springframework.stereotype.Component;

import java.util.Optional;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

Expand All @@ -42,8 +42,10 @@ public class TaskUpdateRequestManager {
*/
private static final int TASK_UPDATE_THREADS_POOL_SIZE = Runtime.getRuntime().availableProcessors() * 2;

private final ContextualLockRunner<String> taskExecutionLockRunner =
new ContextualLockRunner<>(LONGEST_TASK_TIMEOUT.getSeconds(), TimeUnit.SECONDS);
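// Working with one single-permit semaphore per task to guarantee at most 1 item in queue and 1 running thread
// per chainTaskId; a thread that fails to acquire the permit returns immediately instead of blocking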
private final ExpiringMap<String, Semaphore> taskExecutionLockRunner = ExpiringMap.builder()
.expiration(LONGEST_TASK_TIMEOUT.getSeconds(), TimeUnit.SECONDS)
.build();

final TaskUpdatePriorityBlockingQueue queue = new TaskUpdatePriorityBlockingQueue();
// Both `corePoolSize` and `maximumPoolSize` should be set to `TASK_UPDATE_THREADS_POOL_SIZE`.
Expand Down Expand Up @@ -86,24 +88,32 @@ public synchronized boolean publishRequest(String chainTaskId) {
log.debug("Request already published [chainTaskId:{}]", chainTaskId);
return false;
}
final Optional<Task> oTask = taskService.getTaskByChainTaskId(chainTaskId);
if (oTask.isEmpty()) {
log.warn("No such task. [chainTaskId: {}]", chainTaskId);
final Task task = taskService.getTaskByChainTaskId(chainTaskId).orElse(null);
if (task == null) {
log.warn("No such task [chainTaskId: {}]", chainTaskId);
return false;
}

final Task task = oTask.get();
// Add semaphore to expiring map if missing
taskExecutionLockRunner.putIfAbsent(chainTaskId, new Semaphore(1));

taskUpdateExecutor.execute(new TaskUpdate(task, this::updateTask));
log.debug("Published task update request" +
" [chainTaskId:{}, currentStatus:{}, contributionDeadline:{}, queueSize:{}]",
log.debug("Published task update request [chainTaskId:{}, currentStatus:{}, contributionDeadline:{}, queueSize:{}]",
chainTaskId, task.getCurrentStatus(), task.getContributionDeadline(), queue.size());
return true;
}

private void updateTask(String chainTaskId) {
taskExecutionLockRunner.acceptWithLock(
chainTaskId,
taskUpdateManager::updateTask
);
if (!taskExecutionLockRunner.get(chainTaskId).tryAcquire()) {
log.debug("Could not acquire lock for task update [chainTaskId:{}]", chainTaskId);
return;
}
try {
log.debug("Acquire lock for task update [chainTaskId:{}]", chainTaskId);
taskUpdateManager.updateTask(chainTaskId);
} finally {
log.debug("Release lock for task update [chainTaskId:{}]", chainTaskId);
taskExecutionLockRunner.get(chainTaskId).release();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,31 +1,57 @@
/*
* Copyright 2020-2024 IEXEC BLOCKCHAIN TECH
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.iexec.core.task.update;

import com.iexec.core.task.Task;
import com.iexec.core.task.TaskService;
import com.iexec.core.task.TaskStatus;
import lombok.extern.slf4j.Slf4j;
import net.jodah.expiringmap.ExpiringMap;
import org.assertj.core.api.Assertions;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.boot.test.system.CapturedOutput;
import org.springframework.boot.test.system.OutputCaptureExtension;
import org.springframework.test.util.ReflectionTestUtils;

import java.util.*;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.mockito.Mockito.when;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.*;

@Slf4j
@ExtendWith(OutputCaptureExtension.class)
class TaskUpdateRequestManagerTests {

public static final String CHAIN_TASK_ID = "chainTaskId";

@Mock
private TaskService taskService;
@Mock
private TaskUpdateManager taskUpdateManager;

@InjectMocks
private TaskUpdateRequestManager taskUpdateRequestManager;
Expand All @@ -37,43 +63,69 @@ void init() {

// region publishRequest()
@Test
void shouldPublishRequest() throws ExecutionException, InterruptedException {
void shouldPublishRequest(CapturedOutput output) {
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));

boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
await().atMost(5L, TimeUnit.SECONDS)
.until(() -> output.getOut().contains("Acquire lock for task update [chainTaskId:chainTaskId]")
&& output.getOut().contains("Release lock for task update [chainTaskId:chainTaskId]"));

Assertions.assertThat(publishRequestStatus).isTrue();
assertThat(publishRequestStatus).isTrue();
verify(taskUpdateManager).updateTask(CHAIN_TASK_ID);
}

@Test
void shouldNotPublishRequestSinceEmptyTaskId() throws ExecutionException, InterruptedException {
boolean publishRequestStatus = taskUpdateRequestManager.publishRequest("");
void shouldPublishRequestButNotAcquireLock(CapturedOutput output) {
final ExpiringMap<String, Semaphore> locks = ExpiringMap.builder()
.expiration(30L, TimeUnit.SECONDS)
.build();
locks.putIfAbsent(CHAIN_TASK_ID, new Semaphore(1));
assertThat(locks.get(CHAIN_TASK_ID).tryAcquire()).isTrue();
ReflectionTestUtils.setField(taskUpdateRequestManager, "taskExecutionLockRunner", locks);
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));

final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
await().atMost(5L, TimeUnit.SECONDS)
.until(() -> output.getOut().contains("Could not acquire lock for task update [chainTaskId:chainTaskId]"));

Assertions.assertThat(publishRequestStatus).isFalse();
assertThat(publishRequestStatus).isTrue();
verifyNoInteractions(taskUpdateManager);
}

@Test
void shouldNotPublishRequestSinceItemAlreadyAdded() throws ExecutionException, InterruptedException {
void shouldNotPublishRequestSinceEmptyTaskId() {
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest("");

assertThat(publishRequestStatus).isFalse();
verifyNoInteractions(taskService, taskUpdateManager);
}

@Test
void shouldNotPublishRequestSinceItemAlreadyAdded() {
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
.thenReturn(Optional.of(Task.builder().chainTaskId(CHAIN_TASK_ID).build()));
taskUpdateRequestManager.queue.add(
buildTaskUpdate(CHAIN_TASK_ID, null, null, null)
);

boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);

Assertions.assertThat(publishRequestStatus).isFalse();
assertThat(publishRequestStatus).isFalse();
verifyNoInteractions(taskUpdateManager);
}

@Test
void shouldNotPublishRequestSinceTaskDoesNotExist() throws ExecutionException, InterruptedException {
void shouldNotPublishRequestSinceTaskDoesNotExist() {
when(taskService.getTaskByChainTaskId(CHAIN_TASK_ID))
.thenReturn(Optional.empty());

boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);
final boolean publishRequestStatus = taskUpdateRequestManager.publishRequest(CHAIN_TASK_ID);

Assertions.assertThat(publishRequestStatus).isFalse();
assertThat(publishRequestStatus).isFalse();
verifyNoInteractions(taskUpdateManager);
}
// endregion

Expand Down Expand Up @@ -110,20 +162,18 @@ void shouldNotUpdateAtTheSameTime() {
.collect(Collectors.toList());

updates.forEach(taskUpdateRequestManager.taskUpdateExecutor::execute);
Awaitility
.await()
.timeout(30, TimeUnit.SECONDS)
await().timeout(30, TimeUnit.SECONDS)
.until(() -> callsOrder.size() == callsPerUpdate * updates.size());

Assertions.assertThat(callsOrder).hasSize(callsPerUpdate * updates.size());
assertThat(callsOrder).hasSize(callsPerUpdate * updates.size());

// We loop through calls order and see if all calls for a given update have finished
// before another update starts for this task.
// Two updates for different tasks can run at the same time.
Map<String, Map<Integer, Integer>> foundTaskUpdates = new HashMap<>();

for (int updateId : callsOrder) {
System.out.println("[taskId:" + taskForUpdateId.get(updateId) + ", updateId:" + updateId + "]");
log.info("[taskId:{}, updateId:{}]", taskForUpdateId.get(updateId), updateId);
final Map<Integer, Integer> foundOutputsForKeyGroup = foundTaskUpdates.computeIfAbsent(taskForUpdateId.get(updateId), (key) -> new HashMap<>());
for (int alreadyFound : foundOutputsForKeyGroup.keySet()) {
if (!Objects.equals(alreadyFound, updateId) && foundOutputsForKeyGroup.get(alreadyFound) < callsPerUpdate) {
Expand Down Expand Up @@ -161,14 +211,12 @@ void shouldGetInOrderForStatus() throws InterruptedException {
queue.addAll(tasks);

final List<TaskUpdate> prioritizedTasks = queue.takeAll();
Assertions.assertThat(prioritizedTasks)
.containsExactly(
completedTask,
consensusReachedTask,
runningTask,
initializedTask,
initializingTask
);
assertThat(prioritizedTasks).containsExactly(
completedTask,
consensusReachedTask,
runningTask,
initializedTask,
initializingTask);
}

@Test
Expand All @@ -192,14 +240,8 @@ void shouldGetInOrderForContributionDeadline() throws InterruptedException {
queue.addAll(tasks);

final List<TaskUpdate> prioritizedTasks = queue.takeAll();
Assertions.assertThat(prioritizedTasks)
.containsExactly(
t1,
t2,
t3,
t4,
t5
);
assertThat(prioritizedTasks).containsExactly(
t1, t2, t3, t4, t5);
}

@Test
Expand All @@ -219,13 +261,8 @@ void shouldGetInOrderForStatusAndContributionDeadline() throws InterruptedExcept
queue.addAll(tasks);

final List<TaskUpdate> prioritizedTasks = queue.takeAll();
Assertions.assertThat(prioritizedTasks)
.containsExactly(
t3,
t4,
t1,
t2
);
assertThat(prioritizedTasks).containsExactly(
t3, t4, t1, t2);
}
// endregion

Expand All @@ -234,12 +271,11 @@ private TaskUpdate buildTaskUpdate(String chainTaskId,
Date contributionDeadline,
Consumer<String> taskUpdater) {
return new TaskUpdate(
Task
.builder()
Task.builder()
.chainTaskId(chainTaskId)
.currentStatus(status).
contributionDeadline(contributionDeadline).
build(),
.currentStatus(status)
.contributionDeadline(contributionDeadline)
.build(),
taskUpdater
);
}
Expand Down