02. Usage Guide
This guide walks you through the complete workflow of using TritonParse to analyze Triton kernel compilation processes.
The TritonParse workflow consists of three main steps:
- Generate Traces - Capture Triton compilation events
- Parse Traces - Process raw logs into structured format
- Analyze Results - Visualize and explore using the web interface
First, integrate TritonParse into your Triton/PyTorch code:

# === TritonParse initialization ===
import tritonparse.structured_logging
import tritonparse.utils

# Initialize structured logging to capture Triton compilation events
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_launch_trace=True)
# === End TritonParse initialization ===

# Your original Triton/PyTorch code below...

# After your workload has run, parse the captured logs
tritonparse.utils.unified_parse(
    source=log_path,
    out="./parsed_output",
    overwrite=True
)
Here's a complete example showing how to instrument a Triton kernel:
import torch
import triton
import triton.language as tl

import tritonparse.structured_logging
import tritonparse.utils

# Initialize logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_launch_trace=True)

@triton.jit
def add_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    tl.store(c_ptr + offsets, c, mask=mask)

def tensor_add(a, b):
    n_elements = a.numel()
    c = torch.empty_like(a)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    add_kernel[grid](a, b, c, n_elements, BLOCK_SIZE)
    return c

# Example usage
if __name__ == "__main__":
    # Create test tensors
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
    b = torch.randn(1024, 1024, device=device, dtype=torch.float32)

    # Execute kernel (this will be traced)
    c = tensor_add(a, b)

    # Parse the generated logs
    tritonparse.utils.unified_parse(
        source=log_path,
        out="./parsed_output",
        overwrite=True
    )
For PyTorch 2.0+ compiled functions:
import torch

import tritonparse.structured_logging
import tritonparse.utils

# Initialize logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_launch_trace=True)

def simple_add(a, b):
    return a + b

# Test with torch.compile
compiled_add = torch.compile(simple_add)

# Create test data
device = "cuda"
a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
b = torch.randn(1024, 1024, device=device, dtype=torch.float32)

# Execute compiled function (this will be traced)
result = compiled_add(a, b)

# Parse the generated logs
tritonparse.utils.unified_parse(
    source=log_path,
    out="./parsed_output",
    overwrite=True
)
Set these environment variables before running your code:
# Disable FX graph cache to ensure PT2 compilation happens every time (optional)
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
# Enable debug logging (optional)
export TRITONPARSE_DEBUG=1
# Enable NDJSON output (default)
export TRITONPARSE_NDJSON=1
# Enable gzip compression for trace files (optional)
export TRITON_TRACE_GZIP=1
# Run your instrumented code
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python your_script.py
Expected Output:
Triton kernel executed successfully
Torch compiled function executed successfully
tritonparse log file list: /tmp/tmp1gan7zky/log_file_list.json
INFO:tritonparse:Copying parsed logs from /tmp/tmp1gan7zky to /scratch/findhao/tritonparse/tests/parsed_output
================================================================================
📁 TRITONPARSE PARSING RESULTS
================================================================================
📂 Parsed files directory: /scratch/findhao/tritonparse/tests/parsed_output
📊 Total files generated: 2
📄 Generated files:
--------------------------------------------------
1. 📝 dedicated_log_triton_trace_findhao__mapped.ndjson.gz (7.2KB)
2. 📝 log_file_list.json (181B)
================================================================================
✅ Parsing completed successfully!
================================================================================
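If you prefer to set these variables from Python instead of the shell, set them before initializing TritonParse (and, for the Inductor cache flag, before any torch.compile call). A minimal sketch, assuming the variables are read at initialization/compile time:

import os

# Set these before tritonparse and torch read them
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"  # optional: force recompilation
os.environ["TRITONPARSE_DEBUG"] = "1"             # optional: debug logging
os.environ["TRITON_TRACE_GZIP"] = "1"             # optional: gzip trace files

import tritonparse.structured_logging
tritonparse.structured_logging.init("./logs/", enable_launch_trace=True)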
The unified_parse function processes raw logs into a structured format:
import tritonparse.utils

# Parse logs from directory
tritonparse.utils.unified_parse(
    source="./logs/",       # Input directory with raw logs
    out="./parsed_output",  # Output directory for processed files
    overwrite=True          # Overwrite existing output directory
)

# Parse with additional options
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    overwrite=True,
    rank=0,           # Analyze specific rank (for multi-GPU)
    all_ranks=False,  # Analyze all ranks
    verbose=True      # Enable verbose logging
)
After parsing, you'll have:
parsed_output/
├── f0_fc0_a0_cai-.ndjson.gz # Compressed trace for PT2 compiled functions
├── dedicated_log_triton_trace_findhao__mapped.ndjson.gz # Compressed trace for Triton compiled functions
├── ...
└── log_file_list.json # Index of all generated files (optional)
Each .ndjson.gz file contains:
- Kernel metadata (grid size, block size, etc.)
- All IRs (TTGIR, TTIR, LLIR, PTX, AMDGCN)
- Source mappings between IRs
- Compilation stack traces
- Launch diffs (if launch tracing is enabled)
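Since each file is just gzip-compressed newline-delimited JSON, you can also inspect events programmatically. A minimal sketch (the filename is hypothetical; the event_type field is taken from the launch_diff example later in this guide):

import gzip
import json

# Each line of an .ndjson.gz file is one JSON event
with gzip.open("./parsed_output/trace.ndjson.gz", "rt") as f:  # hypothetical filename
    for line in f:
        event = json.loads(line)
        print(event.get("event_type"))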
You can also use the command line interface:
# Basic usage
python run.py ./logs/ -o ./parsed_output
# With options
python run.py ./logs/ -o ./parsed_output --overwrite --verbose
# Parse specific rank
python run.py ./logs/ -o ./parsed_output --rank 0
# Parse all ranks
python run.py ./logs/ -o ./parsed_output --all-ranks
1. Visit the live tool: https://pytorch-labs.github.io/tritonparse/
2. Load your trace files:
   - Click "Browse Files" or drag-and-drop
   - Select .gz files from your parsed_output directory
   - Or select .ndjson files from your logs directory
3. Explore the visualization:
   - Kernel Overview Tab: Kernel metadata, call stack, IR links
   - IR Comparison Tab: Side-by-side IR comparison with line mapping
For contributors or custom deployments:
cd website
npm install
npm run dev
Access at http://localhost:5173
Format | Description | Source Mapping | Recommended |
---|---|---|---|
.gz | Compressed parsed traces | ✅ Yes | ✅ Yes |
.ndjson | Raw trace logs | ❌ No | ❌ No |

Note: .ndjson files don't contain source code mappings between IR stages or launch diffs. Always use .gz files for full functionality.
The overview page shows:
- Kernel Information: Name, hash, grid/block sizes
- Compilation Metadata: Device, compile time, memory usage
- Call Stack: Python source code that triggered compilation
- IR Navigation: Links to different IR representations
- Launch Diff: Launch parameters that changed across different launches of the same kernel
The comparison view offers:
- Side-by-side IR viewing: Compare different compilation stages
- Synchronized highlighting: Click a line to see corresponding lines in other IRs
- Source mapping: Trace transformations across compilation pipeline
Stage | Description | When Generated |
---|---|---|
TTIR | Triton IR - Language-level operations | After the Triton frontend parses the kernel |
TTGIR | Triton GPU IR - High-level GPU operations | After lowering TTIR to the GPU dialect |
LLIR | LLVM IR - Low-level operations | After LLVM conversion |
PTX | NVIDIA PTX Assembly | For NVIDIA GPUs |
AMDGCN | AMD GPU Assembly | For AMD GPUs |
TritonParse can analyze kernel launch parameters to identify variations and commonalities across different launches of the same kernel. This is useful for understanding how dynamic shapes or other factors affect kernel execution.
1. Enable Launch Tracing: You must enable launch tracing during the trace generation step. This is done by passing enable_launch_trace=True to tritonparse.structured_logging.init().
2. Parsing: During the parsing step (tritonparse.utils.unified_parse), TritonParse will automatically group all launches for each kernel.
3. Launch Diff Event: A new event of type launch_diff is generated for each kernel. This event contains:
   - total_launches: The total number of times the kernel was launched.
   - diffs: A dictionary showing which launch parameters (e.g., grid_x, grid_y) changed across launches and what their different values were.
   - sames: A dictionary showing which launch parameters remained constant across all launches.
   - launch_index_map: A mapping from the launch index to the original line number in the trace file.
Example launch_diff event:

{
  "event_type": "launch_diff",
  "hash": "...",
  "name": "triton_kernel_name",
  "total_launches": 10,
  "launch_index_map": { "0": 15, "1": 25, ... },
  "diffs": {
    "grid_x": [1024, 2048]
  },
  "sames": {
    "grid_y": 1,
    "grid_z": 1,
    "stream": 7
  }
}
This example shows that grid_x varied between 1024 and 2048 across 10 launches, while the other parameters remained the same.
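To make the diffs/sames split concrete, here is a simplified sketch of the grouping idea, not TritonParse's actual implementation: collect the values each launch parameter takes across all launches of a kernel, then report varying parameters under diffs and constant ones under sames.

# Simplified illustration of the diffs/sames grouping (not TritonParse's code)
launches = [
    {"grid_x": 1024, "grid_y": 1, "grid_z": 1, "stream": 7},
    {"grid_x": 2048, "grid_y": 1, "grid_z": 1, "stream": 7},
]

diffs, sames = {}, {}
for key in launches[0]:
    values = {launch[key] for launch in launches}
    if len(values) > 1:
        diffs[key] = sorted(values)  # parameter varied across launches
    else:
        sames[key] = values.pop()    # parameter constant across launches

print(diffs)  # {'grid_x': [1024, 2048]}
print(sames)  # {'grid_y': 1, 'grid_z': 1, 'stream': 7}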
# Trace a simple kernel to understand compilation stages
@triton.jit
def simple_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = x * 2.0  # Simple operation
    tl.store(y_ptr + offsets, y, mask=mask)

# Trace and analyze each compilation stage
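Compilation, and therefore tracing, only happens once the kernel is actually launched. A minimal driver, assuming the imports and logging setup from the earlier examples:

# Launch the kernel once so it compiles and all IR stages are traced
x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
BLOCK_SIZE = 1024
grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
simple_kernel[grid](x, y, x.numel(), BLOCK_SIZE)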
Set kernel allowlist to trace only specific kernels:
# Only trace kernels matching these patterns
export TRITONPARSE_KERNEL_ALLOWLIST="my_kernel*,important_*"
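The patterns look like shell-style wildcards. If you want to sanity-check which kernel names a pattern would cover, fnmatch gives a reasonable approximation (this is an assumption about the matching semantics, not confirmed TritonParse behavior):

from fnmatch import fnmatch

patterns = ["my_kernel*", "important_*"]
kernel_names = ["my_kernel_add", "reduce_sum", "important_matmul"]

for name in kernel_names:
    # A kernel would be traced if it matches any allowlisted pattern
    matched = any(fnmatch(name, p) for p in patterns)
    print(f"{name}: {'traced' if matched else 'skipped'}")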
For multi-GPU setups:
# Parse all ranks
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    all_ranks=True  # Analyze all GPU ranks
)

# Or parse specific rank
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    rank=1  # Analyze GPU rank 1
)
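On the trace-generation side, a hedged sketch of what an instrumented distributed script might look like; it assumes each rank process initializes logging itself with a shared log directory, and that TritonParse keeps the ranks' logs distinguishable for rank/all_ranks parsing:

# Hypothetical multi-GPU script (launched with e.g. torchrun --nproc-per-node=8)
import torch
import torch.distributed as dist

import tritonparse.structured_logging

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# Assumption: each rank logs to the same directory; per-rank separation
# is handled by tritonparse
tritonparse.structured_logging.init("./logs/", enable_launch_trace=True)

# ... run the per-rank Triton/PyTorch workload here ...

dist.destroy_process_group()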
Error: No kernels found in the processed data