Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add setting to limit max workflow steps #266

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;

/**
Expand Down Expand Up @@ -106,7 +107,7 @@ public Collection<Object> createComponents(
mlClient,
flowFrameworkIndicesHandler
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler);
}
Expand Down Expand Up @@ -144,6 +145,7 @@ public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList.of(
FLOW_FRAMEWORK_ENABLED,
MAX_WORKFLOWS,
MAX_WORKFLOW_STEPS,
WORKFLOW_REQUEST_TIMEOUT,
MAX_GET_TASK_REQUEST_RETRY
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ private FlowFrameworkSettings() {}

/** The upper limit of max workflows that can be created */
public static final int MAX_WORKFLOWS_LIMIT = 10000;
/** The upper limit of max workflow steps that can be in a single workflow */
public static final int MAX_WORKFLOW_STEPS_LIMIT = 500;

/** This setting sets max workflows that can be created */
public static final Setting<Integer> MAX_WORKFLOWS = Setting.intSetting(
Expand All @@ -29,6 +31,16 @@ private FlowFrameworkSettings() {}
Setting.Property.Dynamic
);

/** This setting sets max workflows that can be created */
public static final Setting<Integer> MAX_WORKFLOW_STEPS = Setting.intSetting(
"plugins.flow_framework.max_workflow_steps",
50,
1,
MAX_WORKFLOW_STEPS_LIMIT,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/** This setting sets the timeout for the request */
public static final Setting<TimeValue> WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting(
"plugins.flow_framework.request_timeout",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
Expand All @@ -32,6 +34,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE;
import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD;
import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD;
Expand All @@ -45,16 +48,26 @@ public class WorkflowProcessSorter {

private WorkflowStepFactory workflowStepFactory;
private ThreadPool threadPool;
private Integer maxWorkflowSteps;

/**
* Instantiate this class.
*
* @param workflowStepFactory The factory which matches template step types to instances.
* @param threadPool The OpenSearch Thread pool to pass to process nodes.
* @param clusterService The OpenSearch cluster service.
* @param settings OpenSerch settings
*/
public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) {
public WorkflowProcessSorter(
WorkflowStepFactory workflowStepFactory,
ThreadPool threadPool,
ClusterService clusterService,
Settings settings
) {
this.workflowStepFactory = workflowStepFactory;
this.threadPool = threadPool;
this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it);
}

/**
Expand All @@ -64,6 +77,20 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool
* @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list.
*/
public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId) {
if (workflow.nodes().size() > this.maxWorkflowSteps) {
throw new FlowFrameworkException(
"Workflow "
+ workflowId
+ " has "
+ workflow.nodes().size()
+ " nodes, which exceeds the maximum of "
+ this.maxWorkflowSteps
+ ". Change the setting ["
+ MAX_WORKFLOW_STEPS.getKey()
+ "] to increase this.",
RestStatus.BAD_REQUEST
);
}
List<WorkflowNode> sortedNodes = topologicalSort(workflow.nodes(), workflow.edges());

List<ProcessNode> nodes = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -62,7 +63,7 @@ public void setUp() throws Exception {

final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
Expand All @@ -84,7 +85,7 @@ public void testPlugin() throws IOException {
assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size());
assertEquals(4, ffp.getActions().size());
assertEquals(1, ffp.getExecutorBuilders(settings).size());
assertEquals(4, ffp.getSettings().size());
assertEquals(5, ffp.getSettings().size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
Expand All @@ -34,18 +37,24 @@

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.mockito.ArgumentCaptor;

import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
Expand All @@ -61,6 +70,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase {
private Template template;
private Client client = mock(Client.class);
private ThreadPool threadPool;
private ClusterSettings clusterSettings;
private ClusterService clusterService;
private ParseUtils parseUtils;
private ThreadContext threadContext;
private Settings settings;
Expand All @@ -73,8 +84,15 @@ public void setUp() throws Exception {
.put("plugins.flow_framework.max_workflows.", 2)
.put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10))
.build();
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool);
this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings);
this.createWorkflowTransportAction = spy(
new CreateWorkflowTransportAction(
mock(TransportService.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.TemplateTestJsonUtil;
Expand All @@ -32,6 +33,7 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
Expand All @@ -41,6 +43,7 @@
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node;
Expand Down Expand Up @@ -79,11 +82,12 @@ public static void setup() {
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

Settings settings = Settings.builder().put("plugins.flow_framework.max_workflow_steps", 5).build();
final Set<Setting<?>> settingsSet = Stream.concat(
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY)
).collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

when(client.admin()).thenReturn(adminClient);
Expand All @@ -96,7 +100,7 @@ public static void setup() {
mlClient,
flowFrameworkIndicesHandler
);
workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool);
workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, settings);
}

@AfterClass
Expand Down Expand Up @@ -245,6 +249,21 @@ public void testExceptions() throws IOException {
ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("A")), Collections.emptyList())));
assertEquals("Duplicate node id A.", ex.getMessage());
assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus());

ex = assertThrows(
FlowFrameworkException.class,
() -> parse(workflow(List.of(node("A"), node("B"), node("C"), node("D"), node("E"), node("F")), Collections.emptyList()))
);
String message = String.format(
Locale.ROOT,
"Workflow %s has %d nodes, which exceeds the maximum of %d. Change the setting [%s] to increase this.",
"123",
6,
5,
FlowFrameworkSettings.MAX_WORKFLOW_STEPS.getKey()
);
assertEquals(message, ex.getMessage());
assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus());
}

public void testSuccessfulGraphValidation() throws Exception {
Expand Down
Loading