"""
Parse a profile to determine the median duration of a training step.

A "step" is measured as the gap between the start offsets of consecutive
events on the "XLA Modules" line of the "/device:TPU:0" plane of an XSpace
profile.
"""

# NOTE: parsing relies on the generated `torchprime.metrics.xplane_pb2`
# module; the accompanying change pins "protobuf==4.25.5" in pyproject.toml.

import glob
import os
import statistics
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torchprime.metrics.xplane_pb2 import XSpace  # type: ignore

# Only events from this plane / line are treated as training steps.
_TPU_PLANE_NAME = "/device:TPU:0"
_XLA_MODULES_LINE_NAME = "XLA Modules"
_PICOSECONDS_PER_SECOND = 1e12


def step_duration_from_latest_profile(profile_dir: str) -> float:
    """Return the step duration (seconds) from the newest profile in a directory.

    Searches `profile_dir` recursively for `*.xplane.pb` files and analyzes
    the one with the greatest `os.path.getctime` value.

    Raises:
        ValueError: if no `*.xplane.pb` file exists under `profile_dir`, or if
            the profile does not contain usable step events.
    """
    profile_dir = os.path.abspath(profile_dir)
    # Escape the directory so characters such as "[" in the path are not
    # interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(profile_dir), "**", "*.xplane.pb")
    profiles = glob.glob(pattern, recursive=True)
    if not profiles:
        # Without this guard, max() fails with an unhelpful
        # "max() arg is an empty sequence" ValueError.
        raise ValueError(f"No *.xplane.pb profiles found under {profile_dir}")
    newest_profile = max(profiles, key=os.path.getctime)
    return analyze_step_duration(newest_profile)


def analyze_step_duration(file_path: str) -> float:
    """Parse the serialized XSpace proto at `file_path` and return the step
    duration in seconds.

    Raises:
        ValueError: if the profile does not contain usable step events.
    """
    # Imported lazily so that `analyze_step_duration_from_pb` can be used
    # without loading the generated protobuf module.
    from torchprime.metrics.xplane_pb2 import XSpace  # type: ignore

    xspace = XSpace()

    # Read and parse the xplane proto
    with open(file_path, "rb") as f:
        print(f"Parsing {file_path}", file=sys.stderr)
        xspace.ParseFromString(f.read())

    return analyze_step_duration_from_pb(xspace)


def _collect_xla_module_events(xspace):
    """Return `(offsets_ps, event_names)` for every event on the XLA Modules
    line of the TPU:0 plane, logging each event to stderr."""
    offsets = []
    unique_names = set()

    for plane in xspace.planes:
        # Only consider /device:TPU:0
        if plane.name != _TPU_PLANE_NAME:
            continue
        print(f"Plane ID: {plane.id}, Name: {plane.name}", file=sys.stderr)

        for line in plane.lines:
            # Only consider XLA Modules line
            if line.name != _XLA_MODULES_LINE_NAME:
                continue
            print(f"  Line ID: {line.id}, Name: {line.name}", file=sys.stderr)

            # Collect offsets and event names
            for event in line.events:
                name = plane.event_metadata[event.metadata_id].name
                offset_ps = event.offset_ps
                unique_names.add(name)
                offsets.append(offset_ps)
                print(
                    f"    Event Metadata Name: {name}, "
                    f"ID: {event.metadata_id}, "
                    f"Offset: {offset_ps / _PICOSECONDS_PER_SECOND:.3f} s, "
                    f"Duration: {event.duration_ps / _PICOSECONDS_PER_SECOND:.3f} s",
                    file=sys.stderr,
                )

    return offsets, unique_names


def analyze_step_duration_from_pb(xspace: "XSpace") -> float:
    """Return the median gap, in seconds, between consecutive step events.

    Args:
        xspace: a parsed XSpace message (any object exposing the same
            `planes` / `lines` / `events` / `event_metadata` attributes works).

    Raises:
        ValueError: if there are no matching events, if the events carry more
            than one distinct name, or if there are fewer than two events.
    """
    offsets, unique_names = _collect_xla_module_events(xspace)

    # Make sure we have events at all
    if not offsets:
        raise ValueError("No events found in the given XSpace data.")

    # Confirm we have exactly one unique event name
    if len(unique_names) > 1:
        raise ValueError(f"Ambiguous event names found in XSpace: {unique_names}")

    # At least two events are required to form one interval.
    if len(offsets) < 2:
        raise ValueError("Not enough events to compute step durations.")

    inferred_event_name = next(iter(unique_names))

    # Sort offsets, then take consecutive differences (picoseconds -> seconds).
    offsets.sort()
    durations = [
        (later - earlier) / _PICOSECONDS_PER_SECOND
        for earlier, later in zip(offsets, offsets[1:])
    ]

    print(
        f"Got {len(durations)} intervals for event '{inferred_event_name}'",
        file=sys.stderr,
    )

    if len(durations) < 3:
        # With one or two intervals the median equals the plain average, so
        # there is nothing an outlier-resistant statistic can discard.
        print(
            "[Warning] Not enough events found to drop outliers.",
            file=sys.stderr,
        )

    return statistics.median(durations)


