Skip to content

Commit

Permalink
Move SearchExperiments endpoint (#68)
Browse files Browse the repository at this point in the history
Signed-off-by: Software Developer <7852635+dsuhinin@users.noreply.github.com>
  • Loading branch information
dsuhinin authored Oct 21, 2024
1 parent 7c480fd commit 5c15e67
Show file tree
Hide file tree
Showing 19 changed files with 1,039 additions and 484 deletions.
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ linters:
- gosimple
- lll
disable:
- ireturn
- depguard
- gochecknoglobals # Immutable globals are fine.
- exhaustruct # Often the case for protobuf generated code or gorm structs.
Expand Down
5 changes: 5 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def pytest_configure(config):
(
"tests.store.tracking.test_sqlalchemy_store.test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db",
"tests/override_test_sqlalchemy_store.py",
), # We do not support applying the SQL schema to sqlite like Python does.
# So we do not support sqlite:////:memory: database.
(
"tests.store.tracking.test_sqlalchemy_store.test_search_experiments_max_results_validation",
"tests/override_test_sqlalchemy_store.py",
),
):
func_name = func_to_patch.rsplit(".", 1)[1]
Expand Down
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
ImplementedEndpoints: []string{
"getExperimentByName",
"createExperiment",
// "searchExperiments",
"searchExperiments",
"getExperiment",
"deleteExperiment",
"restoreExperiment",
Expand Down
1 change: 1 addition & 0 deletions magefiles/generate/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var validations = map[string]string{
"SetExperimentTag_ExperimentId": "required",
"SetExperimentTag_Key": "required,max=250,validMetricParamOrTagName",
"SetExperimentTag_Value": "max=5000",
"SearchExperiments_MaxResults": "positiveNonZeroInteger,max=50000",
"SetTag_Key": "required,max=1000,validMetricParamOrTagName,pathIsUnique",
"SetTag_Value": "omitempty,truncate=8000",
"LogInputs_RunId": "required,runId",
Expand Down
24 changes: 24 additions & 0 deletions mlflow_go/store/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@
LogParam,
RestoreExperiment,
RestoreRun,
SearchExperiments,
SearchRuns,
SetTag,
SetTraceTag,
UpdateExperiment,
UpdateRun,
)
from mlflow.store.entities import PagedList
from mlflow.store.tracking import SEARCH_MAX_RESULTS_DEFAULT
from mlflow.utils.uri import resolve_uri_if_local

from mlflow_go import is_go_enabled
Expand Down Expand Up @@ -201,6 +204,27 @@ def delete_trace_tag(self, request_id: str, key: str):
)
self.service.call_endpoint(get_lib().TrackingServiceDeleteTraceTag, request)

def search_experiments(
self,
view_type=ViewType.ACTIVE_ONLY,
max_results=SEARCH_MAX_RESULTS_DEFAULT,
filter_string=None,
order_by=None,
page_token=None,
):
request = SearchExperiments(
view_type=view_type,
max_results=max_results,
filter=filter_string,
order_by=order_by,
page_token=page_token,
)
response = self.service.call_endpoint(get_lib().TrackingServiceSearchExperiments, request)
experiments = [
Experiment.from_proto(proto_experiment) for proto_experiment in response.experiments
]
return PagedList(experiments, (response.next_page_token or None))

def set_tag(self, run_id, tag):
request = SetTag(run_id=run_id, key=tag.key, value=tag.value)
self.service.call_endpoint(get_lib().TrackingServiceSetTag, request)
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pkg/lib/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/service.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions pkg/server/routes/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions pkg/tracking/service/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,32 @@ func (ts TrackingService) SetExperimentTag(

return &protos.SetExperimentTag_Response{}, nil
}

func (ts TrackingService) SearchExperiments(
ctx context.Context, input *protos.SearchExperiments,
) (*protos.SearchExperiments_Response, *contract.Error) {
experiments, nextPageToken, err := ts.Store.SearchExperiments(
ctx,
input.GetViewType(),
input.GetMaxResults(),
input.GetFilter(),
input.GetOrderBy(),
input.GetPageToken(),
)
if err != nil {
return nil, contract.NewError(protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("error getting experiments: %v", err))
}

response := protos.SearchExperiments_Response{
Experiments: make([]*protos.Experiment, len(experiments)),
}
if nextPageToken != "" {
response.NextPageToken = &nextPageToken
}

for i, experiment := range experiments {
response.Experiments[i] = experiment.ToProto()
}

return &response, nil
}
72 changes: 72 additions & 0 deletions pkg/tracking/store/mock_tracking_store.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions pkg/tracking/store/sql/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,65 @@ func (s TrackingSQLStore) GetExperimentByName(
return experiment.ToEntity(), nil
}

func (s TrackingSQLStore) SearchExperiments(
ctx context.Context,
experimentViewType protos.ViewType,
maxResults int64,
filter string,
orderBy []string,
pageToken string,
) ([]*entities.Experiment, string, *contract.Error) {
query := applyExperimentsLifecycleStagesFilter(s.db.WithContext(ctx), experimentViewType)

// apply Limit
query, limit := applyExperimentsLimitFilter(query, maxResults)

// apply Offset
query, offset, err := applyExperimentsOffsetFilter(query, pageToken)
if err != nil {
return nil, "", err
}

// Apply Filter
query, err = applyExperimentsFilter(s.db, query, filter)
if err != nil {
return nil, "", err
}

// OrderBy
query, err = applyExperimentsOrderBy(query, orderBy)
if err != nil {
return nil, "", err
}

// Actual query
var experiments []models.Experiment
if err := query.Preload("Tags").Find(&experiments).Error; err != nil {
return nil, "", contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR,
fmt.Sprintf("failed to get runs %q", err),
err,
)
}

// encode `nextPageToken` value.
nextPageToken, err := createExperimentsNextPageToken(experiments, limit, offset)
if err != nil {
return nil, "", err
}

if len(experiments) > limit {
experiments = experiments[:limit]
}

data := make([]*entities.Experiment, len(experiments))
for i, experiment := range experiments {
data[i] = experiment.ToEntity()
}

return data, nextPageToken, nil
}

func (s TrackingSQLStore) SetExperimentTag(
ctx context.Context, experimentID, key, value string,
) *contract.Error {
Expand Down
Loading

0 comments on commit 5c15e67

Please sign in to comment.