
Commit 823ddca

Profiler recipe (#1019)
* Profiler recipe

  Summary: Adding a recipe for profiler

  Test Plan: make html-noplot
1 parent 90f4771 commit 823ddca

File tree

4 files changed: +223 −0 lines
(binary image file, 34.9 KB)

_static/img/trace_img.png (134 KB)
recipes_source/recipes/profiler.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
"""
PyTorch Profiler
====================================
This recipe explains how to use the PyTorch profiler and measure the time
and memory consumption of a model's operators.

Introduction
------------
PyTorch includes a simple profiler API that is useful when the user needs
to determine the most expensive operators in the model.

In this recipe, we will use a simple ResNet model to demonstrate how to
use the profiler to analyze model performance.

Setup
-----
To install ``torch`` and ``torchvision`` use the following command:

::

   pip install torch torchvision

"""


######################################################################
# Steps
# -----
#
# 1. Import all necessary libraries
# 2. Instantiate a simple ResNet model
# 3. Use profiler to analyze execution time
# 4. Use profiler to analyze memory consumption
# 5. Use tracing functionality
#
# 1. Import all necessary libraries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this recipe we will use the ``torch``, ``torchvision.models``,
# and ``profiler`` modules:
#

import torch
import torchvision.models as models
import torch.autograd.profiler as profiler


######################################################################
# 2. Instantiate a simple ResNet model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Let's create an instance of a ResNet model and prepare an input
# for it:
#

model = models.resnet18()
inputs = torch.randn(5, 3, 224, 224)
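
######################################################################
# An optional aside (a common practice, not part of the original recipe):
# running the model once before profiling keeps one-time setup costs,
# such as lazy initialization, out of the measurements below.

model(inputs)  # warm-up run; the output is discarded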

######################################################################
# 3. Use profiler to analyze execution time
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The PyTorch profiler is enabled through the context manager and accepts
# a number of parameters; some of the most useful are:
#
# - ``record_shapes`` - whether to record the shapes of the operator inputs;
# - ``profile_memory`` - whether to report the amount of memory consumed by
#   the model's Tensors;
# - ``use_cuda`` - whether to measure the execution time of CUDA kernels
#   (an optional sketch follows the example below).
#
# Let's see how we can use the profiler to analyze the execution time:

with profiler.profile(record_shapes=True) as prof:
    with profiler.record_function("model_inference"):
        model(inputs)
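
######################################################################
# An optional sketch (not part of the original recipe; it assumes a
# CUDA-capable GPU is available): passing ``use_cuda=True`` additionally
# records the execution time of CUDA kernels.

if torch.cuda.is_available():
    cuda_model = models.resnet18().cuda()  # a separate copy on the GPU
    cuda_inputs = inputs.cuda()
    with profiler.profile(use_cuda=True) as prof_cuda:
        cuda_model(cuda_inputs)
    print(prof_cuda.key_averages().table(sort_by="cuda_time_total", row_limit=10))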

######################################################################
# Note that we can use the ``record_function`` context manager to label
# arbitrary code ranges with user-provided names
# (``model_inference`` is used as a label in the example above).
# The profiler allows one to check which operators were called during the
# execution of a code range wrapped with a profiler context manager.
# If multiple profiler ranges are active at the same time (e.g. in
# parallel PyTorch threads), each profiling context manager tracks only
# the operators of its corresponding range.
# The profiler also automatically profiles async tasks launched
# with ``torch.jit._fork`` and (in the case of a backward pass)
# the backward operators launched with the ``backward()`` call.
#
# Let's print out the stats for the execution above:

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

######################################################################
# The output will look like (omitting some columns):

# -------------------------  --------------  ----------  ------------  ---------
# Name                       Self CPU total  CPU total   CPU time avg  # Calls
# -------------------------  --------------  ----------  ------------  ---------
# model_inference            3.541ms         69.571ms    69.571ms      1
# conv2d                     69.122us        40.556ms    2.028ms       20
# convolution                79.100us        40.487ms    2.024ms       20
# _convolution               349.533us       40.408ms    2.020ms       20
# mkldnn_convolution         39.822ms        39.988ms    1.999ms       20
# batch_norm                 105.559us       15.523ms    776.134us     20
# _batch_norm_impl_index     103.697us       15.417ms    770.856us     20
# native_batch_norm          9.387ms         15.249ms    762.471us     20
# max_pool2d                 29.400us        7.200ms     7.200ms       1
# max_pool2d_with_indices    7.154ms         7.170ms     7.170ms       1
# -------------------------  --------------  ----------  ------------  ---------

######################################################################
# Here we see that, as expected, most of the time is spent in convolution
# (and specifically in ``mkldnn_convolution`` for PyTorch compiled with
# MKL-DNN support).
# Note the difference between self CPU time and CPU time: operators can
# call other operators, and self CPU time excludes the time spent in child
# operator calls, while total CPU time includes it.
#
# To get a finer granularity of results and include operator input shapes,
# pass ``group_by_input_shape=True``:

print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))

# (omitting some columns)
# -------------------------  -----------  --------  -------------------------------------
# Name                       CPU total    # Calls   Input Shapes
# -------------------------  -----------  --------  -------------------------------------
# model_inference            69.571ms     1         []
# conv2d                     9.019ms      4         [[5, 64, 56, 56], [64, 64, 3, 3], []]
# convolution                9.006ms      4         [[5, 64, 56, 56], [64, 64, 3, 3], []]
# _convolution               8.982ms      4         [[5, 64, 56, 56], [64, 64, 3, 3], []]
# mkldnn_convolution         8.894ms      4         [[5, 64, 56, 56], [64, 64, 3, 3], []]
# max_pool2d                 7.200ms      1         [[5, 64, 112, 112]]
# conv2d                     7.189ms      3         [[5, 512, 7, 7], [512, 512, 3, 3], []]
# convolution                7.180ms      3         [[5, 512, 7, 7], [512, 512, 3, 3], []]
# _convolution               7.171ms      3         [[5, 512, 7, 7], [512, 512, 3, 3], []]
# max_pool2d_with_indices    7.170ms      1         [[5, 64, 112, 112]]
# -------------------------  -----------  --------  -------------------------------------
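
######################################################################
# A small illustrative sketch (not part of the original recipe; the
# labels are hypothetical): ``record_function`` can also mark finer
# sub-ranges, and each label appears as its own row in the stats table
# alongside the operator calls.

with profiler.profile() as prof_labels:
    with profiler.record_function("forward_pass"):
        out = model(inputs)
    with profiler.record_function("postprocess"):  # hypothetical label
        probs = out.softmax(dim=1)

print(prof_labels.key_averages().table(sort_by="cpu_time_total", row_limit=5))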


######################################################################
# 4. Use profiler to analyze memory consumption
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The PyTorch profiler can also show the amount of memory (used by the
# model's tensors) that was allocated (or released) during the execution
# of the model's operators.
# In the output below, 'self' memory corresponds to the memory allocated
# (released) by the operator itself, excluding children calls to other
# operators.
# To enable the memory profiling functionality, pass ``profile_memory=True``.

with profiler.profile(profile_memory=True, record_shapes=True) as prof:
    model(inputs)

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

# (omitting some columns)
# ---------------------------  ---------------  ---------------  ---------------
# Name                         CPU Mem          Self CPU Mem     Number of Calls
# ---------------------------  ---------------  ---------------  ---------------
# empty                        94.79 Mb         94.79 Mb         123
# resize_                      11.48 Mb         11.48 Mb         2
# addmm                        19.53 Kb         19.53 Kb         1
# empty_strided                4 b              4 b              1
# conv2d                       47.37 Mb         0 b              20
# ---------------------------  ---------------  ---------------  ---------------

print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))

# (omitting some columns)
# ---------------------------  ---------------  ---------------  ---------------
# Name                         CPU Mem          Self CPU Mem     Number of Calls
# ---------------------------  ---------------  ---------------  ---------------
# empty                        94.79 Mb         94.79 Mb         123
# batch_norm                   47.41 Mb         0 b              20
# _batch_norm_impl_index       47.41 Mb         0 b              20
# native_batch_norm            47.41 Mb         0 b              20
# conv2d                       47.37 Mb         0 b              20
# convolution                  47.37 Mb         0 b              20
# _convolution                 47.37 Mb         0 b              20
# mkldnn_convolution           47.37 Mb         0 b              20
# empty_like                   47.37 Mb         0 b              20
# max_pool2d                   11.48 Mb         0 b              1
# max_pool2d_with_indices      11.48 Mb         0 b              1
# resize_                      11.48 Mb         11.48 Mb         2
# addmm                        19.53 Kb         19.53 Kb         1
# adaptive_avg_pool2d          10.00 Kb         0 b              1
# mean                         10.00 Kb         0 b              1
# ---------------------------  ---------------  ---------------  ---------------
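
######################################################################
# A brief aside (a sketch, not part of the original recipe): as noted
# earlier, the profiler also records backward-pass operators, so memory
# allocated for gradients shows up in the same tables when ``backward()``
# is called inside the profiled range.

with profiler.profile(profile_memory=True) as prof_bwd:
    out = model(inputs)
    out.sum().backward()  # illustrative scalar "loss"

print(prof_bwd.key_averages().table(sort_by="cpu_memory_usage", row_limit=5))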

######################################################################
# 5. Use tracing functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiling results can be output as a ``.json`` trace file:

with profiler.profile() as prof:
    with profiler.record_function("model_inference"):
        model(inputs)

prof.export_chrome_trace("trace.json")
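
######################################################################
# An optional aside (a sketch, not part of the original recipe): the
# exported file is plain Chrome-trace JSON, so it can also be inspected
# programmatically. Depending on the PyTorch version, the top level may
# be a list of events or a dict with a ``traceEvents`` key; the code
# below handles both.

import json

with open("trace.json") as f:
    trace = json.load(f)
events = trace["traceEvents"] if isinstance(trace, dict) else trace
print("number of trace events:", len(events))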

######################################################################
# You can examine the sequence of profiled operators by loading the trace
# file in Chrome (``chrome://tracing``):
#
# .. image:: ../../_static/img/trace_img.png
#    :scale: 25 %

######################################################################
# Learn More
# ----------
#
# Take a look at the following tutorial to learn how to visualize your model with TensorBoard:
#
# - `Visualizing models, data, and training with TensorBoard <https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html>`_ tutorial
#

recipes_source/recipes_index.rst

Lines changed: 8 additions & 0 deletions
@@ -95,6 +95,13 @@ Recipes are bite-sized, actionable examples of how to use specific Py
   :link: ../recipes/recipes/zeroing_out_gradients.html
   :tags: Basics

.. customcarditem::
   :header: PyTorch Profiler
   :card_description: Learn how to use PyTorch's profiler to measure operators' time and memory consumption
   :image: ../_static/img/thumbnails/cropped/profiler.png
   :link: ../recipes/recipes/profiler.html
   :tags: Basics

.. Customization

.. customcarditem::
@@ -174,6 +181,7 @@ Recipes are bite-sized, actionable examples of how to use specific Py
   /recipes/recipes/warmstarting_model_using_parameters_from_a_different_model
   /recipes/recipes/save_load_across_devices
   /recipes/recipes/zeroing_out_gradients
   /recipes/recipes/profiler
   /recipes/recipes/custom_dataset_transforms_loader
   /recipes/recipes/Captum_Recipe
   /recipes/recipes/tensorboard_with_pytorch

0 commit comments