From 3dc4a6f904dec54849208a3b64df0379991f5e89 Mon Sep 17 00:00:00 2001 From: The paxml Authors Date: Fri, 27 Sep 2024 17:10:27 -0700 Subject: [PATCH] Add ml_monitoring to trace ML events like train_step, eval_step during the training. PiperOrigin-RevId: 679779091 --- paxml/BUILD | 7 ++ paxml/executors.py | 142 +++++++++++++++++++++++++++++++++++------ paxml/main.py | 27 +++++--- paxml/ml_monitoring.py | 52 +++++++++++++++ 4 files changed, 198 insertions(+), 30 deletions(-) create mode 100644 paxml/ml_monitoring.py diff --git a/paxml/BUILD b/paxml/BUILD index cb7def880..7a23e2559 100644 --- a/paxml/BUILD +++ b/paxml/BUILD @@ -351,6 +351,7 @@ pytype_strict_library( # Implicit etils dependency. # Implicit fiddle.absl_flags dependency. # Implicit jax dependency. + "//paxml:ml_monitoring", "//praxis:pax_fiddle", "//praxis:py_utils", # Implicit seqio dependency. @@ -842,6 +843,7 @@ pytype_strict_library( # Implicit etils dependency. # Implicit jax dependency. "//paxml:checkpoints", + "//paxml:ml_monitoring", "//praxis:base_hyperparams", "//praxis:base_input", "//praxis:base_layer", @@ -888,6 +890,11 @@ pytype_strict_library( srcs = ["host_callback.py"], ) +pytype_strict_library( + name = "ml_monitoring", + srcs = ["ml_monitoring.py"], +) + pytype_strict_test( name = "host_callback_test", srcs = ["host_callback_test.py"], diff --git a/paxml/executors.py b/paxml/executors.py index 26cacbb7b..ec4ca23d1 100644 --- a/paxml/executors.py +++ b/paxml/executors.py @@ -34,6 +34,7 @@ from paxml import train_states from paxml import trainer_lib from paxml import tuning_lib +from paxml import ml_monitoring from praxis import base_hyperparams from praxis import base_input from praxis import base_layer @@ -306,7 +307,7 @@ def start(self): logging.info('[PAX STATUS]: Executor shutdown complete.') -def _train_and_evaluate_common( +def _prepare_train_and_evaluate( *, task: tasks_lib.SingleTask, partitioner: partitioning.Partitioner, @@ -318,14 +319,12 @@ def _train_and_evaluate_common( eval_programs: Sequence[programs.BaseEvalProgram], decode_programs: Sequence[decode_programs_lib.SingleTaskDecodeProgram], total_num_params, - early_stopping_fn: trainer_lib.EarlyStoppingFn | None, checkpointer, job_log_dir: epath.Path, eval_prng_seed, decode_prng_seed, is_vars_replicated, train_prng_seed, - exit_after_ondemand_checkpoint, enable_summary_writer: bool = True, ): """Training loop code common to both pmap and spmd.""" @@ -366,7 +365,11 @@ def _train_and_evaluate_common( ) for program in decode_programs: program.setup( - task, partitioner, job_log_dir, decode_prng_seed, enable_summary_writer + task, + partitioner, + job_log_dir, + decode_prng_seed, + enable_summary_writer, ) trainer_lib.check_unique_names([p.eval_input for p in eval_programs]) trainer_lib.check_unique_names([p.decode_input for p in decode_programs]) @@ -385,14 +388,27 @@ def _train_and_evaluate_common( summary_utils.write_global_batch_size( train_summary_writer, train_program.train_unpadded_global_batch_size ) + return train_state_metadata, train_input_for_checkpoint, step_i, train_p - # Start the train loop. Make sure all at the same step. - py_utils.sync_global_devices(f'Start training loop from step: {step_i}') - # Collect then freeze GC, so that GC in the training loop will not touch the - # python objects used to initialize the model. Unfreeze at the end of the - # loop. - gc.collect() - gc.freeze() + +def _train_loop( + *, + task: tasks_lib.SingleTask, + train_program: programs.BaseTrainProgram, + partitioned_train_state: TrainState, + # TODO(hthu): Take a more generalized form of EvalProgram interface. + eval_programs: Sequence[programs.BaseEvalProgram], + decode_programs: Sequence[decode_programs_lib.SingleTaskDecodeProgram], + total_num_params, + early_stopping_fn: trainer_lib.EarlyStoppingFn | None, + checkpointer, + job_log_dir: epath.Path, + exit_after_ondemand_checkpoint, + train_state_metadata, + train_input_for_checkpoint, + step_i, + train_p, +): while True: logging.log_first_n(INFO, '[PAX STATUS]: Beginning step `%d`.', 5, step_i) checkpointer.save_if_needed( @@ -418,12 +434,12 @@ def _train_and_evaluate_common( train_p.num_train_steps, ) break + with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.TRAIN_STEP): + partitioned_train_state = train_program.update_state( + partitioned_train_state, step_i + ) + program_output = train_program.run(partitioned_train_state, step_i) - partitioned_train_state = train_program.update_state( - partitioned_train_state, step_i - ) - - program_output = train_program.run(partitioned_train_state, step_i) partitioned_train_state = program_output.state train_weighted_scalars = program_output.weighted_scalars steps_per_sec = program_output.steps_per_sec @@ -433,6 +449,7 @@ def _train_and_evaluate_common( # While the eval ones below are post-model weight updates, hence the step # counter is incremented in between. step_i = program_output.new_train_step + ml_monitoring.record_step_number(step_i) eval_metrics: tuning_lib.EvalMetrics | None = None # Run eval at regular step interval or the final training step. @@ -447,11 +464,12 @@ def _train_and_evaluate_common( # If we have eval test then also evaluate on test. if eval_programs: logging.debug('[PAX STATUS]: Running eval programs.') - eval_metrics, elapsed_secs = eval_lib.run_eval_programs( - eval_programs=eval_programs, - train_state=eval_partitioned_train_state, - step=step_i, - ) + with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.EVAL_STEP): + eval_metrics, elapsed_secs = eval_lib.run_eval_programs( + eval_programs=eval_programs, + train_state=eval_partitioned_train_state, + step=step_i, + ) jax.monitoring.record_event_duration_secs( '/jax/pax/train/interleaved_eval_duration_sec', elapsed_secs ) @@ -539,6 +557,88 @@ def _train_and_evaluate_common( train_p.num_train_steps, ) break + return ( + step_i, + partitioned_train_state, + train_state_metadata, + train_input_for_checkpoint, + ) + + +def _train_and_evaluate_common( + *, + task: tasks_lib.SingleTask, + partitioner: partitioning.Partitioner, + train_program: programs.BaseTrainProgram, + train_input: base_input.BaseInput, + partitioned_train_state: TrainState, + train_state_provenance: TrainStateProvenance, + # TODO(hthu): Take a more generalized form of EvalProgram interface. + eval_programs: Sequence[programs.BaseEvalProgram], + decode_programs: Sequence[decode_programs_lib.SingleTaskDecodeProgram], + total_num_params, + early_stopping_fn: trainer_lib.EarlyStoppingFn | None, + checkpointer, + job_log_dir: epath.Path, + eval_prng_seed, + decode_prng_seed, + is_vars_replicated, + train_prng_seed, + exit_after_ondemand_checkpoint, + enable_summary_writer: bool = True, +): + with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.INITIALIZE_BACKEND): + train_state_metadata, train_input_for_checkpoint, step_i, train_p = ( + _prepare_train_and_evaluate( + task=task, + partitioner=partitioner, + train_program=train_program, + train_input=train_input, + partitioned_train_state=partitioned_train_state, + train_state_provenance=train_state_provenance, + eval_programs=eval_programs, + decode_programs=decode_programs, + total_num_params=total_num_params, + checkpointer=checkpointer, + job_log_dir=job_log_dir, + eval_prng_seed=eval_prng_seed, + decode_prng_seed=decode_prng_seed, + is_vars_replicated=is_vars_replicated, + train_prng_seed=train_prng_seed, + enable_summary_writer=enable_summary_writer, + ) + ) + + # Start the train loop. Make sure all at the same step. + py_utils.sync_global_devices(f'Start training loop from step: {step_i}') + # Collect then freeze GC, so that GC in the training loop will not touch the + # python objects used to initialize the model. Unfreeze at the end of the + # loop. + gc.collect() + gc.freeze() + + with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.MAIN_LOOP): + ( + step_i, + partitioned_train_state, + train_state_metadata, + train_input_for_checkpoint, + ) = _train_loop( + task=task, + train_program=train_program, + partitioned_train_state=partitioned_train_state, + eval_programs=eval_programs, + decode_programs=decode_programs, + total_num_params=total_num_params, + early_stopping_fn=early_stopping_fn, + checkpointer=checkpointer, + job_log_dir=job_log_dir, + exit_after_ondemand_checkpoint=exit_after_ondemand_checkpoint, + train_state_metadata=train_state_metadata, + train_input_for_checkpoint=train_input_for_checkpoint, + step_i=step_i, + train_p=train_p, + ) gc.unfreeze() logging.info('[PAX STATUS]: Saving checkpoint for final step.') diff --git a/paxml/main.py b/paxml/main.py index ad5ecf705..f03e2d4a4 100644 --- a/paxml/main.py +++ b/paxml/main.py @@ -51,6 +51,7 @@ from paxml import train from paxml import trainer_lib from paxml import tuning_lib +from paxml import ml_monitoring from praxis import pax_fiddle from praxis import py_utils @@ -496,15 +497,8 @@ def main(argv: Sequence[str]) -> None: _main(argv) -@py_utils.benchmark(prefix='[PAX STATUS]: E2E time: ') -def _main(argv: Sequence[str]) -> None: - logging.info('[PAX STATUS]: Program start.') - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - if tf_data_service_lib.run_tf_data_service(FLAGS.mode): - return - +def create_experiment_config(): + """Creates the experiment config from the command line flags.""" if FLAGS.tfds_data_dir is not None: # seqio import is slow so avoid module-level import import seqio @@ -556,6 +550,21 @@ def _main(argv: Sequence[str]) -> None: ) experiment_config.validate() + return experiment_config + + +@py_utils.benchmark(prefix='[PAX STATUS]: E2E time: ') +def _main(argv: Sequence[str]) -> None: + logging.info('[PAX STATUS]: Program start.') + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + if tf_data_service_lib.run_tf_data_service(FLAGS.mode): + return + + with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.INITIALIZE_SETUP): + experiment_config = create_experiment_config() + run( experiment_config=experiment_config, enable_checkpoint_saving=FLAGS.enable_checkpoint_saving, diff --git a/paxml/ml_monitoring.py b/paxml/ml_monitoring.py new file mode 100644 index 000000000..09c7e0053 --- /dev/null +++ b/paxml/ml_monitoring.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2022 The Pax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ML Monitoring for PAX.""" + +import contextlib +import enum + + +class MlEvent(enum.Enum): + """ML events to be recorded.""" + + INITIALIZE_BACKEND = enum.auto() + INITIALIZE_SETUP = enum.auto() + MAIN_LOOP = enum.auto() + TRAIN_STEP = enum.auto() + EVAL_STEP = enum.auto() + DECODE_STEP = enum.auto() + + +class EventBoundary(enum.Enum): + """Event boundary to be recorded.""" + + START = enum.auto() + END = enum.auto() + + +def record_step_number(step_number: int): + """Records the step number.""" + pass + + +def record_event_boundary(event: MlEvent, boundary: EventBoundary, **kwargs): + """Records the event boundary.""" + pass + + +@contextlib.contextmanager +def ml_event_logger(event: MlEvent, **kwargs): + yield