3 changes: 2 additions & 1 deletion .gitignore
@@ -117,6 +117,7 @@ triton_kernel_logs/
*.log
session_*/
worker_*/
.fuse/

# Generated kernels
kernel.py
@@ -139,6 +140,6 @@ CLAUDE.md
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
Thumbs.db
# Local batch runner
scripts/run_kernelbench_batch.py
50 changes: 27 additions & 23 deletions README.md
@@ -18,42 +18,46 @@ Every stage writes artifacts to a run directory under `.fuse/<run_id>/`, includi
## Quickstart

### Requirements
- Python 3.8 – 3.12
- Linux or macOS; CUDA‑capable GPU for Triton execution
- Python 3.8–3.12
- Triton (install separately: `pip install triton` or nightly from source)
- At least one LLM provider:
- OpenAI (`OPENAI_API_KEY`, models like `o4-mini`, `gpt-5`)
- Anthropic (`ANTHROPIC_API_KEY`; default fallback model is `claude-sonnet-4-20250514` when `OPENAI_MODEL` is unset)
- Any OpenAI‑compatible relay endpoint (`LLM_RELAY_URL`, optional `LLM_RELAY_API_KEY`; see `triton_kernel_agent/providers/relay_provider.py`)
- Gradio (UI dependencies; installed as part of the core package)
- Triton (installed separately: `pip install triton` or nightly from source)
- PyTorch (https://pytorch.org/get-started/locally/)
- LLM provider ([OpenAI](https://openai.com/api/), [Anthropic](https://www.anthropic.com/), or a self-hosted relay)

### Installation
### Install
```bash
git clone https://github.com/pytorch-labs/KernelAgent.git
cd KernelAgent
python -m venv .venv && source .venv/bin/activate # choose your own env manager
pip install -e .[dev] # project + tooling deps
pip install triton # not part of extras; install the version you need
pip install -e .
```

# (optional) Install KernelBench for problem examples
#### (Optional) Install KernelBench for problem examples
```bash
git clone https://github.com/ScalingIntelligence/KernelBench.git
```
Note: by default, the KernelAgent UI looks for KernelBench at the same level as `KernelAgent` (i.e., `../KernelBench`).

### Configure credentials
You can export keys directly or use an `.env` file that the CLIs load automatically:
### Configure
You can export keys directly or use an `.env` file that the CLIs load automatically.

```bash
OPENAI_API_KEY=sk-...
OPENAI_MODEL=gpt-5 # override default fallback (claude-sonnet-4-20250514)
OPENAI_MODEL=gpt-5 # default model for extraction
NUM_KERNEL_SEEDS=4 # parallel workers per kernel
MAX_REFINEMENT_ROUNDS=10 # retry budget per worker
LOG_LEVEL=INFO
LOG_LEVEL=INFO # logging level
```
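
For reference, a minimal sketch of how a CLI might pick up these values after loading the `.env` file. It assumes `python-dotenv`, which this PR does not itself introduce; the real loading logic lives in the CLIs themselves.

```python
# Hypothetical sketch only: load .env and read the knobs listed above.
# Assumes python-dotenv is available; the actual CLIs may load configuration differently.
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory

api_key = os.environ.get("OPENAI_API_KEY")
model = os.environ.get("OPENAI_MODEL")                         # optional; Anthropic fallback applies when unset
num_seeds = int(os.environ.get("NUM_KERNEL_SEEDS", 4))         # parallel workers per kernel
max_rounds = int(os.environ.get("MAX_REFINEMENT_ROUNDS", 10))  # retry budget per worker
```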

#### LLM Providers
KernelAgent supports OpenAI and Anthropic out of the box; you can also point it at a custom OpenAI-compatible endpoint.
These can be configured in `.env` or via environment variables.
```bash
# OpenAI (models like `o4-mini`, `gpt-5`)
OPENAI_API_KEY=sk-...

# Anthropic (default; `claude-sonnet-4-20250514` is used when `OPENAI_MODEL` is unset)
ANTHROPIC_API_KEY=sk-ant-...

# Optional relay configuration for self-hosted gateways
# LLM_RELAY_URL=http://127.0.0.1:11434
# LLM_RELAY_API_KEY=your-relay-token
# LLM_RELAY_TIMEOUT_S=120
# Relay configuration for self-hosted gateways
LLM_RELAY_URL=http://127.0.0.1:11434
LLM_RELAY_TIMEOUT_S=120
```
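
Below is an illustrative sketch of how provider selection could key off these variables. The function name and selection order are assumptions for illustration only, not the package's actual API; the real logic lives under `triton_kernel_agent/providers/`.

```python
# Illustrative only: choose a provider based on which variables are set.
import os


def pick_provider() -> str:
    if os.environ.get("LLM_RELAY_URL"):
        return "relay"        # self-hosted, OpenAI-compatible gateway
    if os.environ.get("OPENAI_API_KEY") and os.environ.get("OPENAI_MODEL"):
        return "openai"       # models like o4-mini, gpt-5
    if os.environ.get("ANTHROPIC_API_KEY"):
        return "anthropic"    # claude-sonnet-4-20250514 when OPENAI_MODEL is unset
    raise RuntimeError("No LLM provider configured")
```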

More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`.
5 changes: 3 additions & 2 deletions triton_kernel_agent/providers/relay_provider.py
@@ -18,6 +18,7 @@

import requests
import logging
import os

from .base import BaseProvider, LLMResponse

@@ -34,7 +35,7 @@ class RelayProvider(BaseProvider):
"""

def __init__(self):
self.server_url = "http://127.0.0.1:11434"
self.server_url = os.environ.get("LLM_RELAY_URL", "http://127.0.0.1:11434")
self.is_available_flag = False
super().__init__()

@@ -68,7 +69,7 @@ def get_response(
            self.server_url,
            json=request_data,
            headers={"Content-Type": "application/json"},
            timeout=120.0,
            timeout=int(os.environ.get("LLM_RELAY_TIMEOUT_S", 120)),
        )

        if response.status_code != 200:
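
The net effect of the relay changes above is that both the endpoint and the request timeout now come from the environment. A standalone sketch of an equivalent request is shown below; the payload shape is a placeholder rather than the provider's actual schema (see `RelayProvider.get_response` for that).

```python
# Minimal sketch of an env-configured relay call, mirroring the diff above.
import os

import requests

server_url = os.environ.get("LLM_RELAY_URL", "http://127.0.0.1:11434")
timeout_s = int(os.environ.get("LLM_RELAY_TIMEOUT_S", 120))

response = requests.post(
    server_url,
    json={"prompt": "hello"},  # placeholder payload, not the real request schema
    headers={"Content-Type": "application/json"},
    timeout=timeout_s,
)
response.raise_for_status()  # surface non-200 responses as errors
```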