-
Notifications
You must be signed in to change notification settings - Fork 4.1k
/
torch_logs.py
96 lines (79 loc) · 2.88 KB
/
torch_logs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
(beta) Using TORCH_LOGS Python API with torch.compile
==========================================================================================
**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""
import logging
######################################################################
#
# This tutorial introduces the ``TORCH_LOGS`` environment variable, as well as the Python API, and
# demonstrates how to apply it to observe the phases of ``torch.compile``.
#
# .. note::
#
#    This tutorial requires PyTorch 2.2.0 or later.
#
#
######################################################################
# Setup
# ~~~~~~~~~~~~~~~~~~~~~
# In this example, we'll set up a simple Python function which performs an elementwise
# add and observe the compilation process with the ``TORCH_LOGS`` Python API.
#
# .. note::
#
#    There is also an environment variable ``TORCH_LOGS``, which can be used to
#    change logging settings at the command line. The equivalent environment
#    variable setting is shown for each example.
import torch
# exit cleanly if we are on a device that doesn't support torch.compile
# NOTE: ``torch.cuda.get_device_capability()`` raises when no CUDA device is
# usable, so availability is checked first to keep the skip path clean.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Skipping because torch.compile is not supported on this device.")
else:

    @torch.compile()
    def fn(x, y):
        """Add ``x`` and ``y`` elementwise, then add 2 to the result."""
        z = x + y
        return z + 2

    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))

    # print separator and reset dynamo
    # between each example
    def separator(name):
        """Print a banner containing ``name`` and reset dynamo.

        Resetting between examples means the next call to ``fn`` is compiled
        again, so each example produces its own log output.
        """
        print(f"==================={name}=========================")
        torch._dynamo.reset()

    separator("Dynamo Tracing")
    # View dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph")
    # View traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View output code generated by inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    fn(*inputs)

    separator("")
######################################################################
# Conclusion
# ~~~~~~~~~~
#
# In this tutorial we introduced the ``TORCH_LOGS`` environment variable and Python API
# by experimenting with a small number of the available logging options.
# To view descriptions of all available options, run any Python script
# that imports torch with ``TORCH_LOGS`` set to ``"help"``.
#
# Alternatively, you can view the `torch._logging documentation`_ to see
# descriptions of all available logging options.
#
# For more information on torch.compile, see the `torch.compile tutorial`_.
#
# .. _torch._logging documentation: https://pytorch.org/docs/main/logging.html
# .. _torch.compile tutorial: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html