Skip to content

Commit ec4e471

Browse files
authored
Surface Fuser and Pipeline UI as script entrypoints (#36)
* Surface Fuser and Pipeline UI as script entrypoints * Cleanup * Remove TODOs * lint
1 parent 5a28c10 commit ec4e471

File tree

4 files changed

+26
-35
lines changed

4 files changed

+26
-35
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`.
9999

100100
- **UIs** — interactive runs with Gradio frontends:
101101
- Triton KernelAgent UI: `kernel-agent` or `python scripts/triton_ui.py`
102-
- Fuser orchestration UI: `python -m Fuser.fuser_ui`
103-
- Full pipeline UI: `python -m Fuser.pipeline_ui`
102+
- Fuser orchestration UI: `fuser-ui` or `python scripts/fuser_ui`
103+
- Full pipeline UI: `pipeline-ui` or `python scripts/pipeline_ui`
104104

105105
## Component Details
106106

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@ dev = [
4343
]
4444

4545
[project.scripts]
46+
fuser-ui = "scripts.fuser_ui:main"
4647
kernel-agent = "scripts.triton_ui:main"
4748
list-models = "scripts.list_models:main"
49+
pipeline-ui = "scripts.pipeline_ui:main"
4850

4951
[project.urls]
5052
"Homepage" = "https://github.com/pytorch-labs/KernelAgent"

Fuser/fuser_ui.py renamed to scripts/fuser_ui.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import argparse
2020
import ast
2121
import os
22-
import sys
2322
import tarfile
2423
import time
2524
import traceback
@@ -38,19 +37,9 @@
3837
except Exception: # pragma: no cover
3938
extract_subgraphs_to_json = None # type: ignore
4039

41-
# Support both package and script execution contexts
42-
if __package__ is None or __package__ == "":
43-
PACKAGE_ROOT = Path(__file__).resolve().parent
44-
REPO_ROOT = PACKAGE_ROOT.parent
45-
if str(REPO_ROOT) not in sys.path:
46-
sys.path.insert(0, str(REPO_ROOT))
47-
from Fuser.config import OrchestratorConfig, new_run_id
48-
from Fuser.orchestrator import Orchestrator
49-
from Fuser.paths import ensure_abs_regular_file, make_run_dirs, PathSafetyError
50-
else:
51-
from .config import OrchestratorConfig, new_run_id
52-
from .orchestrator import Orchestrator
53-
from .paths import ensure_abs_regular_file, make_run_dirs, PathSafetyError
40+
from Fuser.config import OrchestratorConfig, new_run_id
41+
from Fuser.orchestrator import Orchestrator
42+
from Fuser.paths import ensure_abs_regular_file, make_run_dirs, PathSafetyError
5443

5544

5645
@dataclass
@@ -327,6 +316,7 @@ def __init__(self) -> None:
327316
repo_root / "external" / "KernelBench" / "KernelBench",
328317
Path.cwd() / "external" / "KernelBench" / "KernelBench",
329318
Path.cwd() / "KernelBench" / "KernelBench",
319+
Path.cwd().parent / "KernelBench" / "KernelBench",
330320
]
331321
seen: set[str] = set()
332322
collected: list[tuple[str, str]] = []

Fuser/pipeline_ui.py renamed to scripts/pipeline_ui.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,9 @@
2929
import gradio as gr
3030
from dotenv import load_dotenv
3131

32-
try:
33-
# Support both package and script execution contexts (match Fuser/fuser_ui.py pattern)
34-
if __package__ is None or __package__ == "":
35-
PACKAGE_ROOT = Path(__file__).resolve().parent
36-
REPO_ROOT = PACKAGE_ROOT.parent
37-
if str(REPO_ROOT) not in sys.path:
38-
sys.path.insert(0, str(REPO_ROOT))
39-
from Fuser.pipeline import run_pipeline
40-
from Fuser.auto_agent import AutoKernelRouter
41-
from Fuser.code_extractor import extract_single_python_file
42-
else:
43-
from .pipeline import run_pipeline
44-
from .auto_agent import AutoKernelRouter
45-
from .code_extractor import extract_single_python_file
46-
except Exception:
47-
raise
32+
from Fuser.pipeline import run_pipeline
33+
from Fuser.auto_agent import AutoKernelRouter
34+
from Fuser.code_extractor import extract_single_python_file
4835
from triton_kernel_agent.providers.models import (
4936
get_model_provider,
5037
MODEL_NAME_TO_CONFIG,
@@ -416,9 +403,21 @@ def run_pipeline_ui(
416403
class PipelineUI:
417404
def __init__(self) -> None:
418405
load_dotenv()
419-
self.problem_choices = _list_kernelbench_problems(
420-
Path.cwd() / "external" / "KernelBench" / "KernelBench"
421-
)
406+
candidate_roots = [
407+
Path.cwd() / "external" / "KernelBench" / "KernelBench",
408+
Path.cwd() / "KernelBench" / "KernelBench",
409+
Path.cwd().parent / "KernelBench" / "KernelBench",
410+
]
411+
seen: set[str] = set()
412+
collected: list[tuple[str, str]] = []
413+
for base in candidate_roots:
414+
print(base, file=sys.stderr)
415+
for label, abspath in _list_kernelbench_problems(base):
416+
if abspath not in seen:
417+
collected.append((label, abspath))
418+
seen.add(abspath)
419+
self.problem_choices = collected
420+
422421
control_flow_path = (
423422
Path(__file__).resolve().parent.parent.parent
424423
/ "external"

0 commit comments

Comments
 (0)