Skip to content

Commit

Permalink
Change model APIs for multi tenant
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Apr 22, 2024
1 parent 06773b4 commit 1212972
Show file tree
Hide file tree
Showing 16 changed files with 389 additions and 218 deletions.
15 changes: 14 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
public class MLModel implements ToXContentObject {
@Deprecated
public static final String ALGORITHM_FIELD = "algorithm";

public static final String TENANT_ID_FIELD = "tenant_id";
public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
Expand Down Expand Up @@ -133,6 +135,7 @@ public class MLModel implements ToXContentObject {
private Connector connector;
private String connectorId;
private Guardrails guardrails;
private String tenantId;

@Builder(toBuilder = true)
public MLModel(String name,
Expand Down Expand Up @@ -166,7 +169,8 @@ public MLModel(String name,
Boolean isHidden,
Connector connector,
String connectorId,
Guardrails guardrails) {
Guardrails guardrails,
String tenantId) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
Expand Down Expand Up @@ -200,6 +204,7 @@ public MLModel(String name,
this.connector = connector;
this.connectorId = connectorId;
this.guardrails = guardrails;
this.tenantId = tenantId;
}

public MLModel(StreamInput input) throws IOException {
Expand Down Expand Up @@ -442,6 +447,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (guardrails != null) {
builder.field(GUARDRAILS_FIELD, guardrails);
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -486,6 +494,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
Connector connector = null;
String connectorId = null;
Guardrails guardrails = null;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -617,6 +626,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case GUARDRAILS_FIELD:
guardrails = Guardrails.parse(parser);
break;
case TENANT_ID_FIELD:
tenantId = parser.text();
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -656,6 +668,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.connector(connector)
.connectorId(connectorId)
.guardrails(guardrails)
.tenantId(tenantId)
.build();
}

Expand Down
16 changes: 16 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLModelGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ public class MLModelGroup implements ToXContentObject {

public static final String ACCESS = "access"; //assigned to public, private, or null when model group created
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //unique ID assigned to each model group

public static final String TENANT_ID_FIELD = "tenant_id";
public static final String CREATED_TIME_FIELD = "created_time"; //model group created time stamp
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; //updated whenever a new model version is created


@Setter
private String name;
private String description;
@Setter
private int latestVersion;
private List<String> backendRoles;
private User owner;
Expand All @@ -50,14 +53,18 @@ public class MLModelGroup implements ToXContentObject {

private String modelGroupId;

private String tenantId;

private Instant createdTime;
@Setter
private Instant lastUpdatedTime;


@Builder(toBuilder = true)
public MLModelGroup(String name, String description, int latestVersion,
List<String> backendRoles, User owner, String access,
String modelGroupId,
String tenantId,
Instant createdTime,
Instant lastUpdatedTime) {
this.name = Objects.requireNonNull(name, "model group name must not be null");
Expand All @@ -67,6 +74,7 @@ public MLModelGroup(String name, String description, int latestVersion,
this.owner = owner;
this.access = access;
this.modelGroupId = modelGroupId;
this.tenantId = tenantId;
this.createdTime = createdTime;
this.lastUpdatedTime = lastUpdatedTime;
}
Expand Down Expand Up @@ -132,6 +140,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);
}
if (createdTime != null) {
builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli());
}
Expand All @@ -150,6 +161,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException {
User owner = null;
String access = null;
String modelGroupId = null;
String tenantId = null;
Instant createdTime = null;
Instant lastUpdateTime = null;

Expand Down Expand Up @@ -184,6 +196,9 @@ public static MLModelGroup parse(XContentParser parser) throws IOException {
case MODEL_GROUP_ID_FIELD:
modelGroupId = parser.text();
break;
case TENANT_ID_FIELD:
tenantId = parser.text();
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand All @@ -203,6 +218,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException {
.owner(owner)
.access(access)
.modelGroupId(modelGroupId)
.tenantId(tenantId)
.createdTime(createdTime)
.lastUpdatedTime(lastUpdateTime)
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.common.transport;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.opensearch.action.ActionRequest;
import org.opensearch.core.common.io.stream.StreamInput;

import java.io.IOException;

@NoArgsConstructor
@AllArgsConstructor
public abstract class AbstractGetRequest extends ActionRequest {
@Setter
@Getter
private String tenantId;

public AbstractGetRequest(StreamInput in) throws IOException {
super(in);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.transport.AbstractGetRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -24,11 +25,10 @@

@Getter
@Setter
public class MLConnectorGetRequest extends ActionRequest {
public class MLConnectorGetRequest extends AbstractGetRequest {

String connectorId;

String tenantId;
boolean returnContent;

@Builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.transport.AbstractGetRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -27,7 +28,7 @@
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLModelGetRequest extends ActionRequest {
public class MLModelGetRequest extends AbstractGetRequest {

String modelId;
boolean returnContent;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -96,6 +97,9 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
private Boolean isHidden;
private Guardrails guardrails;

@Setter
private String tenantId;

@Builder(toBuilder = true)
public MLRegisterModelInput(FunctionName functionName,
String modelName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ public void initModelIndexIfAbsent(ActionListener<Boolean> listener) {
initMLIndexIfAbsent(MLIndex.MODEL, listener);
}

public Boolean initModelIndexIfAbsent() throws ExecutionException, InterruptedException {
PlainActionFuture<Boolean> actionFuture = PlainActionFuture.newFuture();
initModelIndexIfAbsent(actionFuture);
return actionFuture.get();
}

public void initMLTaskIndex(ActionListener<Boolean> listener) {
initMLIndexIfAbsent(MLIndex.TASK, listener);
}
Expand Down
Loading

0 comments on commit 1212972

Please sign in to comment.