Skip to content

Commit

Permalink
Add ml_monitoring to trace ML events like train_step, eval_step durin…
Browse files Browse the repository at this point in the history
…g the training.

PiperOrigin-RevId: 679779091
  • Loading branch information
The paxml Authors committed Sep 28, 2024
1 parent 98c3daf commit 3dc4a6f
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 30 deletions.
7 changes: 7 additions & 0 deletions paxml/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ pytype_strict_library(
# Implicit etils dependency.
# Implicit fiddle.absl_flags dependency.
# Implicit jax dependency.
"//paxml:ml_monitoring",
"//praxis:pax_fiddle",
"//praxis:py_utils",
# Implicit seqio dependency.
Expand Down Expand Up @@ -842,6 +843,7 @@ pytype_strict_library(
# Implicit etils dependency.
# Implicit jax dependency.
"//paxml:checkpoints",
"//paxml:ml_monitoring",
"//praxis:base_hyperparams",
"//praxis:base_input",
"//praxis:base_layer",
Expand Down Expand Up @@ -888,6 +890,11 @@ pytype_strict_library(
srcs = ["host_callback.py"],
)

pytype_strict_library(
name = "ml_monitoring",
srcs = ["ml_monitoring.py"],
)

pytype_strict_test(
name = "host_callback_test",
srcs = ["host_callback_test.py"],
Expand Down
142 changes: 121 additions & 21 deletions paxml/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from paxml import train_states
from paxml import trainer_lib
from paxml import tuning_lib
from paxml import ml_monitoring
from praxis import base_hyperparams
from praxis import base_input
from praxis import base_layer
Expand Down Expand Up @@ -306,7 +307,7 @@ def start(self):
logging.info('[PAX STATUS]: Executor shutdown complete.')


def _train_and_evaluate_common(
def _prepare_train_and_evaluate(
*,
task: tasks_lib.SingleTask,
partitioner: partitioning.Partitioner,
Expand All @@ -318,14 +319,12 @@ def _train_and_evaluate_common(
eval_programs: Sequence[programs.BaseEvalProgram],
decode_programs: Sequence[decode_programs_lib.SingleTaskDecodeProgram],
total_num_params,
early_stopping_fn: trainer_lib.EarlyStoppingFn | None,
checkpointer,
job_log_dir: epath.Path,
eval_prng_seed,
decode_prng_seed,
is_vars_replicated,
train_prng_seed,
exit_after_ondemand_checkpoint,
enable_summary_writer: bool = True,
):
"""Training loop code common to both pmap and spmd."""
Expand Down Expand Up @@ -366,7 +365,11 @@ def _train_and_evaluate_common(
)
for program in decode_programs:
program.setup(
task, partitioner, job_log_dir, decode_prng_seed, enable_summary_writer
task,
partitioner,
job_log_dir,
decode_prng_seed,
enable_summary_writer,
)
trainer_lib.check_unique_names([p.eval_input for p in eval_programs])
trainer_lib.check_unique_names([p.decode_input for p in decode_programs])
Expand All @@ -385,14 +388,27 @@ def _train_and_evaluate_common(
summary_utils.write_global_batch_size(
train_summary_writer, train_program.train_unpadded_global_batch_size
)
return train_state_metadata, train_input_for_checkpoint, step_i, train_p

# Start the train loop. Make sure all at the same step.
py_utils.sync_global_devices(f'Start training loop from step: {step_i}')
# Collect then freeze GC, so that GC in the training loop will not touch the
# python objects used to initialize the model. Unfreeze at the end of the
# loop.
gc.collect()
gc.freeze()

def _train_loop(
*,
task: tasks_lib.SingleTask,
train_program: programs.BaseTrainProgram,
partitioned_train_state: TrainState,
# TODO(hthu): Take a more generalized form of EvalProgram interface.
eval_programs: Sequence[programs.BaseEvalProgram],
decode_programs: Sequence[decode_programs_lib.SingleTaskDecodeProgram],
total_num_params,
early_stopping_fn: trainer_lib.EarlyStoppingFn | None,
checkpointer,
job_log_dir: epath.Path,
exit_after_ondemand_checkpoint,
train_state_metadata,
train_input_for_checkpoint,
step_i,
train_p,
):
while True:
logging.log_first_n(INFO, '[PAX STATUS]: Beginning step `%d`.', 5, step_i)
checkpointer.save_if_needed(
Expand All @@ -418,12 +434,12 @@ def _train_and_evaluate_common(
train_p.num_train_steps,
)
break
with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.TRAIN_STEP):
partitioned_train_state = train_program.update_state(
partitioned_train_state, step_i
)
program_output = train_program.run(partitioned_train_state, step_i)

partitioned_train_state = train_program.update_state(
partitioned_train_state, step_i
)

program_output = train_program.run(partitioned_train_state, step_i)
partitioned_train_state = program_output.state
train_weighted_scalars = program_output.weighted_scalars
steps_per_sec = program_output.steps_per_sec
Expand All @@ -433,6 +449,7 @@ def _train_and_evaluate_common(
# While the eval ones below are post-model weight updates, hence the step
# counter is incremented in between.
step_i = program_output.new_train_step
ml_monitoring.record_step_number(step_i)

eval_metrics: tuning_lib.EvalMetrics | None = None
# Run eval at regular step interval or the final training step.
Expand All @@ -447,11 +464,12 @@ def _train_and_evaluate_common(
# If we have eval test then also evaluate on test.
if eval_programs:
logging.debug('[PAX STATUS]: Running eval programs.')
eval_metrics, elapsed_secs = eval_lib.run_eval_programs(
eval_programs=eval_programs,
train_state=eval_partitioned_train_state,
step=step_i,
)
with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.EVAL_STEP):
eval_metrics, elapsed_secs = eval_lib.run_eval_programs(
eval_programs=eval_programs,
train_state=eval_partitioned_train_state,
step=step_i,
)
jax.monitoring.record_event_duration_secs(
'/jax/pax/train/interleaved_eval_duration_sec', elapsed_secs
)
Expand Down Expand Up @@ -539,6 +557,88 @@ def _train_and_evaluate_common(
train_p.num_train_steps,
)
break
return (
step_i,
partitioned_train_state,
train_state_metadata,
train_input_for_checkpoint,
)


def _train_and_evaluate_common(
*,
task: tasks_lib.SingleTask,
partitioner: partitioning.Partitioner,
train_program: programs.BaseTrainProgram,
train_input: base_input.BaseInput,
partitioned_train_state: TrainState,
train_state_provenance: TrainStateProvenance,
# TODO(hthu): Take a more generalized form of EvalProgram interface.
eval_programs: Sequence[programs.BaseEvalProgram],
decode_programs: Sequence[decode_programs_lib.SingleTaskDecodeProgram],
total_num_params,
early_stopping_fn: trainer_lib.EarlyStoppingFn | None,
checkpointer,
job_log_dir: epath.Path,
eval_prng_seed,
decode_prng_seed,
is_vars_replicated,
train_prng_seed,
exit_after_ondemand_checkpoint,
enable_summary_writer: bool = True,
):
with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.INITIALIZE_BACKEND):
train_state_metadata, train_input_for_checkpoint, step_i, train_p = (
_prepare_train_and_evaluate(
task=task,
partitioner=partitioner,
train_program=train_program,
train_input=train_input,
partitioned_train_state=partitioned_train_state,
train_state_provenance=train_state_provenance,
eval_programs=eval_programs,
decode_programs=decode_programs,
total_num_params=total_num_params,
checkpointer=checkpointer,
job_log_dir=job_log_dir,
eval_prng_seed=eval_prng_seed,
decode_prng_seed=decode_prng_seed,
is_vars_replicated=is_vars_replicated,
train_prng_seed=train_prng_seed,
enable_summary_writer=enable_summary_writer,
)
)

# Start the train loop. Make sure all at the same step.
py_utils.sync_global_devices(f'Start training loop from step: {step_i}')
# Collect then freeze GC, so that GC in the training loop will not touch the
# python objects used to initialize the model. Unfreeze at the end of the
# loop.
gc.collect()
gc.freeze()

with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.MAIN_LOOP):
(
step_i,
partitioned_train_state,
train_state_metadata,
train_input_for_checkpoint,
) = _train_loop(
task=task,
train_program=train_program,
partitioned_train_state=partitioned_train_state,
eval_programs=eval_programs,
decode_programs=decode_programs,
total_num_params=total_num_params,
early_stopping_fn=early_stopping_fn,
checkpointer=checkpointer,
job_log_dir=job_log_dir,
exit_after_ondemand_checkpoint=exit_after_ondemand_checkpoint,
train_state_metadata=train_state_metadata,
train_input_for_checkpoint=train_input_for_checkpoint,
step_i=step_i,
train_p=train_p,
)
gc.unfreeze()

logging.info('[PAX STATUS]: Saving checkpoint for final step.')
Expand Down
27 changes: 18 additions & 9 deletions paxml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from paxml import train
from paxml import trainer_lib
from paxml import tuning_lib
from paxml import ml_monitoring
from praxis import pax_fiddle
from praxis import py_utils

Expand Down Expand Up @@ -496,15 +497,8 @@ def main(argv: Sequence[str]) -> None:
_main(argv)


@py_utils.benchmark(prefix='[PAX STATUS]: E2E time: ')
def _main(argv: Sequence[str]) -> None:
logging.info('[PAX STATUS]: Program start.')
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')

if tf_data_service_lib.run_tf_data_service(FLAGS.mode):
return

def create_experiment_config():
"""Creates the experiment config from the command line flags."""
if FLAGS.tfds_data_dir is not None:
# seqio import is slow so avoid module-level import
import seqio
Expand Down Expand Up @@ -556,6 +550,21 @@ def _main(argv: Sequence[str]) -> None:
)

experiment_config.validate()
return experiment_config


@py_utils.benchmark(prefix='[PAX STATUS]: E2E time: ')
def _main(argv: Sequence[str]) -> None:
logging.info('[PAX STATUS]: Program start.')
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')

if tf_data_service_lib.run_tf_data_service(FLAGS.mode):
return

with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.INITIALIZE_SETUP):
experiment_config = create_experiment_config()

run(
experiment_config=experiment_config,
enable_checkpoint_saving=FLAGS.enable_checkpoint_saving,
Expand Down
52 changes: 52 additions & 0 deletions paxml/ml_monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ML Monitoring for PAX."""

import contextlib
import enum


class MlEvent(enum.Enum):
"""ML events to be recorded."""

INITIALIZE_BACKEND = enum.auto()
INITIALIZE_SETUP = enum.auto()
MAIN_LOOP = enum.auto()
TRAIN_STEP = enum.auto()
EVAL_STEP = enum.auto()
DECODE_STEP = enum.auto()


class EventBoundary(enum.Enum):
"""Event boundary to be recorded."""

START = enum.auto()
END = enum.auto()


def record_step_number(step_number: int):
"""Records the step number."""
pass


def record_event_boundary(event: MlEvent, boundary: EventBoundary, **kwargs):
"""Records the event boundary."""
pass


@contextlib.contextmanager
def ml_event_logger(event: MlEvent, **kwargs):
yield

0 comments on commit 3dc4a6f

Please sign in to comment.