241 changes: 241 additions & 0 deletions BackendBench/eval_model.py
@@ -0,0 +1,241 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Model-level evaluation utilities for testing full model correctness."""

import logging
import random
import traceback
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

import BackendBench
from BackendBench.eval import allclose
from BackendBench.utils import deserialize_args

logger = logging.getLogger(__name__)


@dataclass
class ModelCorrectnessTestResult:
"""Result from testing a model configuration."""

model_name: str
test_name: str
is_correct: bool = False
error_msg: str = ""
error_type: str = ""
traceback: str = ""
output_match: bool = False
gradients_match: bool = False
num_gradients: int = 0


def eval_model_correctness_test(
model_name: str,
model_class: type,
model_config: Dict[str, Any],
test_name: str,
test_args: str,
    kernel_dir: Optional[str] = None,
atol: float = 1e-2,
rtol: float = 1e-2,
) -> ModelCorrectnessTestResult:
"""Evaluate model correctness by comparing eager vs backend execution.

Similar to eval_correctness_test in eval.py, but for full models instead of individual ops.

Args:
model_name: Name of the model being tested
model_class: Model class to instantiate
model_config: Model configuration dict with init_args
test_name: Name of this test configuration
test_args: Serialized arguments string for forward pass
kernel_dir: Optional directory containing kernels for backend
atol: Absolute tolerance for allclose
rtol: Relative tolerance for allclose

Returns:
ModelCorrectnessTestResult with detailed comparison results
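
    Example:
        A minimal illustrative call; ``ToyModel``, its config, and
        ``serialized_fwd_args`` are hypothetical, with ``serialized_fwd_args``
        standing in for any string accepted by
        ``BackendBench.utils.deserialize_args``::

            result = eval_model_correctness_test(
                model_name="toy",
                model_class=ToyModel,
                model_config={"init_args": {"hidden_dim": 8}},
                test_name="small_input",
                test_args=serialized_fwd_args,
                kernel_dir="generated_kernels",
            )
            if not result.is_correct:
                print(result.error_msg or "output/gradient mismatch")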
"""
try:
# Generate a single seed to use for both eager and backend runs
# This ensures both runs use the same model initialization
seed = random.randint(0, 2**32 - 1)

# Run in eager mode (reference)
eager_out, eager_grads = _run_model(
model_class,
model_config,
test_args,
backend_enabled=False,
kernel_dir=None,
seed=seed,
)

# Run with backend (implementation)
backend_out, backend_grads = _run_model(
model_class,
model_config,
test_args,
backend_enabled=True,
kernel_dir=kernel_dir,
seed=seed,
)

# Compare outputs
output_match = allclose(eager_out, backend_out, atol=atol, rtol=rtol)

# Compare gradients
gradients_match = True
if len(eager_grads) != len(backend_grads):
gradients_match = False
else:
for eager_grad, backend_grad in zip(eager_grads, backend_grads):
if not allclose(eager_grad, backend_grad, atol=atol, rtol=rtol):
gradients_match = False
break

is_correct = output_match and gradients_match

return ModelCorrectnessTestResult(
model_name=model_name,
test_name=test_name,
is_correct=is_correct,
output_match=output_match,
gradients_match=gradients_match,
num_gradients=len(eager_grads),
)

except Exception as e:
error_msg = f"Model {model_name}::{test_name} failed: {e}"
logger.error(error_msg)
return ModelCorrectnessTestResult(
model_name=model_name,
test_name=test_name,
is_correct=False,
error_msg=error_msg,
            error_type=type(e).__name__,
traceback=traceback.format_exc(),
)


def _move_model_to_input_device(
model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any]
) -> torch.nn.Module:
"""Move model to the same device as input tensor.

Args:
model: Model to move
args: Positional arguments list
kwargs: Keyword arguments dict

Returns:
Model on input device (or original model if no input tensor found)
"""

    # This is specific to our current configs (input passed as kwargs["x"]);
    # we should generalize this.
    input_tensor = kwargs.get("x")
    if input_tensor is None and args and isinstance(args[0], torch.Tensor):
        # Fall back to the first positional tensor argument.
        input_tensor = args[0]
    if input_tensor is not None:
        model = model.to(input_tensor.device)
    return model


def _collect_gradients(
model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any]
) -> List[torch.Tensor]:
"""Collect gradients from input and model parameters.

Args:
model: Model with computed gradients
args: Positional arguments list
kwargs: Keyword arguments dict

Returns:
List of gradient tensors [input_grad, param1_grad, ...]
"""
grads = []

# Input gradient - check both args and kwargs
input_grad = None
if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None:
input_grad = args[0].grad
elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None:
input_grad = kwargs["x"].grad

if input_grad is not None:
grads.append(input_grad.clone())

# Parameter gradients
for param in model.parameters():
if param.grad is not None:
grads.append(param.grad.clone())

return grads


def _run_model(
model_class: type,
model_config: Dict[str, Any],
test_args: str,
backend_enabled: bool,
    kernel_dir: Optional[str] = "generated_kernels",
    seed: Optional[int] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Run model with or without backend enabled.

Args:
model_class: Model class to instantiate
model_config: Model configuration dict with init_args
test_args: Serialized arguments string for forward pass
backend_enabled: If True, use BackendBench context manager
kernel_dir: Optional directory containing kernels
seed: Random seed for reproducibility. If None, generates a random seed.

Returns:
Tuple of (output, gradients) where:
- output: Model output tensor (detached)
- gradients: List of gradient tensors [input_grad, param1_grad, ...]
"""

# Generate seed dynamically and set for deterministic behavior
# IMPORTANT: Must set seed BEFORE deserializing args, because deserialization
# may create random tensors!
if seed is None:
seed = random.randint(0, 2**32 - 1)
torch.manual_seed(seed)

# Deserialize test arguments (now uses the seed we just set)
args, kwargs = deserialize_args(test_args)

# Extract model initialization args
init_args = model_config.get("init_args", {}).copy()

# Create fresh model instance
model = model_class(**init_args)
model.train()

# Move model to same device as input
model = _move_model_to_input_device(model, args, kwargs)
ctx = (
BackendBench.BackendBench.enable(kernel_dir=kernel_dir)
if backend_enabled
else nullcontext()
)
# Run forward + backward with or without backend
with ctx:
output = model(*args, **kwargs)
loss = output.sum()
loss.backward()

# Collect gradients
grads = _collect_gradients(model, args, kwargs)

return output.detach(), grads
44 changes: 42 additions & 2 deletions BackendBench/scripts/main.py
@@ -19,6 +19,7 @@
from BackendBench.output import save_results
from BackendBench.suite import (
FactoTestSuite,
ModelSuite,
OpInfoTestSuite,
SmokeTestSuite,
TorchBenchTestSuite,
@@ -40,6 +41,21 @@ def setup_logging(log_level):
)


# Helper function; expected to evolve as the model suite gets fleshed out
def _test_full_models(suite, backend):
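    """Evaluate every model in the ModelSuite against the backend and log a summary.

    This path is reached via ``--suite model`` (optionally narrowed with
    ``--model-filter``) and reports model-level results only; op-level
    correctness/performance scoring is skipped.
    """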
assert suite.name == "model"
all_results = []
for model in suite.models:
results = suite.eval_model(model, backend)
all_results.append(results)
logger.info("=" * 60)
logger.info("MODEL EVALUATION RESULTS")
logger.info("=" * 60)
for result in all_results:
suite.print_results(result)
logger.info("=" * 60)


@click.command()
@click.option(
"--log-level",
@@ -50,7 +66,7 @@
@click.option(
"--suite",
default="smoke",
type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]),
type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]),
help="Which suite to run",
)
@click.option(
@@ -63,7 +79,13 @@
"--ops",
default=None,
type=str,
help="Comma-separated list of ops to run",
help="Comma-separated list of ops to run (not supported for model suite)",
)
@click.option(
"--model-filter",
default=None,
type=str,
help="Comma-separated list of models to run (only for model suite)",
)
@click.option(
"--topn-inputs",
@@ -147,6 +169,7 @@ def cli(
suite,
backend,
ops,
model_filter,
topn_inputs,
llm_attempts,
llm_model,
@@ -166,9 +189,20 @@
if check_overhead_dominated_ops:
raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite")

if suite == "model":
if ops is not None:
raise ValueError(
"--ops filter is not supported for model suite. Use --model-filter instead"
)

if suite != "model" and model_filter is not None:
raise ValueError("--model-filter is only supported for model suite")

setup_logging(log_level)
if ops:
ops = ops.split(",")
if model_filter:
model_filter = model_filter.split(",")

suite = {
"smoke": lambda: SmokeTestSuite,
@@ -191,6 +225,7 @@
torch.bfloat16,
filter=ops,
),
"model": lambda: ModelSuite(filter=model_filter),
}[suite]()

backend_name = backend
@@ -224,6 +259,11 @@
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f"backendbench_output_{timestamp}"

if suite.name == "model":
_test_full_models(suite, backend)
        # The model suite does not yet support op-level testing, so we're done here.
return

overall_correctness = []
overall_performance = []
all_correctness_results = []
2 changes: 2 additions & 0 deletions BackendBench/suite/__init__.py
@@ -15,6 +15,7 @@

from .base import OpTest, Test, TestSuite
from .facto import FactoTestSuite
from .model import ModelSuite
from .opinfo import OpInfoTestSuite
from .smoke import randn, SmokeTestSuite
from .torchbench import TorchBenchOpTest, TorchBenchTestSuite
@@ -24,6 +25,7 @@
"OpTest",
"TestSuite",
"FactoTestSuite",
"ModelSuite",
"OpInfoTestSuite",
"SmokeTestSuite",
"randn",