17 changes: 15 additions & 2 deletions BackendBench/data_loaders.py
@@ -98,6 +98,7 @@ def _parse_trace_stream(
filter: Optional[List[str]] = None,
desc: str = "Parsing stream",
limit: Optional[int] = None,
model_mapping: Optional[Dict] = None,
) -> List[Dict]:
"""
Parse trace data from a text stream (e.g., from requests.Response.iter_lines()).
@@ -110,6 +111,7 @@ def _parse_trace_stream(
op_inputs = []
op = None
num_ops = 0
args_to_model = {}

iterator = tqdm(stream, desc=desc, total=len(stream))

@@ -124,11 +126,14 @@
if num_ops > limit:
break
op = m.group(1)
args_to_model = model_mapping.get(op, {}) if model_mapping else {}
if op == "aten.sum.SymInt":
op = "aten.sum.dim_IntList"
if m := re.match("cnt: \\d+, (.*)", line):
assert op is not None
args_str = m.group(1)
in_models = args_to_model.get(args_str, [])
in_models_count = len(in_models)
cnt = int(m.group(0).split(",")[0].split(":")[1])

if filter is None or any(f in op for f in filter):
@@ -141,6 +146,8 @@
"args": args_str,
"count": cnt,
"is_synthetic": is_synthetic,
"in_models": in_models,
"in_models_count": in_models_count,
}
)
return op_inputs
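
The loader now annotates every parsed `(op, args)` pair with the models it appears in. Below is a minimal sketch of the mapping shape this lookup assumes (the real `operator_input_models_mapping.json` schema may differ): a dict keyed by op name, then by the serialized args string, with model-name lists as values.

```python
# Hypothetical mapping shape; keys and values are illustrative only.
model_mapping = {
    "aten.add.Tensor": {
        "((T([2, 3], f16), T([2, 3], f16)), {})": ["bert-base-uncased", "gpt2"],
    },
}

op = "aten.add.Tensor"
args_str = "((T([2, 3], f16), T([2, 3], f16)), {})"
in_models = model_mapping.get(op, {}).get(args_str, [])  # [] for unseen inputs
print(len(in_models))  # 2
```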
@@ -257,10 +264,14 @@ def op_list_to_benchmark_dict(ops_list: List[Dict]) -> Dict[str, List[str]]:


def _load_from_trace(
source: Union[str, Path], filter: Optional[List[str]], limit: Optional[int] = None
source: Union[str, Path],
filter: Optional[List[str]],
limit: Optional[int] = None,
model_mapping: Optional[Dict] = None,
) -> List[Dict]:
"""Load operations from trace file(s) and return list of dicts."""
op_inputs = []
assert model_mapping is not None

# Handle URLs - stream directly without saving to disk
if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")):
@@ -275,7 +286,9 @@
lines = content.splitlines()

# Now parse with accurate progress (tqdm will know total lines)
op_inputs = _parse_trace_stream(lines, filter, "Parsing", limit=limit)
op_inputs = _parse_trace_stream(
lines, filter, "Parsing", limit=limit, model_mapping=model_mapping
)

# Handle single files
else:
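
With the new keyword threaded through, a direct call to the loader might look like this (the path and mapping contents are illustrative, not from this PR):

```python
from BackendBench.data_loaders import _load_from_trace

model_mapping = {"aten.add.Tensor": {}}  # normally fetched from the hub
ops = _load_from_trace(
    "traces/augmented_hf_op_traces.txt",  # hypothetical local path
    filter=None,
    limit=100,
    model_mapping=model_mapping,
)
print(ops[0]["in_models"], ops[0]["in_models_count"])
```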
82 changes: 69 additions & 13 deletions BackendBench/scripts/parquet_trace_converter.py
@@ -7,6 +7,7 @@
# utility functions to convert parquet and trace files back and forth

import hashlib
import json
import logging
import os
from collections import defaultdict
@@ -16,6 +17,7 @@
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import requests
from BackendBench.data_loaders import _load_from_trace
from BackendBench.scripts.dataset_filters import (
apply_runtime_filter,
@@ -25,7 +27,7 @@

DEFAULT_TRACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/augmented_hf_op_traces.txt"
DEFAULT_PARQUET_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/backend_bench_problems.parquet"

DEFAULT_MODEL_MAPPING_URL = "https://huggingface.co/datasets/GPUMODE/backendbench_tests/resolve/main/operator_input_models_mapping.json"

"""
Columns for the parquet dataset:
@@ -44,6 +46,15 @@
logger = logging.getLogger(__name__)


def load_model_mapping() -> dict:
"""Load model mapping json file."""

response = requests.get(DEFAULT_MODEL_MAPPING_URL)
response.raise_for_status()
content = response.text
return json.loads(content)

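As a reviewer-side sketch (not part of this PR), the same fetch could take a timeout and a one-shot cache so repeated conversions don't re-download the mapping:

```python
import functools
import json

import requests


@functools.lru_cache(maxsize=1)
def load_model_mapping_cached(url: str = DEFAULT_MODEL_MAPPING_URL) -> dict:
    response = requests.get(url, timeout=30)  # fail fast on network stalls
    response.raise_for_status()
    return json.loads(response.text)
```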

def _upload_to_hf(file_path: str) -> None:
"""Upload file to GPUMODE/huggingface_op_trace."""
try:
@@ -53,6 +64,7 @@ def _upload_to_hf(file_path: str) -> None:
path_in_repo=Path(file_path).name,
repo_id="GPUMODE/huggingface_op_trace",
repo_type="dataset",
create_pr=True,
)
logger.info(f"Uploaded {Path(file_path).name} to Hugging Face")
except Exception as e:
@@ -76,24 +88,42 @@ def setup_logging(log_level):
)


def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None):
def convert_trace_to_parquet(
trace_file, parquet_file, json_name: str = None, limit: int = None
):
"""
Convert a trace file to a parquet file
"""

# Load operations using local trace parsing function
ops = _load_from_trace(trace_file, filter=None, limit=limit)

model_mapping = load_model_mapping()
ops = _load_from_trace(
trace_file, filter=None, limit=limit, model_mapping=model_mapping
)
# Add additional metadata fields required for the parquet format
for op in ops:
op["uuid"] = hashlib.sha256(op["args"].encode() + op["op_name"].encode()).hexdigest()
# stable content hash over (args, op_name), used to dedupe identical entries
op["uuid"] = hashlib.sha256(
op["args"].encode() + op["op_name"].encode()
).hexdigest()
op["included_in_benchmark"] = True
op["why_excluded"] = []
op["runtime_ms"] = np.nan
op["relative_runtime_to_kernel_launch"] = np.nan
op["runnable"] = True
op["is_overhead_dominated_op"] = False

# count how many non-synthetic ops are not in any model
nonsynthetic_ops = [op for op in ops if not op["is_synthetic"]]
nonsynthetic_ops_not_in_models = [
op for op in nonsynthetic_ops if len(op["in_models"]) == 0
]
logger.info(
f"Found {len(nonsynthetic_ops_not_in_models)} / {len(nonsynthetic_ops)} non-synthetic ops that are not in any model: "
f"{[op['op_name'] for op in nonsynthetic_ops_not_in_models]}"
)
# apply filters
ops = apply_skip_ops_filter(ops)
ops = apply_runtime_filter(ops)
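
For reference, the uuid added above is a stable content hash over the args string plus the op name, so the same (op, args) pair always maps to the same id. A self-contained sketch of the scheme:

```python
import hashlib


def op_uuid(op_name: str, args: str) -> str:
    # Mirrors convert_trace_to_parquet: sha256 over args bytes + op name bytes
    return hashlib.sha256(args.encode() + op_name.encode()).hexdigest()


a = op_uuid("aten.mm.default", "((T([4, 4], f32), T([4, 4], f32)), {})")
b = op_uuid("aten.mm.default", "((T([4, 4], f32), T([4, 4], f32)), {})")
assert a == b  # deterministic across runs
```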
@@ -139,6 +169,9 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None):

# Write parquet file
pq.write_table(table, parquet_file)
# write to json file
with open(json_name, "w") as f:
json.dump(ops, f)

logger.info(f"Wrote {len(ops)} ops and inputs to {parquet_file}")

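A quick way to sanity-check a conversion run is to read both outputs back and compare row counts (filenames below are the script defaults; adjust if you passed different ones):

```python
import json

import pyarrow.parquet as pq

table = pq.read_table("backend_bench_problems.parquet")
with open("backend_bench_problems.json") as f:
    ops = json.load(f)
assert table.num_rows == len(ops)  # parquet and json should stay in sync
```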
@@ -199,7 +232,9 @@ def _validate_trace_file(trace_file: str, is_input: bool = True) -> str:

# For local files, check extension
if not (trace_file.endswith(".txt") or Path(trace_file).is_dir()):
raise click.BadParameter("Local trace file must end with .txt or be a directory")
raise click.BadParameter(
"Local trace file must end with .txt or be a directory"
)

if Path(trace_file).is_dir() and not is_input:
raise click.BadParameter("Output trace file cannot be a directory")
@@ -211,7 +246,9 @@
@click.option(
"--log-level",
default=os.getenv("LOG_LEVEL", "INFO"),
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False),
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
help="Set the logging level",
)
@click.option(
@@ -244,7 +281,13 @@ def _validate_trace_file(trace_file: str, is_input: bool = True) -> str:
type=int,
help="Limit the number of operators to convert. (Useful for testing)",
)
def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit):
@click.option(
"--json-name",
default="backend_bench_problems.json",
type=str,
help="JSON filename: URL allowed as input in trace-to-parquet mode, local files in datasets/.",
)
def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit, json_name):
"""Convert trace files to parquet format or vice versa."""
setup_logging(log_level)

@@ -253,27 +296,40 @@ def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit):

if mode == "trace-to-parquet":
# Validate inputs/outputs
trace_file = _validate_trace_file(trace_file, is_input=True) # Input: URLs allowed
trace_file = _validate_trace_file(
trace_file, is_input=True
) # Input: URLs allowed
parquet_name = _validate_parquet_name(parquet_name) # Output: URLs not allowed

logger.info(f"Converting trace file {trace_file} to parquet file {parquet_name}")
logger.info(
f"Converting trace file {trace_file} to parquet file {parquet_name}"
)

convert_trace_to_parquet(trace_file, parquet_name, limit=limit)
convert_trace_to_parquet(
trace_file, parquet_name, json_name=json_name, limit=limit
)
logger.info("Conversion completed successfully")

if upload_to_hf:
# Upload to Hugging Face
_upload_to_hf(os.path.abspath(parquet_name))
_upload_to_hf(os.path.abspath(json_name))

elif mode == "parquet-to-trace":
# Validate parquet input (URLs allowed for input in this mode)
parquet_input = _validate_parquet_name(parquet_name)
# Validate trace output (URLs not allowed for output)
trace_output = _validate_trace_file(trace_file, is_input=False) # Output: URLs not allowed
trace_output = _validate_trace_file(
trace_file, is_input=False
) # Output: URLs not allowed

logger.info(f"Converting parquet file {parquet_input} to trace file {trace_output}")
logger.info(
f"Converting parquet file {parquet_input} to trace file {trace_output}"
)
convert_parquet_to_trace(parquet_input, trace_output, limit=limit)
logger.info("Conversion completed successfully")
# _upload_to_hf(os.path.abspath(parquet_name))
# _upload_to_hf(os.path.abspath(json_name))


if __name__ == "__main__":
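
With the new option in place, a trace-to-parquet run might look like the following. Only `--log-level`, `--limit`, and `--json-name` appear verbatim in this diff; the other flag names are inferred from `main()`'s parameters and are an assumption:

```bash
python BackendBench/scripts/parquet_trace_converter.py \
    --mode trace-to-parquet \
    --trace-file augmented_hf_op_traces.txt \
    --parquet-name backend_bench_problems.parquet \
    --json-name backend_bench_problems.json \
    --limit 100
```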
1 change: 1 addition & 0 deletions backend_bench_problems.json

Large diffs are not rendered by default.
