Skip to content

Commit

Permalink
hparams: add Python layer for hparams and metrics (#2126)
Browse files Browse the repository at this point in the history
Summary:
This change introduces `HParam`, `Metric`, and `Experiment` classes,
which represent their proto counterparts in a more Python-friendly way.
It similarly includes a `Domain` class hierarchy, which does not
correspond to a specific proto message, but rather unifies the domain
variants defined on the `HParamInfo` proto.

The design is roughly as in the original sketch of #1998.

The primary benefit of this change is that having first-class domains
enables clients to reuse the domain information for both the experiment
summary and the underlying tuning algorithm. We don’t provide a method
to do this out of the box, because we don’t actually provide any tuners
at this time, but it’s easy to write (e.g.) a `sample_uniform` function
like the one included in this commit. Then, sampling is as easy as

```python
    hparams = {h: sample_uniform(h.domain, rng) for h in HPARAMS}
```

It is also now more convenient to reference hparam values such that
static analysis can detect potential typos, because the `HParam` objects
themselves can be declared as constants and used as keys in a dict.
Writing `hparams["dropuot"]` fails at runtime, but `hparams[HP_DROPUOT]`
fails at lint time.

As a pleasant bonus, hparam definitions are now more compact, fitting
on one line instead of several. The demo code has net fewer lines.

Manual summary writer management is still required. A future change will
introduce a Keras callback to reduce this overhead.

Test Plan:
Some unit tests included, and the demo still works.

wchargin-branch: hparams-structured-api
  • Loading branch information
wchargin authored Apr 23, 2019
1 parent 2a3ef44 commit 669c420
Show file tree
Hide file tree
Showing 4 changed files with 761 additions and 76 deletions.
31 changes: 30 additions & 1 deletion tensorboard/plugins/hparams/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ py_binary(
srcs = ["hparams_demo.py"],
srcs_version = "PY2AND3",
deps = [
":api",
":protos_all_py_pb2",
":summary",
"//tensorboard:expect_absl_app_installed",
"//tensorboard:expect_absl_flags_installed",
"//tensorboard:expect_numpy_installed",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/plugins/scalar:summary",
"@com_google_protobuf//:protobuf_python",
"@org_pythonhosted_six",
],
)
Expand Down Expand Up @@ -149,6 +149,35 @@ sh_test(
],
)

py_library(
name = "api",
srcs = ["api.py"],
srcs_version = "PY2AND3",
visibility = [
"//visibility:public",
],
deps = [
":protos_all_py_pb2",
":summary",
"@org_pythonhosted_six",
],
)

py_test(
name = "api_test",
size = "small",
srcs = ["api_test.py"],
srcs_version = "PY2AND3",
deps = [
":api",
":metadata",
":protos_all_py_pb2",
"//tensorboard:test",
"@com_google_protobuf//:protobuf_python",
"@org_pythonhosted_six",
],
)

py_library(
name = "summary",
srcs = ["summary.py"],
Expand Down
Loading

0 comments on commit 669c420

Please sign in to comment.