A comprehensive visualization and analysis tool for Triton IR files, designed to help developers analyze, debug, and understand Triton kernel compilation processes.
- Interactive Kernel Explorer: Display detailed kernel information and stack traces
- Multi-format IR Support: View and explore multiple Triton IR formats:
  - TTGIR (Triton GPU IR)
  - TTIR (Triton IR)
  - LLIR (LLVM IR)
  - PTX (NVIDIA)
  - AMDGCN (AMD)
- Side-by-side Comparison: Compare the above IR code with synchronized highlighting
- Interactive Code Views: Click-to-highlight corresponding lines across different formats
- Compilation Tracing: Capture detailed Triton compilation events
- Stack Trace Integration: Full Python stack traces for compilation events
- Metadata Extraction: Comprehensive kernel metadata and compilation statistics
- NDJSON Output: Structured logging format for easy processing
- GitHub Pages: Automatic deployment with GitHub Actions
- Local Development: Full development environment setup
Frontend:
- React 19 with TypeScript
- Vite for build tooling
- Tailwind CSS for styling
- Monaco Editor for code display
- React Syntax Highlighter for syntax highlighting
- React Resizable Panels for layout
Backend/Processing:
- Python with Triton integration
- Structured logging and event tracing
- Source mapping extraction utilities
Prerequisites:
- Python >= 3.9
- Triton > 3.3.1
For now, you need to compile the latest Triton from source manually.
Quick Start:
# Clone the repository
git clone https://github.com/pytorch-labs/tritonparse.git
cd tritonparse
# Install Python dependencies
pip install -e .
Additional Prerequisites:
- Node.js >= 18.0.0
- npm
Website Setup:
# Install website dependencies
cd website
npm install
Please refer to the wiki Usage page for more details.
First, integrate TritonParse with your Triton/PyTorch code to generate trace files:
import torch
# === TritonParse init ===
import tritonparse.structured_logging
# Initialize structured logging to capture Triton compilation events
# This will generate NDJSON trace logs in ./logs/
log_path = "./logs/"
tritonparse.structured_logging.init(log_path)
# === TritonParse init end ===
# The below is your original Triton/PyTorch 2 code
...
# === TritonParse parse ===
import tritonparse.utils
tritonparse.utils.unified_parse(log_path)
# === TritonParse parse end ===
See a full example in tests/test_add.py.
Example output:
% TORCHINDUCTOR_FX_GRAPH_CACHE=0 python test_add.py
Triton kernel executed successfully
Torch compiled function executed successfully
WARNING:SourceMapping:No frame_id or frame_compile_id found in the payload.
WARNING:SourceMapping:No frame_id or frame_compile_id found in the payload.
tritonparse log file list: /tmp/tmpl1tp9fto/log_file_list.json
Our test example contains two Triton kernels: a pure Triton kernel and a PT2-compiled Triton kernel. TORCHINDUCTOR_FX_GRAPH_CACHE=0 disables the FX graph cache so that the PT2 compiler recompiles the kernel on every run; otherwise, the final parsed log files would only contain the first Triton kernel. The final parsed gzip files are stored in the /tmp/tmpl1tp9fto/
directory, while the ./logs
directory contains the raw NDJSON logs without source code mapping.
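Each line of an NDJSON log is an independent JSON object, so the raw traces can be post-processed with nothing but the standard library. A minimal sketch (the helper name iter_events and the event_type field are illustrative, not part of the tritonparse API):

```python
import json

def iter_events(lines):
    """Yield one parsed event per non-empty NDJSON line."""
    for line in lines:
        line = line.strip()
        if line:
            yield json.loads(line)

# Example: parse an in-memory trace with a blank line in between
trace = ['{"event_type": "compilation"}', "", '{"event_type": "launch"}']
events = list(iter_events(trace))
```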
Visit https://pytorch-labs.github.io/tritonparse/ to use the tool directly in your browser:
- Open your local trace file (NDJSON or .gz format) directly in the browser
- Explore the visualization using the Overview and Code Comparison tabs
Supported File Formats:
- .ndjson - Newline-delimited JSON trace files
- .gz - Gzip-compressed trace files
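Because both formats hold the same newline-delimited JSON, a loader only needs to branch on the file extension. A sketch of that idea, assuming a hypothetical helper named open_trace (not part of the tool itself):

```python
import gzip
import io

def open_trace(path):
    """Open a trace file as text, transparently decompressing .gz files."""
    if path.endswith(".gz"):
        # Open in binary mode and wrap for explicit UTF-8 decoding
        return io.TextIOWrapper(gzip.open(path, "rb"), encoding="utf-8")
    return open(path, "r", encoding="utf-8")
```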
Once you load a trace file, you'll see the main interface with several key components:
Kernel Overview & Details:
The main interface showing the kernel list, compilation metadata, call stack, and navigation links to different IR representations.
Code Comparison View:
Side-by-side comparison of different IR stages (e.g., TTGIR and PTX) with synchronized line highlighting and interactive navigation.
For contributors working on the website:
cd website
npm install
npm run dev
Access the application at http://localhost:5173
Available Scripts:
- npm run build - Standard build
- npm run build:single - Standalone HTML file
- npm run preview - Preview production build
tritonparse/
├── tritonparse/                   # Python package
│   ├── structured_logging.py      # Main logging infrastructure
│   ├── extract_source_mappings.py # Source mapping utilities
│   ├── source_type.py             # Source type definitions
│   ├── utils.py                   # Helper utilities
│   ├── common.py                  # Common functions
│   └── tp_logger.py               # Logger configuration
├── website/                       # React web application
│   ├── src/                       # React source code
│   ├── public/                    # Static assets and example files
│   ├── scripts/                   # Build utilities (inline-html.js)
│   ├── node_modules/              # Dependencies
│   ├── package.json               # Node.js dependencies
│   ├── vite.config.ts             # Vite configuration
│   └── dist/                      # Built application (after build)
├── docs/                          # Documentation and assets
│   ├── README.md                  # Documentation guidelines
│   └── screenshots/               # Screenshots for README
├── tests/                         # Test files and example traces
│   ├── test_add.py                # Example Triton kernel test
│   ├── unit_tests.py              # Unit tests
│   └── *.ndjson                   # Example trace files
├── run.py                         # Main runner script
├── pyproject.toml                 # Python package configuration
├── LICENSE                        # BSD-3 license
├── CONTRIBUTING.md                # Contribution guidelines
└── CODE_OF_CONDUCT.md             # Code of conduct
Install in development mode:
pip install -e .
Example test:
cd tests
python test_add.py
- TRITONPARSE_DEBUG=1 - Enable debug logging
- TRITONPARSE_NDJSON=1 - Output in NDJSON format (default)
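These flags are read from the process environment, so they can be set per invocation (e.g. TRITONPARSE_DEBUG=1 python tests/test_add.py). A sketch of the kind of check involved; debug_enabled is illustrative and not the package's actual implementation:

```python
import os

def debug_enabled(environ=None):
    """Return True when TRITONPARSE_DEBUG=1 is set in the given environment."""
    environ = os.environ if environ is None else environ
    return environ.get("TRITONPARSE_DEBUG") == "1"
```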
Start development server:
cd website
npm run dev
Available Scripts:
- npm run dev - Start development server
- npm run build - Production build
- npm run build:single - Standalone HTML build
- npm run lint - Run ESLint
- npm run preview - Preview production build
The TritonParse visualization tool is automatically deployed and available at: https://pytorch-labs.github.io/tritonparse/
Build standalone version:
cd website
npm run build:single
The dist/standalone.html
file contains the entire application and can be deployed anywhere.
TritonParse helps visualize the Triton compilation pipeline:
- Python Source → Triton kernel functions
- TTIR → Triton's high-level IR
- TTGIR → GPU-specific Triton IR
- LLIR → LLVM IR representation
- PTX → NVIDIA PTX assembly
- AMDGCN → AMD GPU IR
Each stage can be inspected and compared to understand optimization transformations.
- Fork the repository
- Create a feature branch:
git checkout -b feature-name
- Make your changes
- Run tests: npm test (website) and python -m pytest (Python)
- Submit a pull request
This project is licensed under the BSD-3 License - see the LICENSE file for details.
- OpenAI Triton - The Triton compiler and language
- PyTorch - Deep learning framework with Triton integration
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Wiki: TritonParse Wiki
Note: This tool is designed for developers working with Triton kernels and GPU computing. Basic familiarity with CUDA, GPU programming concepts, and the Triton language is recommended for effective use.