Skip to content

Commit

Permalink
VizierConverter.to_dna to pick up DNA metadata from global namespac…
Browse files Browse the repository at this point in the history
…e if not present under pyglove namespace.

PiperOrigin-RevId: 596084062
  • Loading branch information
daiyip authored and copybara-github committed Jan 5, 2024
1 parent 4a90bbe commit 0bc5ee7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
22 changes: 17 additions & 5 deletions vizier/_src/pyglove/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,14 +420,26 @@ def to_dna(self, trial: vz.Trial) -> pg.DNA:
decision_dict, self.dna_spec, use_ints_as_literals=True)

# Restore DNA metadata if present
dna_metada = trial.metadata.ns(constants.METADATA_NAMESPACE).get(
constants.TRIAL_METADATA_KEY_DNA_METADATA, None)
dna_metadata = trial.metadata.ns(constants.METADATA_NAMESPACE).get(
constants.TRIAL_METADATA_KEY_DNA_METADATA, None
)

if dna_metadata is None:
# NOTE(daiyip): To be compatible with V1 pipeline for transfer learning,
# we also try to read DNA_METADATA stored under the global (empty)
# namespace.
dna_metadata = trial.metadata.get(
constants.TRIAL_METADATA_KEY_DNA_METADATA, None
)

if dna_metada is not None:
if dna_metadata is not None:
dna.rebind(
metadata=pg.from_json_str(dna_metada),
metadata=pg.from_json_str(dna_metadata),
skip_notification=True,
raise_on_no_change=False)
raise_on_no_change=False,
)
else:
logging.warn('DNA metadata is None for trial: %s', trial)
return dna

def to_trial(self, dna: pg.DNA, *,
Expand Down
11 changes: 8 additions & 3 deletions vizier/_src/pyglove/converters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def test_vizier_trial_to_tuner_trial(self):
),
)

@parameterized.parameters((7,), (7.,))
def test_trial_to_dna(self, discrete_int_value):
@parameterized.parameters((7, None), (7.0, 'pyglove'))
def test_trial_to_dna(self, discrete_int_value, metadata_ns):
vc = converters.VizierConverter.from_problem(
vz.ProblemStatement(search_space=self._search_space()))
trial = vz.Trial()
Expand All @@ -108,7 +108,12 @@ def test_trial_to_dna(self, discrete_int_value):
trial.parameters['discrete_int'] = discrete_int_value
trial.parameters['discrete_double'] = 4.1
trial.parameters['categorical'] = 'a'
self.assertEqual(vc.to_dna(trial), pg.DNA([-.5, 1, 2, 1, 0]))

metadata = trial.metadata.ns(metadata_ns) if metadata_ns else trial.metadata
metadata[constants.TRIAL_METADATA_KEY_DNA_METADATA] = '{"log_prob": 1.0}'
dna = vc.to_dna(trial)
self.assertEqual(dna, pg.DNA([-0.5, 1, 2, 1, 0]))
self.assertEqual(dna.metadata, dict(log_prob=1.0))


class PyGloveCreatedSearchSpaceTest(parameterized.TestCase):
Expand Down

0 comments on commit 0bc5ee7

Please sign in to comment.