diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java index 395359435..cf65420c5 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java @@ -32,7 +32,6 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; @@ -64,6 +63,7 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.util.SecurityUtil; import com.google.common.base.Throwables; @@ -509,7 +509,7 @@ private void stopAdJobForEndRunException( executionStartTime, error, true, - ADTaskState.STOPPED.name(), + TaskState.STOPPED.name(), recorder, detector ) diff --git a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java b/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java index fbe9787e9..19bb3b43d 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java +++ b/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java @@ -15,9 +15,9 @@ import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; -import static org.opensearch.ad.model.ADTaskType.taskTypeToString; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_DETECTOR_UPPER_LIMIT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.model.TaskType.taskTypeToString; import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; @@ -39,7 +39,6 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorInternalState; @@ -59,6 +58,7 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.util.ExceptionUtil; /** @@ -245,7 +245,7 @@ private void createRealtimeADTask(Job job, String error, ConcurrentLinkedQueue { private AnomalyDetector detector = null; - private String state = null; - private Float taskProgress = null; - private Float initProgress = null; - private Instant currentPiece = null; - private Instant executionStartTime = null; - private Instant executionEndTime = null; - private Boolean isLatest = null; - private String error = null; - private String checkpointId = null; - private Instant lastUpdateTime = null; - private String startedBy = null; - private String stoppedBy = null; - private String coordinatingNode = null; - private String workerNode = null; private DateRange detectionDateRange = null; - private Entity entity = null; - private String parentTaskId; - private Integer estimatedMinutesLeft; - private User user = null; public Builder() {} - public Builder taskId(String taskId) { - this.taskId = taskId; - return this; - } - - public Builder lastUpdateTime(Instant lastUpdateTime) { - this.lastUpdateTime = lastUpdateTime; - return this; - } - - public Builder startedBy(String startedBy) { - this.startedBy = startedBy; - return this; - } - - public Builder stoppedBy(String stoppedBy) { - this.stoppedBy = stoppedBy; - return this; - } - - public Builder error(String error) { - this.error = error; - return this; - } - - public Builder state(String state) { - this.state = state; - return this; - } - - public Builder detectorId(String detectorId) { - this.detectorId = detectorId; - return this; - } - - public Builder taskProgress(Float taskProgress) { - this.taskProgress = taskProgress; - return this; - } - - public Builder initProgress(Float initProgress) { - this.initProgress = initProgress; - return this; - } - - public Builder currentPiece(Instant currentPiece) { - this.currentPiece = currentPiece; - return this; - } - - public Builder executionStartTime(Instant executionStartTime) { - this.executionStartTime = executionStartTime; - return this; - } - - public Builder executionEndTime(Instant executionEndTime) { - this.executionEndTime = executionEndTime; - return this; - } - - public Builder isLatest(Boolean isLatest) { - this.isLatest = isLatest; - return this; - } - - public Builder taskType(String taskType) { - this.taskType = taskType; - return this; - } - - public Builder checkpointId(String checkpointId) { - this.checkpointId = checkpointId; - return this; - } - public Builder detector(AnomalyDetector detector) { this.detector = detector; return this; } - public Builder coordinatingNode(String coordinatingNode) { - this.coordinatingNode = coordinatingNode; - return this; - } - - public Builder workerNode(String workerNode) { - this.workerNode = workerNode; - return this; - } - public Builder detectionDateRange(DateRange detectionDateRange) { this.detectionDateRange = detectionDateRange; return this; } - public Builder entity(Entity entity) { - this.entity = entity; - return this; - } - - public Builder parentTaskId(String parentTaskId) { - this.parentTaskId = parentTaskId; - return this; - } - - public Builder estimatedMinutesLeft(Integer estimatedMinutesLeft) { - this.estimatedMinutesLeft = estimatedMinutesLeft; - return this; - } - - public Builder user(User user) { - this.user = user; - return this; - } - public ADTask build() { ADTask adTask = new ADTask(); adTask.taskId = this.taskId; adTask.lastUpdateTime = this.lastUpdateTime; adTask.error = this.error; adTask.state = this.state; - adTask.detectorId = this.detectorId; + adTask.configId = this.configId; adTask.taskProgress = this.taskProgress; adTask.initProgress = this.initProgress; adTask.currentPiece = this.currentPiece; @@ -381,56 +195,9 @@ public ADTask build() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject(); - if (taskId != null) { - xContentBuilder.field(TASK_ID_FIELD, taskId); - } - if (lastUpdateTime != null) { - xContentBuilder.field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); - } - if (startedBy != null) { - xContentBuilder.field(STARTED_BY_FIELD, startedBy); - } - if (stoppedBy != null) { - xContentBuilder.field(STOPPED_BY_FIELD, stoppedBy); - } - if (error != null) { - xContentBuilder.field(ERROR_FIELD, error); - } - if (state != null) { - xContentBuilder.field(STATE_FIELD, state); - } - if (detectorId != null) { - xContentBuilder.field(DETECTOR_ID_FIELD, detectorId); - } - if (taskProgress != null) { - xContentBuilder.field(TASK_PROGRESS_FIELD, taskProgress); - } - if (initProgress != null) { - xContentBuilder.field(INIT_PROGRESS_FIELD, initProgress); - } - if (currentPiece != null) { - xContentBuilder.field(CURRENT_PIECE_FIELD, currentPiece.toEpochMilli()); - } - if (executionStartTime != null) { - xContentBuilder.field(EXECUTION_START_TIME_FIELD, executionStartTime.toEpochMilli()); - } - if (executionEndTime != null) { - xContentBuilder.field(EXECUTION_END_TIME_FIELD, executionEndTime.toEpochMilli()); - } - if (isLatest != null) { - xContentBuilder.field(IS_LATEST_FIELD, isLatest); - } - if (taskType != null) { - xContentBuilder.field(TASK_TYPE_FIELD, taskType); - } - if (checkpointId != null) { - xContentBuilder.field(CHECKPOINT_ID_FIELD, checkpointId); - } - if (coordinatingNode != null) { - xContentBuilder.field(COORDINATING_NODE_FIELD, coordinatingNode); - } - if (workerNode != null) { - xContentBuilder.field(WORKER_NODE_FIELD, workerNode); + xContentBuilder = super.toXContent(xContentBuilder, params); + if (configId != null) { + xContentBuilder.field(DETECTOR_ID_FIELD, configId); } if (detector != null) { xContentBuilder.field(DETECTOR_FIELD, detector); @@ -438,18 +205,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (detectionDateRange != null) { xContentBuilder.field(DETECTION_DATE_RANGE_FIELD, detectionDateRange); } - if (entity != null) { - xContentBuilder.field(ENTITY_FIELD, entity); - } - if (parentTaskId != null) { - xContentBuilder.field(PARENT_TASK_ID_FIELD, parentTaskId); - } - if (estimatedMinutesLeft != null) { - xContentBuilder.field(ESTIMATED_MINUTES_LEFT_FIELD, estimatedMinutesLeft); - } - if (user != null) { - xContentBuilder.field(USER_FIELD, user); - } return xContentBuilder.endObject(); } @@ -488,73 +243,73 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept parser.nextToken(); switch (fieldName) { - case LAST_UPDATE_TIME_FIELD: + case TimeSeriesTask.LAST_UPDATE_TIME_FIELD: lastUpdateTime = ParseUtils.toInstant(parser); break; - case STARTED_BY_FIELD: + case TimeSeriesTask.STARTED_BY_FIELD: startedBy = parser.text(); break; - case STOPPED_BY_FIELD: + case TimeSeriesTask.STOPPED_BY_FIELD: stoppedBy = parser.text(); break; - case ERROR_FIELD: + case TimeSeriesTask.ERROR_FIELD: error = parser.text(); break; - case STATE_FIELD: + case TimeSeriesTask.STATE_FIELD: state = parser.text(); break; case DETECTOR_ID_FIELD: detectorId = parser.text(); break; - case TASK_PROGRESS_FIELD: + case TimeSeriesTask.TASK_PROGRESS_FIELD: taskProgress = parser.floatValue(); break; - case INIT_PROGRESS_FIELD: + case TimeSeriesTask.INIT_PROGRESS_FIELD: initProgress = parser.floatValue(); break; - case CURRENT_PIECE_FIELD: + case TimeSeriesTask.CURRENT_PIECE_FIELD: currentPiece = ParseUtils.toInstant(parser); break; - case EXECUTION_START_TIME_FIELD: + case TimeSeriesTask.EXECUTION_START_TIME_FIELD: executionStartTime = ParseUtils.toInstant(parser); break; - case EXECUTION_END_TIME_FIELD: + case TimeSeriesTask.EXECUTION_END_TIME_FIELD: executionEndTime = ParseUtils.toInstant(parser); break; - case IS_LATEST_FIELD: + case TimeSeriesTask.IS_LATEST_FIELD: isLatest = parser.booleanValue(); break; - case TASK_TYPE_FIELD: + case TimeSeriesTask.TASK_TYPE_FIELD: taskType = parser.text(); break; - case CHECKPOINT_ID_FIELD: + case TimeSeriesTask.CHECKPOINT_ID_FIELD: checkpointId = parser.text(); break; case DETECTOR_FIELD: detector = AnomalyDetector.parse(parser); break; - case TASK_ID_FIELD: + case TimeSeriesTask.TASK_ID_FIELD: parsedTaskId = parser.text(); break; - case COORDINATING_NODE_FIELD: + case TimeSeriesTask.COORDINATING_NODE_FIELD: coordinatingNode = parser.text(); break; - case WORKER_NODE_FIELD: + case TimeSeriesTask.WORKER_NODE_FIELD: workerNode = parser.text(); break; case DETECTION_DATE_RANGE_FIELD: detectionDateRange = DateRange.parse(parser); break; - case ENTITY_FIELD: + case TimeSeriesTask.ENTITY_FIELD: entity = Entity.parse(parser); break; - case PARENT_TASK_ID_FIELD: + case TimeSeriesTask.PARENT_TASK_ID_FIELD: parentTaskId = parser.text(); break; - case ESTIMATED_MINUTES_LEFT_FIELD: + case TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD: estimatedMinutesLeft = parser.intValue(); break; - case USER_FIELD: + case TimeSeriesTask.USER_FIELD: user = User.parse(parser); break; default: @@ -591,7 +346,7 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept .stoppedBy(stoppedBy) .error(error) .state(state) - .detectorId(detectorId) + .configId(detectorId) .taskProgress(taskProgress) .initProgress(initProgress) .currentPiece(currentPiece) @@ -613,185 +368,35 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept @Generated @Override - public boolean equals(Object o) { - if (this == o) + public boolean equals(Object other) { + if (this == other) return true; - if (o == null || getClass() != o.getClass()) + if (other == null || getClass() != other.getClass()) return false; - ADTask that = (ADTask) o; - return Objects.equal(getTaskId(), that.getTaskId()) - && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) - && Objects.equal(getStartedBy(), that.getStartedBy()) - && Objects.equal(getStoppedBy(), that.getStoppedBy()) - && Objects.equal(getError(), that.getError()) - && Objects.equal(getState(), that.getState()) - && Objects.equal(getId(), that.getId()) - && Objects.equal(getTaskProgress(), that.getTaskProgress()) - && Objects.equal(getInitProgress(), that.getInitProgress()) - && Objects.equal(getCurrentPiece(), that.getCurrentPiece()) - && Objects.equal(getExecutionStartTime(), that.getExecutionStartTime()) - && Objects.equal(getExecutionEndTime(), that.getExecutionEndTime()) - && Objects.equal(getLatest(), that.getLatest()) - && Objects.equal(getTaskType(), that.getTaskType()) - && Objects.equal(getCheckpointId(), that.getCheckpointId()) - && Objects.equal(getCoordinatingNode(), that.getCoordinatingNode()) - && Objects.equal(getWorkerNode(), that.getWorkerNode()) + ADTask that = (ADTask) other; + return super.equals(that) && Objects.equal(getDetector(), that.getDetector()) - && Objects.equal(getDetectionDateRange(), that.getDetectionDateRange()) - && Objects.equal(getEntity(), that.getEntity()) - && Objects.equal(getParentTaskId(), that.getParentTaskId()) - && Objects.equal(getEstimatedMinutesLeft(), that.getEstimatedMinutesLeft()) - && Objects.equal(getUser(), that.getUser()); + && Objects.equal(getDetectionDateRange(), that.getDetectionDateRange()); } @Generated @Override public int hashCode() { - return Objects - .hashCode( - taskId, - lastUpdateTime, - startedBy, - stoppedBy, - error, - state, - detectorId, - taskProgress, - initProgress, - currentPiece, - executionStartTime, - executionEndTime, - isLatest, - taskType, - checkpointId, - coordinatingNode, - workerNode, - detector, - detectionDateRange, - entity, - parentTaskId, - estimatedMinutesLeft, - user - ); - } - - public String getTaskId() { - return taskId; - } - - public void setTaskId(String taskId) { - this.taskId = taskId; - } - - public Instant getLastUpdateTime() { - return lastUpdateTime; - } - - public String getStartedBy() { - return startedBy; - } - - public String getStoppedBy() { - return stoppedBy; - } - - public String getError() { - return error; - } - - public void setError(String error) { - this.error = error; - } - - public String getState() { - return state; - } - - public void setState(String state) { - this.state = state; - } - - public String getId() { - return detectorId; - } - - public Float getTaskProgress() { - return taskProgress; - } - - public Float getInitProgress() { - return initProgress; - } - - public Instant getCurrentPiece() { - return currentPiece; - } - - public Instant getExecutionStartTime() { - return executionStartTime; - } - - public Instant getExecutionEndTime() { - return executionEndTime; - } - - public Boolean getLatest() { - return isLatest; - } - - public String getTaskType() { - return taskType; - } - - public String getCheckpointId() { - return checkpointId; + int superHashCode = super.hashCode(); + int hash = Objects.hashCode(configId, detector, detectionDateRange); + hash += 89 * superHashCode; + return hash; } public AnomalyDetector getDetector() { return detector; } - public String getCoordinatingNode() { - return coordinatingNode; - } - - public String getWorkerNode() { - return workerNode; - } - public DateRange getDetectionDateRange() { return detectionDateRange; } - public Entity getEntity() { - return entity; - } - - public String getEntityModelId() { - return entity == null ? null : entity.getModelId(getId()).orElse(null); - } - - public String getParentTaskId() { - return parentTaskId; - } - - public Integer getEstimatedMinutesLeft() { - return estimatedMinutesLeft; - } - - public User getUser() { - return user; - } - public void setDetectionDateRange(DateRange detectionDateRange) { this.detectionDateRange = detectionDateRange; } - - public void setLatest(Boolean latest) { - isLatest = latest; - } - - public void setLastUpdateTime(Instant lastUpdateTime) { - this.lastUpdateTime = lastUpdateTime; - } } diff --git a/src/main/java/org/opensearch/ad/model/ADTaskType.java b/src/main/java/org/opensearch/ad/model/ADTaskType.java index b4e06aefc..d235bad7e 100644 --- a/src/main/java/org/opensearch/ad/model/ADTaskType.java +++ b/src/main/java/org/opensearch/ad/model/ADTaskType.java @@ -12,11 +12,12 @@ package org.opensearch.ad.model; import java.util.List; -import java.util.stream.Collectors; + +import org.opensearch.timeseries.model.TaskType; import com.google.common.collect.ImmutableList; -public enum ADTaskType { +public enum ADTaskType implements TaskType { @Deprecated HISTORICAL, REALTIME_SINGLE_ENTITY, @@ -41,8 +42,4 @@ public enum ADTaskType { ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL ); - - public static List taskTypeToString(List adTaskTypes) { - return adTaskTypes.stream().map(type -> type.name()).collect(Collectors.toList()); - } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java index fac26992a..c33abe55b 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java @@ -32,7 +32,6 @@ import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyDetectorJobResponse; @@ -53,6 +52,7 @@ import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; @@ -346,7 +346,7 @@ public void stopAnomalyDetectorJob(String detectorId, ActionListener listener) { boolean isHCDetector = adTask.getDetector().isHighCardinality(); - if (isHCDetector && !adTaskCacheManager.topEntityInited(adTask.getId())) { + if (isHCDetector && !adTaskCacheManager.topEntityInited(adTask.getConfigId())) { // Initialize top entities for HC detector threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { ActionListener hcDelegatedListener = getInternalHCDelegatedListener(adTask); @@ -262,7 +262,7 @@ private ActionListener getTopEntitiesListener( ActionListener listener ) { String taskId = adTask.getTaskId(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); ActionListener actionListener = ActionListener.wrap(response -> { adTaskCacheManager.setTopEntityInited(detectorId); int totalEntities = adTaskCacheManager.getPendingEntityCount(detectorId); @@ -390,16 +390,16 @@ private void searchTopEntitiesForMultiCategoryHC( logger.debug("finish searching top entities at " + System.currentTimeMillis()); List topNEntities = priorityTracker.getTopNEntities(maxTopEntitiesPerHcDetector); if (topNEntities.size() == 0) { - logger.error("There is no entity found for detector " + adTask.getId()); - internalHCListener.onFailure(new ResourceNotFoundException(adTask.getId(), "No entity found")); + logger.error("There is no entity found for detector " + adTask.getConfigId()); + internalHCListener.onFailure(new ResourceNotFoundException(adTask.getConfigId(), "No entity found")); return; } - adTaskCacheManager.addPendingEntities(adTask.getId(), topNEntities); - adTaskCacheManager.setTopEntityCount(adTask.getId(), topNEntities.size()); + adTaskCacheManager.addPendingEntities(adTask.getConfigId(), topNEntities); + adTaskCacheManager.setTopEntityCount(adTask.getConfigId(), topNEntities.size()); internalHCListener.onResponse("Get top entities done"); } }, e -> { - logger.error("Failed to get top entities for detector " + adTask.getId(), e); + logger.error("Failed to get top entities for detector " + adTask.getConfigId(), e); internalHCListener.onFailure(e); }); int minimumDocCount = Math.max((int) (bucketInterval / adTask.getDetector().getIntervalInMilliseconds()) / 2, 1); @@ -467,16 +467,16 @@ private void searchTopEntitiesForSingleCategoryHC( logger.debug("finish searching top entities at " + System.currentTimeMillis()); List topNEntities = priorityTracker.getTopNEntities(maxTopEntitiesPerHcDetector); if (topNEntities.size() == 0) { - logger.error("There is no entity found for detector " + adTask.getId()); - internalHCListener.onFailure(new ResourceNotFoundException(adTask.getId(), "No entity found")); + logger.error("There is no entity found for detector " + adTask.getConfigId()); + internalHCListener.onFailure(new ResourceNotFoundException(adTask.getConfigId(), "No entity found")); return; } - adTaskCacheManager.addPendingEntities(adTask.getId(), topNEntities); - adTaskCacheManager.setTopEntityCount(adTask.getId(), topNEntities.size()); + adTaskCacheManager.addPendingEntities(adTask.getConfigId(), topNEntities); + adTaskCacheManager.setTopEntityCount(adTask.getConfigId(), topNEntities.size()); internalHCListener.onResponse("Get top entities done"); } }, e -> { - logger.error("Failed to get top entities for detector " + adTask.getId(), e); + logger.error("Failed to get top entities for detector " + adTask.getConfigId(), e); internalHCListener.onFailure(e); }); // using the original context in listener as user roles have no permissions for internal operations like fetching a @@ -512,7 +512,7 @@ public void forwardOrExecuteADTask( ) { try { checkIfADTaskCancelledAndCleanupCache(adTask); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); AnomalyDetector detector = adTask.getDetector(); boolean isHCDetector = detector.isHighCardinality(); if (isHCDetector) { @@ -561,14 +561,14 @@ public void forwardOrExecuteADTask( logger.info("Create entity task for entity:{}", entityString); Instant now = Instant.now(); ADTask adEntityTask = new ADTask.Builder() - .detectorId(adTask.getId()) + .configId(adTask.getConfigId()) .detector(detector) .isLatest(true) .taskType(ADTaskType.HISTORICAL_HC_ENTITY.name()) .executionStartTime(now) .taskProgress(0.0f) .initProgress(0.0f) - .state(ADTaskState.INIT.name()) + .state(TaskState.INIT.name()) .initProgress(0.0f) .lastUpdateTime(now) .startedBy(adTask.getStartedBy()) @@ -595,7 +595,7 @@ public void forwardOrExecuteADTask( ); } else { Map updatedFields = new HashMap<>(); - updatedFields.put(STATE_FIELD, ADTaskState.INIT.name()); + updatedFields.put(STATE_FIELD, TaskState.INIT.name()); updatedFields.put(INIT_PROGRESS_FIELD, 0.0f); ActionListener workerNodeResponseListener = workerNodeResponseListener( adTask, @@ -639,7 +639,7 @@ private ActionListener workerNodeResponseListener( if (adTask.isEntityTask()) { // When reach this line, the entity task already been put into worker node's cache. // Then it's safe to move entity from temp entities queue to running entities queue. - adTaskCacheManager.moveToRunningEntity(adTask.getId(), adTaskManager.convertEntityToString(adTask)); + adTaskCacheManager.moveToRunningEntity(adTask.getConfigId(), adTaskManager.convertEntityToString(adTask)); } startNewEntityTaskLane(adTask, transportService); }, e -> { @@ -650,7 +650,7 @@ private ActionListener workerNodeResponseListener( if (adTask.getDetector().isHighCardinality()) { // Entity task done on worker node. Send entity task done message to coordinating node to poll next entity. adTaskManager.entityTaskDone(adTask, e, transportService); - if (adTaskCacheManager.getAvailableNewEntityTaskLanes(adTask.getId()) > 0) { + if (adTaskCacheManager.getAvailableNewEntityTaskLanes(adTask.getConfigId()) > 0) { // When reach this line, it means entity task failed to start on worker node // Sleep some time before starting new task lane. threadPool @@ -699,8 +699,8 @@ private void forwardOrExecuteEntityTask( // start new entity task lane private synchronized void startNewEntityTaskLane(ADTask adTask, TransportService transportService) { - if (adTask.getDetector().isHighCardinality() && adTaskCacheManager.getAndDecreaseEntityTaskLanes(adTask.getId()) > 0) { - logger.debug("start new task lane for detector {}", adTask.getId()); + if (adTask.getDetector().isHighCardinality() && adTaskCacheManager.getAndDecreaseEntityTaskLanes(adTask.getConfigId()) > 0) { + logger.debug("start new task lane for detector {}", adTask.getConfigId()); forwardOrExecuteADTask(adTask, transportService, getInternalHCDelegatedListener(adTask)); } } @@ -722,10 +722,10 @@ private void dispatchTask(ADTask adTask, ActionListener listener) .append(DEFAULT_JVM_HEAP_USAGE_THRESHOLD) .append("%. ") .append(NO_ELIGIBLE_NODE_TO_RUN_DETECTOR) - .append(adTask.getId()); + .append(adTask.getConfigId()); String errorMessage = errorMessageBuilder.toString(); logger.warn(errorMessage + ", task id " + adTask.getTaskId() + ", " + adTask.getTaskType()); - listener.onFailure(new LimitExceededException(adTask.getId(), errorMessage)); + listener.onFailure(new LimitExceededException(adTask.getConfigId(), errorMessage)); return; } candidateNodeResponse = candidateNodeResponse @@ -735,10 +735,10 @@ private void dispatchTask(ADTask adTask, ActionListener listener) if (candidateNodeResponse.size() == 0) { StringBuilder errorMessageBuilder = new StringBuilder("All nodes' executing batch tasks exceeds limitation ") .append(NO_ELIGIBLE_NODE_TO_RUN_DETECTOR) - .append(adTask.getId()); + .append(adTask.getConfigId()); String errorMessage = errorMessageBuilder.toString(); logger.warn(errorMessage + ", task id " + adTask.getTaskId() + ", " + adTask.getTaskType()); - listener.onFailure(new LimitExceededException(adTask.getId(), errorMessage)); + listener.onFailure(new LimitExceededException(adTask.getConfigId(), errorMessage)); return; } Optional targetNode = candidateNodeResponse @@ -798,8 +798,8 @@ public void startADBatchTaskOnWorkerNode( private ActionListener internalBatchTaskListener(ADTask adTask, TransportService transportService) { String taskId = adTask.getTaskId(); - String detectorTaskId = adTask.getDetectorLevelTaskId(); - String detectorId = adTask.getId(); + String detectorTaskId = adTask.getConfigLevelTaskId(); + String detectorId = adTask.getConfigId(); ActionListener listener = ActionListener.wrap(response -> { // If batch task finished normally, remove task from cache and decrease executing task count by 1. adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); @@ -810,11 +810,11 @@ private ActionListener internalBatchTaskListener(ADTask adTask, Transpor .cleanDetectorCache( adTask, transportService, - () -> adTaskManager.updateADTask(taskId, ImmutableMap.of(STATE_FIELD, ADTaskState.FINISHED.name())) + () -> adTaskManager.updateADTask(taskId, ImmutableMap.of(STATE_FIELD, TaskState.FINISHED.name())) ); } else { // Set entity task as FINISHED here - adTaskManager.updateADTask(adTask.getTaskId(), ImmutableMap.of(STATE_FIELD, ADTaskState.FINISHED.name())); + adTaskManager.updateADTask(adTask.getTaskId(), ImmutableMap.of(STATE_FIELD, TaskState.FINISHED.name())); adTaskManager.entityTaskDone(adTask, null, transportService); } }, e -> { @@ -866,7 +866,7 @@ private void executeADBatchTaskOnWorkerNode(ADTask adTask, ActionListener cons double maxValue = maxAgg.getValue(); // If time field not exist or there is no value, will return infinity value if (minValue == Double.POSITIVE_INFINITY) { - internalListener.onFailure(new ResourceNotFoundException(adTask.getId(), "There is no data in the time field")); + internalListener.onFailure(new ResourceNotFoundException(adTask.getConfigId(), "There is no data in the time field")); return; } long interval = ((IntervalTimeConfiguration) adTask.getDetector().getInterval()).toDuration().toMillis(); @@ -983,7 +983,8 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons long maxDate = (long) maxValue; if (minDate >= dataEndTime || maxDate <= dataStartTime) { - internalListener.onFailure(new ResourceNotFoundException(adTask.getId(), "There is no data in the detection date range")); + internalListener + .onFailure(new ResourceNotFoundException(adTask.getConfigId(), "There is no data in the detection date range")); return; } if (minDate > dataStartTime) { @@ -1099,8 +1100,8 @@ private void detectAnomaly( ? "No full shingle in current detection window" : "No data in current detection window"; AnomalyResult anomalyResult = new AnomalyResult( - adTask.getId(), - adTask.getDetectorLevelTaskId(), + adTask.getConfigId(), + adTask.getConfigLevelTaskId(), featureData, Instant.ofEpochMilli(intervalEndTime - interval), Instant.ofEpochMilli(intervalEndTime), @@ -1125,9 +1126,9 @@ private void detectAnomaly( AnomalyResult anomalyResult = AnomalyResult .fromRawTRCFResult( - adTask.getId(), + adTask.getConfigId(), adTask.getDetector().getIntervalInMilliseconds(), - adTask.getDetectorLevelTaskId(), + adTask.getConfigLevelTaskId(), score, descriptor.getAnomalyGrade(), descriptor.getDataConfidence(), @@ -1247,14 +1248,14 @@ private void runNextPiece( ActionListener internalListener ) { String taskId = adTask.getTaskId(); - String detectorId = adTask.getId(); - String detectorTaskId = adTask.getDetectorLevelTaskId(); + String detectorId = adTask.getConfigId(); + String detectorTaskId = adTask.getConfigLevelTaskId(); float initProgress = calculateInitProgress(taskId); - String taskState = initProgress >= 1.0f ? ADTaskState.RUNNING.name() : ADTaskState.INIT.name(); + String taskState = initProgress >= 1.0f ? TaskState.RUNNING.name() : TaskState.INIT.name(); logger.debug("Init progress: {}, taskState:{}, task id: {}", initProgress, taskState, taskId); if (initProgress >= 1.0f && adTask.isEntityTask()) { - updateDetectorLevelTaskState(detectorId, adTask.getParentTaskId(), ADTaskState.RUNNING.name()); + updateDetectorLevelTaskState(detectorId, adTask.getParentTaskId(), TaskState.RUNNING.name()); } if (pieceStartTime < dataEndTime) { @@ -1320,7 +1321,7 @@ private void runNextPiece( INIT_PROGRESS_FIELD, initProgress, STATE_FIELD, - ADTaskState.FINISHED + TaskState.FINISHED ), ActionListener.wrap(r -> internalListener.onResponse("task execution done"), e -> internalListener.onFailure(e)) ); @@ -1360,8 +1361,8 @@ private float calculateInitProgress(String taskId) { private void checkIfADTaskCancelledAndCleanupCache(ADTask adTask) { String taskId = adTask.getTaskId(); - String detectorId = adTask.getId(); - String detectorTaskId = adTask.getDetectorLevelTaskId(); + String detectorId = adTask.getConfigId(); + String detectorTaskId = adTask.getConfigLevelTaskId(); // refresh latest HC task run time adTaskCacheManager.refreshLatestHCTaskRunTime(detectorId); if (adTask.getDetector().isHighCardinality() diff --git a/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java b/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java index 91f00b4cd..7f4c70f81 100644 --- a/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java +++ b/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java @@ -13,7 +13,7 @@ import java.time.Instant; -import org.opensearch.ad.model.ADTaskState; +import org.opensearch.timeseries.model.TaskState; /** * Cache HC batch task running state on coordinating and worker node. @@ -32,7 +32,7 @@ public class ADHCBatchTaskRunState { private Long cancelledTimeInMillis; public ADHCBatchTaskRunState() { - this.detectorTaskState = ADTaskState.INIT.name(); + this.detectorTaskState = TaskState.INIT.name(); } public String getDetectorTaskState() { diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java index 0df994963..682a70a1b 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java @@ -38,7 +38,6 @@ import org.opensearch.action.ActionListener; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -47,6 +46,7 @@ import org.opensearch.timeseries.common.exception.DuplicateTaskException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.transport.TransportService; import com.amazon.randomcutforest.RandomCutForest; @@ -171,7 +171,7 @@ public ADTaskCacheManager(Settings settings, ClusterService clusterService, Memo */ public synchronized void add(ADTask adTask) { String taskId = adTask.getTaskId(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); if (contains(taskId)) { throw new DuplicateTaskException(DETECTOR_IS_RUNNING); } @@ -189,7 +189,7 @@ public synchronized void add(ADTask adTask) { taskCache.getCacheMemorySize().set(neededCacheSize); batchTaskCaches.put(taskId, taskCache); if (adTask.isEntityTask()) { - ADHCBatchTaskRunState hcBatchTaskRunState = getHCBatchTaskRunState(detectorId, adTask.getDetectorLevelTaskId()); + ADHCBatchTaskRunState hcBatchTaskRunState = getHCBatchTaskRunState(detectorId, adTask.getConfigLevelTaskId()); if (hcBatchTaskRunState != null) { hcBatchTaskRunState.setLastTaskRunTimeInMillis(Instant.now().toEpochMilli()); } @@ -1035,7 +1035,7 @@ public boolean isRealtimeTaskChangeNeeded(String detectorId, String newState, Fl String oldState = realtimeTaskCache.getState(); if (newState != null && !newState.equals(oldState) - && !(ADTaskState.INIT.name().equals(newState) && ADTaskState.RUNNING.name().equals(oldState))) { + && !(TaskState.INIT.name().equals(newState) && TaskState.RUNNING.name().equals(oldState))) { stateChangeNeeded = true; } boolean initProgressChangeNeeded = false; @@ -1084,7 +1084,7 @@ public void updateRealtimeTaskCache(String detectorId, String newState, Float ne if (newError != null) { realtimeTaskCache.setError(newError); } - if (newState != null && !ADTaskState.NOT_ENDED_STATES.contains(newState)) { + if (newState != null && !TaskState.NOT_ENDED_STATES.contains(newState)) { // If task is done, will remove its realtime task cache. logger.info("Realtime task done with state {}, remove RT task cache for detector ", newState, detectorId); removeRealtimeTaskCache(detectorId); diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java b/src/main/java/org/opensearch/ad/task/ADTaskManager.java index 9cd21b6f1..c705f835e 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java @@ -33,11 +33,9 @@ import static org.opensearch.ad.model.ADTask.STOPPED_BY_FIELD; import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; -import static org.opensearch.ad.model.ADTaskState.NOT_ENDED_STATES; import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES; import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES; -import static org.opensearch.ad.model.ADTaskType.taskTypeToString; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR; @@ -53,6 +51,8 @@ import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import static org.opensearch.timeseries.constant.CommonName.TASK_ID_FIELD; +import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES; +import static org.opensearch.timeseries.model.TaskType.taskTypeToString; import static org.opensearch.timeseries.util.ExceptionUtil.getErrorMessage; import static org.opensearch.timeseries.util.ExceptionUtil.getShardsFailure; import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; @@ -106,7 +106,6 @@ import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.ad.model.ADTaskProfile; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; @@ -166,6 +165,7 @@ import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportRequestOptions; @@ -669,7 +669,7 @@ private void forwardToCoordinatingNode( .info( "There are {} task slots available now to scale historical analysis task lane for detector {}", approvedTaskSlots, - adTask.getId() + adTask.getConfigId() ); scaleTaskLaneOnCoordinatingNode(adTask, approvedTaskSlots, transportService, wrappedActionListener); break; @@ -707,7 +707,7 @@ private DiscoveryNode getCoordinatingNode(ADTask adTask) { } } if (targetNode == null) { - throw new ResourceNotFoundException(adTask.getId(), "AD task coordinating node not found"); + throw new ResourceNotFoundException(adTask.getConfigId(), "AD task coordinating node not found"); } return targetNode; } @@ -1086,7 +1086,7 @@ private void resetRealtimeDetectorTaskState( return; } ADTask adTask = runningRealtimeTasks.get(0); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); GetRequest getJobRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); client.get(getJobRequest, ActionListener.wrap(r -> { if (r.isExists()) { @@ -1143,7 +1143,7 @@ private void resetHistoricalDetectorTaskState( && !isNullOrEmpty(taskProfile.getEntityTaskProfiles())) { // If coordinating node restarted, HC detector cache on it will be gone. But worker node still // runs entity tasks, we'd better stop these entity tasks to clean up resource earlier. - stopHistoricalAnalysis(adTask.getId(), Optional.of(adTask), null, ActionListener.wrap(r -> { + stopHistoricalAnalysis(adTask.getConfigId(), Optional.of(adTask), null, ActionListener.wrap(r -> { logger.debug("Restop detector successfully"); resetTaskStateAsStopped(adTask, function, transportService, listener); }, e -> { @@ -1267,9 +1267,9 @@ private void resetTaskStateAsStopped( ) { cleanDetectorCache(adTask, transportService, () -> { String taskId = adTask.getTaskId(); - Map updatedFields = ImmutableMap.of(STATE_FIELD, ADTaskState.STOPPED.name()); + Map updatedFields = ImmutableMap.of(STATE_FIELD, TaskState.STOPPED.name()); updateADTask(taskId, updatedFields, ActionListener.wrap(r -> { - adTask.setState(ADTaskState.STOPPED.name()); + adTask.setState(TaskState.STOPPED.name()); if (function != null) { function.execute(); } @@ -1294,7 +1294,7 @@ private void resetEntityTasksAsStopped(String detectorTaskId) { query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); updateByQueryRequest.setQuery(query); updateByQueryRequest.setRefresh(true); - String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", STATE_FIELD, ADTaskState.STOPPED.name()); + String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", STATE_FIELD, TaskState.STOPPED.name()); updateByQueryRequest.setScript(new Script(script)); client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { @@ -1329,7 +1329,7 @@ public void cleanDetectorCache( ActionListener listener ) { String coordinatingNode = adTask.getCoordinatingNode(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); String taskId = adTask.getTaskId(); try { forwardADTaskToCoordinatingNode( @@ -1357,7 +1357,7 @@ public void cleanDetectorCache( } protected void cleanDetectorCache(ADTask adTask, TransportService transportService, ExecutorFunction function) { - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); String taskId = adTask.getTaskId(); cleanDetectorCache( adTask, @@ -1411,7 +1411,7 @@ public void getLatestHistoricalTaskProfile( * @param listener action listener */ private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener listener) { - String detectorId = adDetectorLevelTask.getId(); + String detectorId = adDetectorLevelTask.getConfigId(); hashRing.getAllEligibleDataNodesWithKnownAdVersion(dataNodes -> { ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes); @@ -1519,14 +1519,14 @@ private void createNewADTask( Instant now = Instant.now(); String taskType = getADTaskType(detector, detectionDateRange).name(); ADTask adTask = new ADTask.Builder() - .detectorId(detector.getId()) + .configId(detector.getId()) .detector(detector) .isLatest(true) .taskType(taskType) .executionStartTime(now) .taskProgress(0.0f) .initProgress(0.0f) - .state(ADTaskState.CREATED.name()) + .state(TaskState.CREATED.name()) .lastUpdateTime(now) .startedBy(userName) .coordinatingNode(coordinatingNode) @@ -1562,11 +1562,11 @@ public void createADTaskDirectly(ADTask adTask, Consumer func .source(adTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { - logger.error("Failed to create AD task for detector " + adTask.getId(), e); + logger.error("Failed to create AD task for detector " + adTask.getConfigId(), e); listener.onFailure(e); })); } catch (Exception e) { - logger.error("Failed to create AD task for detector " + adTask.getId(), e); + logger.error("Failed to create AD task for detector " + adTask.getConfigId(), e); listener.onFailure(e); } } @@ -1593,7 +1593,7 @@ private void onIndexADTaskResponse( // ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache for details. Here the // realtime task cache not inited yet when create AD task, so no need to cleanup. if (adTask.isHistoricalTask()) { - adTaskCacheManager.removeHistoricalTaskCache(adTask.getId()); + adTaskCacheManager.removeHistoricalTaskCache(adTask.getConfigId()); } listener.onFailure(e); } @@ -1603,7 +1603,7 @@ private void onIndexADTaskResponse( // DuplicateTaskException. This is to solve race condition when user send // multiple start request for one historical detector. if (adTask.isHistoricalTask()) { - adTaskCacheManager.add(adTask.getId(), adTask); + adTaskCacheManager.add(adTask.getConfigId(), adTask); } } catch (Exception e) { delegatedListener.onFailure(e); @@ -1616,7 +1616,7 @@ private void onIndexADTaskResponse( private void cleanOldAdTaskDocs(IndexResponse response, ADTask adTask, ActionListener delegatedListener) { BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, adTask.getId())); + query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, adTask.getConfigId())); query.filter(new TermQueryBuilder(IS_LATEST_FIELD, false)); if (adTask.isHistoricalTask()) { @@ -1637,7 +1637,7 @@ private void cleanOldAdTaskDocs(IndexResponse response, ADTask adTask, ActionLis .from(maxOldAdTaskDocsPerDetector) .size(MAX_OLD_AD_TASK_DOCS); searchRequest.source(sourceBuilder).indices(DETECTION_STATE_INDEX); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); deleteTaskDocs(detectorId, searchRequest, () -> { if (adTask.isHistoricalTask()) { @@ -1672,7 +1672,7 @@ protected void deleteTaskDocs( try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); ADTask adTask = ADTask.parse(parser, searchHit.getId()); - logger.debug("Delete old task: {} of detector: {}", adTask.getTaskId(), adTask.getId()); + logger.debug("Delete old task: {} of detector: {}", adTask.getTaskId(), adTask.getConfigId()); bulkRequest.add(new DeleteRequest(DETECTION_STATE_INDEX).id(adTask.getTaskId())); } catch (Exception e) { listener.onFailure(e); @@ -1746,7 +1746,7 @@ private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionL .info( "AD task {} of detector {} dispatched to {} node {}", adTask.getTaskId(), - adTask.getId(), + adTask.getConfigId(), remoteOrLocal, r.getNodeId() ); @@ -1769,7 +1769,7 @@ private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionL */ public void handleADTaskException(ADTask adTask, Exception e) { // TODO: handle timeout exception - String state = ADTaskState.FAILED.name(); + String state = TaskState.FAILED.name(); Map updatedFields = new HashMap<>(); if (e instanceof DuplicateTaskException) { // If user send multiple start detector request, we will meet race condition. @@ -1778,7 +1778,7 @@ public void handleADTaskException(ADTask adTask, Exception e) { logger .warn( "There is already one running task for detector, detectorId:" - + adTask.getId() + + adTask.getConfigId() + ". Will delete task " + adTask.getTaskId() ); @@ -1786,14 +1786,14 @@ public void handleADTaskException(ADTask adTask, Exception e) { return; } if (e instanceof TaskCancelledException) { - logger.info("AD task cancelled, taskId: {}, detectorId: {}", adTask.getTaskId(), adTask.getId()); - state = ADTaskState.STOPPED.name(); + logger.info("AD task cancelled, taskId: {}, detectorId: {}", adTask.getTaskId(), adTask.getConfigId()); + state = TaskState.STOPPED.name(); String stoppedBy = ((TaskCancelledException) e).getCancelledBy(); if (stoppedBy != null) { updatedFields.put(STOPPED_BY_FIELD, stoppedBy); } } else { - logger.error("Failed to execute AD batch task, task id: " + adTask.getTaskId() + ", detector id: " + adTask.getId(), e); + logger.error("Failed to execute AD batch task, task id: " + adTask.getTaskId() + ", detector id: " + adTask.getConfigId(), e); } updatedFields.put(ERROR_FIELD, getErrorMessage(e)); updatedFields.put(STATE_FIELD, state); @@ -1981,7 +1981,7 @@ public void updateLatestADTask( */ public void stopLatestRealtimeTask( String detectorId, - ADTaskState state, + TaskState state, Exception error, TransportService transportService, ActionListener listener @@ -2035,11 +2035,11 @@ public void updateLatestRealtimeTaskOnCoordinatingNode( String newState = null; // calculate init progress and task state with RCF total updates if (detectorIntervalInMinutes != null && rcfTotalUpdates != null) { - newState = ADTaskState.INIT.name(); + newState = TaskState.INIT.name(); if (rcfTotalUpdates < NUM_MIN_SAMPLES) { initProgress = (float) rcfTotalUpdates / NUM_MIN_SAMPLES; } else { - newState = ADTaskState.RUNNING.name(); + newState = TaskState.RUNNING.name(); initProgress = 1.0f; } } @@ -2289,10 +2289,10 @@ public boolean isRetryableError(String error) { * @param state AD task state * @param listener action listener */ - public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListener listener) { - String detectorId = adTask.getId(); + public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener listener) { + String detectorId = adTask.getConfigId(); String taskId = adTask.isEntityTask() ? adTask.getParentTaskId() : adTask.getTaskId(); - String detectorTaskId = adTask.getDetectorLevelTaskId(); + String detectorTaskId = adTask.getConfigLevelTaskId(); ActionListener wrappedListener = ActionListener.wrap(response -> { logger.info("Historical HC detector done with state: {}. Remove from cache, detector id:{}", state.name(), detectorId); @@ -2309,11 +2309,11 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen }); long timeoutInMillis = 2000;// wait for 2 seconds to acquire updating HC detector task semaphore - if (state == ADTaskState.FINISHED) { - this.countEntityTasksByState(detectorTaskId, ImmutableList.of(ADTaskState.FINISHED), ActionListener.wrap(r -> { - logger.info("number of finished entity tasks: {}, for detector {}", r, adTask.getId()); + if (state == TaskState.FINISHED) { + this.countEntityTasksByState(detectorTaskId, ImmutableList.of(TaskState.FINISHED), ActionListener.wrap(r -> { + logger.info("number of finished entity tasks: {}, for detector {}", r, adTask.getConfigId()); // Set task as FAILED if no finished entity task; otherwise set as FINISHED - ADTaskState hcDetectorTaskState = r == 0 ? ADTaskState.FAILED : ADTaskState.FINISHED; + TaskState hcDetectorTaskState = r == 0 ? TaskState.FAILED : TaskState.FINISHED; // execute in AD batch task thread pool in case waiting for semaphore waste any shared OpenSearch thread pool threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { updateADHCDetectorTask( @@ -2343,7 +2343,7 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen ImmutableMap .of( STATE_FIELD, - ADTaskState.FAILED.name(),// set as FAILED if fail to get finished entity tasks. + TaskState.FAILED.name(),// set as FAILED if fail to get finished entity tasks. TASK_PROGRESS_FIELD, 1.0, ERROR_FIELD, @@ -2387,7 +2387,7 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen * @param taskStates task states * @param listener action listener */ - public void countEntityTasksByState(String detectorTaskId, List taskStates, ActionListener listener) { + public void countEntityTasksByState(String detectorTaskId, List taskStates, ActionListener listener) { BoolQueryBuilder queryBuilder = new BoolQueryBuilder(); queryBuilder.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); if (taskStates != null && taskStates.size() > 0) { @@ -2496,7 +2496,7 @@ public void runNextEntityForHCADHistorical( TransportService transportService, ActionListener listener ) { - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); int scaleDelta = scaleTaskSlots( adTask, transportService, @@ -2547,7 +2547,7 @@ protected int scaleTaskSlots( TransportService transportService, ActionListener scaleUpListener ) { - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); if (!scaleEntityTaskLane.tryAcquire()) { logger.debug("Can't get scaleEntityTaskLane semaphore"); return 0; @@ -2792,14 +2792,14 @@ public synchronized void removeStaleRunningEntity( TransportService transportService, ActionListener listener ) { - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); boolean removed = adTaskCacheManager.removeRunningEntity(detectorId, entity); if (removed && adTaskCacheManager.getPendingEntityCount(detectorId) > 0) { logger.debug("kick off next pending entities"); this.runNextEntityForHCADHistorical(adTask, transportService, listener); } else { if (!adTaskCacheManager.hasEntity(detectorId)) { - setHCDetectorTaskDone(adTask, ADTaskState.STOPPED, listener); + setHCDetectorTaskDone(adTask, TaskState.STOPPED, listener); } } } @@ -3053,7 +3053,12 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQue ActionListener .wrap( r -> { - logger.debug("Reset historical task state done for task {}, detector {}", adTask.getTaskId(), adTask.getId()); + logger + .debug( + "Reset historical task state done for task {}, detector {}", + adTask.getTaskId(), + adTask.getConfigId() + ); }, e -> { logger.error("Failed to reset historical task state for task " + adTask.getTaskId(), e); } ) diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java index 337e71e6e..addc6fa20 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java @@ -27,7 +27,6 @@ import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; @@ -37,6 +36,7 @@ import org.opensearch.tasks.Task; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; @@ -121,7 +121,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTaskCacheManager.setDetectorTaskSlots(detectorId, 0); logger.info("Historical HC detector done, will remove from cache, detector id:{}", detectorId); listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK)); - ADTaskState state = !adTask.isEntityTask() && adTask.getError() != null ? ADTaskState.FAILED : ADTaskState.FINISHED; + TaskState state = !adTask.isEntityTask() && adTask.getError() != null ? TaskState.FAILED : TaskState.FINISHED; adTaskManager.setHCDetectorTaskDone(adTask, state, listener); } else { logger.debug("Run next entity for detector " + detectorId); @@ -133,7 +133,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener ImmutableMap .of( STATE_FIELD, - ADTaskState.RUNNING.name(), + TaskState.RUNNING.name(), TASK_PROGRESS_FIELD, adTaskManager.hcDetectorProgress(detectorId), ERROR_FIELD, @@ -157,18 +157,18 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener if (adTask.isEntityTask()) { // AD task must be entity level task. adTaskCacheManager.removeRunningEntity(detectorId, entityValue); if (adTaskManager.isRetryableError(adTask.getError()) - && !adTaskCacheManager.exceedRetryLimit(adTask.getId(), adTask.getTaskId())) { + && !adTaskCacheManager.exceedRetryLimit(adTask.getConfigId(), adTask.getTaskId())) { // If retryable exception happens when run entity task, will push back entity to the end // of pending entities queue, then we can retry it later. - adTaskCacheManager.pushBackEntity(adTask.getTaskId(), adTask.getId(), entityValue); + adTaskCacheManager.pushBackEntity(adTask.getTaskId(), adTask.getConfigId(), entityValue); } else { // If exception is not retryable or exceeds retry limit, will remove this entity. - adTaskCacheManager.removeEntity(adTask.getId(), entityValue); + adTaskCacheManager.removeEntity(adTask.getConfigId(), entityValue); logger.warn("Entity task failed, task id: {}, entity: {}", adTask.getTaskId(), adTask.getEntity().toString()); } if (!adTaskCacheManager.hasEntity(detectorId)) { adTaskCacheManager.setDetectorTaskSlots(detectorId, 0); - adTaskManager.setHCDetectorTaskDone(adTask, ADTaskState.FINISHED, listener); + adTaskManager.setHCDetectorTaskDone(adTask, TaskState.FINISHED, listener); } else { logger.debug("scale task slots for PUSH_BACK_ENTITY, detector {} task {}", detectorId, adTask.getTaskId()); int taskSlots = adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, 1); @@ -204,7 +204,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTaskCacheManager.clearPendingEntities(detectorId); adTaskCacheManager.removeRunningEntity(detectorId, entityValue); if (!adTaskCacheManager.hasEntity(detectorId) || !adTask.isEntityTask()) { - adTaskManager.setHCDetectorTaskDone(adTask, ADTaskState.STOPPED, listener); + adTaskManager.setHCDetectorTaskDone(adTask, TaskState.STOPPED, listener); } listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.OK)); } else { diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTask.java b/src/main/java/org/opensearch/forecast/model/ForecastTask.java new file mode 100644 index 000000000..4d7e889d7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastTask.java @@ -0,0 +1,389 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +public class ForecastTask extends TimeSeriesTask { + public static final String FORECASTER_ID_FIELD = "forecaster_id"; + public static final String FORECASTER_FIELD = "forecaster"; + public static final String DATE_RANGE_FIELD = "date_range"; + + private Forecaster forecaster = null; + private DateRange dateRange = null; + + private ForecastTask() {} + + public ForecastTask(StreamInput input) throws IOException { + this.taskId = input.readOptionalString(); + this.taskType = input.readOptionalString(); + this.configId = input.readOptionalString(); + if (input.readBoolean()) { + this.forecaster = new Forecaster(input); + } else { + this.forecaster = null; + } + this.state = input.readOptionalString(); + this.taskProgress = input.readOptionalFloat(); + this.initProgress = input.readOptionalFloat(); + this.currentPiece = input.readOptionalInstant(); + this.executionStartTime = input.readOptionalInstant(); + this.executionEndTime = input.readOptionalInstant(); + this.isLatest = input.readOptionalBoolean(); + this.error = input.readOptionalString(); + this.checkpointId = input.readOptionalString(); + this.lastUpdateTime = input.readOptionalInstant(); + this.startedBy = input.readOptionalString(); + this.stoppedBy = input.readOptionalString(); + this.coordinatingNode = input.readOptionalString(); + this.workerNode = input.readOptionalString(); + if (input.readBoolean()) { + this.user = new User(input); + } else { + user = null; + } + if (input.readBoolean()) { + this.dateRange = new DateRange(input); + } else { + this.dateRange = null; + } + if (input.readBoolean()) { + this.entity = new Entity(input); + } else { + this.entity = null; + } + this.parentTaskId = input.readOptionalString(); + this.estimatedMinutesLeft = input.readOptionalInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(taskId); + out.writeOptionalString(taskType); + out.writeOptionalString(configId); + if (forecaster != null) { + out.writeBoolean(true); + forecaster.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(state); + out.writeOptionalFloat(taskProgress); + out.writeOptionalFloat(initProgress); + out.writeOptionalInstant(currentPiece); + out.writeOptionalInstant(executionStartTime); + out.writeOptionalInstant(executionEndTime); + out.writeOptionalBoolean(isLatest); + out.writeOptionalString(error); + out.writeOptionalString(checkpointId); + out.writeOptionalInstant(lastUpdateTime); + out.writeOptionalString(startedBy); + out.writeOptionalString(stoppedBy); + out.writeOptionalString(coordinatingNode); + out.writeOptionalString(workerNode); + if (user != null) { + out.writeBoolean(true); // user exists + user.writeTo(out); + } else { + out.writeBoolean(false); // user does not exist + } + // Only forward forecast task to nodes with same version, so it's ok to write these new fields. + if (dateRange != null) { + out.writeBoolean(true); + dateRange.writeTo(out); + } else { + out.writeBoolean(false); + } + if (entity != null) { + out.writeBoolean(true); + entity.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(parentTaskId); + out.writeOptionalInt(estimatedMinutesLeft); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public boolean isEntityTask() { + return ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY.name().equals(taskType); + } + + public static class Builder extends TimeSeriesTask.Builder { + private Forecaster forecaster = null; + private DateRange dateRange = null; + + public Builder() {} + + public Builder forecaster(Forecaster forecaster) { + this.forecaster = forecaster; + return this; + } + + public Builder dateRange(DateRange dateRange) { + this.dateRange = dateRange; + return this; + } + + public ForecastTask build() { + ForecastTask forecastTask = new ForecastTask(); + forecastTask.taskId = this.taskId; + forecastTask.lastUpdateTime = this.lastUpdateTime; + forecastTask.error = this.error; + forecastTask.state = this.state; + forecastTask.configId = this.configId; + forecastTask.taskProgress = this.taskProgress; + forecastTask.initProgress = this.initProgress; + forecastTask.currentPiece = this.currentPiece; + forecastTask.executionStartTime = this.executionStartTime; + forecastTask.executionEndTime = this.executionEndTime; + forecastTask.isLatest = this.isLatest; + forecastTask.taskType = this.taskType; + forecastTask.checkpointId = this.checkpointId; + forecastTask.forecaster = this.forecaster; + forecastTask.startedBy = this.startedBy; + forecastTask.stoppedBy = this.stoppedBy; + forecastTask.coordinatingNode = this.coordinatingNode; + forecastTask.workerNode = this.workerNode; + forecastTask.dateRange = this.dateRange; + forecastTask.entity = this.entity; + forecastTask.parentTaskId = this.parentTaskId; + forecastTask.estimatedMinutesLeft = this.estimatedMinutesLeft; + forecastTask.user = this.user; + + return forecastTask; + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder = super.toXContent(xContentBuilder, params); + if (configId != null) { + xContentBuilder.field(FORECASTER_ID_FIELD, configId); + } + if (forecaster != null) { + xContentBuilder.field(FORECASTER_FIELD, forecaster); + } + if (dateRange != null) { + xContentBuilder.field(DATE_RANGE_FIELD, dateRange); + } + return xContentBuilder.endObject(); + } + + public static ForecastTask parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static ForecastTask parse(XContentParser parser, String taskId) throws IOException { + Instant lastUpdateTime = null; + String startedBy = null; + String stoppedBy = null; + String error = null; + String state = null; + String configId = null; + Float taskProgress = null; + Float initProgress = null; + Instant currentPiece = null; + Instant executionStartTime = null; + Instant executionEndTime = null; + Boolean isLatest = null; + String taskType = null; + String checkpointId = null; + Forecaster forecaster = null; + String parsedTaskId = taskId; + String coordinatingNode = null; + String workerNode = null; + DateRange dateRange = null; + Entity entity = null; + String parentTaskId = null; + Integer estimatedMinutesLeft = null; + User user = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = ParseUtils.toInstant(parser); + break; + case STARTED_BY_FIELD: + startedBy = parser.text(); + break; + case STOPPED_BY_FIELD: + stoppedBy = parser.text(); + break; + case ERROR_FIELD: + error = parser.text(); + break; + case STATE_FIELD: + state = parser.text(); + break; + case FORECASTER_ID_FIELD: + configId = parser.text(); + break; + case TASK_PROGRESS_FIELD: + taskProgress = parser.floatValue(); + break; + case INIT_PROGRESS_FIELD: + initProgress = parser.floatValue(); + break; + case CURRENT_PIECE_FIELD: + currentPiece = ParseUtils.toInstant(parser); + break; + case EXECUTION_START_TIME_FIELD: + executionStartTime = ParseUtils.toInstant(parser); + break; + case EXECUTION_END_TIME_FIELD: + executionEndTime = ParseUtils.toInstant(parser); + break; + case IS_LATEST_FIELD: + isLatest = parser.booleanValue(); + break; + case TASK_TYPE_FIELD: + taskType = parser.text(); + break; + case CHECKPOINT_ID_FIELD: + checkpointId = parser.text(); + break; + case FORECASTER_FIELD: + forecaster = Forecaster.parse(parser); + break; + case TASK_ID_FIELD: + parsedTaskId = parser.text(); + break; + case COORDINATING_NODE_FIELD: + coordinatingNode = parser.text(); + break; + case WORKER_NODE_FIELD: + workerNode = parser.text(); + break; + case DATE_RANGE_FIELD: + dateRange = DateRange.parse(parser); + break; + case ENTITY_FIELD: + entity = Entity.parse(parser); + break; + case PARENT_TASK_ID_FIELD: + parentTaskId = parser.text(); + break; + case ESTIMATED_MINUTES_LEFT_FIELD: + estimatedMinutesLeft = parser.intValue(); + break; + case USER_FIELD: + user = User.parse(parser); + break; + default: + parser.skipChildren(); + break; + } + } + Forecaster copyForecaster = forecaster == null + ? null + : new Forecaster( + configId, + forecaster.getVersion(), + forecaster.getName(), + forecaster.getDescription(), + forecaster.getTimeField(), + forecaster.getIndices(), + forecaster.getFeatureAttributes(), + forecaster.getFilterQuery(), + forecaster.getInterval(), + forecaster.getWindowDelay(), + forecaster.getShingleSize(), + forecaster.getUiMetadata(), + forecaster.getSchemaVersion(), + forecaster.getLastUpdateTime(), + forecaster.getCategoryFields(), + forecaster.getUser(), + forecaster.getCustomResultIndex(), + forecaster.getHorizon(), + forecaster.getImputationOption() + ); + return new Builder() + .taskId(parsedTaskId) + .lastUpdateTime(lastUpdateTime) + .startedBy(startedBy) + .stoppedBy(stoppedBy) + .error(error) + .state(state) + .configId(configId) + .taskProgress(taskProgress) + .initProgress(initProgress) + .currentPiece(currentPiece) + .executionStartTime(executionStartTime) + .executionEndTime(executionEndTime) + .isLatest(isLatest) + .taskType(taskType) + .checkpointId(checkpointId) + .coordinatingNode(coordinatingNode) + .workerNode(workerNode) + .forecaster(copyForecaster) + .dateRange(dateRange) + .entity(entity) + .parentTaskId(parentTaskId) + .estimatedMinutesLeft(estimatedMinutesLeft) + .user(user) + .build(); + } + + @Generated + @Override + public boolean equals(Object other) { + if (this == other) + return true; + if (other == null || getClass() != other.getClass()) + return false; + ForecastTask that = (ForecastTask) other; + return super.equals(that) + && Objects.equal(getForecaster(), that.getForecaster()) + && Objects.equal(getDateRange(), that.getDateRange()); + } + + @Generated + @Override + public int hashCode() { + int superHashCode = super.hashCode(); + int hash = Objects.hashCode(configId, forecaster, dateRange); + hash += 89 * superHashCode; + return hash; + } + + public Forecaster getForecaster() { + return forecaster; + } + + public DateRange getDateRange() { + return dateRange; + } + + public void setDateRange(DateRange dateRange) { + this.dateRange = dateRange; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java b/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java new file mode 100644 index 000000000..76e1aac88 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java @@ -0,0 +1,69 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.model; + +import java.util.List; + +import org.opensearch.timeseries.model.TaskType; + +import com.google.common.collect.ImmutableList; + +/** + * The ForecastTaskType enum defines different task types for forecasting, categorized into real-time and historical settings. + * In real-time forecasting, we monitor states at the forecaster level, resulting in two distinct task types: one for + * single-stream forecasting and another for high cardinality (HC). In the historical setting, state tracking is more nuanced, + * encompassing both entity and forecaster levels. This leads to three specific task types: a forecaster-level task dedicated + * to single-stream forecasting, and two tasks for HC, one at the forecaster level and another at the entity level. + * + * Real-time forecasting: + * - FORECAST_REALTIME_SINGLE_STREAM: Represents a task type for single-stream forecasting. Ideal for scenarios where a single + * time series is processed in real-time. + * - FORECAST_REALTIME_HC_FORECASTER: Represents a task type for high cardinality (HC) forecasting. Used when dealing with a + * large number of distinct entities in real-time. + * + * Historical forecasting: + * - FORECAST_HISTORICAL_SINGLE_STREAM: Represents a forecaster-level task for single-stream historical forecasting. + * Suitable for analyzing a single time series in a sequential manner. + * - FORECAST_HISTORICAL_HC_FORECASTER: A forecaster-level task to track overall state, initialization progress, errors, etc., + * for HC forecasting. Central to managing multiple historical time series with high cardinality. + * - FORECAST_HISTORICAL_HC_ENTITY: An entity-level task to track the state, initialization progress, errors, etc., of a + * specific entity within HC historical forecasting. Allows for fine-grained information recording at the entity level. + * + */ +public enum ForecastTaskType implements TaskType { + FORECAST_REALTIME_SINGLE_STREAM, + FORECAST_REALTIME_HC_FORECASTER, + FORECAST_HISTORICAL_SINGLE_STREAM, + // forecaster level task to track overall state, init progress, error etc. for HC forecaster + FORECAST_HISTORICAL_HC_FORECASTER, + // entity level task to track just one specific entity's state, init progress, error etc. + FORECAST_HISTORICAL_HC_ENTITY; + + public static List HISTORICAL_FORECASTER_TASK_TYPES = ImmutableList + .of(ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM); + public static List ALL_HISTORICAL_TASK_TYPES = ImmutableList + .of( + ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, + ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + ); + public static List REALTIME_TASK_TYPES = ImmutableList + .of(ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER); + public static List ALL_FORECAST_TASK_TYPES = ImmutableList + .of( + ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, + ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, + ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + ); +} diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java index a9d033b78..6b4078ad4 100644 --- a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java +++ b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java @@ -39,6 +39,17 @@ public final class ForecastSettings { Setting.Property.Dynamic ); + // ====================================== + // cleanup resouce setting + // ====================================== + public static final Setting DELETE_FORECAST_RESULT_WHEN_DELETE_FORECASTER = Setting + .boolSetting( + "plugins.forecast.delete_forecast_result_when_delete_forecaster", + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + // ====================================== // resource constraint // ====================================== diff --git a/src/main/java/org/opensearch/ad/model/ADTaskState.java b/src/main/java/org/opensearch/timeseries/model/TaskState.java similarity index 68% rename from src/main/java/org/opensearch/ad/model/ADTaskState.java rename to src/main/java/org/opensearch/timeseries/model/TaskState.java index 68462f816..2b5c4240e 100644 --- a/src/main/java/org/opensearch/ad/model/ADTaskState.java +++ b/src/main/java/org/opensearch/timeseries/model/TaskState.java @@ -9,47 +9,47 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.util.List; import com.google.common.collect.ImmutableList; /** - * AD task states. + * AD and forecasting task states. *
    *
  • CREATED: - * When user start a historical detector, we will create one task to track the detector + * AD: When user start a historical detector, we will create one task to track the detector * execution and set its state as CREATED * *
  • INIT: - * After task created, coordinate node will gather all eligible node’s state and dispatch + * AD: After task created, coordinate node will gather all eligible node’s state and dispatch * task to the worker node with lowest load. When the worker node receives the request, * it will set the task state as INIT immediately, then start to run cold start to train * RCF model. We will track the initialization progress in task. * Init_Progress=ModelUpdates/MinSampleSize * *
  • RUNNING: - * If RCF model gets enough data points and passed training, it will start to detect data + * AD: If RCF model gets enough data points and passed training, it will start to detect data * normally and output positive anomaly scores. Once the RCF model starts to output positive * anomaly score, we will set the task state as RUNNING and init progress as 100%. We will * track task running progress in task: Task_Progress=DetectedPieces/AllPieces * *
  • FINISHED: - * When all historical data detected, we set the task state as FINISHED and task progress + * AD: When all historical data detected, we set the task state as FINISHED and task progress * as 100%. * *
  • STOPPED: - * User can cancel a running task by stopping detector, for example, user want to tune + * AD: User can cancel a running task by stopping detector, for example, user want to tune * feature and reran and don’t want current task run any more. When a historical detector * stopped, we will mark the task flag cancelled as true, when run next piece, we will * check this flag and stop the task. Then task stopped, will set its state as STOPPED * *
  • FAILED: - * If any exception happen, we will set task state as FAILED + * AD: If any exception happen, we will set task state as FAILED *
*/ -public enum ADTaskState { +public enum TaskState { CREATED, INIT, RUNNING, @@ -58,5 +58,5 @@ public enum ADTaskState { FINISHED; public static List NOT_ENDED_STATES = ImmutableList - .of(ADTaskState.CREATED.name(), ADTaskState.INIT.name(), ADTaskState.RUNNING.name()); + .of(TaskState.CREATED.name(), TaskState.INIT.name(), TaskState.RUNNING.name()); } diff --git a/src/main/java/org/opensearch/timeseries/model/TaskType.java b/src/main/java/org/opensearch/timeseries/model/TaskType.java new file mode 100644 index 000000000..74481871d --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/TaskType.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.model; + +import java.util.List; +import java.util.stream.Collectors; + +public interface TaskType { + String name(); + + public static List taskTypeToString(List adTaskTypes) { + return adTaskTypes.stream().map(type -> type.name()).collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java b/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java new file mode 100644 index 000000000..fd57de7cd --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java @@ -0,0 +1,448 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.model; + +import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.annotation.Generated; + +import com.google.common.base.Objects; + +public abstract class TimeSeriesTask implements ToXContentObject, Writeable { + + public static final String TASK_ID_FIELD = "task_id"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + public static final String STARTED_BY_FIELD = "started_by"; + public static final String STOPPED_BY_FIELD = "stopped_by"; + public static final String ERROR_FIELD = "error"; + public static final String STATE_FIELD = "state"; + public static final String TASK_PROGRESS_FIELD = "task_progress"; + public static final String INIT_PROGRESS_FIELD = "init_progress"; + public static final String CURRENT_PIECE_FIELD = "current_piece"; + public static final String EXECUTION_START_TIME_FIELD = "execution_start_time"; + public static final String EXECUTION_END_TIME_FIELD = "execution_end_time"; + public static final String IS_LATEST_FIELD = "is_latest"; + public static final String TASK_TYPE_FIELD = "task_type"; + public static final String CHECKPOINT_ID_FIELD = "checkpoint_id"; + public static final String COORDINATING_NODE_FIELD = "coordinating_node"; + public static final String WORKER_NODE_FIELD = "worker_node"; + public static final String ENTITY_FIELD = "entity"; + public static final String PARENT_TASK_ID_FIELD = "parent_task_id"; + public static final String ESTIMATED_MINUTES_LEFT_FIELD = "estimated_minutes_left"; + public static final String USER_FIELD = "user"; + public static final String HISTORICAL_TASK_PREFIX = "HISTORICAL"; + + protected String configId = null; + protected String taskId = null; + protected Instant lastUpdateTime = null; + protected String startedBy = null; + protected String stoppedBy = null; + protected String error = null; + protected String state = null; + protected Float taskProgress = null; + protected Float initProgress = null; + protected Instant currentPiece = null; + protected Instant executionStartTime = null; + protected Instant executionEndTime = null; + protected Boolean isLatest = null; + protected String taskType = null; + protected String checkpointId = null; + protected String coordinatingNode = null; + protected String workerNode = null; + protected Entity entity = null; + protected String parentTaskId = null; + protected Integer estimatedMinutesLeft = null; + protected User user = null; + + @SuppressWarnings("unchecked") + public abstract static class Builder> { + protected String configId = null; + protected String taskId = null; + protected String taskType = null; + protected String state = null; + protected Float taskProgress = null; + protected Float initProgress = null; + protected Instant currentPiece = null; + protected Instant executionStartTime = null; + protected Instant executionEndTime = null; + protected Boolean isLatest = null; + protected String error = null; + protected String checkpointId = null; + protected Instant lastUpdateTime = null; + protected String startedBy = null; + protected String stoppedBy = null; + protected String coordinatingNode = null; + protected String workerNode = null; + protected Entity entity = null; + protected String parentTaskId; + protected Integer estimatedMinutesLeft; + protected User user = null; + + public Builder() {} + + public T configId(String configId) { + this.configId = configId; + return (T) this; + } + + public T taskId(String taskId) { + this.taskId = taskId; + return (T) this; + } + + public T lastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + return (T) this; + } + + public T startedBy(String startedBy) { + this.startedBy = startedBy; + return (T) this; + } + + public T stoppedBy(String stoppedBy) { + this.stoppedBy = stoppedBy; + return (T) this; + } + + public T error(String error) { + this.error = error; + return (T) this; + } + + public T state(String state) { + this.state = state; + return (T) this; + } + + public T taskProgress(Float taskProgress) { + this.taskProgress = taskProgress; + return (T) this; + } + + public T initProgress(Float initProgress) { + this.initProgress = initProgress; + return (T) this; + } + + public T currentPiece(Instant currentPiece) { + this.currentPiece = currentPiece; + return (T) this; + } + + public T executionStartTime(Instant executionStartTime) { + this.executionStartTime = executionStartTime; + return (T) this; + } + + public T executionEndTime(Instant executionEndTime) { + this.executionEndTime = executionEndTime; + return (T) this; + } + + public T isLatest(Boolean isLatest) { + this.isLatest = isLatest; + return (T) this; + } + + public T taskType(String taskType) { + this.taskType = taskType; + return (T) this; + } + + public T checkpointId(String checkpointId) { + this.checkpointId = checkpointId; + return (T) this; + } + + public T coordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + return (T) this; + } + + public T workerNode(String workerNode) { + this.workerNode = workerNode; + return (T) this; + } + + public T entity(Entity entity) { + this.entity = entity; + return (T) this; + } + + public T parentTaskId(String parentTaskId) { + this.parentTaskId = parentTaskId; + return (T) this; + } + + public T estimatedMinutesLeft(Integer estimatedMinutesLeft) { + this.estimatedMinutesLeft = estimatedMinutesLeft; + return (T) this; + } + + public T user(User user) { + this.user = user; + return (T) this; + } + } + + public boolean isHistoricalTask() { + return taskType.startsWith(TimeSeriesTask.HISTORICAL_TASK_PREFIX); + } + + /** + * Get config level task id. If a task has no parent task, the task is config level task. + * @return config level task id + */ + public String getConfigLevelTaskId() { + return getParentTaskId() != null ? getParentTaskId() : getTaskId(); + } + + public String getTaskId() { + return taskId; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + public String getStartedBy() { + return startedBy; + } + + public String getStoppedBy() { + return stoppedBy; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Float getTaskProgress() { + return taskProgress; + } + + public Float getInitProgress() { + return initProgress; + } + + public Instant getCurrentPiece() { + return currentPiece; + } + + public Instant getExecutionStartTime() { + return executionStartTime; + } + + public Instant getExecutionEndTime() { + return executionEndTime; + } + + public Boolean isLatest() { + return isLatest; + } + + public String getTaskType() { + return taskType; + } + + public String getCheckpointId() { + return checkpointId; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public String getWorkerNode() { + return workerNode; + } + + public Entity getEntity() { + return entity; + } + + public String getParentTaskId() { + return parentTaskId; + } + + public Integer getEstimatedMinutesLeft() { + return estimatedMinutesLeft; + } + + public User getUser() { + return user; + } + + public String getConfigId() { + return configId; + } + + public void setLatest(Boolean latest) { + isLatest = latest; + } + + public void setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + } + + public boolean isDone() { + return !NOT_ENDED_STATES.contains(this.getState()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (taskId != null) { + builder.field(TimeSeriesTask.TASK_ID_FIELD, taskId); + } + if (lastUpdateTime != null) { + builder.field(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + if (startedBy != null) { + builder.field(TimeSeriesTask.STARTED_BY_FIELD, startedBy); + } + if (stoppedBy != null) { + builder.field(TimeSeriesTask.STOPPED_BY_FIELD, stoppedBy); + } + if (error != null) { + builder.field(TimeSeriesTask.ERROR_FIELD, error); + } + if (state != null) { + builder.field(TimeSeriesTask.STATE_FIELD, state); + } + if (taskProgress != null) { + builder.field(TimeSeriesTask.TASK_PROGRESS_FIELD, taskProgress); + } + if (initProgress != null) { + builder.field(TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress); + } + if (currentPiece != null) { + builder.field(TimeSeriesTask.CURRENT_PIECE_FIELD, currentPiece.toEpochMilli()); + } + if (executionStartTime != null) { + builder.field(TimeSeriesTask.EXECUTION_START_TIME_FIELD, executionStartTime.toEpochMilli()); + } + if (executionEndTime != null) { + builder.field(TimeSeriesTask.EXECUTION_END_TIME_FIELD, executionEndTime.toEpochMilli()); + } + if (isLatest != null) { + builder.field(TimeSeriesTask.IS_LATEST_FIELD, isLatest); + } + if (taskType != null) { + builder.field(TimeSeriesTask.TASK_TYPE_FIELD, taskType); + } + if (checkpointId != null) { + builder.field(TimeSeriesTask.CHECKPOINT_ID_FIELD, checkpointId); + } + if (coordinatingNode != null) { + builder.field(TimeSeriesTask.COORDINATING_NODE_FIELD, coordinatingNode); + } + if (workerNode != null) { + builder.field(TimeSeriesTask.WORKER_NODE_FIELD, workerNode); + } + if (entity != null) { + builder.field(TimeSeriesTask.ENTITY_FIELD, entity); + } + if (parentTaskId != null) { + builder.field(TimeSeriesTask.PARENT_TASK_ID_FIELD, parentTaskId); + } + if (estimatedMinutesLeft != null) { + builder.field(TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD, estimatedMinutesLeft); + } + if (user != null) { + builder.field(TimeSeriesTask.USER_FIELD, user); + } + return builder; + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + TimeSeriesTask that = (TimeSeriesTask) o; + return Objects.equal(getConfigId(), that.getConfigId()) + && Objects.equal(getTaskId(), that.getTaskId()) + && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) + && Objects.equal(getStartedBy(), that.getStartedBy()) + && Objects.equal(getStoppedBy(), that.getStoppedBy()) + && Objects.equal(getError(), that.getError()) + && Objects.equal(getState(), that.getState()) + && Objects.equal(getTaskProgress(), that.getTaskProgress()) + && Objects.equal(getInitProgress(), that.getInitProgress()) + && Objects.equal(getCurrentPiece(), that.getCurrentPiece()) + && Objects.equal(getExecutionStartTime(), that.getExecutionStartTime()) + && Objects.equal(getExecutionEndTime(), that.getExecutionEndTime()) + && Objects.equal(isLatest(), that.isLatest()) + && Objects.equal(getTaskType(), that.getTaskType()) + && Objects.equal(getCheckpointId(), that.getCheckpointId()) + && Objects.equal(getCoordinatingNode(), that.getCoordinatingNode()) + && Objects.equal(getWorkerNode(), that.getWorkerNode()) + && Objects.equal(getEntity(), that.getEntity()) + && Objects.equal(getParentTaskId(), that.getParentTaskId()) + && Objects.equal(getEstimatedMinutesLeft(), that.getEstimatedMinutesLeft()) + && Objects.equal(getUser(), that.getUser()); + } + + @Generated + @Override + public int hashCode() { + return Objects + .hashCode( + taskId, + lastUpdateTime, + startedBy, + stoppedBy, + error, + state, + taskProgress, + initProgress, + currentPiece, + executionStartTime, + executionEndTime, + isLatest, + taskType, + checkpointId, + coordinatingNode, + workerNode, + entity, + parentTaskId, + estimatedMinutesLeft, + user + ); + } + + public abstract boolean isEntityTask(); + + public String getEntityModelId() { + return entity == null ? null : entity.getModelId(configId).orElse(null); + } +} diff --git a/src/test/java/org/opensearch/StreamInputOutputTests.java b/src/test/java/org/opensearch/StreamInputOutputTests.java index a1906c43f..82ff5cc24 100644 --- a/src/test/java/org/opensearch/StreamInputOutputTests.java +++ b/src/test/java/org/opensearch/StreamInputOutputTests.java @@ -39,8 +39,8 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.model.Entity; diff --git a/src/test/java/org/opensearch/ad/ADUnitTestCase.java b/src/test/java/org/opensearch/ad/ADUnitTestCase.java index 232c5dcdc..6adf8e6f9 100644 --- a/src/test/java/org/opensearch/ad/ADUnitTestCase.java +++ b/src/test/java/org/opensearch/ad/ADUnitTestCase.java @@ -28,7 +28,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.test.OpenSearchTestCase; import com.google.common.collect.ImmutableMap; diff --git a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java index 2353f69b9..a98eef88d 100644 --- a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java @@ -41,7 +41,7 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java index 7c619fda1..a1f8296e2 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java @@ -55,8 +55,8 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.index.IndexNotFoundException; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java index a419b8719..d40fa84f8 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java @@ -37,7 +37,6 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.mock.plugin.MockReindexPlugin; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.transport.AnomalyDetectorJobAction; @@ -57,6 +56,7 @@ import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import com.google.common.collect.ImmutableList; @@ -120,6 +120,7 @@ public void ingestTestData( } } + @Override public Feature maxValueFeature() throws IOException { AggregationBuilder aggregationBuilder = TestHelpers.parseAggregation("{\"test\":{\"max\":{\"field\":\"" + valueField + "\"}}}"); return new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder); @@ -135,21 +136,15 @@ public ADTask randomCreatedADTask(String taskId, AnomalyDetector detector, DateR } public ADTask randomCreatedADTask(String taskId, AnomalyDetector detector, String detectorId, DateRange detectionDateRange) { - return randomADTask(taskId, detector, detectorId, detectionDateRange, ADTaskState.CREATED); + return randomADTask(taskId, detector, detectorId, detectionDateRange, TaskState.CREATED); } - public ADTask randomADTask( - String taskId, - AnomalyDetector detector, - String detectorId, - DateRange detectionDateRange, - ADTaskState state - ) { + public ADTask randomADTask(String taskId, AnomalyDetector detector, String detectorId, DateRange detectionDateRange, TaskState state) { ADTask.Builder builder = ADTask .builder() .taskId(taskId) .taskType(ADTaskType.HISTORICAL_SINGLE_ENTITY.name()) - .detectorId(detectorId) + .configId(detectorId) .detectionDateRange(detectionDateRange) .detector(detector) .state(state.name()) @@ -158,12 +153,12 @@ public ADTask randomADTask( .isLatest(true) .startedBy(randomAlphaOfLength(5)) .executionStartTime(Instant.now().minus(randomLongBetween(10, 100), ChronoUnit.MINUTES)); - if (ADTaskState.FINISHED == state) { + if (TaskState.FINISHED == state) { setPropertyForNotRunningTask(builder); - } else if (ADTaskState.FAILED == state) { + } else if (TaskState.FAILED == state) { setPropertyForNotRunningTask(builder); builder.error(randomAlphaOfLength(5)); - } else if (ADTaskState.STOPPED == state) { + } else if (TaskState.STOPPED == state) { setPropertyForNotRunningTask(builder); builder.error(randomAlphaOfLength(5)); builder.stoppedBy(randomAlphaOfLength(5)); diff --git a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java index f21b74b11..9e2fda3e1 100644 --- a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java +++ b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java @@ -24,7 +24,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.jvm.JvmInfo.Mem; import org.opensearch.monitor.jvm.JvmService; diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java index e90783d0f..173e9d4b0 100644 --- a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java @@ -61,7 +61,7 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java index 7774fb314..f43457318 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java @@ -57,7 +57,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.jvm.JvmInfo.Mem; import org.opensearch.monitor.jvm.JvmService; diff --git a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java index 125d58d2c..2ba8f391a 100644 --- a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java @@ -32,7 +32,7 @@ import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.component.LifecycleListener; +import org.opensearch.common.lifecycle.LifecycleListener; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; diff --git a/src/test/java/org/opensearch/ad/model/ADTaskTests.java b/src/test/java/org/opensearch/ad/model/ADTaskTests.java index 1cd2e6cc8..d97dc15dd 100644 --- a/src/test/java/org/opensearch/ad/model/ADTaskTests.java +++ b/src/test/java/org/opensearch/ad/model/ADTaskTests.java @@ -25,6 +25,7 @@ import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.TaskState; public class ADTaskTests extends OpenSearchSingleNodeTestCase { @@ -39,7 +40,7 @@ protected NamedWriteableRegistry writableRegistry() { } public void testAdTaskSerialization() throws IOException { - ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), ADTaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), true); + ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), TaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), true); BytesStreamOutput output = new BytesStreamOutput(); adTask.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); @@ -48,7 +49,7 @@ public void testAdTaskSerialization() throws IOException { } public void testAdTaskSerializationWithNullDetector() throws IOException { - ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), ADTaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), false); + ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), TaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), false); BytesStreamOutput output = new BytesStreamOutput(); adTask.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); @@ -58,7 +59,7 @@ public void testAdTaskSerializationWithNullDetector() throws IOException { public void testParseADTask() throws IOException { ADTask adTask = TestHelpers - .randomAdTask(null, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); + .randomAdTask(null, TaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); String taskId = randomAlphaOfLength(5); adTask.setTaskId(taskId); String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); @@ -69,7 +70,7 @@ public void testParseADTask() throws IOException { public void testParseADTaskWithoutTaskId() throws IOException { String taskId = null; ADTask adTask = TestHelpers - .randomAdTask(taskId, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); + .randomAdTask(taskId, TaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); ADTask parsedADTask = ADTask.parse(TestHelpers.parser(adTaskString)); assertEquals("Parsing AD task doesn't work", adTask, parsedADTask); @@ -78,7 +79,7 @@ public void testParseADTaskWithoutTaskId() throws IOException { public void testParseADTaskWithNullDetector() throws IOException { String taskId = randomAlphaOfLength(5); ADTask adTask = TestHelpers - .randomAdTask(taskId, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), false); + .randomAdTask(taskId, TaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), false); String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); ADTask parsedADTask = ADTask.parse(TestHelpers.parser(adTaskString), taskId); assertEquals("Parsing AD task doesn't work", adTask, parsedADTask); diff --git a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java index e9923a24e..3411f37ac 100644 --- a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java +++ b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java @@ -270,7 +270,7 @@ public static List searchLatestAdTaskOfDetector(RestClient client, Strin .builder() .taskId(id) .state(state) - .detectorId(parsedDetectorId) + .configId(parsedDetectorId) .taskProgress(taskProgress.floatValue()) .initProgress(initProgress.floatValue()) .taskType(parsedTaskType) @@ -398,7 +398,7 @@ private static ADTask parseAdTask(Map taskMap) { .builder() .taskId(id) .state(state) - .detectorId(parsedDetectorId) + .configId(parsedDetectorId) .taskProgress(taskProgress.floatValue()) .initProgress(initProgress.floatValue()) .taskType(parsedTaskType) diff --git a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java index c083ea25f..e3881c968 100644 --- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java @@ -34,7 +34,6 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskProfile; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; @@ -42,6 +41,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -118,8 +118,8 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul // get task profile ADTaskProfile adTaskProfile = waitUntilGetTaskProfile(detectorId); if (categoryFieldSize > 0) { - if (!ADTaskState.RUNNING.name().equals(adTaskProfile.getAdTask().getState())) { - adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(ADTaskState.RUNNING.name())).get(0); + if (!TaskState.RUNNING.name().equals(adTaskProfile.getAdTask().getState())) { + adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(TaskState.RUNNING.name())).get(0); } assertEquals((int) Math.pow(categoryFieldDocCount, categoryFieldSize), adTaskProfile.getTotalEntitiesCount().intValue()); assertTrue(adTaskProfile.getPendingEntitiesCount() > 0); @@ -172,7 +172,7 @@ public void testStopHistoricalAnalysis() throws Exception { assertEquals(RestStatus.OK, TestHelpers.restStatus(stopDetectorResponse)); // get task profile - checkIfTaskCanFinishCorrectly(detectorId, taskId, ImmutableSet.of(ADTaskState.STOPPED.name())); + checkIfTaskCanFinishCorrectly(detectorId, taskId, ImmutableSet.of(TaskState.STOPPED.name())); updateClusterSettings(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1); waitUntilTaskDone(detectorId); diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java index ba9698d6a..baff7155f 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java @@ -35,7 +35,6 @@ import org.junit.Before; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -46,6 +45,7 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.DuplicateTaskException; import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.model.TaskState; import com.google.common.collect.ImmutableList; @@ -94,7 +94,7 @@ public void testPutTask() throws IOException { adTaskCacheManager.add(adTask); assertEquals(1, adTaskCacheManager.size()); assertTrue(adTaskCacheManager.contains(adTask.getTaskId())); - assertTrue(adTaskCacheManager.containsTaskOfDetector(adTask.getId())); + assertTrue(adTaskCacheManager.containsTaskOfDetector(adTask.getConfigId())); assertNotNull(adTaskCacheManager.getTRcfModel(adTask.getTaskId())); assertNotNull(adTaskCacheManager.getShingle(adTask.getTaskId())); assertFalse(adTaskCacheManager.isThresholdModelTrained(adTask.getTaskId())); @@ -113,10 +113,10 @@ public void testPutDuplicateTask() throws IOException { ADTask adTask2 = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, adTask1.getExecutionEndTime(), adTask1.getStoppedBy(), - adTask1.getId(), + adTask1.getConfigId(), adTask1.getDetector(), ADTaskType.HISTORICAL_SINGLE_ENTITY ); @@ -137,7 +137,7 @@ public void testPutMultipleEntityTasks() throws IOException { ADTask adTask1 = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.CREATED, + TaskState.CREATED, Instant.now(), null, detector.getId(), @@ -147,7 +147,7 @@ public void testPutMultipleEntityTasks() throws IOException { ADTask adTask2 = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.CREATED, + TaskState.CREATED, Instant.now(), null, detector.getId(), @@ -223,8 +223,8 @@ public void testCancelByDetectorId() throws IOException { when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.add(adTask); - String detectorId = adTask.getId(); - String detectorTaskId = adTask.getId(); + String detectorId = adTask.getConfigId(); + String detectorTaskId = adTask.getConfigId(); String reason = randomAlphaOfLength(10); String userName = randomAlphaOfLength(5); ADTaskCancellationState state = adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); @@ -310,7 +310,7 @@ public void testPushBackEntity() throws IOException { public void testRealtimeTaskCache() { String detectorId1 = randomAlphaOfLength(10); - String newState = ADTaskState.INIT.name(); + String newState = TaskState.INIT.name(); Float newInitProgress = 0.0f; String newError = randomAlphaOfLength(5); assertTrue(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); @@ -328,7 +328,7 @@ public void testRealtimeTaskCache() { adTaskCacheManager.updateRealtimeTaskCache(detectorId2, newState, newInitProgress, newError); assertEquals(2, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); - newState = ADTaskState.RUNNING.name(); + newState = TaskState.RUNNING.name(); newInitProgress = 1.0f; newError = "test error"; assertTrue(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); @@ -354,7 +354,7 @@ public void testUpdateRealtimeTaskCache() { assertNull(realtimeTaskCache.getError()); assertNull(realtimeTaskCache.getInitProgress()); - String state = ADTaskState.RUNNING.name(); + String state = TaskState.RUNNING.name(); Float initProgress = 0.1f; String error = randomAlphaOfLength(5); adTaskCacheManager.updateRealtimeTaskCache(detectorId, state, initProgress, error); @@ -363,7 +363,7 @@ public void testUpdateRealtimeTaskCache() { assertEquals(error, realtimeTaskCache.getError()); assertEquals(initProgress, realtimeTaskCache.getInitProgress()); - state = ADTaskState.STOPPED.name(); + state = TaskState.STOPPED.name(); adTaskCacheManager.updateRealtimeTaskCache(detectorId, state, initProgress, error); realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); assertNull(realtimeTaskCache); @@ -434,7 +434,7 @@ private List addHCDetectorCache() throws IOException { ADTask adDetectorTask = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.CREATED, + TaskState.CREATED, Instant.now(), null, detectorId, @@ -444,7 +444,7 @@ private List addHCDetectorCache() throws IOException { ADTask adEntityTask = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.CREATED, + TaskState.CREATED, Instant.now(), null, detectorId, @@ -621,11 +621,11 @@ public void testADHCBatchTaskRunStateCacheWithCancel() { ADHCBatchTaskRunState state = adTaskCacheManager.getOrCreateHCDetectorTaskStateCache(detectorId, detectorTaskId); assertTrue(adTaskCacheManager.detectorTaskStateExists(detectorId, detectorTaskId)); - assertEquals(ADTaskState.INIT.name(), state.getDetectorTaskState()); + assertEquals(TaskState.INIT.name(), state.getDetectorTaskState()); assertFalse(state.expired()); - state.setDetectorTaskState(ADTaskState.RUNNING.name()); - assertEquals(ADTaskState.RUNNING.name(), adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); + state.setDetectorTaskState(TaskState.RUNNING.name()); + assertEquals(TaskState.RUNNING.name(), adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); String cancelReason = randomAlphaOfLength(5); String cancelledBy = randomAlphaOfLength(5); @@ -647,7 +647,7 @@ public void testADHCBatchTaskRunStateCacheWithCancel() { public void testUpdateDetectorTaskState() { String detectorId = randomAlphaOfLength(5); String detectorTaskId = randomAlphaOfLength(5); - String newState = ADTaskState.RUNNING.name(); + String newState = TaskState.RUNNING.name(); adTaskCacheManager.updateDetectorTaskState(detectorId, detectorTaskId, newState); assertEquals(newState, adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java index cbdb72646..160b84aaa 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java @@ -87,7 +87,6 @@ import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.ad.model.ADTaskProfile; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; @@ -104,11 +103,11 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.ToXContent; @@ -129,6 +128,7 @@ import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -450,7 +450,7 @@ private void setupTaskSlots(int node1UsedTaskSlots, int node1AssignedTaskSLots, public void testCheckTaskSlotsWithNoAvailableTaskSlots() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -475,7 +475,7 @@ private void setupSearchTopEntities(int entitySize) { public void testCheckTaskSlotsWithAvailableTaskSlotsForHC() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -494,7 +494,7 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsForHC() throws IOException { public void testCheckTaskSlotsWithAvailableTaskSlotsForSingleEntityDetector() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of()) @@ -512,7 +512,7 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsForSingleEntityDetector() th public void testCheckTaskSlotsWithAvailableTaskSlotsAndNoEntity() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -530,7 +530,7 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsAndNoEntity() throws IOExcep public void testCheckTaskSlotsWithAvailableTaskSlotsForScale() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -562,7 +562,7 @@ public void testDeleteDuplicateTasks() throws IOException { public void testParseEntityForSingleCategoryHC() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -575,7 +575,7 @@ public void testParseEntityForSingleCategoryHC() throws IOException { public void testParseEntityForMultiCategoryHC() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers @@ -715,7 +715,7 @@ public void testGetADTaskWithExistingTask() { @SuppressWarnings("unchecked") public void testUpdateLatestRealtimeTaskOnCoordinatingNode() { String detectorId = randomAlphaOfLength(5); - String state = ADTaskState.RUNNING.name(); + String state = TaskState.RUNNING.name(); Long rcfTotalUpdates = randomLongBetween(200, 1000); Long detectorIntervalInMinutes = 1L; String error = randomAlphaOfLength(5); @@ -1031,10 +1031,10 @@ public void testGetAndExecuteOnLatestADTasksWithRunningRealtimeTaskWithTaskStopp .builder() .taskId(randomAlphaOfLength(5)) .taskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()) - .detectorId(randomAlphaOfLength(5)) + .configId(randomAlphaOfLength(5)) .detector(detector) .entity(null) - .state(ADTaskState.RUNNING.name()) + .state(TaskState.RUNNING.name()) .taskProgress(0.5f) .initProgress(1.0f) .currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) @@ -1097,10 +1097,10 @@ public void testGetAndExecuteOnLatestADTasksWithRunningHistoricalTask() throws I .builder() .taskId(historicalTaskId) .taskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()) - .detectorId(randomAlphaOfLength(5)) + .configId(randomAlphaOfLength(5)) .detector(detector) .entity(null) - .state(ADTaskState.RUNNING.name()) + .state(TaskState.RUNNING.name()) .taskProgress(0.5f) .initProgress(1.0f) .currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java index b9595e2e7..c23f26575 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java @@ -40,8 +40,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java index 87a5853f9..e3f241e02 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java @@ -46,7 +46,6 @@ import org.opensearch.ad.mock.transport.MockAnomalyDetectorJobAction; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskProfile; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.client.Client; @@ -58,6 +57,7 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.stats.StatNames; import com.google.common.collect.ImmutableList; @@ -180,7 +180,7 @@ public void testStartHistoricalAnalysisForSingleCategoryHCWithUser() throws IOEx assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(adTask.getState())); assertEquals(categoryField, adTask.getDetector().getCategoryFields().get(0)); - if (ADTaskState.FINISHED.name().equals(adTask.getState())) { + if (TaskState.FINISHED.name().equals(adTask.getState())) { List adTasks = searchADTasks(detectorId, true, 100); assertEquals(4, adTasks.size()); List entityTasks = adTasks @@ -236,7 +236,7 @@ public void testStartHistoricalAnalysisForMultiCategoryHCWithUser() throws IOExc assertEquals(categoryField, adTask.getDetector().getCategoryFields().get(0)); assertEquals(ipField, adTask.getDetector().getCategoryFields().get(1)); - if (ADTaskState.FINISHED.name().equals(adTask.getState())) { + if (TaskState.FINISHED.name().equals(adTask.getState())) { List adTasks = searchADTasks(detectorId, taskId, true, 100); assertEquals(5, adTasks.size()); List entityTasks = adTasks @@ -297,8 +297,8 @@ public void testRaceConditionByStartingMultipleTasks() throws IOException, Inter List adTasks = searchADTasks(detectorId, null, 100); assertEquals(1, adTasks.size()); - assertTrue(adTasks.get(0).getLatest()); - assertNotEquals(ADTaskState.FAILED.name(), adTasks.get(0).getState()); + assertTrue(adTasks.get(0).isLatest()); + assertNotEquals(TaskState.FAILED.name(), adTasks.get(0).getState()); } // TODO: fix this flaky test case @@ -309,8 +309,8 @@ public void testCleanOldTaskDocs() throws InterruptedException, IOException { String detectorId = createDetector(detector); createDetectionStateIndex(); - List states = ImmutableList.of(ADTaskState.FAILED, ADTaskState.FINISHED, ADTaskState.STOPPED); - for (ADTaskState state : states) { + List states = ImmutableList.of(TaskState.FAILED, TaskState.FINISHED, TaskState.STOPPED); + for (TaskState state : states) { ADTask task = randomADTask(randomAlphaOfLength(5), detector, detectorId, dateRange, state); createADTask(task); } @@ -431,13 +431,13 @@ public void testStopRealtimeDetector() throws IOException { assertEquals(1, adTasks.size()); assertEquals(ADTaskType.REALTIME_SINGLE_ENTITY.name(), adTasks.get(0).getTaskType()); assertNotEquals(jobId, adTasks.get(0).getTaskId()); - assertEquals(ADTaskState.STOPPED.name(), adTasks.get(0).getState()); + assertEquals(TaskState.STOPPED.name(), adTasks.get(0).getState()); } public void testStopHistoricalDetector() throws IOException, InterruptedException { updateTransientSettings(ImmutableMap.of(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 5)); ADTask adTask = startHistoricalAnalysis(startTime, endTime); - assertEquals(ADTaskState.INIT.name(), adTask.getState()); + assertEquals(TaskState.INIT.name(), adTask.getState()); assertNull(adTask.getStartedBy()); assertNull(adTask.getUser()); waitUntil(() -> { @@ -447,7 +447,7 @@ public void testStopHistoricalDetector() throws IOException, InterruptedExceptio if (taskRunning) { // It's possible that the task not started on worker node yet. Recancel it to make sure // task cancelled. - AnomalyDetectorJobRequest request = stopDetectorJobRequest(adTask.getId(), true); + AnomalyDetectorJobRequest request = stopDetectorJobRequest(adTask.getConfigId(), true); client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); } return !taskRunning; @@ -456,13 +456,13 @@ public void testStopHistoricalDetector() throws IOException, InterruptedExceptio } }, 20, TimeUnit.SECONDS); ADTask stoppedTask = getADTask(adTask.getTaskId()); - assertEquals(ADTaskState.STOPPED.name(), stoppedTask.getState()); + assertEquals(TaskState.STOPPED.name(), stoppedTask.getState()); assertEquals(0, getExecutingADTask()); } public void testProfileHistoricalDetector() throws IOException, InterruptedException { ADTask adTask = startHistoricalAnalysis(startTime, endTime); - GetAnomalyDetectorRequest request = taskProfileRequest(adTask.getId()); + GetAnomalyDetectorRequest request = taskProfileRequest(adTask.getConfigId()); GetAnomalyDetectorResponse response = client().execute(GetAnomalyDetectorAction.INSTANCE, request).actionGet(10000); assertTrue(response.getDetectorProfile().getAdTaskProfile() != null); @@ -479,7 +479,7 @@ public void testProfileHistoricalDetector() throws IOException, InterruptedExcep assertNull(response.getDetectorProfile().getAdTaskProfile().getNodeId()); ADTask profileAdTask = response.getDetectorProfile().getAdTaskProfile().getAdTask(); assertEquals(finishedTask.getTaskId(), profileAdTask.getTaskId()); - assertEquals(finishedTask.getId(), profileAdTask.getId()); + assertEquals(finishedTask.getConfigId(), profileAdTask.getConfigId()); assertEquals(finishedTask.getDetector(), profileAdTask.getDetector()); assertEquals(finishedTask.getState(), profileAdTask.getState()); } @@ -488,8 +488,8 @@ public void testProfileWithMultipleRunningTask() throws IOException { ADTask adTask1 = startHistoricalAnalysis(startTime, endTime); ADTask adTask2 = startHistoricalAnalysis(startTime, endTime); - GetAnomalyDetectorRequest request1 = taskProfileRequest(adTask1.getId()); - GetAnomalyDetectorRequest request2 = taskProfileRequest(adTask2.getId()); + GetAnomalyDetectorRequest request1 = taskProfileRequest(adTask1.getConfigId()); + GetAnomalyDetectorRequest request2 = taskProfileRequest(adTask2.getConfigId()); GetAnomalyDetectorResponse response1 = client().execute(GetAnomalyDetectorAction.INSTANCE, request1).actionGet(10000); GetAnomalyDetectorResponse response2 = client().execute(GetAnomalyDetectorAction.INSTANCE, request2).actionGet(10000); ADTaskProfile taskProfile1 = response1.getDetectorProfile().getAdTaskProfile(); diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index 6c8634959..a92704b33 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -93,12 +93,13 @@ import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.transport.TransportResponse; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -126,7 +127,6 @@ import org.opensearch.transport.TransportInterceptor; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestOptions; -import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; diff --git a/src/test/java/org/opensearch/ad/transport/DeleteTests.java b/src/test/java/org/opensearch/ad/transport/DeleteTests.java index 619ee6bb2..de531c9b0 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteTests.java @@ -50,9 +50,9 @@ import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.reindex.BulkByScrollResponse; diff --git a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java b/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java index 4a1cfc718..32b3226b4 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java @@ -43,8 +43,9 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.transport.TransportResponse; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.tasks.Task; import org.opensearch.timeseries.AbstractTimeSeriesTest; @@ -58,7 +59,6 @@ import org.opensearch.transport.TransportInterceptor; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestOptions; -import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index fe8877c2a..6e9dda0f0 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -98,6 +98,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.transport.TransportResponse; import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHits; @@ -128,7 +129,6 @@ import org.opensearch.transport.TransportInterceptor; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestOptions; -import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTests.java index ff33ed277..ca006de2e 100644 --- a/src/test/java/org/opensearch/ad/transport/ProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/ProfileTests.java @@ -37,8 +37,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; diff --git a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java index 0ed3fd1ee..6a8056e02 100644 --- a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java +++ b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java @@ -36,8 +36,9 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.transport.TransportResponse; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.tasks.Task; import org.opensearch.timeseries.AbstractTimeSeriesTest; @@ -50,7 +51,6 @@ import org.opensearch.transport.TransportInterceptor; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestOptions; -import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; diff --git a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java index d6ed84d2d..2f854f5fd 100644 --- a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java @@ -21,9 +21,9 @@ import org.opensearch.action.ActionResponse; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchIntegTestCase; @@ -91,7 +91,7 @@ public void writeTo(StreamOutput streamOutput) throws IOException { @Test public void toXContentTest() throws IOException { StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true); - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); stopDetectorResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = Strings.toString(builder); diff --git a/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java b/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java new file mode 100644 index 000000000..28ec18bab --- /dev/null +++ b/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import java.io.IOException; +import java.util.Collection; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class ForecastTaskSerializationTests extends OpenSearchSingleNodeTestCase { + private BytesStreamOutput output; + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + + output = new BytesStreamOutput(); + } + + public void testConstructor_allFieldsPresent() throws IOException { + // Set up a StreamInput that contains all fields + ForecastTask originalTask = TestHelpers.ForecastTaskBuilder.newInstance().build(); + + originalTask.writeTo(output); + // required by AggregationBuilder in Feature's constructor for named writeable + NamedWriteableAwareStreamInput streamInput = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + + ForecastTask readTask = new ForecastTask(streamInput); + + assertEquals("task123", readTask.getTaskId()); + assertEquals("FORECAST_HISTORICAL_HC_ENTITY", readTask.getTaskType()); + assertTrue(readTask.isEntityTask()); + assertEquals("config123", readTask.getConfigId()); + assertEquals(originalTask.getForecaster(), readTask.getForecaster()); + assertEquals("Running", readTask.getState()); + assertEquals(Float.valueOf(0.5f), readTask.getTaskProgress()); + assertEquals(Float.valueOf(0.1f), readTask.getInitProgress()); + assertEquals(originalTask.getCurrentPiece(), readTask.getCurrentPiece()); + assertEquals(originalTask.getExecutionStartTime(), readTask.getExecutionStartTime()); + assertEquals(originalTask.getExecutionEndTime(), readTask.getExecutionEndTime()); + assertEquals(Boolean.TRUE, readTask.isLatest()); + assertEquals("No errors", readTask.getError()); + assertEquals("checkpoint1", readTask.getCheckpointId()); + assertEquals(originalTask.getLastUpdateTime(), readTask.getLastUpdateTime()); + assertEquals("user1", readTask.getStartedBy()); + assertEquals("user2", readTask.getStoppedBy()); + assertEquals("node1", readTask.getCoordinatingNode()); + assertEquals("node2", readTask.getWorkerNode()); + assertEquals(originalTask.getUser(), readTask.getUser()); + assertEquals(originalTask.getDateRange(), readTask.getDateRange()); + assertEquals(originalTask.getEntity(), readTask.getEntity()); + // since entity attributes are random, we cannot have a fixed model id to verify + assertTrue(readTask.getEntityModelId().startsWith("config123_entity_")); + assertEquals("parentTask1", readTask.getParentTaskId()); + assertEquals(Integer.valueOf(10), readTask.getEstimatedMinutesLeft()); + } + + public void testConstructor_missingOptionalFields() throws IOException { + // Set up a StreamInput that contains all fields + ForecastTask originalTask = TestHelpers.ForecastTaskBuilder + .newInstance() + .setForecaster(null) + .setUser(null) + .setDateRange(null) + .setEntity(null) + .build(); + + originalTask.writeTo(output); + // required by AggregationBuilder in Feature's constructor for named writeable + NamedWriteableAwareStreamInput streamInput = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + + ForecastTask readTask = new ForecastTask(streamInput); + + assertEquals("task123", readTask.getTaskId()); + assertEquals("FORECAST_HISTORICAL_HC_ENTITY", readTask.getTaskType()); + assertTrue(readTask.isEntityTask()); + assertEquals("config123", readTask.getConfigId()); + assertEquals(null, readTask.getForecaster()); + assertEquals("Running", readTask.getState()); + assertEquals(Float.valueOf(0.5f), readTask.getTaskProgress()); + assertEquals(Float.valueOf(0.1f), readTask.getInitProgress()); + assertEquals(originalTask.getCurrentPiece(), readTask.getCurrentPiece()); + assertEquals(originalTask.getExecutionStartTime(), readTask.getExecutionStartTime()); + assertEquals(originalTask.getExecutionEndTime(), readTask.getExecutionEndTime()); + assertEquals(Boolean.TRUE, readTask.isLatest()); + assertEquals("No errors", readTask.getError()); + assertEquals("checkpoint1", readTask.getCheckpointId()); + assertEquals(originalTask.getLastUpdateTime(), readTask.getLastUpdateTime()); + assertEquals("user1", readTask.getStartedBy()); + assertEquals("user2", readTask.getStoppedBy()); + assertEquals("node1", readTask.getCoordinatingNode()); + assertEquals("node2", readTask.getWorkerNode()); + assertEquals(null, readTask.getUser()); + assertEquals(null, readTask.getDateRange()); + assertEquals(null, readTask.getEntity()); + assertEquals(null, readTask.getEntityModelId()); + assertEquals("parentTask1", readTask.getParentTaskId()); + assertEquals(Integer.valueOf(10), readTask.getEstimatedMinutesLeft()); + } + +} diff --git a/src/test/java/org/opensearch/forecast/model/ForecastTaskTests.java b/src/test/java/org/opensearch/forecast/model/ForecastTaskTests.java new file mode 100644 index 000000000..bce749757 --- /dev/null +++ b/src/test/java/org/opensearch/forecast/model/ForecastTaskTests.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import java.io.IOException; + +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; + +public class ForecastTaskTests extends OpenSearchTestCase { + public void testParse() throws IOException { + ForecastTask originalTask = TestHelpers.ForecastTaskBuilder.newInstance().build(); + String forecastTaskString = TestHelpers + .xContentBuilderToString(originalTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ForecastTask parsedForecastTask = ForecastTask.parse(TestHelpers.parser(forecastTaskString)); + assertEquals("Parsing forecast task doesn't work", originalTask, parsedForecastTask); + } + + public void testParseEmptyForecaster() throws IOException { + ForecastTask originalTask = TestHelpers.ForecastTaskBuilder.newInstance().setForecaster(null).build(); + String forecastTaskString = TestHelpers + .xContentBuilderToString(originalTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ForecastTask parsedForecastTask = ForecastTask.parse(TestHelpers.parser(forecastTaskString)); + assertEquals("Parsing forecast task doesn't work", originalTask, parsedForecastTask); + } + + public void testParseEmptyForecasterRange() throws IOException { + ForecastTask originalTask = TestHelpers.ForecastTaskBuilder.newInstance().setForecaster(null).setDateRange(null).build(); + String forecastTaskString = TestHelpers + .xContentBuilderToString(originalTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ForecastTask parsedForecastTask = ForecastTask.parse(TestHelpers.parser(forecastTaskString)); + assertEquals("Parsing forecast task doesn't work", originalTask, parsedForecastTask); + } +} diff --git a/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java b/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java new file mode 100644 index 000000000..4ee403a0e --- /dev/null +++ b/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import java.util.Arrays; + +import org.opensearch.test.OpenSearchTestCase; + +public class ForecastTaskTypeTests extends OpenSearchTestCase { + + public void testHistoricalForecasterTaskTypes() { + assertEquals( + Arrays.asList(ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM), + ForecastTaskType.HISTORICAL_FORECASTER_TASK_TYPES + ); + } + + public void testAllHistoricalTaskTypes() { + assertEquals( + Arrays + .asList( + ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, + ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + ), + ForecastTaskType.ALL_HISTORICAL_TASK_TYPES + ); + } + + public void testRealtimeTaskTypes() { + assertEquals( + Arrays.asList(ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER), + ForecastTaskType.REALTIME_TASK_TYPES + ); + } + + public void testAllForecastTaskTypes() { + assertEquals( + Arrays + .asList( + ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, + ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, + ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + ), + ForecastTaskType.ALL_FORECAST_TASK_TYPES + ); + } +} diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index 3287f4118..d929ce1e6 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -67,7 +67,6 @@ import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorExecutionInput; @@ -113,6 +112,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.ForecastTask; import org.opensearch.forecast.model.Forecaster; import org.opensearch.index.get.GetResult; import org.opensearch.index.query.BoolQueryBuilder; @@ -136,6 +136,7 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.ImputationMethod; import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.DataByFeatureId; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; @@ -143,6 +144,7 @@ import org.opensearch.timeseries.model.FeatureData; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; import org.opensearch.timeseries.model.TimeConfiguration; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; @@ -160,12 +162,12 @@ public class TestHelpers { public static final String AD_BASE_PREVIEW_URI = AD_BASE_DETECTORS_URI + "/%s/_preview"; public static final String AD_BASE_STATS_URI = "/_plugins/_anomaly_detection/stats"; public static ImmutableSet HISTORICAL_ANALYSIS_RUNNING_STATS = ImmutableSet - .of(ADTaskState.CREATED.name(), ADTaskState.INIT.name(), ADTaskState.RUNNING.name()); + .of(TaskState.CREATED.name(), TaskState.INIT.name(), TaskState.RUNNING.name()); // Task may fail if memory circuit breaker triggered. public static final Set HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS = ImmutableSet - .of(ADTaskState.FINISHED.name(), ADTaskState.FAILED.name()); + .of(TaskState.FINISHED.name(), TaskState.FAILED.name()); public static ImmutableSet HISTORICAL_ANALYSIS_DONE_STATS = ImmutableSet - .of(ADTaskState.FAILED.name(), ADTaskState.FINISHED.name(), ADTaskState.STOPPED.name()); + .of(TaskState.FAILED.name(), TaskState.FINISHED.name(), TaskState.STOPPED.name()); private static final Logger logger = LogManager.getLogger(TestHelpers.class); public static final Random random = new Random(42); @@ -1261,7 +1263,7 @@ public static Map attrMap = new HashMap<>(); - detector.getCategoryFields().stream().forEach(f -> attrMap.put(f, randomAlphaOfLength(5))); - entity = Entity.createEntityByReordering(attrMap); - } else if (detector.isHighCardinality()) { - entity = Entity.createEntityByReordering(ImmutableMap.of(detector.getCategoryFields().get(0), randomAlphaOfLength(5))); - } - } + Entity entity = randomEntity(detector); String taskType = entity == null ? ADTaskType.HISTORICAL_SINGLE_ENTITY.name() : ADTaskType.HISTORICAL_HC_ENTITY.name(); ADTask task = ADTask .builder() .taskId(taskId) .taskType(taskType) - .detectorId(randomAlphaOfLength(5)) + .configId(randomAlphaOfLength(5)) .detector(detector) .state(state.name()) .taskProgress(0.5f) @@ -1427,6 +1420,33 @@ public static ADTask randomAdTask( return task; } + /** + * Generates a random Entity based on the provided configuration. + * + * If the configuration has multiple categories, a new Entity is created with attributes + * populated with random alphanumeric strings of length 5. + * + * If the configuration is marked as high cardinality and does not have multiple categories, + * a new Entity is created with a single attribute using the first category field and a random + * alphanumeric string of length 5. + * + * @param config The configuration object containing information about a time series analysis. + * @return A randomly generated Entity based on the configuration, or null if the config is null. + */ + public static Entity randomEntity(Config config) { + Entity entity = null; + if (config != null) { + if (config.hasMultipleCategories()) { + Map attrMap = new HashMap<>(); + config.getCategoryFields().stream().forEach(f -> attrMap.put(f, randomAlphaOfLength(5))); + entity = Entity.createEntityByReordering(attrMap); + } else if (config.isHighCardinality()) { + entity = Entity.createEntityByReordering(ImmutableMap.of(config.getCategoryFields().get(0), randomAlphaOfLength(5))); + } + } + return entity; + } + public static HttpEntity toHttpEntity(ToXContentObject object) throws IOException { return new StringEntity(toJsonString(object), APPLICATION_JSON); } @@ -1765,4 +1785,87 @@ public static Forecaster randomForecaster() throws IOException { randomImputationOption() ); } + + public static class ForecastTaskBuilder { + private String configId = "config123"; + private String taskId = "task123"; + private String taskType = "FORECAST_HISTORICAL_HC_ENTITY"; + private String state = "Running"; + private Float taskProgress = 0.5f; + private Float initProgress = 0.1f; + private Instant currentPiece = Instant.now().truncatedTo(ChronoUnit.SECONDS); + private Instant executionStartTime = Instant.now().truncatedTo(ChronoUnit.SECONDS); + private Instant executionEndTime = Instant.now().truncatedTo(ChronoUnit.SECONDS); + private Boolean isLatest = true; + private String error = "No errors"; + private String checkpointId = "checkpoint1"; + private Instant lastUpdateTime = Instant.now().truncatedTo(ChronoUnit.SECONDS); + private String startedBy = "user1"; + private String stoppedBy = "user2"; + private String coordinatingNode = "node1"; + private String workerNode = "node2"; + private Forecaster forecaster = TestHelpers.randomForecaster(); + private Entity entity = TestHelpers.randomEntity(forecaster); + private String parentTaskId = "parentTask1"; + private Integer estimatedMinutesLeft = 10; + protected User user = TestHelpers.randomUser(); + + private DateRange dateRange = new DateRange(Instant.ofEpochMilli(123), Instant.ofEpochMilli(456)); + + public ForecastTaskBuilder() throws IOException { + forecaster = TestHelpers.randomForecaster(); + } + + public static ForecastTaskBuilder newInstance() throws IOException { + return new ForecastTaskBuilder(); + } + + public ForecastTaskBuilder setForecaster(Forecaster associatedForecaster) { + this.forecaster = associatedForecaster; + return this; + } + + public ForecastTaskBuilder setUser(User associatedUser) { + this.user = associatedUser; + return this; + } + + public ForecastTaskBuilder setDateRange(DateRange associatedRange) { + this.dateRange = associatedRange; + return this; + } + + public ForecastTaskBuilder setEntity(Entity associatedEntity) { + this.entity = associatedEntity; + return this; + } + + public ForecastTask build() { + return new ForecastTask.Builder() + .configId(configId) + .taskId(taskId) + .lastUpdateTime(lastUpdateTime) + .startedBy(startedBy) + .stoppedBy(stoppedBy) + .error(error) + .state(state) + .taskProgress(taskProgress) + .initProgress(initProgress) + .currentPiece(currentPiece) + .executionStartTime(executionStartTime) + .executionEndTime(executionEndTime) + .isLatest(isLatest) + .taskType(taskType) + .checkpointId(checkpointId) + .coordinatingNode(coordinatingNode) + .workerNode(workerNode) + .entity(entity) + .parentTaskId(parentTaskId) + .estimatedMinutesLeft(estimatedMinutesLeft) + .user(user) + .forecaster(forecaster) + .dateRange(dateRange) + .build(); + } + } } diff --git a/src/test/java/org/opensearch/timeseries/feature/NoPowermockSearchFeatureDaoTests.java b/src/test/java/org/opensearch/timeseries/feature/NoPowermockSearchFeatureDaoTests.java index c689d3679..d2aec688e 100644 --- a/src/test/java/org/opensearch/timeseries/feature/NoPowermockSearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/timeseries/feature/NoPowermockSearchFeatureDaoTests.java @@ -67,9 +67,9 @@ import org.opensearch.common.util.BitMixer; import org.opensearch.common.util.MockBigArrays; import org.opensearch.common.util.MockPageCacheRecycler; +import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.indices.breaker.NoneCircuitBreakerService; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; diff --git a/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java b/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java index 3eb4fa80a..5c66d8f54 100644 --- a/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java +++ b/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java @@ -28,7 +28,7 @@ import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; -import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.transport.TransportAddress; public class ClusterCreation { /** diff --git a/src/test/java/test/org/opensearch/ad/util/FakeNode.java b/src/test/java/test/org/opensearch/ad/util/FakeNode.java index 6cacecc95..58f3f14bb 100644 --- a/src/test/java/test/org/opensearch/ad/util/FakeNode.java +++ b/src/test/java/test/org/opensearch/ad/util/FakeNode.java @@ -41,11 +41,11 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.BoundTransportAddress; -import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.core.common.transport.BoundTransportAddress; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; import org.opensearch.tasks.TaskManager; import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase;