Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Substitute REST path or body parameters in Workflow Steps #536

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.12...2.x)
### Features
### Enhancements
- Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525))

### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
Expand Down Expand Up @@ -75,6 +78,19 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
String workflowId = request.param(WORKFLOW_ID);
String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" });
boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);
final List<String> validCreateParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW);
// If provisioning, consume all other params and pass to provision transport action
Map<String, String> params = provision
? request.params()
.keySet()
.stream()
.filter(k -> !validCreateParams.contains(k))
.collect(Collectors.toMap(Function.identity(), request::param))
: request.params()
.entrySet()
.stream()
.filter(e -> !validCreateParams.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
if (!flowFrameworkSettings.isFlowFrameworkEnabled()) {
FlowFrameworkException ffe = new FlowFrameworkException(
"This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
Expand All @@ -84,12 +100,24 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
}
if (!provision && !params.isEmpty()) {
// Consume params and content so custom exception is processed
params.keySet().stream().forEach(request::param);
request.content();
FlowFrameworkException ffe = new FlowFrameworkException(
"Only the parameters " + validCreateParams + " are permitted unless the provision parameter is set to true.",
RestStatus.BAD_REQUEST
);
return channel -> channel.sendResponse(
new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
}
try {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Template template = Template.parse(parser);

WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision, params);

return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.ProvisionWorkflowAction;
Expand All @@ -27,7 +28,11 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -69,23 +74,19 @@ public List<Route> routes() {
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String workflowId = request.param(WORKFLOW_ID);
try {
Map<String, String> params = parseParamsAndContent(request);
if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) {
throw new FlowFrameworkException(
"This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
RestStatus.FORBIDDEN
);
}
// Validate content
if (request.hasContent()) {
// BaseRestHandler will give appropriate error message
return channel -> channel.sendResponse(null);
}
// Validate params
if (workflowId == null) {
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
}
// Create request and provision
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params);
return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
Expand All @@ -108,4 +109,31 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
}

private Map<String, String> parseParamsAndContent(RestRequest request) {
// Get any other params from path
Map<String, String> params = request.params()
.keySet()
.stream()
.filter(k -> !WORKFLOW_ID.equals(k))
.collect(Collectors.toMap(Function.identity(), request::param));
// If body is included get any params from body
if (request.hasContent()) {
try (XContentParser parser = request.contentParser()) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String key = parser.currentName();
if (params.containsKey(key)) {
throw new FlowFrameworkException("Duplicate key " + key, RestStatus.BAD_REQUEST);
}
if (parser.nextToken() != XContentParser.Token.VALUE_STRING) {
throw new FlowFrameworkException("Request body fields must have string values", RestStatus.BAD_REQUEST);
}
params.put(key, parser.text());
}
} catch (IOException e) {
throw new FlowFrameworkException("Request body parsing failed", RestStatus.BAD_REQUEST);
}
}
return params;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.transport.TransportService;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -282,7 +283,7 @@ void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, ActionList

private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null);
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null, Collections.emptyMap());
workflowProcessSorter.validate(sortedNodes, pluginsService);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
deprovisionStepId,
workflowStepFactory.createStep(deprovisionStep),
Collections.emptyMap(),
Collections.emptyMap(),
new WorkflowData(Map.of(getResourceByWorkflowStep(stepName), resource.resourceId()), workflowId, deprovisionStepId),
Collections.emptyList(),
this.threadPool,
Expand Down Expand Up @@ -194,6 +195,7 @@
pn.id(),
workflowStepFactory.createStep(pn.workflowStep().getName()),
pn.previousNodeInputs(),
pn.params(),

Check warning on line 198 in src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java#L198

Added line #L198 was not covered by tests
pn.input(),
pn.predecessors(),
this.threadPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work

// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(
provisionWorkflow,
workflowId,
request.getParams()
);
workflowProcessSorter.validate(provisionProcessSequence, pluginsService);

