-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/create model group before register (#192)
* feat: add model group related API Signed-off-by: Lin Wang <wonglam@amazon.com> * feat: call model group register and delete when model register Signed-off-by: Lin Wang <wonglam@amazon.com> * feat: check name unique from model group search Signed-off-by: Lin Wang <wonglam@amazon.com> * fix: register model group call in register model version Signed-off-by: Lin Wang <wonglam@amazon.com> * fix: model delete after model version register failed Signed-off-by: Lin Wang <wonglam@amazon.com> * feat: add model access control related fields for create model group API Signed-off-by: Lin Wang <wonglam@amazon.com> --------- Signed-off-by: Lin Wang <wonglam@amazon.com>
- Loading branch information
Showing
22 changed files
with
665 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
import { MODEL_GROUP_API_ENDPOINT } from '../../server/routes/constants'; | ||
import { InnerHttpProvider } from './inner_http_provider'; | ||
|
||
interface ModelGroupSearchItem { | ||
id: string; | ||
owner: { | ||
backend_roles: string[]; | ||
roles: string[]; | ||
name: string; | ||
}; | ||
latest_version: number; | ||
last_updated_time: number; | ||
name: string; | ||
description?: string; | ||
} | ||
|
||
export interface ModelGroupSearchResponse { | ||
data: ModelGroupSearchItem[]; | ||
total_model_groups: number; | ||
} | ||
|
||
export class ModelGroup { | ||
public register(body: { | ||
name: string; | ||
description?: string; | ||
modelAccessMode: 'public' | 'restricted' | 'private'; | ||
backendRoles?: string[]; | ||
addAllBackendRoles?: boolean; | ||
}) { | ||
return InnerHttpProvider.getHttp().post<{ model_group_id: string; status: 'CREATED' }>( | ||
MODEL_GROUP_API_ENDPOINT, | ||
{ | ||
body: JSON.stringify(body), | ||
} | ||
); | ||
} | ||
|
||
public update({ id, name, description }: { id: string; name?: string; description?: string }) { | ||
return InnerHttpProvider.getHttp().put<{ status: 'success' }>( | ||
`${MODEL_GROUP_API_ENDPOINT}/${id}`, | ||
{ | ||
body: JSON.stringify({ | ||
name, | ||
description, | ||
}), | ||
} | ||
); | ||
} | ||
|
||
public delete(id: string) { | ||
return InnerHttpProvider.getHttp().delete<{ status: 'success' }>( | ||
`${MODEL_GROUP_API_ENDPOINT}/${id}` | ||
); | ||
} | ||
|
||
public search(query: { id?: string; name?: string; from: number; size: number }) { | ||
return InnerHttpProvider.getHttp().get<ModelGroupSearchResponse>(MODEL_GROUP_API_ENDPOINT, { | ||
query, | ||
}); | ||
} | ||
|
||
public getOne = async (id: string) => { | ||
return (await this.search({ id, from: 0, size: 1 })).data[0]; | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
144 changes: 144 additions & 0 deletions
144
public/components/register_model/__tests__/register_model_api.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
import { ModelGroup } from '../../../apis/model_group'; | ||
import { Model } from '../../../apis/model'; | ||
import { submitModelWithFile, submitModelWithURL } from '../register_model_api'; | ||
|
||
describe('register model api', () => { | ||
beforeEach(() => { | ||
jest | ||
.spyOn(ModelGroup.prototype, 'register') | ||
.mockResolvedValue({ model_group_id: 'foo', status: 'success' }); | ||
jest.spyOn(ModelGroup.prototype, 'delete').mockResolvedValue({ status: 'success' }); | ||
jest.spyOn(Model.prototype, 'upload').mockResolvedValue({ task_id: 'foo', model_id: 'bar' }); | ||
}); | ||
|
||
afterEach(() => { | ||
jest.spyOn(ModelGroup.prototype, 'register').mockRestore(); | ||
jest.spyOn(ModelGroup.prototype, 'delete').mockRestore(); | ||
jest.spyOn(Model.prototype, 'upload').mockRestore(); | ||
}); | ||
|
||
it('should not call register model group API if modelId provided', async () => { | ||
expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); | ||
|
||
await submitModelWithFile({ | ||
name: 'foo', | ||
description: 'bar', | ||
configuration: '{}', | ||
modelFileFormat: '', | ||
modelId: 'a-exists-model-id', | ||
modelFile: new File([], 'artifact.zip'), | ||
}); | ||
|
||
expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); | ||
}); | ||
|
||
it('should not call delete model group API if modelId provided and model upload failed', async () => { | ||
const uploadError = new Error(); | ||
const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); | ||
|
||
try { | ||
await submitModelWithFile({ | ||
name: 'foo', | ||
description: 'bar', | ||
configuration: '{}', | ||
modelFileFormat: '', | ||
modelId: 'a-exists-model-id', | ||
modelFile: new File([], 'artifact.zip'), | ||
}); | ||
} catch (error) { | ||
expect(error).toBe(uploadError); | ||
} | ||
expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); | ||
|
||
uploadMock.mockRestore(); | ||
}); | ||
|
||
describe('submitModelWithFile', () => { | ||
it('should call register model group API with name and description', async () => { | ||
expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); | ||
|
||
await submitModelWithFile({ | ||
name: 'foo', | ||
description: 'bar', | ||
configuration: '{}', | ||
modelFileFormat: '', | ||
modelFile: new File([], 'artifact.zip'), | ||
}); | ||
|
||
expect(ModelGroup.prototype.register).toHaveBeenCalledWith( | ||
expect.objectContaining({ | ||
name: 'foo', | ||
description: 'bar', | ||
}) | ||
); | ||
}); | ||
|
||
it('should delete created model group API upload failed', async () => { | ||
const uploadError = new Error(); | ||
const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); | ||
|
||
expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); | ||
try { | ||
await submitModelWithFile({ | ||
name: 'foo', | ||
description: 'bar', | ||
configuration: '{}', | ||
modelFileFormat: '', | ||
modelFile: new File([], 'artifact.zip'), | ||
}); | ||
} catch (error) { | ||
expect(uploadError).toBe(error); | ||
} | ||
expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('foo'); | ||
|
||
uploadMock.mockRestore(); | ||
}); | ||
}); | ||
|
||
describe('submitModelWithURL', () => { | ||
it('should call register model group API with name and description', async () => { | ||
expect(ModelGroup.prototype.register).not.toHaveBeenCalled(); | ||
|
||
await submitModelWithURL({ | ||
name: 'foo', | ||
description: 'bar', | ||
configuration: '{}', | ||
modelFileFormat: '', | ||
modelURL: 'https://address.to/artifact.zip', | ||
}); | ||
|
||
expect(ModelGroup.prototype.register).toHaveBeenCalledWith( | ||
expect.objectContaining({ | ||
name: 'foo', | ||
description: 'bar', | ||
}) | ||
); | ||
}); | ||
|
||
it('should delete created model group API upload failed', async () => { | ||
const uploadError = new Error(); | ||
const uploadMock = jest.spyOn(Model.prototype, 'upload').mockRejectedValue(uploadError); | ||
|
||
expect(ModelGroup.prototype.delete).not.toHaveBeenCalled(); | ||
try { | ||
await submitModelWithURL({ | ||
name: 'foo', | ||
description: 'bar', | ||
configuration: '{}', | ||
modelFileFormat: '', | ||
modelURL: 'https://address.to/artifact.zip', | ||
}); | ||
} catch (error) { | ||
expect(uploadError).toBe(error); | ||
} | ||
expect(ModelGroup.prototype.delete).toHaveBeenCalledWith('foo'); | ||
|
||
uploadMock.mockRestore(); | ||
}); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.