def _main(argv) -> int:
    """Command-line entry point; returns the process exit code."""
    if len(argv) != 2:
        print(f"Usage: {argv[0]} <path/to/profile.xplane.pb>", file=sys.stderr)
        return 1
    try:
        median_duration = analyze_step_duration(argv[1])
    except Exception as e:  # top-level boundary: report and exit non-zero
        print(f"Error: {e}", file=sys.stderr)
        return 1
    print(f"Median step duration: {median_duration:.4f}")
    return 0


if __name__ == "__main__":
    sys.exit(_main(sys.argv))
"""Tests for torchprime.metrics.step_duration."""

import gzip
import os
import tempfile

import pytest

from torchprime.metrics.step_duration import (
    analyze_step_duration,
    analyze_step_duration_from_pb,
)
from torchprime.metrics.xplane_pb2 import XSpace  # type: ignore


def _build_xspace(events, metadata_names=("SyncTensorsGraph.1",)):
    """Build an XSpace with one "/device:TPU:0" plane and one "XLA Modules" line.

    Args:
        events: iterable of `(metadata_id, offset_seconds, duration_seconds)`.
        metadata_names: event metadata names; the index is the metadata id.
    """
    xspace = XSpace()
    plane = xspace.planes.add()
    plane.name = "/device:TPU:0"
    for metadata_id, name in enumerate(metadata_names):
        plane.event_metadata[metadata_id].name = name
    line = plane.lines.add()
    line.name = "XLA Modules"
    for metadata_id, offset_s, duration_s in events:
        event = line.events.add()
        event.metadata_id = metadata_id
        event.offset_ps = int(offset_s * 1e12)
        event.duration_ps = int(duration_s * 1e12)
    return xspace


