Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mitigate potential race conditions by enforcing currentStatus value when updating a task #681

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.
- Always provide a `WorkerpoolAuthorization` to a worker during its recovery. (#674)
- Move task metrics from `TaskUpdateManager` to `TaskService`. (#676)
- Fail fast when tasks are detected past their contribution or final deadline. (#677)
- Mitigate potential race conditions by enforcing `currentStatus` value when updating a task. (#681)

### Quality

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public void detect() {
.push("dateStatusList").each(
TaskStatusChange.builder().status(TaskStatus.CONTRIBUTION_TIMEOUT).build(),
TaskStatusChange.builder().status(TaskStatus.FAILED).build());
taskService.failMultipleTasksByQuery(update, query)
taskService.updateMultipleTasksByQuery(query, update)
.forEach(id -> applicationEventPublisher.publishEvent(new ContributionTimeoutEvent(id)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public void detect() {
.push("dateStatusList").each(
TaskStatusChange.builder().status(TaskStatus.FINAL_DEADLINE_REACHED).build(),
TaskStatusChange.builder().status(TaskStatus.FAILED).build());
taskService.failMultipleTasksByQuery(update, query)
taskService.updateMultipleTasksByQuery(query, update)
.forEach(id -> applicationEventPublisher.publishEvent(new TaskFailedEvent(id)));
}
}
2 changes: 2 additions & 0 deletions src/main/java/com/iexec/core/task/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@
unique = true)
public class Task {

public static final String CHAIN_TASK_ID_FIELD_NAME = "chainTaskId";
public static final String CURRENT_STATUS_FIELD_NAME = "currentStatus";
public static final String CONTRIBUTION_DEADLINE_FIELD_NAME = "contributionDeadline";
public static final String DATE_STATUS_LIST_FIELD_NAME = "dateStatusList";

@Id
private String id;
Expand Down
3 changes: 0 additions & 3 deletions src/main/java/com/iexec/core/task/TaskRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ public interface TaskRepository extends MongoRepository<Task, String> {
@Query("{ 'currentStatus': {$in: ?0} }")
List<Task> findByCurrentStatus(List<TaskStatus> statuses);

@Query("{ 'currentStatus': {$in: ?0} }")
List<Task> findByCurrentStatus(List<TaskStatus> statuses, Sort sort);

/**
* Retrieves the prioritized task matching with given criteria:
* <ul>
Expand Down
139 changes: 89 additions & 50 deletions src/main/java/com/iexec/core/task/TaskService.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Metrics;
import lombok.extern.slf4j.Slf4j;
import org.bson.Document;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.event.EventListener;
import org.springframework.dao.DuplicateKeyException;
Expand All @@ -45,13 +46,13 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.iexec.core.task.Task.*;
import static com.iexec.core.task.TaskStatus.*;

@Slf4j
@Service
public class TaskService {

private static final String CHAIN_TASK_ID_FIELD = "chainTaskId";
public static final String METRIC_TASKS_STATUSES_COUNT = "iexec.core.tasks.count";
private final MongoTemplate mongoTemplate;
private final TaskRepository taskRepository;
Expand Down Expand Up @@ -106,21 +107,19 @@ private void initializeCurrentTaskStatusesCount() {
}

/**
* Save task in database if it does not
* already exist.
* Save task in database if it does not already exist.
*
* @param chainDealId
* @param taskIndex
* @param dealBlockNumber
* @param imageName
* @param commandLine
* @param trust
* @param maxExecutionTime
* @param tag
* @param contributionDeadline
* @param finalDeadline
* @return optional containing the saved
* task, {@link Optional#empty()} otherwise.
* @param chainDealId on-chain deal id
* @param taskIndex task index in deal
* @param dealBlockNumber block number when orders were matched to produce the current deal
* @param imageName OCI image to use for replicates computation
* @param commandLine command that will be executed during replicates computation
* @param trust trust level, impacts replication
* @param maxExecutionTime execution time
* @param tag deal tag describing additional features like TEE framework
* @param contributionDeadline date after which a worker cannot contribute
* @param finalDeadline date after which a task cannot be updated
* @return {@code Optional} containing the saved task, {@link Optional#empty()} otherwise.
*/
public Optional<Task> addTask(
String chainDealId,
Expand All @@ -134,40 +133,73 @@ public Optional<Task> addTask(
Date contributionDeadline,
Date finalDeadline
) {
Task newTask = new Task(chainDealId, taskIndex, imageName,
commandLine, trust, maxExecutionTime, tag);
Task newTask = new Task(chainDealId, taskIndex, imageName, commandLine, trust, maxExecutionTime, tag);
newTask.setDealBlockNumber(dealBlockNumber);
newTask.setFinalDeadline(finalDeadline);
newTask.setContributionDeadline(contributionDeadline);
final String taskLogDetails = String.format("chainDealId:%s, taskIndex:%s, imageName:%s, commandLine:%s, trust:%s, contributionDeadline:%s, finalDeadline:%s",
chainDealId, taskIndex, imageName, commandLine, trust, contributionDeadline, finalDeadline);
try {
newTask = taskRepository.save(newTask);
log.info("Added new task [chainDealId:{}, taskIndex:{}, imageName:{}, commandLine:{}, trust:{}, chainTaskId:{}]",
chainDealId, taskIndex, imageName, commandLine, trust, newTask.getChainTaskId());
log.info("Added new task [{}}, chainTaskId:{}]", taskLogDetails, newTask.getChainTaskId());
return Optional.of(newTask);
} catch (DuplicateKeyException e) {
log.info("Task already added [chainDealId:{}, taskIndex:{}, imageName:{}, commandLine:{}, trust:{}]",
chainDealId, taskIndex, imageName, commandLine, trust);
log.info("Task already added [{}]", taskLogDetails);
return Optional.empty();
}
}

public long updateTaskStatus(Task task, TaskStatus currentStatus, List<TaskStatusChange> statusChanges) {
Update update = Update.update("currentStatus", task.getCurrentStatus());
update.push("dateStatusList").each(statusChanges);
UpdateResult result = mongoTemplate.updateFirst(
Query.query(Criteria.where(CHAIN_TASK_ID_FIELD).is(task.getChainTaskId())),
update,
Task.class);
log.debug("Updated chainTaskId [chainTaskId:{}, result:{}]", task.getChainTaskId(), result);
updateMetricsAfterStatusUpdate(currentStatus, task.getCurrentStatus());
/**
* Updates the status of a single task in the collection
*
* @param chainTaskId On-chain ID of the task to update
* @param currentStatus Expected {@code currentStatus} of the task when executing the update
* @param targetStatus Wished {@code currentStatus} the task should be updated to
* @param statusChanges List of {@code TaskStatusChange} to append to the {@code dateStatusList} field
* @return The number of updated documents in the task collection, should be {@literal 0} or {@literal 1} due to task ID uniqueness
*/
public long updateTaskStatus(String chainTaskId, TaskStatus currentStatus, TaskStatus targetStatus, List<TaskStatusChange> statusChanges) {
final Update update = Update.update(CURRENT_STATUS_FIELD_NAME, targetStatus)
.push(DATE_STATUS_LIST_FIELD_NAME).each(statusChanges);
final UpdateResult result = updateTask(chainTaskId, currentStatus, update);
return result.getModifiedCount();
}

public void updateTask(String chainTaskId, Update update) {
UpdateResult result = mongoTemplate.updateFirst(
Query.query(Criteria.where(CHAIN_TASK_ID_FIELD).is(chainTaskId)),
update, Task.class);
log.debug("Updated chainTaskId [chainTaskId:{}, result{}]", chainTaskId, result);
/**
* Update a single task in the collection
*
* @param chainTaskId On-chain ID of the task to update
* @param currentStatus Expected {@code currentStatus} of the task when executing the update
* @param update Update to execute on the task if criteria are respected
* @return The result of the update execution on the task collection
*/
public UpdateResult updateTask(String chainTaskId, TaskStatus currentStatus, Update update) {
final Criteria criteria = Criteria.where(CHAIN_TASK_ID_FIELD_NAME).is(chainTaskId)
.and(CURRENT_STATUS_FIELD_NAME).is(currentStatus);
// chainTaskId and currentStatus are part of the criteria, no need to add them explicitly
log.debug("Update request [criteria:{}, update:{}]",
criteria.getCriteriaObject(), update.getUpdateObject());
final UpdateResult result = mongoTemplate.updateFirst(Query.query(criteria), update, Task.class);
log.debug("Update execution result [chainTaskId:{}, result:{}]", chainTaskId, result);
if (result.getModifiedCount() == 0L) {
log.warn("The task was not updated [chainTaskId:{}]", chainTaskId);
} else if (isTaskCurrentStatusUpdated(update)) {
// A single document has been updated (chainTaskId uniqueness) and the currentStatus has been modified
updateMetricsAfterStatusUpdate(currentStatus, update.getUpdateObject().get("$set", Document.class)
.get(CURRENT_STATUS_FIELD_NAME, TaskStatus.class));
}
return result;
}

/**
* Checks if provided MongoDB update has modified the {@code currentStatus} field of the task
*
* @param update The MongoDB request to check
* @return {@literal true} if the {@code currentStatus} was updated, {@literal false} otherwise
*/
private boolean isTaskCurrentStatusUpdated(Update update) {
return update.getUpdateObject().containsKey("$set")
&& update.getUpdateObject().get("$set", Document.class).containsKey(CURRENT_STATUS_FIELD_NAME);
}

public Optional<Task> getTaskByChainTaskId(String chainTaskId) {
Expand Down Expand Up @@ -214,8 +246,8 @@ public Optional<Task> getPrioritizedInitializedOrRunningTask(
Arrays.asList(INITIALIZED, RUNNING),
excludedTags,
excludedChainTaskIds,
Sort.by(Sort.Order.desc(Task.CURRENT_STATUS_FIELD_NAME),
Sort.Order.asc(Task.CONTRIBUTION_DEADLINE_FIELD_NAME)));
Sort.by(Sort.Order.desc(CURRENT_STATUS_FIELD_NAME),
Sort.Order.asc(CONTRIBUTION_DEADLINE_FIELD_NAME)));
}

/**
Expand Down Expand Up @@ -249,23 +281,29 @@ private Optional<Task> findPrioritizedTask(List<TaskStatus> statuses,
);
}

public List<String> failMultipleTasksByQuery(Update update, Query query) {
/**
* Updates task on a given MongoDB query.
*
* @param query The query to perform to lookup for tasks in the collection
* @param update The update to execute on the tasks returned by the query
* @return The list of modified chain task ids
*/
public List<String> updateMultipleTasksByQuery(Query query, Update update) {
return mongoTemplate.find(query, Task.class).stream()
.map(task -> failSingleTask(update, task))
.map(task -> updateSingleTask(task, update))
.collect(Collectors.toList());
}

private String failSingleTask(Update update, Task task) {
final TaskStatus beforeUpdate = task.getCurrentStatus();
final UpdateResult updateResult = mongoTemplate.updateFirst(
Query.query(Criteria.where(CHAIN_TASK_ID_FIELD).is(task.getChainTaskId())), update, Task.class);
if (updateResult.getModifiedCount() == 0) {
log.warn("The task was not updated [chainTaskId:{}]", task.getChainTaskId());
return "";
} else {
updateMetricsAfterStatusUpdate(beforeUpdate, FAILED);
return task.getChainDealId();
}
/**
* Updates a single task in the task collection
*
* @param task The task to update
* @param update The update to perform
* @return The chain task id of the task
*/
private String updateSingleTask(Task task, Update update) {
final UpdateResult updateResult = updateTask(task.getChainTaskId(), task.getCurrentStatus(), update);
return updateResult.getModifiedCount() == 0L ? "" : task.getChainTaskId();
}

public List<String> getChainTaskIdsOfTasksExpiredBefore(Date expirationDate) {
Expand Down Expand Up @@ -325,6 +363,7 @@ public long countByCurrentStatus(TaskStatus status) {
}

void updateMetricsAfterStatusUpdate(TaskStatus previousStatus, TaskStatus newStatus) {
log.debug("updateMetricsAfterStatusUpdate [prev:{}, next:{}]", previousStatus, newStatus);
currentTaskStatusesCount.get(previousStatus).decrementAndGet();
currentTaskStatusesCount.get(newStatus).incrementAndGet();
publishTaskStatusesCountUpdate();
Expand Down
34 changes: 19 additions & 15 deletions src/main/java/com/iexec/core/task/update/TaskUpdateManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,47 +162,50 @@ void updateTask(String chainTaskId) {
* @param statuses List of statuses to append to the task {@code dateStatusList}
*/
void updateTaskStatusesAndSave(Task task, TaskStatus... statuses) {
TaskStatus currentStatus = task.getCurrentStatus();
List<TaskStatusChange> statusChanges = new ArrayList<>();
final TaskStatus currentStatus = task.getCurrentStatus();
final List<TaskStatusChange> statusChanges = new ArrayList<>();
for (TaskStatus newStatus : statuses) {
log.info("Create TaskStatusChange succeeded [chainTaskId:{}, currentStatus:{}, newStatus:{}]",
task.getChainTaskId(), task.getCurrentStatus(), newStatus);
final TaskStatusChange statusChange = TaskStatusChange.builder().status(newStatus).build();
// task update required by tests
task.setCurrentStatus(newStatus);
task.getDateStatusList().add(statusChange);
statusChanges.add(statusChange);
}
saveTask(task, currentStatus, statusChanges);
saveTask(task.getChainTaskId(), currentStatus, statuses[statuses.length - 1], statusChanges);
}

void updateTaskStatusAndSave(Task task, TaskStatus newStatus) {
updateTaskStatusAndSave(task, newStatus, null);
}

void updateTaskStatusAndSave(Task task, TaskStatus newStatus, ChainReceipt chainReceipt) {
TaskStatus currentStatus = task.getCurrentStatus();
TaskStatusChange statusChange = TaskStatusChange.builder().status(newStatus).chainReceipt(chainReceipt).build();
final TaskStatus currentStatus = task.getCurrentStatus();
final TaskStatusChange statusChange = TaskStatusChange.builder().status(newStatus).chainReceipt(chainReceipt).build();
// task update required by tests
task.setCurrentStatus(newStatus);
task.getDateStatusList().add(statusChange);
saveTask(task, currentStatus, List.of(statusChange));
saveTask(task.getChainTaskId(), currentStatus, newStatus, List.of(statusChange));
}

/**
* Saves the task to the database.
*
* @param task The task
* @param chainTaskId ID of the task
* @param currentStatus The current status in database
* @param wishedStatus The status the task should have after the update
* @param statusChanges List of changes
*/
void saveTask(Task task, TaskStatus currentStatus, List<TaskStatusChange> statusChanges) {
long updatedTaskCount = taskService.updateTaskStatus(task, currentStatus, statusChanges);
void saveTask(String chainTaskId, TaskStatus currentStatus, TaskStatus wishedStatus, List<TaskStatusChange> statusChanges) {
long updatedTaskCount = taskService.updateTaskStatus(chainTaskId, currentStatus, wishedStatus, statusChanges);
// `savedTask.isPresent()` should always be true if the task exists in the repository.
if (updatedTaskCount != 0L) {
log.info("UpdateTaskStatus succeeded [chainTaskId:{}, currentStatus:{}, newStatus:{}]",
task.getChainTaskId(), currentStatus, task.getCurrentStatus());
chainTaskId, currentStatus, wishedStatus);
} else {
log.warn("UpdateTaskStatus failed. Chain Task is probably unknown [chainTaskId:{}, currentStatus:{}, wishedStatus:{}]",
task.getChainTaskId(), currentStatus, task.getCurrentStatus());
chainTaskId, currentStatus, wishedStatus);
}
}
// endregion
Expand Down Expand Up @@ -244,7 +247,7 @@ void received2Initializing(Task task) {
}
task.setEnclaveChallenge(enclaveChallenge.get());
update.set("enclaveChallenge", enclaveChallenge.get());
taskService.updateTask(task.getChainTaskId(), update);
taskService.updateTask(task.getChainTaskId(), task.getCurrentStatus(), update);

blockchainAdapterService
.requestInitialize(task.getChainDealId(), task.getTaskIndex())
Expand Down Expand Up @@ -392,7 +395,7 @@ private void running2ConsensusReached(ChainTask chainTask, Task task, Replicates
task.setConsensus(chainTask.getConsensusValue());
long consensusBlockNumber = iexecHubService.getConsensusBlock(chainTaskId, task.getInitializationBlockNumber()).getBlockNumber();
task.setConsensusReachedBlockNumber(consensusBlockNumber);
taskService.updateTask(task.getChainTaskId(),
taskService.updateTask(task.getChainTaskId(), task.getCurrentStatus(),
Update.update("revealDeadline", task.getRevealDeadline())
.set("consensus", task.getConsensus())
.set("consensusReachedBlockNumber", task.getConsensusReachedBlockNumber()));
Expand Down Expand Up @@ -570,7 +573,7 @@ void resultUploading2Uploaded(ChainTask chainTask, Task task) {
if (uploadedReplicate != null) {
task.setResultLink(uploadedReplicate.getResultLink());
task.setChainCallbackData(uploadedReplicate.getChainCallbackData());
taskService.updateTask(task.getChainTaskId(),
taskService.updateTask(task.getChainTaskId(), task.getCurrentStatus(),
Update.update("resultLink", uploadedReplicate.getResultLink())
.set("chainCallbackData", uploadedReplicate.getChainCallbackData()));
updateTaskStatusAndSave(task, RESULT_UPLOADED);
Expand Down Expand Up @@ -616,7 +619,8 @@ void requestUpload(Task task) {
replicatesService.getRandomReplicateWithRevealStatus(task.getChainTaskId()).ifPresent(replicate -> {
// save in the task the workerWallet that is in charge of uploading the result
task.setUploadingWorkerWalletAddress(replicate.getWalletAddress());
taskService.updateTask(task.getChainTaskId(), Update.update("uploadingWorkerWalletAddress", replicate.getWalletAddress()));
taskService.updateTask(task.getChainTaskId(), task.getCurrentStatus(),
Update.update("uploadingWorkerWalletAddress", replicate.getWalletAddress()));
updateTaskStatusAndSave(task, RESULT_UPLOADING);
replicatesService.updateReplicateStatus(
task.getChainTaskId(), replicate.getWalletAddress(),
Expand Down
Loading