Skip to content

Commit

Permalink
Add more testing
Browse files Browse the repository at this point in the history
  • Loading branch information
dipannita08 committed Nov 4, 2024
1 parent 1dff92c commit 1417133
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 186 deletions.
34 changes: 33 additions & 1 deletion axlearn/cloud/gcp/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
from absl import flags, logging
from ml_goodput_measurement import goodput
from ml_goodput_measurement import monitoring as goodput_monitoring

from axlearn.cloud.common.utils import parse_kv_flags
from axlearn.common import measurement
Expand All @@ -22,7 +23,11 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
"""Converts flags to a recorder.
`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
corresponding to keys will be set to the corresponding values.
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
additionally take in following Tensorboard configs in the recorder_spec:
- upload_dir: The directory to write Tensorboard data to.
- upload_interval: The time interval in seconds at which to query and upload data
to Tensorboard.
"""
cfg: measurement.Recorder.Config = cls.default_config()
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
Expand All @@ -32,6 +37,7 @@ def __init__(self, cfg):
super().__init__(cfg)
cfg: GoodputRecorder.Config = self.config
self._recorder = None
self._monitor = None

def record(self, event: measurement.Event, *args, **kwargs):
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
Expand Down Expand Up @@ -68,3 +74,29 @@ def record(self, event: measurement.Event, *args, **kwargs):
1,
event,
)

def start_monitoring(self, *args, **kwargs):
# Instantiate ml-goodput-measurement's GoodputMonitor
# to asynchronously calculate goodput and badput at
# the upload_interval and upload to the specified
# tensorboard directory.
if self._monitor is None:
cfg: GoodputRecorder.Config = self.config
self._monitor = goodput_monitoring.GoodputMonitor(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
tensorboard_dir=cfg.upload_dir,
upload_interval=int(cfg.upload_interval),
monitoring_enabled=(jax.process_index() == 0),
include_badput_breakdown=True,
)

if self._monitor:
self._monitor.start_goodput_uploader(*args, **kwargs)
logging.info("Started Goodput upload to Tensorboard in the background!")
else:
logging.log_first_n(
logging.WARNING,
"Goodput upload could not be started. Please check GoodputMonitor logs.",
1,
)
41 changes: 38 additions & 3 deletions axlearn/cloud/gcp/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
class GoodputRecorderTest(parameterized.TestCase):
"""Tests GoodputRecorder."""

@parameterized.parameters(None, ["name=test-name"])
@parameterized.parameters(
(None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],)
)
def test_from_flags(self, spec):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
Expand All @@ -34,13 +36,46 @@ def test_from_flags(self, spec):
# Recorder is not instantiated until first event.
self.assertIsNone(recorder._recorder)

def test_record(self):
def test_record_and_monitor(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default("recorder_spec", ["name=test-name"])
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
recorder._recorder = mock.MagicMock()
recorder.record(measurement.Event.START_JOB)
self.assertTrue(recorder._recorder.record_job_start_time.called)

def test_start_monitoring(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
recorder._monitor = None # Ensure _monitor is initially None

with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor:
mock_monitor_instance = mock_goodput_monitor.return_value
recorder.start_monitoring()

# Check that GoodputMonitor was instantiated
mock_goodput_monitor.assert_called_once_with(
job_name="test-name",
logger_name="goodput_logger_test-name",
tensorboard_dir="/test/path/to/upload",
upload_interval=15,
monitoring_enabled=True,
include_badput_breakdown=True,
)

# Ensure that start_goodput_uploader is called on the monitor instance
mock_monitor_instance.start_goodput_uploader.assert_called_once()
self.assertIsNotNone(recorder._monitor)
64 changes: 0 additions & 64 deletions axlearn/cloud/gcp/monitoring.py

This file was deleted.

6 changes: 2 additions & 4 deletions axlearn/common/launch_trainer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@

from absl import app, flags

from axlearn.common import launch, launch_trainer, measurement, monitoring
from axlearn.common import launch, launch_trainer, measurement
from axlearn.common.config import config_for_function


def main(_):
measurement.initialize(flags.FLAGS)
monitoring.initialize(flags.FLAGS)
launch.setup()
trainer_config = launch_trainer.get_trainer_config()
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
monitoring.start_monitoring()
measurement.start_monitoring()
launch_trainer.run_trainer(trainer_config)


if __name__ == "__main__":
measurement.define_flags()
monitoring.define_flags()
app.run(main)
21 changes: 21 additions & 0 deletions axlearn/common/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,13 @@ class Config(Configurable.Config):
Attributes:
name: Name of the recorder.
upload_dir: Directory to store metrics for the monitor.
upload_interval: Time interval (seconds) for monitoring uploads.
"""

name: Required[str] = REQUIRED
upload_dir: Required[str] = REQUIRED
upload_interval: Required[int] = REQUIRED

@classmethod
def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Recorder":
Expand All @@ -59,6 +63,10 @@ def record(self, event: Event, *args, **kwargs):
"""Records an event with the given name."""
raise NotImplementedError(type(self))

def start_monitoring(self, **kwargs):
"""Starts computing and uploading metrics at some configured interval in the background."""
raise NotImplementedError(type(self))


_recorders: dict[str, type] = {}
_T = TypeVar("_T")
Expand Down Expand Up @@ -132,3 +140,16 @@ def record_event(event: Event):
logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1)
else:
global_recorder.record(event)


def start_monitoring():
"""Begins monitoring events as per global monitor functionality."""
if global_recorder is None:
logging.log_first_n(
logging.INFO, "Since recorder is not set up, monitoring cannot be started.", 1
)
else:
global_recorder.start_monitoring()
logging.info(
"Starting monitoring of events using global recorder's monitor: %s", global_recorder
)
7 changes: 7 additions & 0 deletions axlearn/common/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ def test_initialize(self, recorder_type, expected):
with mock.patch.object(measurement.global_recorder, "record") as mock_record:
measurement.record_event(measurement.Event.START_JOB)
self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0])

# Ensure that start_monitoring does not fail.
with mock.patch.object(
measurement.global_recorder, "start_monitoring"
) as mock_start_monitoring:
measurement.start_monitoring()
mock_start_monitoring.assert_called_once()
113 changes: 0 additions & 113 deletions axlearn/common/monitoring.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@ class DummyRecorder(measurement.Recorder):
@classmethod
def from_flags(cls, fv) -> measurement.Recorder:
del fv
return cls.default_config().set(name="dummy_recorder").instantiate()
return (
cls.default_config()
.set(name="dummy_recorder", upload_dir="/dummy/upload_dir", upload_interval=15)
.instantiate()
)

0 comments on commit 1417133

Please sign in to comment.