flowFrameworkIndicesHandler.isWorkflowNotStarted(workflowId, workflowIsNotStarted -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.opensearch.flowframework.model.Template;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;

/**
* Transport Request to create, provision, and deprovision a workflow
Expand Down Expand Up @@ -43,12 +45,27 @@ public class WorkflowRequest extends ActionRequest {
private boolean provision;

/**
* Instantiates a new WorkflowRequest, set validation to false and set requestTimeout and maxWorkflows to null
* Params map
*/
private Map<String, String> params;

/**
* Instantiates a new WorkflowRequest, set validation to all, no provisioning
* @param workflowId the documentId of the workflow
* @param template the use case template which describes the workflow
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) {
this(workflowId, template, new String[] { "all" }, false);
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap());
}

/**
* Instantiates a new WorkflowRequest with params map, set validation to all, provisioning to true
* @param workflowId the documentId of the workflow
* @param template the use case template which describes the workflow
* @param params The parameters from the REST path
*/
public WorkflowRequest(String workflowId, @Nullable Template template, Map<String, String> params) {
this(workflowId, template, new String[] { "all" }, true, params);
}

/**
Expand All @@ -57,12 +74,23 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template)
* @param template the use case template which describes the workflow
* @param validation flag to indicate if validation is necessary
* @param provision flag to indicate if provision is necessary
* @param params map of REST path params. If provision is false, must be an empty map.
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, String[] validation, boolean provision) {
public WorkflowRequest(
@Nullable String workflowId,
@Nullable Template template,
String[] validation,
boolean provision,
Map<String, String> params
) {
this.workflowId = workflowId;
this.template = template;
this.validation = validation;
this.provision = provision;
if (!provision && !params.isEmpty()) {
throw new IllegalArgumentException("Params may only be included when provisioning.");
}
this.params = params;
}

/**
Expand All @@ -77,6 +105,7 @@ public WorkflowRequest(StreamInput in) throws IOException {
this.template = templateJson == null ? null : Template.parse(templateJson);
this.validation = in.readStringArray();
this.provision = in.readBoolean();
this.params = this.provision ? in.readMap(StreamInput::readString, StreamInput::readString) : Collections.emptyMap();
}

/**
Expand Down Expand Up @@ -113,13 +142,24 @@ public boolean isProvision() {
return this.provision;
}

/**
* Gets the params map
* @return the params map
*/
public Map<String, String> getParams() {
return Map.copyOf(this.params);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(workflowId);
out.writeOptionalString(template == null ? null : template.toJson());
out.writeStringArray(validation);
out.writeBoolean(provision);
if (provision) {
out.writeMap(params, StreamOutput::writeString, StreamOutput::writeString);
}
}

@Override
Expand Down
18 changes: 13 additions & 5 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ public static Map<String, String> getStringToStringMap(Object map, String fieldN
* @param currentNodeInputs Input params and content for this node, from workflow parsing
* @param outputs WorkflowData content of previous steps
* @param previousNodeInputs Input params for this node that come from previous steps
* @param params Params that came from REST path
* @return A map containing the requiredInputKeys with their corresponding values,
* and optionalInputKeys with their corresponding values if present.
* Throws a {@link FlowFrameworkException} if a required key is not present.
Expand All @@ -257,7 +258,8 @@ public static Map<String, Object> getInputsFromPreviousSteps(
Set<String> optionalInputKeys,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
Map<String, String> previousNodeInputs,
Map<String, String> params
) {
// Mutable set to ensure all required keys are used
Set<String> requiredKeys = new HashSet<>(requiredInputKeys);
Expand Down Expand Up @@ -308,11 +310,11 @@ public static Map<String, Object> getInputsFromPreviousSteps(
Map<String, Object> valueMap = (Map<String, Object>) value;
value = valueMap.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> conditionallySubstitute(e.getValue(), outputs)));
.collect(Collectors.toMap(Map.Entry::getKey, e -> conditionallySubstitute(e.getValue(), outputs, params)));
} else if (value instanceof List) {
value = ((List<?>) value).stream().map(v -> conditionallySubstitute(v, outputs)).collect(Collectors.toList());
value = ((List<?>) value).stream().map(v -> conditionallySubstitute(v, outputs, params)).collect(Collectors.toList());
} else {
value = conditionallySubstitute(value, outputs);
value = conditionallySubstitute(value, outputs, params);
}
// Add value to inputs and mark that a required key was present
inputs.put(key, value);
Expand All @@ -336,15 +338,21 @@ public static Map<String, Object> getInputsFromPreviousSteps(
return inputs;
}

private static Object conditionallySubstitute(Object value, Map<String, WorkflowData> outputs) {
private static Object conditionallySubstitute(Object value, Map<String, WorkflowData> outputs, Map<String, String> params) {
if (value instanceof String) {
Matcher m = SUBSTITUTION_PATTERN.matcher((String) value);
if (m.matches()) {
// Try matching a previous step+value pair
WorkflowData data = outputs.get(m.group(1));
if (data != null && data.getContent().containsKey(m.group(2))) {
return data.getContent().get(m.group(2));
}
}
// Replace all params if present
for (Entry<String, String> e : params.entrySet()) {
String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}";
value = ((String) value).replaceAll(regex, e.getValue());
}
}
return value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ public PlainActionFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
Map<String, String> previousNodeInputs,
Map<String, String> params
) {

PlainActionFuture<WorkflowData> registerLocalModelFuture = PlainActionFuture.newFuture();
Expand All @@ -90,7 +91,8 @@ public PlainActionFuture<WorkflowData> execute(
getOptionalKeys(),
currentNodeInputs,
outputs,
previousNodeInputs
previousNodeInputs,
params
);

// Extract common fields of OS provided text-embedding, sparse encoding and custom models
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public PlainActionFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
Map<String, String> previousNodeInputs,
Map<String, String> params
) {
PlainActionFuture<WorkflowData> createConnectorFuture = PlainActionFuture.newFuture();

Expand Down Expand Up @@ -138,7 +139,8 @@ public void onFailure(Exception e) {
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
previousNodeInputs,
params
);

String name = (String) inputs.get(NAME_FIELD);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public PlainActionFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
Map<String, String> previousNodeInputs,
Map<String, String> params
) {
PlainActionFuture<WorkflowData> createIndexFuture = PlainActionFuture.newFuture();
ActionListener<CreateIndexResponse> actionListener = new ActionListener<>() {
Expand Down
Loading
Loading