2 changes: 1 addition & 1 deletion docs/ab_testing.md
@@ -1,4 +1,4 @@
# A/B Testing Documentation
# TritonBench A/B Testing

## Overview

33 changes: 33 additions & 0 deletions docs/accuracy.md
@@ -0,0 +1,33 @@
# Numerics Checking with TritonBench

TritonBench supports numerics checking and export for accuracy checking. Every operator backend can export its output with a given input.
In the forward mode, it compares the forward-pass output tensors.
In the backward mode, it compares the gradients of the input tensors that require gradients.

## Compare numerics of different compiler backends with `--metrics accuracy`

`--metrics accuracy` requires that the operator declare one backend as the baseline; the numerics of all other backends are compared against it.
Users can use `--baseline <BACKEND_NAME>` to specify the baseline backend. If unspecified, TritonBench will use the backend decorated with `@register_benchmark(baseline=True)`.

By default, TritonBench uses the `torch.testing.assert_close()` API [link](https://docs.pytorch.org/docs/stable/testing.html),
which sets dtype-dependent `rtol` and `atol` thresholds. For example, for the `bfloat16` dtype, `rtol` is `1.6e-2` and `atol` is `1e-5`.
`--metrics accuracy` returns `1` when the numerics match the baseline, and `0` when they do not.
We provide the CLI options `--rtol` and `--atol` for users to tune these thresholds; both default to `None`, which falls back to the default values used by PyTorch.
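Conceptually, the check behind `--metrics accuracy` boils down to something like the sketch below (a minimal illustration, not the exact TritonBench implementation):

```
import torch

def check_accuracy(output, baseline_output, rtol=None, atol=None):
    """Return 1 if `output` matches `baseline_output` within tolerance, else 0.

    When rtol and atol are both None, torch.testing.assert_close picks its
    dtype-dependent defaults (e.g., rtol=1.6e-2 and atol=1e-5 for bfloat16).
    """
    try:
        torch.testing.assert_close(output, baseline_output, rtol=rtol, atol=atol)
        return 1
    except AssertionError:
        return 0
```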

If users want to define their own numeric checking method, they can override the accuracy metric as in this [code example](https://github.com/meta-pytorch/tritonbench/blob/9a4bbc7070b134fb274114018ac02b38fcfd4ba7/tritonbench/operators/vector_exp/operator.py#L88).
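For illustration, a custom criterion could look like the sketch below; how it hooks into the operator is omitted here, since the exact override pattern is shown in the linked example:

```
import torch

# Hypothetical custom criterion: accept an output whose relative L2 error
# against the baseline is below 1%, instead of element-wise assert_close.
def custom_accuracy(output: torch.Tensor, baseline_output: torch.Tensor) -> bool:
    rel_err = torch.linalg.vector_norm((output - baseline_output).float()) / torch.linalg.vector_norm(
        baseline_output.float()
    )
    return bool(rel_err < 1e-2)
```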

We force all backends of an operator to comply with the same numeric checking criteria.

## Compare numerics on different hardware platforms with export output

When comparing numerics across different devices, we provide the `--export [input | output | both]` and `--export-dir <DIR>` options.
Users first export the tensor outputs to a directory on one device, then run the same command to export the outputs on the second device,
and finally copy the two directories onto the same filesystem for comparison.

We provide a simple script to compare two directories:

```
python benchmarks/numeric_check/run.py --a <DIR_ON_DEVICE_A> --b <DIR_ON_DEVICE_B>
```

For cross-device numeric checking, we only support the default thresholds of `torch.testing.assert_close()`.
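As a rough illustration of what the comparison involves, the sketch below walks two export directories and compares files with matching names; the file layout and serialization format are assumptions here, so use the provided script for real runs:

```
import os
import torch

def compare_export_dirs(dir_a: str, dir_b: str) -> None:
    # Assumes both directories contain tensors saved with torch.save under
    # identical relative file names (an assumption about the export layout).
    for name in sorted(os.listdir(dir_a)):
        path_b = os.path.join(dir_b, name)
        if not os.path.exists(path_b):
            print(f"{name}: missing in {dir_b}")
            continue
        t_a = torch.load(os.path.join(dir_a, name), map_location="cpu")
        t_b = torch.load(path_b, map_location="cpu")
        try:
            # Default dtype-dependent thresholds, as noted above.
            torch.testing.assert_close(t_a, t_b)
            print(f"{name}: PASS")
        except AssertionError as err:
            print(f"{name}: FAIL\n{err}")
```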
17 changes: 17 additions & 0 deletions docs/data.md
@@ -0,0 +1,17 @@
# TritonBench Input Data

In TritonBench, users can customize the input data to run. Here is an overview of the CLI options related to inputs.

| Option | Usage |
|-----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `--input-id` | Input ID to run, starting from 0. Default is 0. |
| `--num-inputs` | Number of inputs to run. By default, run all available inputs. |
| `--input-sample-mode` | Input sampling mode. `first-k` (default) uses the first k inputs starting from `--input-id`. `equally-spaced-k` selects k equally spaced inputs from the entire input range, where k is specified by `--num-inputs`. |
| `--input-loader` | Specify a JSON file to load inputs from. |
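For example, a command like the following (the operator and backend names are only for illustration) benchmarks 4 equally spaced inputs instead of the first 4:

```
python run.py --op flash_attention --only triton_tutorial_flash_v2 --num-inputs 4 --input-sample-mode equally-spaced-k --metrics latency
```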


## Input Data Collection

We keep a set of input data in the [data/input_configs](https://github.com/meta-pytorch/tritonbench/tree/main/tritonbench/data/input_configs) directory.
The input data is organized by model name and stored in JSON format. Users can specify an input config with `--input-loader <path-to-input-json>`.
TritonBench will generate synthetic inputs based on the input config.
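For example, assuming a config file under that directory (the path below is only a placeholder), inputs can be loaded with:

```
python run.py --op flash_attention --input-loader tritonbench/data/input_configs/<MODEL_NAME>/<CONFIG>.json --metrics latency
```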
26 changes: 23 additions & 3 deletions docs/kineto_trace.md
@@ -1,7 +1,11 @@
# Kineto Trace Analysis with TritonBench



## Example 1: Kineto Trace Analysis

TritonBench supports generating a Kineto trace file for each `<input, impl>` pair.
For example, the following command will generate 6 Kineto traces, as it is running 2 inputs(`--num-inputs 2`) with 3 impls (`flash_v3,cudnn,triton_tutorial_flash_v2`).
We use the following command to generate 6 Kineto traces, as it runs 2 inputs (`--num-inputs 2`) with 3 impls (`flash_v3,cudnn,triton_tutorial_flash_v2`).

```
$ python run.py --op flash_attention --num-inputs 2 --metrics kineto_trace --only flash_v3,cudnn,triton_tutorial_flash_v2
@@ -14,8 +18,6 @@ $ python run.py --op flash_attention --num-inputs 2 --metrics kineto_trace --onl

The output table shows the directory where the Kineto trace file is stored.

## Example Kineto Trace Analysis

When opening the trace file with the Chrome Trace Viewer, we first need to separate the profiling iteration from the warm-up iterations.
The profiling iteration runs after all warm-up iterations and is labeled `ProfilerStep#<number>`.

@@ -26,3 +28,21 @@ The second one corresponds to the actual computation kernel, which is from CUDNN

![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_2.png "Kineto Trace - Zoomed into Profile Iteration")

## Example 2: Kineto Trace with CUDA Graph enabled

If the operator supports CUDA Graph and CUPTI, we can generate a Kineto trace with CUDA Graph enabled. To do that, simply combine `--cudagraph` with `--metrics kineto_trace`.
Here is an example command:

```
$ python run.py --op flash_attention --num-inputs 1 --metrics kineto_trace --only triton_tutorial_flash_v2 --cudagraph

(Batch, Heads, SeqLen, SeqLen_KV, Dhead) triton_tutorial_flash_v2-kineto_trace
------------------------------------------ ---------------------------------------------------------------------------
(4, 48, 128, 128, 64) /tmp/tritonbench_xzhao9/bf16_flash_attention_fwd/triton_tutorial_flash_v2_0
average

```



![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_cudagraph_fig_1.png "Kineto Trace - CUDA Graph launch")
37 changes: 32 additions & 5 deletions docs/metrics.md
@@ -1,4 +1,4 @@
# TritonBench Metrics
# TritonBench Metrics and Performance Measurement Options

TritonBench supports two types of metrics: built-in and user-defined.
All metrics are specified with the `--metrics <METRIC_NAMES>` option, where `<METRIC_NAMES>` is a comma-separated list of built-in or user-defined metric names.
@@ -9,23 +9,50 @@ TritonBench supports a rich set of built-in metrics.

| Metric Name | Definition |
|-----------------|---------------------------------------------------------------------------------------------------|
| `latency` | The latency given by `triton.testing.do_bench`. |
| `kineto_trace` | Chrome Trace generated by Kineto. |
| `latency` | The latency given by `triton.testing.do_bench`, in milliseconds. |
| `kineto_trace` | Chrome Trace generated by Kineto. More details in [Kineto Trace Analysis](kineto_trace.md). |
| `walltime` | CPU-side wall latency, including CPU kernel launch time. |
| `cuda_time` | Sum of all GPU-side kernel times of an operator backend, measured by Kineto and the PyTorch Profiler. |
| `ncu_rep` | (NVIDIA-only) Generate the NVIDIA Nsight Compute report file. |
| `nsys_rep` | (NVIDIA-only) Generate the NVIDIA Nsight Systems report file. |
| `speedup` | (Requires baseline backend) Latency speedup compared to the baseline backend. |
| `accuracy` | (Requires baseline backend) Numeric accuracy compared to the baseline backend. |
| `compile_time` | (Triton-only) Triton compile time. |
| `compile_trace` | (Triton-only) Kineto profiling of Triton compile. |

## Additional Options for High-precision GPU Kernel Latency Measurement

For most of the built-in metrics (e.g., `latency`,`kineto_trace`,`cuda_time`), user can use the `--cudagraph` option to improve reduce the CPU-side launch overhead.
### Latency: `--latency-measure-mode [triton_do_bench | profiler]` and `--cudagraph`

Latency is the foundation of all performance-related metrics, such as memory throughput and TFLOPS.

By default, latency is measured by `triton.testing.do_bench`. This method is
fast, but it may not be accurate because it cannot exclude the time spent on
CUDA kernel launch overhead, especially when the operator involves multiple
CUDA kernel launches. The `--cudagraph` option uses CUDA Graph to reduce this
CPU-side launch overhead.

Another option is `--latency-measure-mode profiler`, which is slower but more
accurate because it uses the Kineto profiler to measure the latency, excluding
the CUDA kernel launch overhead. Using `--latency-measure-mode
profiler --cudagraph` is by far the most accurate latency measurement approach.
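For example, the following command (the operator and backend are the ones used in the Kineto trace examples) combines both options:

```
python run.py --op flash_attention --only triton_tutorial_flash_v2 --metrics latency --latency-measure-mode profiler --cudagraph
```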

The `--cudagraph` option also works with `--metrics kineto_trace`, which
collects the Kineto trace when launching the kernel with CUDA Graph. However,
note that not all kernels will work with CUDA Graph enabled.

### Latency: `--warmup <MS>`, `--rep <MS>` and `--sleep <SEC>`

There are three runtime options that can also affect the latency measurement: `--warmup`, `--rep` and `--sleep` (see the example after the list).

- `--warmup <MS>`: The number of milliseconds to warm up the kernel before measuring the latency.
- `--rep <MS>`: The number of milliseconds to repeat the kernel execution. For example, `--rep 1000` will repeat the kernel execution for 1 second.
- `--sleep <SEC>`: The number of seconds to sleep between backend executions. For example, `--sleep 1` will sleep for 1 second between each backend execution.
This restores the GPU power state to the normal idle state before each backend execution.
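For example, the command below (operator chosen for illustration) warms up for 100 ms, measures for 1000 ms, and sleeps for 1 second between backends:

```
python run.py --op flash_attention --metrics latency --warmup 100 --rep 1000 --sleep 1
```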

## User-defined metrics

Additionally, users can define custom metrics, or override existing ones, in `operator.py` using the `@register_metric` decorator.
These user-defined metrics can utilize the basic metrics provided by TritonBench, such as `latency`, `walltime`, `kineto_trace`, etc.

Here are some examples:
