Skip to content

Commit

Permalink
feat: det.keras.DeterminedCallback (#10075)
Browse files Browse the repository at this point in the history
It's everything you've ever wanted.
  • Loading branch information
rb-determined-ai committed Oct 23, 2024
1 parent 0210b1c commit 7212d0e
Show file tree
Hide file tree
Showing 29 changed files with 1,850 additions and 370 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ the the following capabilities:
- hyperparameter search
- distributing work across multiple GPUs and/or nodes

These are the same features provided by the higher-level PyTorchTrial, DeepSpeedTrial, and
TFKerasTrial APIs: those APIs are implemented using the Core API.
These features are also available in the higher-level PyTorchTrial and DeepSpeedTrial APIs, both of
which are built on top of the Core API.

This user guide shows you how to get started using the Core API.

Expand Down
200 changes: 114 additions & 86 deletions docs/model-dev-guide/api-guides/apis-howto/api-keras-ug.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,129 +7,157 @@
.. meta::
:description: Learn how to use the Keras API to train a Keras model. This user guide walks you through loading your data, defining the model, customizing how the model.fit function is called, checkpointing, and callbacks.

In this guide, you'll learn how to use the Keras API.
In this guide, you'll learn how to use Determined's ``keras.DeterminedCallback`` while training your
Keras model.

+---------------------------------------------------------------------+
| Visit the API reference |
+=====================================================================+
| :ref:`keras-reference` |
+---------------------------------------------------------------------+

This document guides you through training a Keras model in Determined. You need to implement a trial
class that inherits :class:`~determined.keras.TFKerasTrial` and specify it as the entrypoint in the
:ref:`experiment-configuration`.
This document guides you through training a Keras model in Determined. You will need to update your
``model.fit()`` call to include a :class:`~determined.keras.DeterminedCallback` and submit it to a
Determined cluster.

To learn about this API, you can start by reading the trial definitions in the `Iris categorization
example
To learn about this API, you can start by reading the ``train.py`` script in the `Iris
categorization example
<https://github.com/determined-ai/determined-examples/tree/main/computer_vision/iris_tf_keras>`__.

***********
Load Data
***********
**********************
Configure Entrypoint
**********************

.. note::
Determined requires you to launch training jobs by submitting them with an
:ref:`experiment-configuration`, which tells the Determined master how to start your container. For
Keras training, you should always wrap your training script in Determined's :ref:`TensorFlow
launcher <launch-tensorflow>`:

Before loading data, visit :ref:`load-model-data` to understand how to work with different
sources of data.
.. code:: yaml
Loading data is done by defining :meth:`~determined.keras.TFKerasTrial.build_training_data_loader`
and :meth:`~determined.keras.TFKerasTrial.build_validation_data_loader` methods. Each should return
one of the following data types:
entrypoint: >-
python3 -m determined.launch.tensorflow --
python3 my_train.py --my-arg...
#. A tuple ``(x, y)`` of NumPy arrays. x must be a NumPy array (or array-like), a list of arrays (in
case the model has multiple inputs), or a dict mapping input names to the corresponding array, if
the model has named inputs. y should be a numpy array.
Determined's TensorFlow launcher will automatically configure your training script with the right
``TF_CONFIG`` environment variable for distributed training when distributed resources are
available, and will safely do nothing when they are not.

#. A tuple ``(x, y, sample_weights)`` of NumPy arrays.
****************************************************************
Obtain a ``det.core.Context`` and a ``tf.distribute.Strategy``
****************************************************************

#. A ``tf.data.dataset`` returning a tuple of either (inputs, targets) or (inputs, targets,
sample_weights).
When using distributed training, TensorFlow requires you to create your ``Strategy`` early in the
process lifetime, before creating your model.

#. A ``keras.utils.Sequence`` returning a tuple of either (inputs, targets) or (inputs, targets,
sample weights).
Since you wrapped your training script in Determined's TensorFlow launcher, you can use Determined's
``core.DistributedContext.from_tf_config()`` helper, which will create both a suitable
``DistributedContext`` and ``Strategy`` for the training environment in your training job. Then you
can feed that ``DistributedContext`` to ``det.core.init()`` to get a ``core.Context``, and feed all
of that to your ``main()`` function (or equivalent) in your training script:

If using ``tf.data.Dataset``, users are required to wrap both their training and validation dataset
using :meth:`self.context.wrap_dataset <determined.keras.TFKerasTrialContext.wrap_dataset>`. This
wrapper is used to shard the dataset for distributed training. For optimal performance, users should
wrap a dataset immediately after creating it.
.. code:: python
.. include:: ../../../_shared/note-dtrain-learn-more.txt
if __name__ == "__main__":
distributed, strategy = det.core.DistributedContext.from_tf_config()
with det.core.init(distributed=distributed) as core_context:
main(core_context, strategy)
******************
Define the Model
******************
*****************
Build the Model
*****************

Users are required wrap their model prior to compiling it using :meth:`self.context.wrap_model
<determined.keras.TFKerasTrialContext.wrap_model>`. This is typically done inside
:meth:`~determined.keras.TFKerasTrial.build_model`.
Building a distributed-capable model is easy in Keras; you just need to wrap your model building and
compiling in the ``strategy.scope()``. See the `TensorFlow documentation
<https://www.tensorflow.org/tutorials/distribute/keras#create_the_model_and_instantiate_the_optimizer>`__
for more details

******************************************
Customize Calling Model Fitting Function
******************************************
.. code:: python
The :class:`~determined.keras.TFKerasTrial` interface allows the user to configure how ``model.fit``
is called by calling :meth:`self.context.configure_fit()
<determined.keras.TFKerasTrialContext.configure_fit>`.
def main(core_context, strategy):
with strategy.scope():
model = my_build_model()
model.compile(...)
***********************************
Create the ``DeterminedCallback``
***********************************

The :class:`~determined.keras.DeterminedCallback` automatically integrates your training with the
Determined cluster. It reports both train and test metrics, reports progress, saves checkpoints, and
uploads them to checkpoint storage. Additionally, it manages preemption signals from the Determined
master (for example, when you pause your experiment), gracefully halting training and later resuming
from where it left off.

The ``DeterminedCallback`` has only three required inputs:
- the ``core_context`` you already created
- a ``checkpoint`` UUID to start training from, or ``None``
- a ``continue_id`` used to decide how to treat the checkpoint

In training jobs, an easy value for ``checkpoint`` is ``det.get_cluster_info().latest_checkpoint``,
which will automatically be populated with the latest checkpoint saved by this trial, or ``None``.
If, for example, you wanted to start training from a checkpoint and support pausing and resuming,
you could use ``info.latest_checkpoint or my_starting_checkpoint``.

The ``continue_id`` helps the ``DeterminedCallback`` decide if the provided checkpoint represents
just the starting weights and training should begin at epoch=0, or if the checkpoint represents a
partially complete training that should pick up where it left off (at epoch > 0). The provided
``continue_id`` is saved along with every checkpoint, and when loading the starting checkpoint, if
the ``continue_id`` matches what was in the checkpoint, training state is also loaded from the
checkpoint. In training jobs, an easy value for ``continue_id`` is
``det.get_cluster_info.trial.trial_id``.

See the reference for :class:`~determined.keras.DeterminedCallback` for details on its optional
parameters.

***************
Checkpointing
***************
.. code:: python
A checkpoint includes the model definition (Python source code), experiment configuration file,
network architecture, and the values of the model's parameters (i.e., weights) and hyperparameters.
When using a stateful optimizer during training, checkpoints will also include the state of the
optimizer (i.e., learning rate). You can also embed arbitrary metadata in checkpoints via a
:ref:`Python SDK <store-checkpoint-metadata>`.
info = det.get_cluster_info()
assert info and info.task_type == "TRIAL", "this example only runs as a trial on the cluster"
TensorFlow Keras trials are checkpointed to a file named ``determined-keras-model.h5`` using
``tf.keras.models.save_model``. You can learn more from the `TF Keras docs
<https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras/models/save_model>`__.
det_cb = det.keras.DeterminedCallback(
core_context,
checkpoint=info.latest_checkpoint,
continue_id=info.trial.trial_id,
)
***********
Callbacks
Load Data
***********

To execute arbitrary Python code during the lifecycle of a :class:`~determined.keras.TFKerasTrial`,
implement the :class:`determined.keras.callbacks.Callback` interface (an extension of the
``tf.keras.callbacks.Callbacks`` interface) and supply them to the
:class:`~determined.keras.TFKerasTrial` by implementing
:meth:`~determined.keras.TFKerasTrial.keras_callbacks`.
Loading data is done as usual, though additional considerations may arise if your existing
data-loading code is not container-ready. For more details, see :ref:`load-model-data`.

.. _keras-profiler:
If you want to take advantage Determined's distributed training, you may need to ensure that your
input data is properly sharded. See `TensorFlow documentation
<https://www.tensorflow.org/tutorials/distribute/input#sharding>`__ for details.

***********
Profiling
***********

Determined supports integration with the native TF Keras profiler. Results will automatically be
uploaded to the trial's TensorBoard path and can be viewed in the Determined Web UI.
.. include:: ../../../_shared/note-dtrain-learn-more.txt

The Keras profiler is configured as a callback in the :class:`~determined.keras.TFKerasTrial` class.
The :class:`determined.keras.callbacks.TensorBoard` callback is a thin wrapper around the native
Keras TensorBoard callback, ``tf.keras.callbacks.TensorBoard``. It overrides the ``log_dir``
argument to set the Determined TensorBoard path, while other arguments are passed directly into
``tf.keras.callbacks.TensorBoard``. For a list of accepted arguments, consult the `official Keras
API documentation <https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/TensorBoard>`_.
*************************
TensorBoard Integration
*************************

The following code snippet will configure profiling for batches 5 and 10, and will compute weight
histograms every 1 epochs.
Optionally, you can use Determined's :class:`~determined.keras.TensorBoard` callback, which extends
Keras' ``TensorBoard`` callback with the ability to automatically upload metrics to Determined's
checkpoint storage. Determined's ``TensorBoard`` callback is configured identically to Keras' except
it takes an additional ``core_context`` initial argument:

.. code:: python
from determined import keras
tb_cb = det.keras.TensorBoard(core_context, ...)
def keras_callbacks(self) -> List[tf.keras.callbacks.Callback]:
return [
keras.callbacks.TensorBoard(
update_freq="batch",
profile_batch='5, 10',
histogram_freq=1,
)
]
Then simply include it in your ``model.fit()`` as normal.

.. note::
*************************
Calling ``model.fit()``
*************************

The only remaining step is to pass your callbacks to your ``model.fit()``:

.. code:: python
Though specifying batches to profile with ``profile_batch`` is optional, profiling every batch
may cause a large amount of data to be uploaded to Tensorboard. This may result in long rendering
times for Tensorboard and memory issues. For long-running experiments, it is recommended to
configure profiling only on desired batches.
model.fit(
...,
callbacks=[det_cb, tb_cb],
)
25 changes: 25 additions & 0 deletions docs/model-dev-guide/create-experiment.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,31 @@ Use the ``-h`` option to get the latest usage:
python3 -m determined.launch.deepspeed -h
.. _launch-tensorflow:

TensorFlow Launcher
===================

Format:

``determined.launch.tensorflow [--] SCRIPT...``

This launcher configures a ``TF_CONFIG`` environment variable suitable for whichever level of
TensorFlow distributed training is appropriate for the available training resources
(``MultiWorkerMirroredStrategy``, ``MirroredStrategy``, or the default strategy).

Example:

.. code:: bash
python3 -m determined.launch.tensorflow -- python3 ./my_train.py --my-arg=value
Use the ``-h`` option to get the latest usage:

.. code:: bash
python3 -m determined.launch.tensorflow -h
Legacy Launcher
===============

Expand Down
10 changes: 5 additions & 5 deletions docs/model-dev-guide/debug-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ for debugging. See :ref:`pytorch_trainer_ug` for usage details.
#. Create simple tests to verify each ``Trial`` subclass method.

Examples of what these tests might look like for
:class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and :class:`~determined.keras.TFKerasTrial`
can be found in the :meth:`determined.TrialContext.from_config` documentation, but only you can
verify what is reasonable for your test.
:class:`~determined.pytorch.deepspeed.DeepSpeedTrial` can be found in the
:meth:`determined.TrialContext.from_config` documentation, but only you can verify what is
reasonable for your test.

#. Diagnose failures:

Expand Down Expand Up @@ -385,8 +385,8 @@ step only applies if you have multiple GPUs and want to use distributed training
consume too many resources and prevent the experiment from starting.

- Determined is designed to control the details of distributed training for you. If you also try
to control those details, such as by calling ``tf.config.set_visible_devices()`` in a
:class:`~determined.keras.TFKerasTrial`, it is likely to cause issues.
to control those details, such as by calling ``tf.config.set_visible_devices()`` while
training a Keras model, it is likely to cause issues.

- Some classes of metrics must be specially calculated during distributed training. Most
metrics, such as loss or accuracy, can be calculated piecemeal on each worker in a distributed
Expand Down
14 changes: 5 additions & 9 deletions docs/model-dev-guide/dtrain/reproducibility.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ The experiment seed is used as a source of randomness for any hyperparameter sam
The experiment seed is also used to generate a **trial seed** for every trial associated with the
experiment.

In the ``Trial`` interface, the trial seed is accessible within the trial class using
``self.ctx.get_trial_seed()``.
When training on-cluster, the trial seed is accessible via
:class:`det.get_cluster_info().trial.trial_seed <determined.get_cluster_info>`

*******************
Coding Guidelines
Expand All @@ -67,16 +67,12 @@ To achieve reproducible initial conditions in an experiment, please follow these
**************************************

When doing CPU-only training with TensorFlow, it is possible to achieve floating-point
reproducibility throughout optimization. If using the :class:`~determined.keras.TFKerasTrial` API,
implement the optional :meth:`~determined.keras.TFKerasTrial.session_config` method to override the
default session configuration:
reproducibility throughout optimization:

.. code:: python
def session_config(self) -> tf.ConfigProto:
return tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
)
tf.config.threading.set_intra_op_parallelism_threads(1)
tf.config.threading.set_inter_op_parallelism_threads(1)
.. warning::

