Skip to content

Commit

Permalink
if model version fails to register, update model group accordingly (#…
Browse files Browse the repository at this point in the history
…1463)

* if model version fails to register, update model group accordingly

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
(cherry picked from commit e2d2778)
  • Loading branch information
rbhavna authored and github-actions[bot] committed Oct 6, 2023
1 parent 7f548a6 commit 52fb87e
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String ACCESS_MODE_FIELD = "access_mode";
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";

public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";
private FunctionName functionName;
private String modelName;
private String modelGroupId;
Expand All @@ -73,6 +73,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode accessMode;
private Boolean doesVersionCreateModelGroup;

@Builder(toBuilder = true)
public MLRegisterModelInput(FunctionName functionName,
Expand All @@ -90,7 +91,8 @@ public MLRegisterModelInput(FunctionName functionName,
String connectorId,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode accessMode
AccessMode accessMode,
Boolean doesVersionCreateModelGroup
) {
if (functionName == null) {
this.functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -123,6 +125,7 @@ public MLRegisterModelInput(FunctionName functionName,
this.backendRoles = backendRoles;
this.addAllBackendRoles = addAllBackendRoles;
this.accessMode = accessMode;
this.doesVersionCreateModelGroup = doesVersionCreateModelGroup;
}


Expand Down Expand Up @@ -157,6 +160,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
}
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}

@Override
Expand Down Expand Up @@ -202,6 +206,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}

@Override
Expand Down Expand Up @@ -249,6 +254,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (accessMode != null) {
builder.field(ACCESS_MODE_FIELD, accessMode);
}
if (doesVersionCreateModelGroup != null) {
builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup);
}
builder.endObject();
return builder;
}
Expand All @@ -267,6 +275,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
List<String> backendRoles = new ArrayList<>();
Boolean addAllBackendRoles = null;
AccessMode accessMode = null;
Boolean doesVersionCreateModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -318,12 +327,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
}

public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException {
Expand All @@ -342,6 +354,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
List<String> backendRoles = new ArrayList<>();
AccessMode accessMode = null;
Boolean addAllBackendRoles = null;
Boolean doesVersionCreateModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -400,11 +413,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";


private FunctionName functionName;
private String name;
Expand All @@ -65,11 +67,13 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
private List<String> backendRoles;
private AccessMode accessMode;
private Boolean isAddAllBackendRoles;
private Boolean doesVersionCreateModelGroup;

@Builder(toBuilder = true)
public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List<String> backendRoles,
AccessMode accessMode,
Boolean isAddAllBackendRoles) {
Boolean isAddAllBackendRoles,
Boolean doesVersionCreateModelGroup) {
if (name == null) {
throw new IllegalArgumentException("model name is null");
}
Expand Down Expand Up @@ -103,6 +107,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
this.backendRoles = backendRoles;
this.accessMode = accessMode;
this.isAddAllBackendRoles = isAddAllBackendRoles;
this.doesVersionCreateModelGroup = doesVersionCreateModelGroup;
}

public MLRegisterModelMetaInput(StreamInput in) throws IOException{
Expand All @@ -128,6 +133,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{
accessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}

@Override
Expand Down Expand Up @@ -171,6 +177,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}

@Override
Expand Down Expand Up @@ -206,6 +213,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (isAddAllBackendRoles != null) {
builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles);
}
if (doesVersionCreateModelGroup != null) {
builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup);
}
builder.endObject();
return builder;
}
Expand All @@ -225,6 +235,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
List<String> backendRoles = null;
AccessMode accessMode = null;
Boolean isAddAllBackendRoles = null;
Boolean doesVersionCreateModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -277,12 +288,15 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = parser.booleanValue();
break;
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles);
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void setup() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void setUp() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,14 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(registerModelInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
registerModelInput.setModelGroupId(modelGroupId);
registerModelInput.setDoesVersionCreateModelGroup(true);
registerModel(registerModelInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
listener.onFailure(e);
}));
} else {
registerModelInput.setDoesVersionCreateModelGroup(false);
registerModel(registerModelInput, listener);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@ private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionList
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
mlUploadInput.setModelGroupId(modelGroupId);
mlUploadInput.setDoesVersionCreateModelGroup(true);
registerModelMeta(mlUploadInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
listener.onFailure(e);
}));
} else {
mlUploadInput.setDoesVersionCreateModelGroup(false);
registerModelMeta(mlUploadInput, listener);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,13 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us
AccessMode modelAccessMode = input.getModelAccessMode();
Boolean isAddAllBackendRoles = input.getIsAddAllBackendRoles();
if (modelAccessMode == null) {
if (modelAccessMode == null) {
if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) {
throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
} else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) {
input.setModelAccessMode(AccessMode.RESTRICTED);
modelAccessMode = AccessMode.RESTRICTED;
} else {
input.setModelAccessMode(AccessMode.PRIVATE);
}
if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) {
throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
} else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) {
input.setModelAccessMode(AccessMode.RESTRICTED);
modelAccessMode = AccessMode.RESTRICTED;
} else {
input.setModelAccessMode(AccessMode.PRIVATE);
}
}
if ((AccessMode.PUBLIC == modelAccessMode || AccessMode.PRIVATE == modelAccessMode)
Expand Down Expand Up @@ -184,20 +182,29 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us
}

public void validateUniqueModelGroupName(String name, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder);

client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onResponse(null);
} else {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}
}));
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder);

client
.search(
searchRequest,
ActionListener.runBefore(ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onResponse(null);
} else {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}
}), () -> context.restore())
);
} catch (Exception e) {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}
}

private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) {
Expand Down
Loading

0 comments on commit 52fb87e

Please sign in to comment.