Skip to content

Commit

Permalink
feat(bff): create endpoint to list all model versions (#707)
Browse files Browse the repository at this point in the history
Signed-off-by: Eder Ignatowicz <ignatowicz@gmail.com>
  • Loading branch information
ederign authored Jan 16, 2025
1 parent a7af392 commit f9f78c3
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 8 deletions.
4 changes: 4 additions & 0 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/a
}}'
```
```
# GET /api/v1/model_registry/{model_registry_id}/model_versions
curl -i -H "kubeflow-userid: user@example.com" "http://localhost:4000/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow"
```
```
# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}
curl -i -H "kubeflow-userid: user@example.com" "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow"
```
Expand Down
3 changes: 2 additions & 1 deletion clients/ui/bff/internal/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ func (app *App) Routes() http.Handler {
apiRouter.PATCH(RegisteredModelPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateRegisteredModelHandler))))
apiRouter.GET(RegisteredModelVersionsPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelVersionsForRegisteredModelHandler))))
apiRouter.POST(RegisteredModelVersionsPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelVersionForRegisteredModelHandler))))
apiRouter.GET(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient((app.GetModelVersionHandler)))))
apiRouter.POST(ModelVersionListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelVersionHandler))))
apiRouter.GET(ModelVersionListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelVersionHandler))))
apiRouter.GET(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetModelVersionHandler))))
apiRouter.PATCH(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateModelVersionHandler))))
apiRouter.GET(ModelVersionArtifactListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler))))
apiRouter.POST(ModelVersionArtifactListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelArtifactByModelVersionHandler))))
Expand Down
24 changes: 24 additions & 0 deletions clients/ui/bff/internal/api/model_versions_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,30 @@ type ModelVersionUpdateEnvelope Envelope[*openapi.ModelVersionUpdate, None]
type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None]
type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None]

func (app *App) GetAllModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersions(client)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}

responseBody := ModelVersionListEnvelope{
Data: versionList,
}

err = app.WriteJSON(w, http.StatusOK, responseBody, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}

}

func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface)
if !ok {
Expand Down
12 changes: 12 additions & 0 deletions clients/ui/bff/internal/api/model_versions_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ import (
var _ = Describe("TestGetModelVersionHandler", func() {
Context("testing Model Version Handler", Ordered, func() {

It("should retrieve all model versions", func() {
By("fetching all model versions")
data := mocks.GetModelVersionListMock()
expected := ModelVersionListEnvelope{Data: &data}
actual, rs, err := setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow", nil, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow")
Expect(err).NotTo(HaveOccurred())
By("should match the expected model versions")
Expect(rs.StatusCode).To(Equal(http.StatusOK))
Expect(actual.Data.Size).To(Equal(expected.Data.Size))
Expect(actual.Data.Items).To(Equal(expected.Data.Items))
})

It("should retrieve a model version", func() {
By("fetching a model version")
data := mocks.GetModelVersionMocks()[0]
Expand Down
2 changes: 1 addition & 1 deletion clients/ui/bff/internal/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWrit
return
}

versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId), r.URL.Query())
versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersionsForRegisteredModel(client, ps.ByName(RegisteredModelId), r.URL.Query())

if err != nil {
app.serverErrorResponse(w, r, err)
Expand Down
7 changes: 6 additions & 1 deletion clients/ui/bff/internal/mocks/model_registry_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ func (m *ModelRegistryClientMock) UpdateRegisteredModel(_ integrations.HTTPClien
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface) (*openapi.ModelVersionList, error) {
mockData := GetModelVersionListMock()
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetModelVersion(_ integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) {
if id == "3" {
mockData := GetModelVersionMocks()[2]
Expand All @@ -61,7 +66,7 @@ func (m *ModelRegistryClientMock) UpdateModelVersion(_ integrations.HTTPClientIn
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) {
func (m *ModelRegistryClientMock) GetAllModelVersionsForRegisteredModel(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) {
mockData := GetModelVersionListMock()
return &mockData, nil
}
Expand Down
16 changes: 16 additions & 0 deletions clients/ui/bff/internal/repositories/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const modelVersionPath = "/model_versions"
const artifactsByModelVersionPath = "/artifacts"

type ModelVersionInterface interface {
GetAllModelVersions(client integrations.HTTPClientInterface) (*openapi.ModelVersionList, error)
GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error)
CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error)
UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
Expand All @@ -24,6 +25,21 @@ type ModelVersion struct {
ModelVersionInterface
}

func (v ModelVersion) GetAllModelVersions(client integrations.HTTPClientInterface) (*openapi.ModelVersionList, error) {
response, err := client.GET(modelVersionPath)

if err != nil {
return nil, fmt.Errorf("error fetching model versions: %w", err)
}

var models openapi.ModelVersionList
if err := json.Unmarshal(response, &models); err != nil {
return nil, fmt.Errorf("error decoding response data: %w", err)
}

return &models, nil
}

func (v ModelVersion) GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) {
path, err := url.JoinPath(modelVersionPath, id)
if err != nil {
Expand Down
26 changes: 26 additions & 0 deletions clients/ui/bff/internal/repositories/model_version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,32 @@ func TestGetModelVersion(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllModelVersions(t *testing.T) {
_ = gofakeit.Seed(0)

expected := mocks.GenerateMockModelVersionList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

modelVersion := ModelVersion{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", modelVersionPath).Return(mockData, nil)

actual, err := modelVersion.GetAllModelVersions(mockClient)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.NextPageToken, actual.NextPageToken)
assert.Equal(t, expected.PageSize, actual.PageSize)
assert.Equal(t, expected.Size, actual.Size)
assert.Equal(t, len(expected.Items), len(actual.Items))

mockClient.AssertExpectations(t)
}

func TestCreateModelVersion(t *testing.T) {
_ = gofakeit.Seed(0)

Expand Down
4 changes: 2 additions & 2 deletions clients/ui/bff/internal/repositories/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type RegisteredModelInterface interface {
CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error)
GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error)
UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error)
GetAllModelVersionsForRegisteredModel(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error)
CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
}

Expand Down Expand Up @@ -94,7 +94,7 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt
return &model, nil
}

func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) {
func (m RegisteredModel) GetAllModelVersionsForRegisteredModel(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) {
path, err := url.JoinPath(registeredModelPath, id, versionsPath)

if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions clients/ui/bff/internal/repositories/registered_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestUpdateRegisteredModel(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllModelVersions(t *testing.T) {
func TestGetAllModelVersionsByRegisteredModel(t *testing.T) {
_ = gofakeit.Seed(0)

expected := mocks.GenerateMockModelVersionList()
Expand All @@ -149,7 +149,7 @@ func TestGetAllModelVersions(t *testing.T) {
assert.NoError(t, err)
mockClient.On("GET", path).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1", nil)
actual, err := registeredModel.GetAllModelVersionsForRegisteredModel(mockClient, "1", nil)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.NoError(t, err)
Expand Down Expand Up @@ -180,7 +180,7 @@ func TestGetAllModelVersionsWithPageParams(t *testing.T) {

mockClient.On("GET", reqUrl).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1", pageValues)
actual, err := registeredModel.GetAllModelVersionsForRegisteredModel(mockClient, "1", pageValues)
assert.NoError(t, err)
assert.NotNil(t, actual)

Expand Down

0 comments on commit f9f78c3

Please sign in to comment.