diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index 0a917ac7..c0bcb2d5 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -56,7 +56,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ // "searchRegisteredModels", "getLatestVersions", // "createModelVersion", - // "updateModelVersion", + "updateModelVersion", // "transitionModelVersionStage", "deleteModelVersion", // "getModelVersion", diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index f42f6904..01d76e9f 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -8,6 +8,7 @@ GetLatestVersions, GetRegisteredModel, RenameRegisteredModel, + UpdateModelVersion, UpdateRegisteredModel, ) @@ -84,6 +85,10 @@ def delete_model_version(self, name, version): request = DeleteModelVersion(name=name, version=str(version)) self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteModelVersion, request) + def update_model_version(self, name, version, description=None): + request = UpdateModelVersion(name=name, version=str(version), description=description) + self.service.call_endpoint(get_lib().ModelRegistryServiceUpdateModelVersion, request) + def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index af202b9d..573383c7 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -15,5 +15,6 @@ type ModelRegistryService interface { DeleteRegisteredModel(ctx context.Context, input *protos.DeleteRegisteredModel) (*protos.DeleteRegisteredModel_Response, *contract.Error) GetRegisteredModel(ctx context.Context, input *protos.GetRegisteredModel) (*protos.GetRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) + UpdateModelVersion(ctx context.Context, input *protos.UpdateModelVersion) (*protos.UpdateModelVersion_Response, *contract.Error) DeleteModelVersion(ctx context.Context, input *protos.DeleteModelVersion) (*protos.DeleteModelVersion_Response, *contract.Error) } diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index e114f057..9a5bdba8 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -47,6 +47,14 @@ func ModelRegistryServiceGetLatestVersions(serviceID int64, requestData unsafe.P } return invokeServiceMethod(service.GetLatestVersions, new(protos.GetLatestVersions), requestData, requestSize, responseSize) } +//export ModelRegistryServiceUpdateModelVersion +func ModelRegistryServiceUpdateModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.UpdateModelVersion, new(protos.UpdateModelVersion), requestData, requestSize, responseSize) +} //export ModelRegistryServiceDeleteModelVersion func ModelRegistryServiceDeleteModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index 725c68bc..ac53eca7 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -86,3 +86,16 @@ func (m *ModelRegistryService) DeleteModelVersion( return &protos.DeleteModelVersion_Response{}, nil } + +func (m *ModelRegistryService) UpdateModelVersion( + ctx context.Context, input *protos.UpdateModelVersion, +) (*protos.UpdateModelVersion_Response, *contract.Error) { + modelVersion, err := m.store.UpdateModelVersion(ctx, input.GetName(), input.GetVersion(), input.GetDescription()) + if err != nil { + return nil, err + } + + return &protos.UpdateModelVersion_Response{ + ModelVersion: modelVersion.ToProto(), + }, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 1ecfc586..4b18973e 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -353,3 +353,27 @@ func (m *ModelRegistrySQLStore) DeleteModelVersion(ctx context.Context, name, ve return nil } + +func (m *ModelRegistrySQLStore) UpdateModelVersion( + ctx context.Context, name, version, description string, +) (*entities.ModelVersion, *contract.Error) { + modelVersion, err := m.GetModelVersion(ctx, name, version) + if err != nil { + return nil, err + } + + if err := m.db.WithContext(ctx).Model( + &models.ModelVersion{}, + ).Where( + "name = ?", modelVersion.Name, + ).Where( + "version = ?", modelVersion.Version, + ).Updates(&models.ModelVersion{ + Description: sql.NullString{String: description, Valid: description != ""}, + LastUpdatedTime: time.Now().UnixMilli(), + }).Error; err != nil { + return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "error updating model version", err) + } + + return modelVersion, nil +} diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index fbbffe29..b1245e30 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -16,4 +16,5 @@ type ModelRegistryStore interface { RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error) DeleteRegisteredModel(ctx context.Context, name string) *contract.Error DeleteModelVersion(ctx context.Context, name, version string) *contract.Error + UpdateModelVersion(ctx context.Context, name, version, description string) (*entities.ModelVersion, *contract.Error) } diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index 80618f06..ca0f2489 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -77,6 +77,17 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Patch("/mlflow/model-versions/update", func(ctx *fiber.Ctx) error { + input := &protos.UpdateModelVersion{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.UpdateModelVersion(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Delete("/mlflow/model-versions/delete", func(ctx *fiber.Ctx) error { input := &protos.DeleteModelVersion{} if err := parser.ParseBody(ctx, input); err != nil {