Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Script to compute step duration #70

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"tensorboard==2.18.0",
"tensorboard-plugin-profile==2.18.0",
"tf_keras==2.18.0",
"protobuf==4.25.5",
]

[project.optional-dependencies]
Expand Down
116 changes: 116 additions & 0 deletions torchprime/metrics/step_duration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Parse a profile to determine the median duration of a training step.
"""

import glob
import os
import statistics
import sys

from torchprime.metrics.xplane_pb2 import XSpace # type: ignore


def step_duration_from_latest_profile(profile_dir: str) -> float:
  """Compute the step duration from the newest profile under a directory.

  Recursively searches `profile_dir` for `*.xplane.pb` files, picks the one with
  the greatest `os.path.getctime` value, and analyzes it.

  Args:
    profile_dir: Directory that is searched recursively for profile files.

  Returns:
    The value `analyze_step_duration` returns for the newest profile.

  Raises:
    ValueError: If no `*.xplane.pb` file exists under `profile_dir`.
  """
  profile_dir = os.path.abspath(profile_dir)
  # Escape the directory so glob metacharacters in its name (e.g. "[" or "*")
  # are matched literally instead of being treated as part of the pattern.
  pattern = os.path.join(glob.escape(profile_dir), "**", "*.xplane.pb")
  profiles = glob.glob(pattern, recursive=True)
  if not profiles:
    # Fail with an actionable message rather than max()'s "empty sequence".
    raise ValueError(f"No *.xplane.pb profile found under {profile_dir}")
  newest_profile = max(profiles, key=os.path.getctime)
  return analyze_step_duration(newest_profile)


def analyze_step_duration(file_path: str) -> float:
  """Load a serialized XSpace profile from disk and compute its step duration.

  Args:
    file_path: Path to a file holding a serialized XSpace proto.

  Returns:
    Whatever `analyze_step_duration_from_pb` computes for the parsed profile.
  """
  profile = XSpace()

  with open(file_path, "rb") as profile_file:
    # The file name is announced on stderr only after it opened successfully.
    print(f"Parsing {file_path}", file=sys.stderr)
    serialized = profile_file.read()
    profile.ParseFromString(serialized)

  return analyze_step_duration_from_pb(profile)


def analyze_step_duration_from_pb(xspace: "XSpace") -> float:
  """Compute the median training step duration from a parsed XSpace profile.

  Only events on lines named "XLA Modules" within planes named "/device:TPU:0"
  are considered. A step duration is measured as the gap between the start
  offsets of two consecutive events (after sorting), not as an event's own
  `duration_ps`. Every inspected plane, line and event is logged to stderr.

  Args:
    xspace: The parsed profile.

  Returns:
    The median gap between consecutive event offsets, in seconds.

  Raises:
    ValueError: If no matching events are found, if the matching events carry
      more than one distinct name, or if fewer than two events are found.
  """
  offsets = []
  unique_names = set()

  for plane in xspace.planes:
    # Only consider /device:TPU:0
    if plane.name != "/device:TPU:0":
      continue
    print(f"Plane ID: {plane.id}, Name: {plane.name}", file=sys.stderr)

    for line in plane.lines:
      # Only consider XLA Modules line
      if line.name != "XLA Modules":
        continue
      print(f" Line ID: {line.id}, Name: {line.name}", file=sys.stderr)

      # Collect offsets and event names
      for event in line.events:
        name = plane.event_metadata[event.metadata_id].name
        offset_ps = event.offset_ps
        unique_names.add(name)
        offsets.append(offset_ps)
        print(
          f" Event Metadata Name: {name}, "
          f"ID: {event.metadata_id}, Offset: {offset_ps / 1e12:.3f} s, "
          f"Duration: {event.duration_ps / 1e12:.3f} s",
          file=sys.stderr,
        )

  # Make sure we have events at all
  if not offsets:
    raise ValueError("No events found in the given XSpace data.")

  # Confirm we have exactly one unique event name
  if len(unique_names) > 1:
    raise ValueError(f"Ambiguous event names found in XSpace: {unique_names}")

  # A single event yields no interval. This is the only "not enough data" check
  # needed: two or more offsets always produce at least one interval below.
  if len(offsets) < 2:
    raise ValueError("Not enough events to compute step durations.")

  inferred_event_name = next(iter(unique_names))

  # Sort offsets, then take consecutive differences (picoseconds -> seconds)
  offsets.sort()
  durations = [(later - earlier) / 1e12 for earlier, later in zip(offsets, offsets[1:])]

  event_count = len(durations)
  print(
    f"Got {event_count} intervals for event '{inferred_event_name}'", file=sys.stderr
  )

  # With one or two intervals the median equals the plain mean, so the result
  # is the same either way; the warning just flags that a single outlier
  # cannot be rejected with so few samples.
  if event_count < 3:
    print(
      "[Warning] Not enough events found to drop outliers.",
      file=sys.stderr,
    )

  return statistics.median(durations)


if __name__ == "__main__":
  # Command-line entry point: analyze the profile given as the only argument
  # and print its median step duration to stdout.
  if len(sys.argv) != 2:
    # Diagnostics go to stderr, like every other message in this module, so
    # stdout carries nothing but the result.
    print(f"Usage: {sys.argv[0]} <path_to_proto_file>", file=sys.stderr)
    sys.exit(1)
  proto_file_path = sys.argv[1]
  try:
    median_duration = analyze_step_duration(proto_file_path)
  except Exception as e:
    # Top-level boundary: report any failure and exit with a non-zero status.
    print(f"Error: {e}", file=sys.stderr)
    sys.exit(1)
  print(f"Median step duration: {median_duration:.4f}")
Binary file added torchprime/metrics/tests/real_profile.pb.gz
Binary file not shown.
105 changes: 105 additions & 0 deletions torchprime/metrics/tests/test_step_duration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import gzip
import os
import tempfile

import pytest

from torchprime.metrics.step_duration import (
analyze_step_duration,
analyze_step_duration_from_pb,
)
from torchprime.metrics.xplane_pb2 import XSpace # type: ignore


def test_empty():
  """
  A profile with no planes contains no events, so analysis must raise.
  """
  serialized = XSpace().SerializeToString()
  with tempfile.NamedTemporaryFile(delete=True) as profile_file:
    profile_file.write(serialized)
    # Flush so the bytes are on disk before the file is reopened by name.
    profile_file.flush()
    with pytest.raises(ValueError):
      analyze_step_duration(profile_file.name)


def test_two_events():
  """
  Two events whose offsets are one second apart give a step duration of 1.0.
  """
  xspace = XSpace()
  tpu_plane = xspace.planes.add()
  tpu_plane.name = "/device:TPU:0"
  tpu_plane.event_metadata[0].name = "SyncTensorsGraph.1"
  modules_line = tpu_plane.lines.add()
  modules_line.name = "XLA Modules"

  # (offset_ps, duration_ps) of each event, in insertion order.
  timings = (
    (int(1e12), int(1e12)),
    (int(2e12), int(2e12)),
  )
  for offset_ps, duration_ps in timings:
    event = modules_line.events.add()
    event.metadata_id = 0
    event.offset_ps = offset_ps
    event.duration_ps = duration_ps

  with tempfile.NamedTemporaryFile(delete=True) as profile_file:
    profile_file.write(xspace.SerializeToString())
    profile_file.flush()
    assert analyze_step_duration(profile_file.name) == 1.0


def test_three_events():
  """
  Three events at offsets 1 s, 2 s and 4 s give a step duration of 1.5.
  """
  xspace = XSpace()
  tpu_plane = xspace.planes.add()
  tpu_plane.name = "/device:TPU:0"
  tpu_plane.event_metadata[0].name = "SyncTensorsGraph.1"
  modules_line = tpu_plane.lines.add()
  modules_line.name = "XLA Modules"

  # (offset_ps, duration_ps) of each event, in insertion order.
  timings = (
    (int(1e12), int(1e12)),
    (int(2e12), int(2e12)),
    (int(4e12), int(1e12)),
  )
  for offset_ps, duration_ps in timings:
    event = modules_line.events.add()
    event.metadata_id = 0
    event.offset_ps = offset_ps
    event.duration_ps = duration_ps

  with tempfile.NamedTemporaryFile(delete=True) as profile_file:
    profile_file.write(xspace.SerializeToString())
    profile_file.flush()
    assert analyze_step_duration(profile_file.name) == 1.5


def test_conflicting_step_names():
  """
  Events with two different names on the same line must be rejected as ambiguous.
  """
  xspace = XSpace()
  tpu_plane = xspace.planes.add()
  tpu_plane.name = "/device:TPU:0"
  tpu_plane.event_metadata[0].name = "SyncTensorsGraph.1"
  tpu_plane.event_metadata[1].name = "SyncTensorsGraph.2"
  modules_line = tpu_plane.lines.add()
  modules_line.name = "XLA Modules"

  # (metadata_id, offset_ps, duration_ps); each event points at a different name.
  events = (
    (0, int(1e12), int(1e12)),
    (1, int(2e12), int(2e12)),
  )
  for metadata_id, offset_ps, duration_ps in events:
    event = modules_line.events.add()
    event.metadata_id = metadata_id
    event.offset_ps = offset_ps
    event.duration_ps = duration_ps

  with tempfile.NamedTemporaryFile(delete=True) as profile_file:
    profile_file.write(xspace.SerializeToString())
    profile_file.flush()
    with pytest.raises(ValueError, match="Ambiguous"):
      analyze_step_duration(profile_file.name)


def test_real_profile():
  """
  Checks the step duration computed from a real profile of a simple CNN.
  """
  # The gzipped profile is stored next to this test file.
  test_dir = os.path.dirname(os.path.realpath(__file__))
  profile_path = os.path.join(test_dir, "real_profile.pb.gz")
  with gzip.open(profile_path, "rb") as compressed:
    serialized = compressed.read()

  xspace = XSpace()
  xspace.ParseFromString(serialized)
  step_duration = analyze_step_duration_from_pb(xspace)
  assert pytest.approx(step_duration, abs=1e-4) == 0.0206
49 changes: 49 additions & 0 deletions torchprime/metrics/xplane_pb2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-

# ruff: noqa

# Generated by the protocol buffer compiler. DO NOT EDIT!
tengyifei marked this conversation as resolved.
Show resolved Hide resolved
# source: xplane.proto
"""Generated protocol buffer code."""

from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0cxplane.proto\x12\x13tensorflow.profiler"j\n\x06XSpace\x12+\n\x06planes\x18\x01 \x03(\x0b\x32\x1b.tensorflow.profiler.XPlane\x12\x0e\n\x06\x65rrors\x18\x02 \x03(\t\x12\x10\n\x08warnings\x18\x03 \x03(\t\x12\x11\n\thostnames\x18\x04 \x03(\t"\xba\x03\n\x06XPlane\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12)\n\x05lines\x18\x03 \x03(\x0b\x32\x1a.tensorflow.profiler.XLine\x12\x46\n\x0e\x65vent_metadata\x18\x04 \x03(\x0b\x32..tensorflow.profiler.XPlane.EventMetadataEntry\x12\x44\n\rstat_metadata\x18\x05 \x03(\x0b\x32-.tensorflow.profiler.XPlane.StatMetadataEntry\x12)\n\x05stats\x18\x06 \x03(\x0b\x32\x1a.tensorflow.profiler.XStat\x1aY\n\x12\x45ventMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.tensorflow.profiler.XEventMetadata:\x02\x38\x01\x1aW\n\x11StatMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12\x31\n\x05value\x18\x02 \x01(\x0b\x32".tensorflow.profiler.XStatMetadata:\x02\x38\x01"\xbb\x01\n\x05XLine\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x12\n\ndisplay_id\x18\n \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x0b \x01(\t\x12\x14\n\x0ctimestamp_ns\x18\x03 \x01(\x03\x12\x13\n\x0b\x64uration_ps\x18\t \x01(\x03\x12+\n\x06\x65vents\x18\x04 \x03(\x0b\x32\x1b.tensorflow.profiler.XEventJ\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x07\x10\x08J\x04\x08\x08\x10\t"\x95\x01\n\x06XEvent\x12\x13\n\x0bmetadata_id\x18\x01 \x01(\x03\x12\x13\n\toffset_ps\x18\x02 \x01(\x03H\x00\x12\x19\n\x0fnum_occurrences\x18\x05 \x01(\x03H\x00\x12\x13\n\x0b\x64uration_ps\x18\x03 \x01(\x03\x12)\n\x05stats\x18\x04 \x03(\x0b\x32\x1a.tensorflow.profiler.XStatB\x06\n\x04\x64\x61ta"\xad\x01\n\x05XStat\x12\x13\n\x0bmetadata_id\x18\x01 \x01(\x03\x12\x16\n\x0c\x64ouble_value\x18\x02 \x01(\x01H\x00\x12\x16\n\x0cuint64_value\x18\x03 \x01(\x04H\x00\x12\x15\n\x0bint64_value\x18\x04 \x01(\x03H\x00\x12\x13\n\tstr_value\x18\x05 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x06 
\x01(\x0cH\x00\x12\x13\n\tref_value\x18\x07 \x01(\x04H\x00\x42\x07\n\x05value"\x8f\x01\n\x0eXEventMetadata\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x03 \x01(\x0c\x12)\n\x05stats\x18\x05 \x03(\x0b\x32\x1a.tensorflow.profiler.XStat\x12\x10\n\x08\x63hild_id\x18\x06 \x03(\x03">\n\rXStatMetadata\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x03 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3'
)

# Both builder calls are handed this module's globals(). The _X* names used
# below are never assigned in this file, so they are presumably injected into
# the module namespace by these calls. NOTE(review): confirm against protobuf.
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "xplane_pb2", globals())
# The assignments below apply only when _USE_C_DESCRIPTORS is False.
if _descriptor._USE_C_DESCRIPTORS == False:
  # Clear _options and set the serialized option bytes for the file descriptor
  # and for the two XPlane *Entry messages.
  DESCRIPTOR._options = None
  DESCRIPTOR._serialized_options = b"\370\001\001"
  _XPLANE_EVENTMETADATAENTRY._options = None
  _XPLANE_EVENTMETADATAENTRY._serialized_options = b"8\001"
  _XPLANE_STATMETADATAENTRY._options = None
  _XPLANE_STATMETADATAENTRY._serialized_options = b"8\001"
  # One start/end pair per message type. NOTE(review): these look like byte
  # offsets into the serialized descriptor passed to AddSerializedFile
  # (e.g. 37..143 for XSpace); confirm before relying on them.
  _XSPACE._serialized_start = 37
  _XSPACE._serialized_end = 143
  _XPLANE._serialized_start = 146
  _XPLANE._serialized_end = 588
  _XPLANE_EVENTMETADATAENTRY._serialized_start = 410
  _XPLANE_EVENTMETADATAENTRY._serialized_end = 499
  _XPLANE_STATMETADATAENTRY._serialized_start = 501
  _XPLANE_STATMETADATAENTRY._serialized_end = 588
  _XLINE._serialized_start = 591
  _XLINE._serialized_end = 778
  _XEVENT._serialized_start = 781
  _XEVENT._serialized_end = 930
  _XSTAT._serialized_start = 933
  _XSTAT._serialized_end = 1106
  _XEVENTMETADATA._serialized_start = 1109
  _XEVENTMETADATA._serialized_end = 1252
  _XSTATMETADATA._serialized_start = 1254
  _XSTATMETADATA._serialized_end = 1316
# @@protoc_insertion_point(module_scope)
33 changes: 22 additions & 11 deletions torchprime/torch_xla_models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from transformers.trainer_pt_utils import get_module_class_from_name
from transformers.utils import check_min_version

from torchprime.metrics.step_duration import step_duration_from_latest_profile

check_min_version("4.39.3")
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -207,22 +209,25 @@ def train_loop(self):
except StopIteration:
break

# For logging step, we explicitly isolate this step from tracing and execution overlapping.
if step % self.config.logging_steps == 0:
xm.wait_device_ops()
trace_start_time = timer()

loss = self.train_step(batch)

trace_end_time = timer()

if step % self.config.logging_steps == 0:
xm.wait_device_ops()
execute_end_time = timer()
logger.info(
f"Step: {step}, loss: {loss:0.4f}, trace time: {(trace_end_time - trace_start_time) * 1000:0.2f} ms, step time: {(execute_end_time - trace_end_time) * 1000:0.2f} ms"

def step_closure(step, loss, trace_start_time, trace_end_time):
logger.info(
f"Step: {step}, loss: {loss:0.4f}, "
f"trace time: {(trace_end_time - trace_start_time) * 1000:0.2f} ms"
)
if math.isnan(loss):
raise ValueError(f"Loss is NaN at step {step}")

xm.add_step_closure(
step_closure,
args=(step, loss, trace_start_time, trace_end_time),
run_async=True,
)
if math.isnan(loss):
raise ValueError(f"Loss is NaN at step {step}")

# Capture profile at the preferred step
if step == self.config.profile_step:
Expand All @@ -236,8 +241,14 @@ def train_loop(self):
self.config.profile_duration,
)

xm.wait_device_ops()
logger.info("Finished training run")

# Analyze the step duration from the latest profile
if self.config.profile_step >= 0:
step_duration = step_duration_from_latest_profile(self.config.profile_dir)
logger.info(f"Step duration: {step_duration:.3f} s")

@torch_xla.compile(full_graph=True)
def train_step(self, batch):
_logits, loss = self.model(**batch)
Expand Down