def _analyze_serialized(xspace):
    """Serialize `xspace` to a temporary file and analyze it from disk.

    The file is written and closed before being analyzed, so the test does
    not depend on reopening a still-open NamedTemporaryFile by name.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "profile.xplane.pb")
        with open(path, "wb") as f:
            f.write(xspace.SerializeToString())
        return analyze_step_duration(path)


def test_empty():
    with pytest.raises(ValueError):
        _analyze_serialized(XSpace())


def test_two_events():
    xspace = _build_xspace([(0, 1, 1), (0, 2, 2)])
    assert _analyze_serialized(xspace) == 1.0


def test_three_events():
    # Offsets 1 s, 2 s, 4 s give intervals of 1 s and 2 s.
    xspace = _build_xspace([(0, 1, 1), (0, 2, 2), (0, 4, 1)])
    assert _analyze_serialized(xspace) == 1.5


def test_conflicting_step_names():
    """
    There should only ever be one unique step name in the profile.
    """
    xspace = _build_xspace(
        [(0, 1, 1), (1, 2, 2)],
        metadata_names=("SyncTensorsGraph.1", "SyncTensorsGraph.2"),
    )
    with pytest.raises(ValueError, match="Ambiguous"):
        _analyze_serialized(xspace)


def test_real_profile():
    """
    Tests parsing a real profile generated by a simple CNN.
    """
    # Read real_profile.pb.gz relative to this test directory, and decompress it.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    compressed_pb = os.path.join(script_dir, "real_profile.pb.gz")
    xspace = XSpace()
    with gzip.open(compressed_pb, "rb") as f:
        xspace.ParseFromString(f.read())
    assert analyze_step_duration_from_pb(xspace) == pytest.approx(0.0206, abs=1e-4)
+# source: xplane.proto +"""Generated protocol buffer code.""" + +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0cxplane.proto\x12\x13tensorflow.profiler"j\n\x06XSpace\x12+\n\x06planes\x18\x01 \x03(\x0b\x32\x1b.tensorflow.profiler.XPlane\x12\x0e\n\x06\x65rrors\x18\x02 \x03(\t\x12\x10\n\x08warnings\x18\x03 \x03(\t\x12\x11\n\thostnames\x18\x04 \x03(\t"\xba\x03\n\x06XPlane\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12)\n\x05lines\x18\x03 \x03(\x0b\x32\x1a.tensorflow.profiler.XLine\x12\x46\n\x0e\x65vent_metadata\x18\x04 \x03(\x0b\x32..tensorflow.profiler.XPlane.EventMetadataEntry\x12\x44\n\rstat_metadata\x18\x05 \x03(\x0b\x32-.tensorflow.profiler.XPlane.StatMetadataEntry\x12)\n\x05stats\x18\x06 \x03(\x0b\x32\x1a.tensorflow.profiler.XStat\x1aY\n\x12\x45ventMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.tensorflow.profiler.XEventMetadata:\x02\x38\x01\x1aW\n\x11StatMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12\x31\n\x05value\x18\x02 \x01(\x0b\x32".tensorflow.profiler.XStatMetadata:\x02\x38\x01"\xbb\x01\n\x05XLine\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x12\n\ndisplay_id\x18\n \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x0b \x01(\t\x12\x14\n\x0ctimestamp_ns\x18\x03 \x01(\x03\x12\x13\n\x0b\x64uration_ps\x18\t \x01(\x03\x12+\n\x06\x65vents\x18\x04 \x03(\x0b\x32\x1b.tensorflow.profiler.XEventJ\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x07\x10\x08J\x04\x08\x08\x10\t"\x95\x01\n\x06XEvent\x12\x13\n\x0bmetadata_id\x18\x01 \x01(\x03\x12\x13\n\toffset_ps\x18\x02 \x01(\x03H\x00\x12\x19\n\x0fnum_occurrences\x18\x05 
\x01(\x03H\x00\x12\x13\n\x0b\x64uration_ps\x18\x03 \x01(\x03\x12)\n\x05stats\x18\x04 \x03(\x0b\x32\x1a.tensorflow.profiler.XStatB\x06\n\x04\x64\x61ta"\xad\x01\n\x05XStat\x12\x13\n\x0bmetadata_id\x18\x01 \x01(\x03\x12\x16\n\x0c\x64ouble_value\x18\x02 \x01(\x01H\x00\x12\x16\n\x0cuint64_value\x18\x03 \x01(\x04H\x00\x12\x15\n\x0bint64_value\x18\x04 \x01(\x03H\x00\x12\x13\n\tstr_value\x18\x05 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x06 \x01(\x0cH\x00\x12\x13\n\tref_value\x18\x07 \x01(\x04H\x00\x42\x07\n\x05value"\x8f\x01\n\x0eXEventMetadata\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x14\n\x0c\x64isplay_name\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x03 \x01(\x0c\x12)\n\x05stats\x18\x05 \x03(\x0b\x32\x1a.tensorflow.profiler.XStat\x12\x10\n\x08\x63hild_id\x18\x06 \x03(\x03">\n\rXStatMetadata\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x03 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "xplane_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\370\001\001" + _XPLANE_EVENTMETADATAENTRY._options = None + _XPLANE_EVENTMETADATAENTRY._serialized_options = b"8\001" + _XPLANE_STATMETADATAENTRY._options = None + _XPLANE_STATMETADATAENTRY._serialized_options = b"8\001" + _XSPACE._serialized_start = 37 + _XSPACE._serialized_end = 143 + _XPLANE._serialized_start = 146 + _XPLANE._serialized_end = 588 + _XPLANE_EVENTMETADATAENTRY._serialized_start = 410 + _XPLANE_EVENTMETADATAENTRY._serialized_end = 499 + _XPLANE_STATMETADATAENTRY._serialized_start = 501 + _XPLANE_STATMETADATAENTRY._serialized_end = 588 + _XLINE._serialized_start = 591 + _XLINE._serialized_end = 778 + _XEVENT._serialized_start = 781 + _XEVENT._serialized_end = 930 + _XSTAT._serialized_start = 933 + _XSTAT._serialized_end = 1106 + 
_XEVENTMETADATA._serialized_start = 1109 + _XEVENTMETADATA._serialized_end = 1252 + _XSTATMETADATA._serialized_start = 1254 + _XSTATMETADATA._serialized_end = 1316 +# @@protoc_insertion_point(module_scope) diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 01b933a..65f6e87 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -41,6 +41,8 @@ from transformers.trainer_pt_utils import get_module_class_from_name from transformers.utils import check_min_version +from torchprime.metrics.step_duration import step_duration_from_latest_profile + check_min_version("4.39.3") logger = logging.getLogger(__name__) @@ -207,22 +209,25 @@ def train_loop(self): except StopIteration: break - # For logging step, we explicitly isolate this step from tracing and execution overlapping. - if step % self.config.logging_steps == 0: - xm.wait_device_ops() trace_start_time = timer() - loss = self.train_step(batch) - trace_end_time = timer() + if step % self.config.logging_steps == 0: - xm.wait_device_ops() - execute_end_time = timer() - logger.info( - f"Step: {step}, loss: {loss:0.4f}, trace time: {(trace_end_time - trace_start_time) * 1000:0.2f} ms, step time: {(execute_end_time - trace_end_time) * 1000:0.2f} ms" + + def step_closure(step, loss, trace_start_time, trace_end_time): + logger.info( + f"Step: {step}, loss: {loss:0.4f}, " + f"trace time: {(trace_end_time - trace_start_time) * 1000:0.2f} ms" + ) + if math.isnan(loss): + raise ValueError(f"Loss is NaN at step {step}") + + xm.add_step_closure( + step_closure, + args=(step, loss, trace_start_time, trace_end_time), + run_async=True, ) - if math.isnan(loss): - raise ValueError(f"Loss is NaN at step {step}") # Capture profile at the prefer step if step == self.config.profile_step: @@ -236,8 +241,14 @@ def train_loop(self): self.config.profile_duration, ) + xm.wait_device_ops() logger.info("Finished training run") + # Analyze the step duration from the 
latest profile + if self.config.profile_step >= 0: + step_duration = step_duration_from_latest_profile(self.config.profile_dir) + logger.info(f"Step duration: {step_duration:.3f} s") + @torch_xla.compile(full_graph=True) def train_step(self, batch): _logits, loss = self.model(**batch)