Skip to content

Commit

Permalink
Mitigate potential race conditions by enforcing currentStatus value…
Browse files Browse the repository at this point in the history
… when updating a task (#681)
  • Loading branch information
jbern0rd authored Mar 20, 2024
1 parent 163d5d9 commit 66c946b
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 184 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.
- Always provide a `WorkerpoolAuthorization` to a worker during its recovery. (#674)
- Move task metrics from `TaskUpdateManager` to `TaskService`. (#676)
- Fail fast when tasks are detected past their contribution or final deadline. (#677)
- Mitigate potential race conditions by enforcing `currentStatus` value when updating a task. (#681)

### Quality

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public void detect() {
.push("dateStatusList").each(
TaskStatusChange.builder().status(TaskStatus.CONTRIBUTION_TIMEOUT).build(),
TaskStatusChange.builder().status(TaskStatus.FAILED).build());
taskService.failMultipleTasksByQuery(update, query)
taskService.updateMultipleTasksByQuery(query, update)
.forEach(id -> applicationEventPublisher.publishEvent(new ContributionTimeoutEvent(id)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public void detect() {
.push("dateStatusList").each(
TaskStatusChange.builder().status(TaskStatus.FINAL_DEADLINE_REACHED).build(),
TaskStatusChange.builder().status(TaskStatus.FAILED).build());
taskService.failMultipleTasksByQuery(update, query)
taskService.updateMultipleTasksByQuery(query, update)
.forEach(id -> applicationEventPublisher.publishEvent(new TaskFailedEvent(id)));
}
}
2 changes: 2 additions & 0 deletions src/main/java/com/iexec/core/task/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@
unique = true)
public class Task {

public static final String CHAIN_TASK_ID_FIELD_NAME = "chainTaskId";
public static final String CURRENT_STATUS_FIELD_NAME = "currentStatus";
public static final String CONTRIBUTION_DEADLINE_FIELD_NAME = "contributionDeadline";
public static final String DATE_STATUS_LIST_FIELD_NAME = "dateStatusList";

@Id
private String id;
Expand Down
3 changes: 0 additions & 3 deletions src/main/java/com/iexec/core/task/TaskRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ public interface TaskRepository extends MongoRepository<Task, String> {
@Query("{ 'currentStatus': {$in: ?0} }")
List<Task> findByCurrentStatus(List<TaskStatus> statuses);

@Query("{ 'currentStatus': {$in: ?0} }")
List<Task> findByCurrentStatus(List<TaskStatus> statuses, Sort sort);

/**
* Retrieves the prioritized task matching with given criteria:
* <ul>
Expand Down
139 changes: 89 additions & 50 deletions src/main/java/com/iexec/core/task/TaskService.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Metrics;
import lombok.extern.slf4j.Slf4j;
import org.bson.Document;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.event.EventListener;
import org.springframework.dao.DuplicateKeyException;
Expand All @@ -45,13 +46,13 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.iexec.core.task.Task.*;
import static com.iexec.core.task.TaskStatus.*;

@Slf4j
@Service
public class TaskService {

private static final String CHAIN_TASK_ID_FIELD = "chainTaskId";
public static final String METRIC_TASKS_STATUSES_COUNT = "iexec.core.tasks.count";
private final MongoTemplate mongoTemplate;
private final TaskRepository taskRepository;
Expand Down Expand Up @@ -106,21 +107,19 @@ private void initializeCurrentTaskStatusesCount() {
}

/**
* Save task in database if it does not
* already exist.
* Save task in database if it does not already exist.
*
* @param chainDealId
* @param taskIndex
* @param dealBlockNumber
* @param imageName
* @param commandLine
* @param trust
* @param maxExecutionTime
* @param tag
* @param contributionDeadline
* @param finalDeadline
* @return optional containing the saved
* task, {@link Optional#empty()} otherwise.
* @param chainDealId on-chain deal id
* @param taskIndex task index in deal
* @param dealBlockNumber block number when orders were matched to produce the current deal
* @param imageName OCI image to use for replicates computation
* @param commandLine command that will be executed during replicates computation
* @param trust trust level, impacts replication
* @param maxExecutionTime execution time
* @param tag deal tag describing additional features like TEE framework
* @param contributionDeadline date after which a worker cannot contribute
* @param finalDeadline date after which a task cannot be updated
* @return {@code Optional} containing the saved task, {@link Optional#empty()} otherwise.
*/
public Optional<Task> addTask(
String chainDealId,
Expand All @@ -134,40 +133,73 @@ public Optional<Task> addTask(
Date contributionDeadline,
Date finalDeadline
) {
Task newTask = new Task(chainDealId, taskIndex, imageName,
commandLine, trust, maxExecutionTime, tag);
Task newTask = new Task(chainDealId, taskIndex, imageName, commandLine, trust, maxExecutionTime, tag);
newTask.setDealBlockNumber(dealBlockNumber);
newTask.setFinalDeadline(finalDeadline);
newTask.setContributionDeadline(contributionDeadline);
final String taskLogDetails = String.format("chainDealId:%s, taskIndex:%s, imageName:%s, commandLine:%s, trust:%s, contributionDeadline:%s, finalDeadline:%s",
chainDealId, taskIndex, imageName, commandLine, trust, contributionDeadline, finalDeadline);
try {
newTask = taskRepository.save(newTask);
log.info("Added new task [chainDealId:{}, taskIndex:{}, imageName:{}, commandLine:{}, trust:{}, chainTaskId:{}]",
chainDealId, taskIndex, imageName, commandLine, trust, newTask.getChainTaskId());
log.info("Added new task [{}}, chainTaskId:{}]", taskLogDetails, newTask.getChainTaskId());
return Optional.of(newTask);
} catch (DuplicateKeyException e) {
log.info("Task already added [chainDealId:{}, taskIndex:{}, imageName:{}, commandLine:{}, trust:{}]",
chainDealId, taskIndex, imageName, commandLine, trust);
log.info("Task already added [{}]", taskLogDetails);
return Optional.empty();
}
}

public long updateTaskStatus(Task task, TaskStatus currentStatus, List<TaskStatusChange> statusChanges) {
Update update = Update.update("currentStatus", task.getCurrentStatus());
update.push("dateStatusList").each(statusChanges);
UpdateResult result = mongoTemplate.updateFirst(
Query.query(Criteria.where(CHAIN_TASK_ID_FIELD).is(task.getChainTaskId())),
update,
Task.class);
log.debug("Updated chainTaskId [chainTaskId:{}, result:{}]", task.getChainTaskId(), result);
updateMetricsAfterStatusUpdate(currentStatus, task.getCurrentStatus());
/**
* Updates the status of a single task in the collection
*
* @param chainTaskId On-chain ID of the task to update
* @param currentStatus Expected {@code currentStatus} of the task when executing the update
* @param targetStatus Wished {@code currentStatus} the task should be updated to
* @param statusChanges List of {@code TaskStatusChange} to append to the {@code dateStatusList} field
* @return The number of updated documents in the task collection, should be {@literal 0} or {@literal 1} due to task ID uniqueness
*/
public long updateTaskStatus(String chainTaskId, TaskStatus currentStatus, TaskStatus targetStatus, List<TaskStatusChange> statusChanges) {
final Update update = Update.update(CURRENT_STATUS_FIELD_NAME, targetStatus)
.push(DATE_STATUS_LIST_FIELD_NAME).each(statusChanges);
final UpdateResult result = updateTask(chainTaskId, currentStatus, update);
return result.getModifiedCount();
}

public void updateTask(String chainTaskId, Update update) {
UpdateResult result = mongoTemplate.updateFirst(
Query.query(Criteria.where(CHAIN_TASK_ID_FIELD).is(chainTaskId)),
update, Task.class);
log.debug("Updated chainTaskId [chainTaskId:{}, result{}]", chainTaskId, result);
/**
* Update a single task in the collection
*
* @param chainTaskId On-chain ID of the task to update
* @param currentStatus Expected {@code currentStatus} of the task when executing the update
* @param update Update to execute on the task if criteria are respected
* @return The result of the update execution on the task collection
*/
public UpdateResult updateTask(String chainTaskId, TaskStatus currentStatus, Update update) {
final Criteria criteria = Criteria.where(CHAIN_TASK_ID_FIELD_NAME).is(chainTaskId)
.and(CURRENT_STATUS_FIELD_NAME).is(currentStatus);
// chainTaskId and currentStatus are part of the criteria, no need to add them explicitly
log.debug("Update request [criteria:{}, update:{}]",
criteria.getCriteriaObject(), update.getUpdateObject());
final UpdateResult result = mongoTemplate.updateFirst(Query.query(criteria), update, Task.class);
log.debug("Update execution result [chainTaskId:{}, result:{}]", chainTaskId, result);
if (result.getModifiedCount() == 0L) {
log.warn("The task was not updated [chainTaskId:{}]", chainTaskId);
} else if (isTaskCurrentStatusUpdated(update)) {
// A single document has been updated (chainTaskId uniqueness) and the currentStatus has been modified
updateMetricsAfterStatusUpdate(currentStatus, update.getUpdateObject().get("$set", Document.class)
.get(CURRENT_STATUS_FIELD_NAME, TaskStatus.class));
}
return result;
}

/**
* Checks if provided MongoDB update has modified the {@code currentStatus} field of the task
*
* @param update The MongoDB request to check
* @return {@literal true} if the {@code currentStatus} was updated, {@literal false} otherwise
*/
private boolean isTaskCurrentStatusUpdated(Update update) {
return update.getUpdateObject().containsKey("$set")
&& update.getUpdateObject().get("$set", Document.class).containsKey(CURRENT_STATUS_FIELD_NAME);
}

public Optional<Task> getTaskByChainTaskId(String chainTaskId) {
Expand Down Expand Up @@ -214,8 +246,8 @@ public Optional<Task> getPrioritizedInitializedOrRunningTask(
Arrays.asList(INITIALIZED, RUNNING),
excludedTags,
excludedChainTaskIds,
Sort.by(Sort.Order.desc(Task.CURRENT_STATUS_FIELD_NAME),
Sort.Order.asc(Task.CONTRIBUTION_DEADLINE_FIELD_NAME)));
Sort.by(Sort.Order.desc(CURRENT_STATUS_FIELD_NAME),
Sort.Order.asc(CONTRIBUTION_DEADLINE_FIELD_NAME)));
}

/**
Expand Down Expand Up @@ -249,23 +281,29 @@ private Optional<Task> findPrioritizedTask(List<TaskStatus> statuses,
);
}

public List<String> failMultipleTasksByQuery(Update update, Query query) {
/**
* Updates task on a given MongoDB query.
*
* @param query The query to perform to lookup for tasks in the collection
* @param update The update to execute on the tasks returned by the query
* @return The list of modified chain task ids
*/
public List<String> updateMultipleTasksByQuery(Query query, Update update) {
return mongoTemplate.find(query, Task.class).stream()
.map(task -> failSingleTask(update, task))
.map(task -> updateSingleTask(task, update))
.collect(Collectors.toList());
}

private String failSingleTask(Update update, Task task) {
final TaskStatus beforeUpdate = task.getCurrentStatus();
final UpdateResult updateResult = mongoTemplate.updateFirst(
Query.query(Criteria.where(CHAIN_TASK_ID_FIELD).is(task.getChainTaskId())), update, Task.class);
if (updateResult.getModifiedCount() == 0) {
log.warn("The task was not updated [chainTaskId:{}]", task.getChainTaskId());
return "";
} else {
updateMetricsAfterStatusUpdate(beforeUpdate, FAILED);
return task.getChainDealId();
}
/**
* Updates a single task in the task collection
*
* @param task The task to update
* @param update The update to perform
* @return The chain task id of the task
*/
private String updateSingleTask(Task task, Update update) {
final UpdateResult updateResult = updateTask(task.getChainTaskId(), task.getCurrentStatus(), update);
return updateResult.getModifiedCount() == 0L ? "" : task.getChainTaskId();
}

public List<String> getChainTaskIdsOfTasksExpiredBefore(Date expirationDate) {
Expand Down Expand Up @@ -325,6 +363,7 @@ public long countByCurrentStatus(TaskStatus status) {
}

void updateMetricsAfterStatusUpdate(TaskStatus previousStatus, TaskStatus newStatus) {
log.debug("updateMetricsAfterStatusUpdate [prev:{}, next:{}]", previousStatus, newStatus);
currentTaskStatusesCount.get(previousStatus).decrementAndGet();
currentTaskStatusesCount.get(newStatus).incrementAndGet();
publishTaskStatusesCountUpdate();
Expand Down
34 changes: 19 additions & 15 deletions src/main/java/com/iexec/core/task/update/TaskUpdateManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,47 +162,50 @@ void updateTask(String chainTaskId) {
* @param statuses List of statuses to append to the task {@code dateStatusList}
*/
void updateTaskStatusesAndSave(Task task, TaskStatus... statuses) {
TaskStatus currentStatus = task.getCurrentStatus();
List<TaskStatusChange> statusChanges = new ArrayList<>();
final TaskStatus currentStatus = task.getCurrentStatus();
final List<TaskStatusChange> statusChanges = new ArrayList<>();
for (TaskStatus newStatus : statuses) {
log.info("Create TaskStatusChange succeeded [chainTaskId:{}, currentStatus:{}, newStatus:{}]",
task.getChainTaskId(), task.getCurrentStatus(), newStatus);
final TaskStatusChange statusChange = TaskStatusChange.builder().status(newStatus).build();
// task update required by tests
task.setCurrentStatus(newStatus);
task.getDateStatusList().add(statusChange);
statusChanges.add(statusChange);
}
saveTask(task, currentStatus, statusChanges);
saveTask(task.getChainTaskId(), currentStatus, statuses[statuses.length - 1], statusChanges);
}

void updateTaskStatusAndSave(Task task, TaskStatus newStatus) {
updateTaskStatusAndSave(task, newStatus, null);
}

void updateTaskStatusAndSave(Task task, TaskStatus newStatus, ChainReceipt chainReceipt) {
TaskStatus currentStatus = task.getCurrentStatus();
TaskStatusChange statusChange = TaskStatusChange.builder().status(newStatus).chainReceipt(chainReceipt).build();
final TaskStatus currentStatus = task.getCurrentStatus();
final TaskStatusChange statusChange = TaskStatusChange.builder().status(newStatus).chainReceipt(chainReceipt).build();
// task update required by tests
task.setCurrentStatus(newStatus);
task.getDateStatusList().add(statusChange);
saveTask(task, currentStatus, List.of(statusChange));
saveTask(task.getChainTaskId(), currentStatus, newStatus, List.of(statusChange));
}

/**
* Saves the task to the database.
*
* @param task The task
* @param chainTaskId ID of the task
* @param currentStatus The current status in database
* @param wishedStatus The status the task should have after the update
* @param statusChanges List of changes
*/
void saveTask(Task task, TaskStatus currentStatus, List<TaskStatusChange> statusChanges) {
long updatedTaskCount = taskService.updateTaskStatus(task, currentStatus, statusChanges);
void saveTask(String chainTaskId, TaskStatus currentStatus, TaskStatus wishedStatus, List<TaskStatusChange> statusChanges) {
long updatedTaskCount = taskService.updateTaskStatus(chainTaskId, currentStatus, wishedStatus, statusChanges);
// `savedTask.isPresent()` should always be true if the task exists in the repository.
if (updatedTaskCount != 0L) {
log.info("UpdateTaskStatus succeeded [chainTaskId:{}, currentStatus:{}, newStatus:{}]",
task.getChainTaskId(), currentStatus, task.getCurrentStatus());
chainTaskId, currentStatus, wishedStatus);
} else {
log.warn("UpdateTaskStatus failed. Chain Task is probably unknown [chainTaskId:{}, currentStatus:{}, wishedStatus:{}]",
task.getChainTaskId(), currentStatus, task.getCurrentStatus());
chainTaskId, currentStatus, wishedStatus);
}
}
// endregion
Expand Down Expand Up @@ -244,7 +247,7 @@ void received2Initializing(Task task) {
}
task.setEnclaveChallenge(enclaveChallenge.get());
update.set("enclaveChallenge", enclaveChallenge.get());
taskService.updateTask(task.getChainTaskId(), update);
taskService.updateTask(task.getChainTaskId(), task.getCurrentStatus(), update);

blockchainAdapterService
.requestInitialize(task.getChainDealId(), task.getTaskIndex())
Expand Down Expand Up @@ -392,7 +395,7 @@ private void running2ConsensusReached(ChainTask chainTask, Task task, Replicates
task.setConsensus(chainTask.getConsensusValue());
long consensusBlockNumber = iexecHubService.getConsensusBlock(chainTaskId, task.getInitializationBlockNumber()).getBlockNumber();
task.setConsensusReachedBlockNumber(consensusBlockNumber);
taskService.updateTask(task.getChainTaskId(),
taskService.updateTask(task.getChainTaskId(), task.getCurrentStatus(),
Update.update("revealDeadline", task.getRevealDeadline())
.set("consensus", task.getConsensus())
.set("consensusReachedBlockNumber", task.getConsensusReachedBlockNumber()));
Expand Down Expand Up @@ -570,7 +573,7 @@ void resultUploading2Uploaded(ChainTask chainTask, Task task) {
if (uploadedReplicate != null) {
task.setResultLink(uploadedReplicate.getResultLink());
task.setChainCallbackData(uploadedReplicate.getChainCallbackData());
taskService.updateTask(task.getChainTaskId(),
taskService.updateTask(task.getChainTaskId(), task.getCurrentStatus(),
Update.update("resultLink", uploadedReplicate.getResultLink())
.set("chainCallbackData", uploadedReplicate.getChainCallbackData()));
updateTaskStatusAndSave(task, RESULT_UPLOADED);
Expand Down Expand Up @@ -616,7 +619,8 @@ void requestUpload(Task task) {
replicatesService.getRandomReplicateWithRevealStatus(task.getChainTaskId()).ifPresent(replicate -> {
// save in the task the workerWallet that is in charge of uploading the result
task.setUploadingWorkerWalletAddress(replicate.getWalletAddress());
taskService.updateTask(task.getChainTaskId(), Update.update("uploadingWorkerWalletAddress", replicate.getWalletAddress()));
taskService.updateTask(task.getChainTaskId(), task.getCurrentStatus(),
Update.update("uploadingWorkerWalletAddress", replicate.getWalletAddress()));
updateTaskStatusAndSave(task, RESULT_UPLOADING);
replicatesService.updateReplicateStatus(
task.getChainTaskId(), replicate.getWalletAddress(),
Expand Down
Loading

0 comments on commit 66c946b

Please sign in to comment.