Skip to content

Commit

Permalink
Access row elements with row.x rather than row[x], compatible wit…
Browse files Browse the repository at this point in the history
…h sqlalchemy 2.0. Resolves #993

PiperOrigin-RevId: 586493566
  • Loading branch information
xingyousong authored and copybara-github committed Nov 30, 2023
1 parent aa3d9c0 commit c5ce9e0
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 30 deletions.
8 changes: 0 additions & 8 deletions requirements-client.txt

This file was deleted.

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ portpicker>=1.3.1
grpcio>=1.35.0
grpcio-tools>=1.35.0
googleapis-common-protos>=1.56.4
sqlalchemy>=1.4,<=1.4.20
sqlalchemy>=1.4
2 changes: 1 addition & 1 deletion vizier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@

sys.path.append(PROTO_ROOT)

__version__ = "0.1.12"
__version__ = "0.1.13"
31 changes: 11 additions & 20 deletions vizier/_src/service/sql_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import collections
import threading
from typing import Callable, DefaultDict, Iterable, List, Optional
from typing import Callable, Iterable, List, Optional

from absl import logging
import sqlalchemy as sqla
Expand Down Expand Up @@ -124,7 +124,7 @@ def load_study(self, study_name: str) -> study_pb2.Study:
row = result.fetchone()
if not row:
raise NotFoundError('Failed to find study name: %s' % study_name)
return study_pb2.Study.FromString(row['serialized_study'])
return study_pb2.Study.FromString(row.serialized_study)

def update_study(self, study: study_pb2.Study) -> resources.StudyResource:
study_resource = resources.StudyResource.from_name(study.name)
Expand Down Expand Up @@ -190,9 +190,7 @@ def list_studies(self, owner_name: str) -> List[study_pb2.Study]:
raise NotFoundError('Owner name %s does not exist.' % owner_name)
result = self._connection.execute(lq).fetchall()

return [
study_pb2.Study.FromString(row['serialized_study']) for row in result
]
return [study_pb2.Study.FromString(row.serialized_study) for row in result]

def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
trial_resource = resources.TrialResource.from_name(trial.name)
Expand Down Expand Up @@ -223,7 +221,7 @@ def get_trial(self, trial_name: str) -> study_pb2.Trial:
row = result.fetchone()
if not row:
raise NotFoundError('Failed to find trial name: %s' % trial_name)
return study_pb2.Trial.FromString(row['serialized_trial'])
return study_pb2.Trial.FromString(row.serialized_trial)

def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
trial_resource = resources.TrialResource.from_name(trial.name)
Expand Down Expand Up @@ -269,9 +267,7 @@ def list_trials(self, study_name: str) -> List[study_pb2.Trial]:
raise NotFoundError('Study name %s does not exist.' % study_name)
result = self._connection.execute(lq)

return [
study_pb2.Trial.FromString(row['serialized_trial']) for row in result
]
return [study_pb2.Trial.FromString(row.serialized_trial) for row in result]

def delete_trial(self, trial_name: str) -> None:
# Exist query
Expand Down Expand Up @@ -347,7 +343,7 @@ def get_suggestion_operation(
row = result.fetchone()
if not row:
raise NotFoundError('Failed to find suggest op name: %s' % operation_name)
return operations_pb2.Operation.FromString(row['serialized_op'])
return operations_pb2.Operation.FromString(row.serialized_op)

def update_suggestion_operation(
self, operation: operations_pb2.Operation
Expand Down Expand Up @@ -407,8 +403,7 @@ def list_suggestion_operations(
result = self._connection.execute(q)

all_ops = [
operations_pb2.Operation.FromString(row['serialized_op'])
for row in result
operations_pb2.Operation.FromString(row.serialized_op) for row in result
]
if filter_fn is not None:
output_list = []
Expand Down Expand Up @@ -495,9 +490,7 @@ def get_early_stopping_operation(
raise NotFoundError(
'Failed to find early stopping op name: %s' % operation_name
)
return vizier_oss_pb2.EarlyStoppingOperation.FromString(
row['serialized_op']
)
return vizier_oss_pb2.EarlyStoppingOperation.FromString(row.serialized_op)

def update_early_stopping_operation(
self, operation: vizier_oss_pb2.EarlyStoppingOperation
Expand Down Expand Up @@ -552,7 +545,7 @@ def update_metadata(
row = study_result.fetchone()
if not row:
raise NotFoundError('No such study:', s_resource.name)
original_study = study_pb2.Study.FromString(row['serialized_study'])
original_study = study_pb2.Study.FromString(row.serialized_study)

# Store Study-related metadata into the database.
vz.metadata_util.merge_study_metadata(
Expand All @@ -565,9 +558,7 @@ def update_metadata(
self._connection.execute(usq)

# Split the trial-related metadata by Trial.
split_metadata: DefaultDict[str, List[datastore.UnitMetadataUpdate]] = (
collections.defaultdict(list)
)
split_metadata = collections.defaultdict(list)
for md in trial_metadata:
split_metadata[md.trial_id].append(md)

Expand All @@ -583,7 +574,7 @@ def update_metadata(
row = trial_result.fetchone()
if not row:
raise NotFoundError('No such trial:', trial_name)
original_trial = study_pb2.Trial.FromString(row['serialized_trial'])
original_trial = study_pb2.Trial.FromString(row.serialized_trial)

# Update Trial.
vz.metadata_util.merge_trial_metadata(original_trial, md_list)
Expand Down

0 comments on commit c5ce9e0

Please sign in to comment.