[ROOFLINE] Add CUDA support to roofline analysis (#12205)
* [ROOFLINE] Add CUDA support to roofline analysis

Add functions to estimate peak FLOP/s and bandwidth for CUDA. Add a new
registration mechanism to the roofline analysis so that any target can be
supported. The mechanism uses generic functions with per-target overrides:
a new target only needs to register `estimate_peak_bandwidth` and
`estimate_peak_flops` functions.
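
A minimal sketch of how such a registration might look, assuming the registry
builds on `tvm.target.generic_func` (module layout and names here are
illustrative, not the exact code in this change):

from tvm.target import generic_func

@generic_func
def estimate_peak_flops(target, dev, remote=None):
    # Fallback when no override is registered for the current target.
    raise NotImplementedError(f"No peak FLOP/s estimator for target {target}")

@estimate_peak_flops.register("cuda")
def estimate_peak_flops_cuda(target, dev, remote=None):
    ...  # run a CUDA microbenchmark and return the measured FLOP/s

Dispatch follows whichever target is currently entered via `with target:`,
which is why the analysis below wraps its estimator calls in a target scope.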

Also fix cuda codegen and tensorcore_infer_fragment.cc to support
filling matrix_a and matrix_b fragments.

* formatting

* move statement back inside loops

* print out report for debugging

* default to avx2

* review comments
Tristan Konolige authored Jul 30, 2022
1 parent e756980 commit 961a7c7
Showing 10 changed files with 736 additions and 342 deletions.
2 changes: 1 addition & 1 deletion python/tvm/utils/__init__.py
@@ -16,4 +16,4 @@
# under the License.
"""Utilities operating at a graph/model or other "high" level"""

-from .roofline import estimate_peak_bandwidth, estimate_peak_fma_flops, roofline_analysis
+from .roofline import roofline_analysis
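
For context, a minimal usage sketch of the remaining public entry point
(`mod` and `params` are a Relay model obtained elsewhere; the exact argument
order is assumed here):

import tvm
from tvm.utils import roofline_analysis

report = roofline_analysis(mod, params, tvm.target.Target("llvm"), tvm.cpu())
print(report)  # per-call FLOP/s, loaded bytes, and roofline metrics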
266 changes: 28 additions & 238 deletions python/tvm/utils/roofline.py → python/tvm/utils/roofline/__init__.py
@@ -18,15 +18,17 @@
from typing import Dict, Union, Optional
import numpy as np

-from .. import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
-from ..target import Target
-from ..runtime import profiler_vm, profiling, Device, num_threads
-from ..script import tir as T
-from ..ir.instrument import pass_instrument
-from ..ir.expr import GlobalVar
-from ..rpc.base import RPC_SESS_MASK
-from ..rpc.client import RPCSession
-from ..contrib import utils
+from ... import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
+from ...target import Target
+from ...runtime import profiler_vm, profiling, Device, num_threads
+from ...script import tir as T
+from ...ir.instrument import pass_instrument
+from ...ir.expr import GlobalVar
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from ...contrib import utils
+
+from . import registry, cuda, x86


def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=None):
@@ -47,231 +49,6 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=None):
    return args


def _detect_vec_width_registers(
    target: Target, vec_width: Optional[int], num_vector_registers: Optional[int]
):
    """Get the vector width and number of vector registers for a target.

    Parameters
    ----------
    target : Target
        Target to detect vector width and registers for.
    vec_width : Optional[int]
        If None, try to detect the vector width from `target`. Otherwise the
        provided input is used.
    num_vector_registers : Optional[int]
        If None, try to detect the number of vector registers from `target`.
        Otherwise the provided input is used.

    Returns
    -------
    vec_width : int
        Width of a vector register on `target`.
    num_vector_registers : int
        Number of vector registers on `target`.
    """
    if vec_width is None:
        # Only implemented for x86 so far...
        if (
            str(target.kind) == "llvm"
            and target.device_name == ""
            and len(target.keys) == 1
            and target.keys[0] == "cpu"
        ):
            with target:
                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
        else:
            raise RuntimeError(f"Cannot determine vector width for target {target}")
    if num_vector_registers is None:
        if target.device_name == "":  # indicates x86
            num_vector_registers = 16  # Assuming for all platforms, probably wrong on older ones
        else:
            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
    return vec_width, num_vector_registers


@T.prim_func
def peakflops_fma_tir(
    a: T.handle,
    vec_width: T.int32,
    iters: T.int32,
    num_vector_registers: T.int32,
    threads: T.int32,
) -> None:
    # pylint: disable=invalid-name, missing-function-docstring
    A = T.match_buffer(a, [threads, num_vector_registers, vec_width], "float32")
    for t in T.parallel(threads):
        for _j in range(iters):
            for l in T.unroll(num_vector_registers):
                # We want to use as few registers as possible, so we perform
                # all operations on the same element
                for k in T.vectorized(vec_width):
                    A[t, l, k] = A[t, l, k] * A[t, l, k] + A[t, l, k]


def estimate_peak_fma_flops(
    target: Target,
    dev: Device,
    vec_width: Optional[int] = None,
    num_vector_registers: Optional[int] = None,
    remote: Optional[RPCSession] = None,
) -> float:
    """
    Estimate the maximum number of FLOP/s this target/device combo is capable
    of reaching by running a test program. This assumes vectorized f32 FMA
    (fused-multiply-add) instructions.

    Parameters
    ----------
    target : Target
        Target to run on. This should be as specific to the actual hardware as
        possible to make sure that LLVM generates the best vector code.
    dev : Device
        Device to run on.
    vec_width : Optional[int]
        Vector width of SIMD units on the underlying hardware. Will try to
        infer if no value is provided.
    num_vector_registers : Optional[int]
        Number of vector registers on the underlying hardware. Will try to
        infer if no value is provided.
    remote : Optional[RPCSession]
        Remote session used to upload artifacts for runtime evaluation. Must be
        the same session used to create `dev`.

    Returns
    -------
    float
        Approximate sustained FLOP/s of this target/device combo assuming
        vectorized f32 FMA instructions.
    """
    assert str(target.kind) == "llvm", "Only llvm targets are supported"
    vec_width, num_vector_registers = _detect_vec_width_registers(
        target, vec_width, num_vector_registers
    )
    iters = 1000000
    nthreads = num_threads()
    specialized = peakflops_fma_tir.specialize(
        {
            peakflops_fma_tir.params[1]: vec_width,
            peakflops_fma_tir.params[2]: iters,
            peakflops_fma_tir.params[3]: num_vector_registers,
            peakflops_fma_tir.params[4]: nthreads,
        }
    )
    with transform.PassContext(opt_level=3):
        f = build(specialized, target=target)

    # upload to remote if running over rpc
    if dev.device_type >= RPC_SESS_MASK:
        if remote is None:
            raise RuntimeError("A RPCSession must be provided when using a remote device.")
        temp = utils.tempdir()
        path = temp.relpath("peak_fma_flops.tar")
        f.export_library(path)
        remote.upload(path)
        f = remote.load_module("peak_fma_flops.tar")
        random_fill = remote.get_function("tvm.contrib.random.random_fill")
    else:
        random_fill = get_global_func("tvm.contrib.random.random_fill")
    assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"

    a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev)
    random_fill(a)
    times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a)
    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
    flop_s = flops / times.min
    return flop_s


@T.prim_func
def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.int32) -> None:
    # pylint: disable=invalid-name, missing-function-docstring
    N = T.var("int32")
    A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
    B = T.match_buffer(b, [threads, vec_width, 4], "float32")
    # Parallelism is necessary to hit all cores/nodes
    for i in T.parallel(threads):
        for k in T.serial(N):
            for l in T.unroll(4):
                # vectorized load is necessary to hit peak bandwidth
                for j in T.vectorized(vec_width):
                    # += is necessary to introduce a data dependency for all
                    # elements of A, preventing the backend from removing the
                    # `k` loop and setting `k` to the loop extent.
                    B[i, l, j] += A[i, k, l, j]


def estimate_peak_bandwidth(
    target: Target,
    dev: Device,
    vec_width: Optional[int] = None,
    remote: Optional[RPCSession] = None,
) -> float:
    """Estimate peak memory bandwidth of a target/device combo.

    Peak bandwidth is estimated by running a small experiment on the underlying
    hardware. The peak bandwidth measurement assumes that vector instructions
    are being used to load the data.

    Parameters
    ----------
    target : Target
        Target to use for measurement. This target should be as specific to the
        underlying hardware as possible.
    dev : Device
        Device to measure peak bandwidth on.
    vec_width : Optional[int]
        Vector unit width, determined from target if not supplied.
    remote : Optional[RPCSession]
        Remote session used to upload artifacts for runtime evaluation. Must be
        the same session used to create `dev`.

    Returns
    -------
    float
        Peak memory bandwidth in bytes/second.
    """
    # Ideally we'd be able to use this code to measure peak bandwidth of the
    # different cache levels. If we could just generate load commands, then we
    # could use those in a tight loop. Instead we need some code that is
    # limited on the cache bandwidth. With the L1 cache we need an operation
    # that has a very low arithmetic intensity and we haven't come up with one
    # yet.
    vec_width, _ = _detect_vec_width_registers(target, vec_width, 1)
    specialized = peak_bandwidth_tir.specialize(
        {
            peak_bandwidth_tir.params[3]: vec_width,
        }
    )
    with transform.PassContext(opt_level=3):
        f = build(specialized, target=target)

    # upload to remote if running over rpc
    if dev.device_type >= RPC_SESS_MASK:
        if remote is None:
            raise RuntimeError("A RPCSession must be provided when using a remote device.")
        temp = utils.tempdir()
        path = temp.relpath("peak_bandwidth.tar")
        f.export_library(path)
        remote.upload(path)
        f = remote.load_module("peak_bandwidth.tar")
        random_fill = remote.get_function("tvm.contrib.random.random_fill")
    else:
        random_fill = get_global_func("tvm.contrib.random.random_fill")
    assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"

    threads = num_threads()
    # Data size needs to be larger than last level of cache. We don't have a
    # way of getting cache sizes, so this number should give us a large enough
    # size.
    size = 10**8 // (4 * threads * vec_width)
    a = nd.empty((threads, size, 4, vec_width), dtype="float32", device=dev)
    random_fill(a)
    b = nd.empty((threads, vec_width, 4), dtype="float32", device=dev)
    random_fill(b)
    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b, threads)
    return a.numpy().size * 4 / times.min  # 4 bytes per float32
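
For orientation, the bandwidth arithmetic this routine performs, restated as a
standalone sketch with hypothetical numbers (not measured values):

threads, vec_width = 8, 16                     # hypothetical machine
size = 10**8 // (4 * threads * vec_width)      # same sizing rule as above
bytes_moved = threads * size * 4 * vec_width * 4   # 4 bytes per float32
min_time_s = 0.004                             # hypothetical fastest repeat
print(bytes_moved / min_time_s / 1e9, "GB/s")  # ~100 GB/s in this example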


@pass_instrument
class SaveLoweredTIR:
"""Save TIR functions from right before final lowering. Right now this
Expand Down Expand Up @@ -357,8 +134,9 @@ def roofline_from_existing(
    :py:func:`roofline_analysis` for more information on which metrics
    are included.
    """
-    peak_bandwidth = estimate_peak_bandwidth(target, dev, remote=remote)
-    peak_flops = estimate_peak_fma_flops(target, dev, remote=remote)
+    with target:
+        peak_bandwidth = registry.estimate_peak_bandwidth(target, dev, remote)
+        peak_flops = registry.estimate_peak_flops(target, dev, remote)

    ridge_point = peak_flops / peak_bandwidth

@@ -377,7 +155,19 @@ def roofline_from_existing(
    loaded_bytes = 0.0
    # assume no more than 100 buffers
    for i in range(100):
-        key = f"B{i}.bytes"
+        if str(target.kind) == "cuda":
+            # Autoscheduler features do not take into account that:
+            # 1. global and shared memory have very different performance
+            #    characteristics -- both are included in the same bytes
+            #    touched count;
+            # 2. multiple threads accessing the same byte of memory does
+            #    not use the same amount of bandwidth as multiple threads
+            #    accessing different bytes of memory.
+            # We use unique bytes accessed here to avoid these two issues,
+            # but this does bias results towards being more compute bound.
+            key = f"B{i}.unique_bytes"
+        else:
+            key = f"B{i}.bytes"
        if not key in features.keys():
            break
        loaded_bytes += np.sum(features[key])
@@ -401,7 +191,7 @@ def roofline_from_existing(
        else:
            new_calls.append(call)
    new_configuration = dict(report.configuration.items())
-    new_configuration["Estimated Peak FMA FLOP/s"] = profiling.Ratio(peak_flops)
+    new_configuration["Estimated Peak FLOP/s"] = profiling.Ratio(peak_flops)
    new_configuration["Estimated Peak Bandwidth (byte/second)"] = profiling.Ratio(peak_bandwidth)
    return profiling.Report(new_calls, report.device_metrics, new_configuration)
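
To make the ridge point concrete, a short sketch of how these two
configuration values classify a call (all numbers hypothetical):

peak_flops = 2.0e12            # "Estimated Peak FLOP/s"
peak_bandwidth = 5.0e11        # "Estimated Peak Bandwidth (byte/second)"
ridge_point = peak_flops / peak_bandwidth      # 4.0 FLOPs per byte

flops, loaded_bytes = 1.0e9, 5.0e8             # per-call totals
intensity = flops / loaded_bytes               # 2.0 FLOPs per byte
# intensity < ridge_point, so the call is memory bound: its attainable
# FLOP/s is capped at intensity * peak_bandwidth = 1.0e12, half of peak.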

