Feature/gcp workload monitoring #1150

Open: jcyang43 wants to merge 3 commits into main

@jcyang43 jcyang43 commented Jan 7, 2025

Description

Add option to enable GCP workload monitoring for MaxText workloads.

  • GCP workload monitoring sends performance metrics (heartbeat and training step times) to cloud monarch. If a metric crosses its predefined threshold, on-call engineers are notified so they can decide whether any action is needed. This is well suited to critical workloads that are sensitive to infrastructure changes.
  • Each metric can be turned on or off individually through config flags. Examples are included in MaxText/configs/base.yml
  • Documentation can be found at getting_started/GCP_Workload_Monitoring.md

Tests

Tested on a Trillium TPU and confirmed that metrics are sent to cloud monarch successfully when the configs are enabled. No metrics are sent to cloud monarch when the configs are set to False.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

This file looks independent of maxtext (other than the maxtext logging plumbed in). Is there any opportunity to move this file / APIs to one of the packages we are also depending on (google-cloud-monitoring, google-api-core, google-api-python-client)?

Author

Hi Matthew, thanks for the suggestion. We're planning to make this improvement in a subsequent version, where we'll build an API and Python package for workload observability. We'll send a follow-up PR when the API is ready to use.

@@ -456,6 +456,10 @@ monitor_goodput: True
goodput_upload_interval_seconds: 60
enable_pathways_goodput: False

# GCP workload monitoring
report_heartbeat_metric_for_gcp_monitoring: True
Collaborator

Can we default these to false? Most runs of maxtext will not be prod runs that want heartbeating or alerting

We would like to default to true as it would enable us to collect telemetry data. We don't necessarily need prod runs/alerting. The metrics can just be used to monitor performance of maxtext workloads.

Collaborator

Keeping this on by default might lead to errors in internal tests that rely on google3 libraries, since some of them (such as cloud logging and perhaps even monitoring_v3) are not available in google3 yet. MaxText does not support dependency injection in testing, so GCP services such as monitoring and logging do not yet work as intended in tests.

I would suggest running some internal tests with the google3 repo as well to validate the above.

This guide provides an overview of how to enable GCP workload monitoring for your MaxText workload.

## Overview
To address Google Cloud's lack of visibility into user workload performance, users now have the option to enable GCP workload monitoring for all of their MaxText workloads. This workload performance monitoring feature is ideal for critical workloads sensitive to infrastructure changes.
Collaborator

Can this be worded more positively =)

We offer a monitoring and alerting feature that is well suited for critical workloads sensitive to...

Author

Done :)

requirements.txt Outdated
@@ -10,6 +10,9 @@ datasets
gcsfs
google-cloud-aiplatform==1.61.0
google-cloud-storage
google-cloud-monitoring==2.20.0
Collaborator

Do these need to be pinned? Can we use latest?

Author

Done

@gobbleturk gobbleturk requested a review from dipannita08 January 8, 2025 18:13
@jcyang43 jcyang43 force-pushed the feature/gcp-workload-monitoring branch 2 times, most recently from 5e4817b to caef832 Compare January 8, 2025 21:15
@jcyang43 jcyang43 force-pushed the feature/gcp-workload-monitoring branch from caef832 to 25e43fb Compare January 9, 2025 18:26
@@ -852,6 +860,12 @@ def train_loop(config, state=None):
example_batch = None
last_step_completion = datetime.datetime.now()

prof = profiler.Profiler(config)
if gcp_workload_monitor and heartbeat_reporting_stop_event:
  max_logging.log("Starting background thread for reporting heartbeat")
Collaborator

Can we expose an API to start this thread based on configs?
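
For illustration, one possible shape for such an API is sketched below. The helper name `maybe_start_heartbeat_thread` and the monitor method `heartbeat_loop` are hypothetical and not taken from this PR; only the config flag `report_heartbeat_metric_for_gcp_monitoring` appears in the diff. This is a sketch under those assumptions, not the PR's implementation.

```python
import threading


def maybe_start_heartbeat_thread(config, monitor):
  """Start a daemon heartbeat thread if the config flag is set.

  Returns (thread, stop_event), or (None, None) when reporting is disabled.
  `monitor.heartbeat_loop(stop_event)` is an assumed (hypothetical) method
  that keeps reporting until stop_event is set.
  """
  if not getattr(config, "report_heartbeat_metric_for_gcp_monitoring", False):
    return None, None
  stop_event = threading.Event()
  thread = threading.Thread(
      target=monitor.heartbeat_loop,
      args=(stop_event,),
      daemon=True,  # do not keep the process alive if training exits
  )
  thread.start()
  return thread, stop_event
```

With a helper like this, the training loop would only need to call it once before the first step, then set the returned event and join the thread on shutdown.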

step_time_delta = datetime.datetime.now() - last_step_completion
record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
if gcp_workload_monitor and config.report_performance_metric_for_gcp_monitoring:
  gcp_workload_monitor.report_performance(step_time_delta.total_seconds())
Collaborator

This should ideally be in its own API.
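
As a rough sketch of what a dedicated API could look like: the wrapper name `maybe_report_step_time` is hypothetical, while `report_performance` and the flag `report_performance_metric_for_gcp_monitoring` are the names used in the diff above. The wrapper keeps the flag check out of the training loop.

```python
def maybe_report_step_time(monitor, config, step_time_delta):
  """Report the step time in seconds if performance reporting is enabled.

  `monitor` may be None when workload monitoring is not configured.
  Returns True when a value was reported, False otherwise.
  """
  if monitor is None:
    return False
  if not getattr(config, "report_performance_metric_for_gcp_monitoring", False):
    return False
  # step_time_delta is a datetime.timedelta, as in the training loop.
  monitor.report_performance(step_time_delta.total_seconds())
  return True
```

The call site in the loop would then reduce to a single line, for example `maybe_report_step_time(gcp_workload_monitor, config, step_time_delta)`.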

global_rank = jax.process_index()
while not stop_event.is_set():
  self.report_heartbeat(local_rank, str(global_rank))
  time.sleep(5)
Collaborator

Can we make this configurable?
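
A minimal sketch of a configurable interval, assuming a hypothetical function `run_heartbeat_loop` and a hypothetical `interval_seconds` parameter that would be read from config (neither exists in this PR). It also swaps `time.sleep` for `Event.wait(timeout=...)`, so the thread wakes up as soon as the stop event is set rather than finishing the sleep first.

```python
import threading


def run_heartbeat_loop(report_fn, stop_event, interval_seconds=5.0):
  """Call report_fn repeatedly until stop_event is set.

  Returns the number of heartbeats sent. The default of 5 seconds
  matches the hard-coded value in the loop under review.
  """
  if interval_seconds <= 0:
    raise ValueError("interval_seconds must be positive")
  count = 0
  while not stop_event.is_set():
    report_fn()
    count += 1
    # Returns early if stop_event is set during the wait.
    stop_event.wait(timeout=interval_seconds)
  return count
```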

Collaborator

@dipannita08 dipannita08 left a comment

Thank you for this PR!

4 participants