Skip to content

Commit

Permalink
fix update connector API
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Oct 11, 2023
1 parent de8b411 commit 3e75c0a
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,7 @@
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;
Expand Down Expand Up @@ -80,6 +70,7 @@ public interface Connector extends ToXContentObject, Writeable {

void writeTo(StreamOutput out) throws IOException;

void update(MLCreateConnectorInput updateContent, Function<String, String> function);

<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

@Log4j2
@NoArgsConstructor
Expand Down Expand Up @@ -248,6 +249,38 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

@Override
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
if (updateContent.getDescription() != null) {
this.description = updateContent.getDescription();
}
if (updateContent.getVersion() != null) {
this.version = updateContent.getVersion();
}
if (updateContent.getProtocol() != null) {
this.protocol = updateContent.getProtocol();
}
if (updateContent.getParameters() != null && updateContent.getParameters().size() > 0) {
this.parameters = updateContent.getParameters();
}
if (updateContent.getCredential() != null && updateContent.getCredential().size() > 0) {
this.credential = updateContent.getCredential();
encrypt(function);
}
if (updateContent.getActions() != null) {
this.actions = updateContent.getActions();
}
if (updateContent.getBackendRoles() != null) {
this.backendRoles = updateContent.getBackendRoles();
}
if (updateContent.getAccess() != null) {
this.access = updateContent.getAccess();
}
}

@Override
public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
private Boolean addAllBackendRoles;
private AccessMode access;
private boolean dryRun = false;
private boolean updateConnector = false;

@Builder(toBuilder = true)
public MLCreateConnectorInput(String name,
Expand All @@ -68,9 +69,10 @@ public MLCreateConnectorInput(String name,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode access,
boolean dryRun
boolean dryRun,
boolean updateConnector
) {
if (!dryRun) {
if (!dryRun && !updateConnector) {
if (name == null) {
throw new IllegalArgumentException("Connector name is null");
}
Expand All @@ -92,9 +94,14 @@ public MLCreateConnectorInput(String name,
this.addAllBackendRoles = addAllBackendRoles;
this.access = access;
this.dryRun = dryRun;
this.updateConnector = updateConnector;
}

public static MLCreateConnectorInput parse(XContentParser parser) throws IOException {
return parse(parser, false);
}

public static MLCreateConnectorInput parse(XContentParser parser, boolean updateConnector) throws IOException {
String name = null;
String description = null;
String version = null;
Expand Down Expand Up @@ -159,7 +166,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
break;
}
}
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun);
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun, updateConnector);
}

@Override
Expand Down Expand Up @@ -201,10 +208,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

@Override
public void writeTo(StreamOutput output) throws IOException {
output.writeString(name);
output.writeOptionalString(name);
output.writeOptionalString(description);
output.writeString(version);
output.writeString(protocol);
output.writeOptionalString(version);
output.writeOptionalString(protocol);
if (parameters != null) {
output.writeBoolean(true);
output.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
Expand Down Expand Up @@ -240,13 +247,14 @@ public void writeTo(StreamOutput output) throws IOException {
output.writeBoolean(false);
}
output.writeBoolean(dryRun);
output.writeBoolean(updateConnector);
}

public MLCreateConnectorInput(StreamInput input) throws IOException {
name = input.readString();
name = input.readOptionalString();
description = input.readOptionalString();
version = input.readString();
protocol = input.readString();
version = input.readOptionalString();
protocol = input.readOptionalString();
if (input.readBoolean()) {
parameters = input.readMap(s -> s.readString(), s -> s.readString());
}
Expand All @@ -268,5 +276,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException {
this.access = input.readEnum(AccessMode.class);
}
dryRun = input.readBoolean();
updateConnector = input.readBoolean();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,31 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
public class MLUpdateConnectorRequest extends ActionRequest {
String connectorId;
Map<String, Object> updateContent;
MLCreateConnectorInput updateContent;

@Builder
public MLUpdateConnectorRequest(String connectorId, Map<String, Object> updateContent) {
public MLUpdateConnectorRequest(String connectorId, MLCreateConnectorInput updateContent) {
this.connectorId = connectorId;
this.updateContent = updateContent;
}

public MLUpdateConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.updateContent = in.readMap();
this.updateContent = new MLCreateConnectorInput(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
out.writeMap(this.getUpdateContent());
this.updateContent.writeTo(out);
}

@Override
Expand All @@ -55,14 +54,17 @@ public ActionRequestValidationException validate() {
exception = addValidationError("ML connector id can't be null", exception);
}

if (updateContent == null) {
exception = addValidationError("Update connector content can't be null", exception);
}

return exception;
}

public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException {
Map<String, Object> dataAsMap = null;
dataAsMap = parser.map();
MLCreateConnectorInput updateContent = MLCreateConnectorInput.parse(parser, true);

return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build();
return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build();
}

public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,37 @@

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.rest.RestRequest;
import org.opensearch.search.SearchModule;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Collections;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.when;
import static org.junit.Assert.assertTrue;

public class MLUpdateConnectorRequestTests {
private String connectorId;
private Map<String, Object> updateContent;
private MLCreateConnectorInput updateContent;
private MLUpdateConnectorRequest mlUpdateConnectorRequest;

@Mock
XContentParser parser;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
this.connectorId = "test-connector_id";
this.updateContent = Map.of("description", "new description");
this.updateContent = MLCreateConnectorInput.builder().description("new description").updateConnector(true).build();
mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder()
.connectorId(connectorId)
.updateContent(updateContent)
Expand All @@ -64,18 +63,20 @@ public void validate_Exception_NullConnectorId() {
MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build();
Exception exception = updateConnectorRequest.validate();

assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage());
assertEquals("Validation Failed: 1: ML connector id can't be null;2: Update connector content can't be null;", exception.getMessage());
}

@Test
public void parse_success() throws IOException {
RestRequest.Method method = RestRequest.Method.POST;
final Map<String, Object> updatefields = Map.of("version", "new version", "description", "new description");
when(parser.map()).thenReturn(updatefields);

String jsonStr = "{\"version\":\"new version\",\"description\":\"new description\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId);
assertEquals(updateConnectorRequest.getConnectorId(), connectorId);
assertEquals(updateConnectorRequest.getUpdateContent(), updatefields);
assertTrue(updateConnectorRequest.getUpdateContent().isUpdateConnector());
assertEquals("new version", updateConnectorRequest.getUpdateContent().getVersion());
assertEquals("new description", updateConnectorRequest.getUpdateContent().getDescription());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.delete.DeleteRequest;
Expand Down Expand Up @@ -77,11 +81,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
.error(
searchHits.length + " models are still using this connector, please delete or update the models first!"
);
List<String> modelIds = new ArrayList<>();
for (SearchHit hit : searchHits) {
modelIds.add(hit.getId());
}
actionListener
.onFailure(
new MLValidationException(
searchHits.length
+ " models are still using this connector, please delete or update the models first!"
+ " models are still using this connector, please delete or update the models first: "
+ Arrays.toString(modelIds.toArray(new String[0]))
)
);
}
Expand Down
Loading

0 comments on commit 3e75c0a

Please sign in to comment.