Skip to content

Commit

Permalink
[Profiler] Defer recording startup python events (take 2) (pytorch#91684)
Browse files Browse the repository at this point in the history

This is my commandeer of pytorch#82154 with a couple extra fixes.

The high level idea is that when we start profiling we see python frames which are currently executing, but we don't know what system TID created them. So instead we defer the TID assignment, and then during post processing we peer into the future and use the system TID *of the next* call on that Python TID.

As an aside, it turns out that CPython does some bookkeeping (https://github.com/python/cpython/blob/ee821dcd3961efc47262322848267fe398faa4e4/Include/cpython/pystate.h#L159-L165, thanks @dzhulgakov for the pointer), but you'd have to do some extra work at runtime to know how to map their TID to ours so for now I'm going to stick to what I can glean from post processing alone.

As we start observing more threads it becomes more important to be principled about how we start up and shut down. (Since threads may die while the profiler is running.) pytorch#82154 had various troubles with segfaults that wound up being related to accessing Python thread pointers which were no longer alive. I've tweaked the startup and shutdown interaction with the CPython interpreter and it should be safer now.

Differential Revision: [D42336292](https://our.internmc.facebook.com/intern/diff/D42336292/)
Pull Request resolved: pytorch#91684
Approved by: https://github.com/chaekit
  • Loading branch information
Taylor Robie authored and pytorchmergebot committed Feb 11, 2023
1 parent 8d45f55 commit d09cd15
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 35 deletions.
159 changes: 159 additions & 0 deletions test/profiler/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import tempfile
import textwrap
import threading
import unittest
from unittest.mock import patch
import weakref
Expand Down Expand Up @@ -57,6 +58,8 @@
from torch.testing._internal.common_device_type import skipCUDAVersionIn
from torch.testing._internal.common_utils import (
IS_WINDOWS,
instantiate_parametrized_tests,
parametrize,
run_tests,
TemporaryDirectoryName,
TemporaryFileName,
Expand Down Expand Up @@ -478,6 +481,7 @@ def test_execution_graph_no_capture(self):
assert found_root_node


@instantiate_parametrized_tests
class TestProfiler(TestCase):

@unittest.skipIf(TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite.")
Expand Down Expand Up @@ -549,6 +553,161 @@ def extract(pattern: str):

torch._C._set_graph_executor_optimize(prev_opt)

@parametrize(
    "name,thread_spec",
    {
        # Each entry maps a scenario name to a tuple of
        # (start_under_profiler, join_under_profiler) pairs, one per
        # background thread.
        "basic": ((False, False),),
        "multiple_preexisting": ((False, False), ) * 2,
        "open_in_scope": ((True, False),),
        "close_in_scope": ((False, True),),
        "complex": (
            # Large number of background threads
            (False, False),
            (False, False),
            (False, False),
            (False, False),
            # some of which finish during profiling
            (False, True),
            (False, True),
            # And the profiled section is also multithreaded
            (True, False),
            (True, True),
        ),
    }.items(),
    name_fn=lambda name, thread_spec: name
)
@parametrize("work_in_main_thread", [True, False])
def test_source_multithreaded(self, name: str, thread_spec, work_in_main_thread: bool):
    """Test various threading configurations.

    `thread_spec` is a Tuple[Tuple[bool, bool], ...] where each pair is a
    thread. The first bool indicates if the thread should be started under
    the profiler context and the second is if it should be joined under the
    profiler context.

    `work_in_main_thread` additionally runs the worker payload on the main
    thread itself, so the main thread's frames are also profiled.
    """

    # Generous upper bound so a misbehaving configuration fails loudly
    # instead of hanging the test suite.
    timeout = 15
    num_threads = len(thread_spec) + 1  # Main thread
    # All workers (plus the main thread) rendezvous at these barriers so the
    # profiled region observes every thread in a known state.
    start_barrier = threading.Barrier(num_threads, timeout=timeout)
    end_barrier = threading.Barrier(num_threads, timeout=timeout)

    class Task(threading.Thread):
        # Self-starting daemon thread which blocks on `_end_gate` after
        # finishing its work, so the test controls exactly when it exits.

        def __init__(self):
            self._end_gate = threading.Event()
            super().__init__(daemon=True)
            self.start()
            # NOTE(review): `finished` is written here but never read
            # anywhere in this test; it is also assigned after `start()`,
            # so `run` may already be executing when it is set.
            self.finished = False

        def run(self):
            self._run(self._end_gate)

        def release(self):
            # Allow the thread to fall off the end of `_run` and exit.
            self._end_gate.set()

        @staticmethod
        def _run(end_gate=None):
            # Worker payload. Static so the main thread can invoke it
            # directly (without an `end_gate`) when `work_in_main_thread`
            # is set.

            def known_preexisting_function():
                start_barrier.wait()

            # Fixed point that we can use to test capture of functions
            # which are already running when profiling is enabled.
            known_preexisting_function()

            model = torch.nn.Sequential(
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            )

            def invoked_during_run():
                pass

            invoked_during_run()

            _ = model(torch.rand(4, 10))
            end_barrier.wait()

            if end_gate is not None:
                end_gate.wait(timeout=timeout)

    # Maps index in `thread_spec` -> running Task.
    threads = {}

    def add_threads(context: bool):
        # Start every thread whose `start_under_profiler` flag matches
        # `context` (False = before profiling, True = inside the profiler).
        for idx, (start_under_profiler, _) in enumerate(thread_spec):
            if start_under_profiler == context:
                assert idx not in threads
                threads[idx] = Task()

    def join_threads(context: bool):
        # First release every matching thread, then join them, so threads
        # are not serialized on one another's gates.
        for idx, (_, end_under_profiler) in enumerate(thread_spec):
            if end_under_profiler == context:
                threads[idx].release()

        for idx, (_, end_under_profiler) in enumerate(thread_spec):
            t = threads[idx]
            if end_under_profiler == context:
                t.join(timeout=timeout)

    try:
        add_threads(False)
        with torch.profiler.profile(with_stack=True) as prof:
            # Threads added while the profiler is running will not be observed
            # since there is no way to hook into Python's thread start call to
            # register the observer. These are here purely to verify safety.
            add_threads(True)

            if work_in_main_thread:
                Task._run()
            else:
                # Main thread still participates in the barriers so the
                # workers can make progress.
                start_barrier.wait()
                end_barrier.wait()

            join_threads(True)
        join_threads(False)

    finally:
        # It is very important that we clean up everything because the
        # Python tracer will detect ALL active threads. (Even orphans from
        # prior failed tests.) If we don't clean up properly we can
        # contaminate subsequent tests.
        start_barrier.abort()
        end_barrier.abort()
        for t in threads.values():
            t.release()

        for t in threads.values():
            t.join(timeout=timeout)

        for t in threads.values():
            self.assertFalse(t.is_alive())

    # Post-processing: inspect the recorded Python-call events and the
    # thread IDs the profiler assigned to them.
    roots = prof.profiler.kineto_results.experimental_event_tree()
    nodes = [node for node in _utils.traverse_dfs(roots) if isinstance(node.extra_fields, _ExtraFields_PyCall)]
    tid_counts = collections.Counter([node.start_tid for node in nodes])

    # Only threads alive when profiling started (plus the main thread) can
    # be observed; threads started under the profiler are not registered.
    prior_threads = sum(not start_under_profiler for start_under_profiler, _ in thread_spec)
    expected_threads = prior_threads + 1
    self.assertEqual(len(tid_counts), expected_threads, f"{expected_threads}, {tid_counts}")
    self.assertEqual(len(nodes), sum(tid_counts.values()))

    # Profiler uses uint64_t max as a placeholder until TID can be determined.
    no_tid = 2 ** 64 - 1
    self.assertFalse(no_tid in tid_counts)

    worker_threads = prior_threads + (1 if work_in_main_thread else 0)

    # Each worker runs `known_preexisting_function` exactly once, on a
    # distinct thread, before profiling begins.
    observed_preexisting = [node.start_tid for node in nodes if "known_preexisting_function" in node.name]
    self.assertEqual(len(observed_preexisting), worker_threads)
    self.assertEqual(len(observed_preexisting), len(set(observed_preexisting)))

    # And each worker calls `invoked_during_run` exactly once while the
    # profiler is active.
    observed_during_run = [node.start_tid for node in nodes if "invoked_during_run" in node.name]
    self.assertEqual(len(observed_during_run), worker_threads)
    self.assertEqual(len(observed_during_run), len(set(observed_during_run)))

def payload(self, use_cuda=False):
x = torch.randn(10, 10)
if use_cuda:
Expand Down
Loading

0 comments on commit d09cd15

Please sign in to comment.