Skip to content

Commit

Permalink
Feature/create model group before register (#192)
Browse files Browse the repository at this point in the history
* feat: add model group related API

Signed-off-by: Lin Wang <wonglam@amazon.com>

* feat: call model group register and delete when model register

Signed-off-by: Lin Wang <wonglam@amazon.com>

* feat: check name unique from model group search

Signed-off-by: Lin Wang <wonglam@amazon.com>

* fix: register model group call in register model version

Signed-off-by: Lin Wang <wonglam@amazon.com>

* fix: model delete after model version register failed

Signed-off-by: Lin Wang <wonglam@amazon.com>

* feat: add model access control related fields for create model group API

Signed-off-by: Lin Wang <wonglam@amazon.com>

---------

Signed-off-by: Lin Wang <wonglam@amazon.com>
  • Loading branch information
wanglam authored May 26, 2023
1 parent 6424c8e commit 4161537
Show file tree
Hide file tree
Showing 22 changed files with 665 additions and 106 deletions.
9 changes: 9 additions & 0 deletions public/apis/api_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import { Model } from './model';
import { ModelAggregate } from './model_aggregate';
import { ModelGroup } from './model_group';
import { ModelRepository } from './model_repository';
import { Profile } from './profile';
import { Security } from './security';
Expand All @@ -17,13 +18,15 @@ const apiInstanceStore: {
security: Security | undefined;
task: Task | undefined;
modelRepository: ModelRepository | undefined;
modelGroup: ModelGroup | undefined;
} = {
model: undefined,
modelAggregate: undefined,
profile: undefined,
security: undefined,
task: undefined,
modelRepository: undefined,
modelGroup: undefined,
};

export class APIProvider {
Expand All @@ -33,6 +36,7 @@ export class APIProvider {
public static getAPI(type: 'profile'): Profile;
public static getAPI(type: 'security'): Security;
public static getAPI(type: 'modelRepository'): ModelRepository;
public static getAPI(type: 'modelGroup'): ModelGroup;
public static getAPI(type: keyof typeof apiInstanceStore) {
if (apiInstanceStore[type]) {
return apiInstanceStore[type]!;
Expand Down Expand Up @@ -68,6 +72,11 @@ export class APIProvider {
apiInstanceStore.modelRepository = newInstance;
return newInstance;
}
case 'modelGroup': {
const newInstance = new ModelGroup();
apiInstanceStore.modelGroup = newInstance;
return newInstance;
}
}
}
public static clear() {
Expand Down
7 changes: 4 additions & 3 deletions public/apis/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { MODEL_STATE, ModelSearchSort } from '../../common';
import { MODEL_STATE } from '../../common';
import {
MODEL_API_ENDPOINT,
MODEL_LOAD_API_ENDPOINT,
Expand Down Expand Up @@ -79,10 +79,11 @@ export interface ModelProfileResponse {

interface UploadModelBase {
name: string;
version: string;
description: string;
version?: string;
description?: string;
modelFormat: string;
modelConfig: Record<string, unknown>;
modelGroupId: string;
}

export interface UploadModelByURL extends UploadModelBase {
Expand Down
70 changes: 70 additions & 0 deletions public/apis/model_group.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { MODEL_GROUP_API_ENDPOINT } from '../../server/routes/constants';
import { InnerHttpProvider } from './inner_http_provider';

interface ModelGroupSearchItem {
id: string;
owner: {
backend_roles: string[];
roles: string[];
name: string;
};
latest_version: number;
last_updated_time: number;
name: string;
description?: string;
}

export interface ModelGroupSearchResponse {
data: ModelGroupSearchItem[];
total_model_groups: number;
}

export class ModelGroup {
public register(body: {
name: string;
description?: string;
modelAccessMode: 'public' | 'restricted' | 'private';
backendRoles?: string[];
addAllBackendRoles?: boolean;
}) {
return InnerHttpProvider.getHttp().post<{ model_group_id: string; status: 'CREATED' }>(
MODEL_GROUP_API_ENDPOINT,
{
body: JSON.stringify(body),
}
);
}

public update({ id, name, description }: { id: string; name?: string; description?: string }) {
return InnerHttpProvider.getHttp().put<{ status: 'success' }>(
`${MODEL_GROUP_API_ENDPOINT}/${id}`,
{
body: JSON.stringify({
name,
description,
}),
}
);
}

public delete(id: string) {
return InnerHttpProvider.getHttp().delete<{ status: 'success' }>(
`${MODEL_GROUP_API_ENDPOINT}/${id}`
);
}

public search(query: { id?: string; name?: string; from: number; size: number }) {
return InnerHttpProvider.getHttp().get<ModelGroupSearchResponse>(MODEL_GROUP_API_ENDPOINT, {
query,
});
}

public getOne = async (id: string) => {
return (await this.search({ id, from: 0, size: 1 })).data[0];
};
}
4 changes: 2 additions & 2 deletions public/components/common/forms/model_name_field.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ interface ModelNameFieldProps<TFieldValues extends ModelNameFormData> {
}

const isDuplicateModelName = async (name: string) => {
const searchResult = await APIProvider.getAPI('model').search({
const searchResult = await APIProvider.getAPI('modelGroup').search({
name,
from: 0,
size: 1,
});
return searchResult.total_models >= 1;
return searchResult.total_model_groups >= 1;
};

export const ModelNameField = <TFieldValues extends ModelNameFormData>({
Expand Down
2 changes: 1 addition & 1 deletion public/components/global_breadcrumbs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ const getModelRegisterBreadcrumbs = (basename: string, matchedParams: {}) => {
staticBreadcrumbs: baseModelRegistryBreadcrumbs,
// TODO: Change to model group API
asyncBreadcrumbsLoader: () =>
APIProvider.getAPI('model')
APIProvider.getAPI('modelGroup')
.getOne(modelId)
.then(
(model) =>
Expand Down
144 changes: 144 additions & 0 deletions public/components/register_model/__tests__/register_model_api.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ModelGroup } from '../../../apis/model_group';
import { Model } from '../../../apis/model';
import { submitModelWithFile, submitModelWithURL } from '../register_model_api';

describe('register model api', () => {
beforeEach(() => {
jest
.spyOn(ModelGroup.prototype, 'register')
.mockResolvedValue({ model_group_id: 'foo', status: 'success' });
jest.spyOn(ModelGroup.prototype, 'delete').mockResolvedValue({ status: 'success' });
jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ task_id: 'foo', model_id: 'bar' });
});

afterEach(() => {
jest.spyOn(ModelGroup.prototype, 'register').mockRestore();
jest.spyOn(ModelGroup.prototype, 'delete').mockRestore();
jest.spyOn(Model.prototype, 'upload').mockRestore();
});

it('should not call register model group API if modelId provided', async () => {
expect(ModelGroup.prototype.register).not.toHaveBeenCalled();

await submitModelWithFile({
name: 'foo',
description: 'bar',
configuration: '{}',
modelFileFormat: '',
modelId: 'a-exists-model-id',
modelFile: new File([], 'artifact.zip'),
});

expect(ModelGroup.prototype.register).not.toHaveBeenCalled();
});

it('should not call delete model group API if modelId provided and model upload failed', async () => {
const uploadError = new Error();
const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError);

try {
await submitModelWithFile({
name: 'foo',
description: 'bar',
configuration: '{}',
modelFileFormat: '',
modelId: 'a-exists-model-id',
modelFile: new File([], 'artifact.zip'),
});
} catch (error) {
expect(error).toBe(uploadError);
}
expect(ModelGroup.prototype.delete).not.toHaveBeenCalled();

uploadMock.mockRestore();
});

describe('submitModelWithFile', () => {
it('should call register model group API with name and description', async () => {
expect(ModelGroup.prototype.register).not.toHaveBeenCalled();

await submitModelWithFile({
name: 'foo',
description: 'bar',
configuration: '{}',
modelFileFormat: '',
modelFile: new File([], 'artifact.zip'),
});

expect(ModelGroup.prototype.register).toHaveBeenCalledWith(
expect.objectContaining({
name: 'foo',
description: 'bar',
})
);
});

it('should delete created model group API upload failed', async () => {
const uploadError = new Error();
const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError);

expect(ModelGroup.prototype.delete).not.toHaveBeenCalled();
try {
await submitModelWithFile({
name: 'foo',
description: 'bar',
configuration: '{}',
modelFileFormat: '',
modelFile: new File([], 'artifact.zip'),
});
} catch (error) {
expect(uploadError).toBe(error);
}
expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('foo');

uploadMock.mockRestore();
});
});

describe('submitModelWithURL', () => {
it('should call register model group API with name and description', async () => {
expect(ModelGroup.prototype.register).not.toHaveBeenCalled();

await submitModelWithURL({
name: 'foo',
description: 'bar',
configuration: '{}',
modelFileFormat: '',
modelURL: 'https://address.to/artifact.zip',
});

expect(ModelGroup.prototype.register).toHaveBeenCalledWith(
expect.objectContaining({
name: 'foo',
description: 'bar',
})
);
});

it('should delete created model group API upload failed', async () => {
const uploadError = new Error();
const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError);

expect(ModelGroup.prototype.delete).not.toHaveBeenCalled();
try {
await submitModelWithURL({
name: 'foo',
description: 'bar',
configuration: '{}',
modelFileFormat: '',
modelURL: 'https://address.to/artifact.zip',
});
} catch (error) {
expect(uploadError).toBe(error);
}
expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('foo');

uploadMock.mockRestore();
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import { setup } from './setup';
import * as formAPI from '../register_model_api';
import { Model } from '../../../apis/model';

describe('<RegisterModel /> Details', () => {
const onSubmitMock = jest.fn().mockResolvedValue('model_id');
Expand Down Expand Up @@ -53,13 +52,9 @@ describe('<RegisterModel /> Details', () => {

it('should NOT submit the register model form if model name is duplicated', async () => {
const result = await setup();
jest.spyOn(Model.prototype, 'search').mockResolvedValue({
data: [],
total_models: 1,
});

await result.user.clear(result.nameInput);
await result.user.type(result.nameInput, 'a-duplicated-model-name');
await result.user.type(result.nameInput, 'model1');
await result.user.click(result.submitButton);

expect(result.nameInput).toBeInvalid();
Expand Down
Loading

0 comments on commit 4161537

Please sign in to comment.