Expand Down
6 changes: 3 additions & 3 deletions docs/model-dev-guide/profiling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ training code. Identifying inefficiencies in individual training operations or s
fine-grained context than generic system metrics can provide. For this level of profiling,
Determined supports integration with training profilers that are native to their frameworks:

- PyTorch Profiler (:ref:`PyTorch API <pytorch_profiler>`)
- DeepSpeed Profiler (:ref:`DeepSpeed API <deepspeed-profiler>`)
- TensorFlow Keras Profiler (:ref:`Keras API <keras-profiler>`)
- :ref:`PyTorch Profiler <pytorch_profiler>`
- :ref:`DeepSpeed Profiler <deepspeed-profiler>`
- :class:`Keras TensorBoard callback <determined.keras.TensorBoard>`

Please see your framework's profiler documentation and the Determined Training API guide for usage
details.
Expand Down
9 changes: 5 additions & 4 deletions docs/reference/experiment-config-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,30 +122,31 @@ field is empty.
Arbitrary Script
----------------

Required. An arbitrary entrypoint script name.
Required. An arbitrary entrypoint script with args.

Example:

.. code:: yaml
entrypoint: ./hello.sh
entrypoint: ./hello.sh args...
Preconfigured Launch Module with Script
---------------------------------------

Required. The name of a preconfigured launch module and script name.
Required. The name of a preconfigured launch module and script with args.

Example:

.. code:: yaml
entrypoint: python3 -m (LAUNCH_MODULE) train.py
entrypoint: python3 -m (LAUNCH_MODULE) train.py args...
``LAUNCH_MODULE`` options:

- Horovod (determined.launch.horovod)
- PyTorch (determined.launch.torch_distributed)
- Deepspeed (determined.launch.deepspeed)
- TensorFlow (determined.launch.tensorflow)

Preconfigured Launch Module with Legacy Trial Definition
--------------------------------------------------------
Expand Down
Loading

0 comments on commit 7212d0e

Please sign in to comment.