"""
(beta) Using TORCH_LOGS python API with torch.compile
==========================================================================================
**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""
import logging
######################################################################
#
# This tutorial introduces both the ``TORCH_LOGS`` environment variable and the corresponding
# Python API, and demonstrates how to use them to observe the phases of ``torch.compile``.
#
# .. note::
#
#    This tutorial requires PyTorch 2.2.0 or later.
#
#
######################################################################
# Setup
# ~~~~~~~~~~~~~~~~~~~~~
# In this example, we'll set up a simple Python function that performs an elementwise
# add and observe the compilation process with the ``TORCH_LOGS`` Python API.
#
# .. note::
#
#    There is also an environment variable ``TORCH_LOGS``, which can be used to
#    change logging settings at the command line. The equivalent environment
#    variable setting is shown for each example.
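#
#    For example (``my_script.py`` below is just a placeholder for your own
#    script), the Python API call ``torch._logging.set_logs(dynamo=logging.DEBUG)``
#    used later in this tutorial corresponds to the command line:
#
#    .. code-block:: bash
#
#       TORCH_LOGS="+dynamo" python my_script.py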
import torch
import sys


def env_setup():
    """Set up the environment to run the example. Exit cleanly if CUDA is not available."""
    if not torch.cuda.is_available():
        print("CUDA is not available. Exiting.")
        sys.exit(0)
    if torch.cuda.get_device_capability() < (7, 0):
        print("Skipping because torch.compile is not supported on this device.")
        sys.exit(0)


def separator(name):
    """Print a separator and reset dynamo between each example"""
    print(f"\n{'='*20} {name} {'='*20}")
    torch._dynamo.reset()


def run_debugging_suite():
    """Run the complete debugging suite with all logging options"""
    env_setup()

    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2

    inputs = (
        torch.ones(2, 2, device="cuda"),
        torch.zeros(2, 2, device="cuda"),
    )

    logging_scenarios = [
        # View dynamo tracing; TORCH_LOGS="+dynamo"
        ("Dynamo Tracing", {"dynamo": logging.DEBUG}),
        # View traced graph; TORCH_LOGS="graph"
        ("Traced Graph", {"graph": True}),
        # View fusion decisions; TORCH_LOGS="fusion"
        ("Fusion Decisions", {"fusion": True}),
        # View output code generated by inductor; TORCH_LOGS="output_code"
        ("Output Code", {"output_code": True}),
    ]

    for name, log_config in logging_scenarios:
        separator(name)
        torch._logging.set_logs(**log_config)
        try:
            result = fn(*inputs)
            print(f"Function output shape: {result.shape}")
        except Exception as e:
            print(f"Error during {name}: {str(e)}")


run_debugging_suite()
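
######################################################################
# After experimenting with the scenarios above, you may want to return logging
# to its default state. A minimal sketch, assuming the default behavior of
# ``torch._logging.set_logs`` (options left unspecified fall back to their
# defaults):

# Reset torch logging back to the default settings
torch._logging.set_logs()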
######################################################################
# Using ``TORCH_TRACE/tlparse`` to produce compilation reports (for PyTorch 2)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we introduce ``TORCH_TRACE`` and ``tlparse`` to produce compilation
# reports. A combined end-to-end invocation is shown after the steps below.
#
#
# 1. Generate the raw trace logs by running the following command:
#
#    .. code-block:: bash
#
#       TORCH_TRACE="/tmp/tracedir" python script.py
#
#    Ensure you replace ``/tmp/tracedir`` with the path to the directory where you want
#    to store the trace logs and replace ``script.py`` with the name of your script.
#
# 2. Install ``tlparse`` by running:
#
#    .. code-block:: bash
#
#       pip install tlparse
#
# 3. Pass the trace log directory to ``tlparse`` to generate compilation reports:
#
#    .. code-block:: bash
#
#       tlparse /tmp/tracedir
#
#    This will open a compilation report in your browser, generated from the trace
#    logs produced above.
#
# By default, reports generated by ``tlparse`` are stored in the ``tl_out`` directory.
# You can change that by running:
#
# .. code-block:: bash
#
#    tlparse /tmp/tracedir -o output_dir/
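#
# Putting the steps together (``script.py`` is a placeholder for your own
# program, such as this tutorial script), an end-to-end trace-and-report run
# looks like:
#
# .. code-block:: bash
#
#    pip install tlparse
#    TORCH_TRACE="/tmp/tracedir" python script.py
#    tlparse /tmp/tracedir -o output_dir/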
######################################################################
# Conclusion
# ~~~~~~~~~~
#
# In this tutorial we introduced the ``TORCH_LOGS`` environment variable and Python API
# by experimenting with a small number of the available logging options.
# To view descriptions of all available options, set ``TORCH_LOGS`` to ``"help"`` and
# run any Python script that imports ``torch``.
#
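# For example (``python -c "import torch"`` stands in for any script that
# imports ``torch``), the following command should print descriptions of the
# available logging options:
#
# .. code-block:: bash
#
#    TORCH_LOGS="help" python -c "import torch"
#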
# Alternatively, you can view the `torch._logging documentation`_ to see
# descriptions of all available logging options.
#
# For more information on torch.compile, see the `torch.compile tutorial`_.
#
# .. _torch._logging documentation: https://pytorch.org/docs/main/logging.html
# .. _torch.compile tutorial: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html