Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions benchmarks/flash_attn/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from torch.nn import functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from top.utils.utils import load_input_from_path


class mha_fwd_benchmark(Benchmark):
Expand All @@ -27,13 +28,28 @@ def total_flops(self):
def total_memory(self):
return 4 * self.batch * self.heads * self.seq_len * self.dim * self.dtype.itemsize

def gen_inputs(self):
Q = torch.randn(
self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype)
K = torch.randn(
self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype)
V = torch.randn(
self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype)
def gen_inputs(self, input_path=None):
if input_path is None:
# gen random inputs
print("Gen random inputs!")
Q = torch.randn(
self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype)
K = torch.randn(
self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype)
V = torch.randn(
self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype)
else:
# Load input data from file paths
print("Gen inputs from file!")
Comment on lines +32 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better logging control, consider using the logging module instead of print. This allows for configuring verbosity levels and directing output to different handlers, which is more flexible for a benchmarking suite.

paths = input_path.split(';')
if len(paths) != 3:
raise ValueError(f"Expected 3 input paths for Q, K, V, but got {len(paths)}")

# Load Q, K, V
expected_shape = (self.batch, self.seq_len, self.heads, self.dim)
Q = load_input_from_path(paths[0], expected_shape, self.dtype)
K = load_input_from_path(paths[1], expected_shape, self.dtype)
V = load_input_from_path(paths[2], expected_shape, self.dtype)
return Q, K, V

def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
Expand Down
18 changes: 12 additions & 6 deletions tests/ops/test_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
from benchmarks import mha_fwd_benchmark, mha_bwd_benchmark


def test_mha_fwd(B, S, H, D, causal, dtype, tune=False):
def test_mha_fwd(B, S, H, D, causal, dtype, tune=False, input_path=None):
op = mha_fwd(B, H, S, D, causal, dtype, tune=tune)
benchmark = mha_fwd_benchmark(B, H, S, D, causal, dtype)

inputs = benchmark.gen_inputs()
inputs = benchmark.gen_inputs(input_path)
benchmark.check(op, *inputs)
benchmark.profile(op, *inputs)


def test_mha_bwd(B, S, H, D, causal, dtype, tune=False):
def test_mha_bwd(B, S, H, D, causal, dtype, tune=False, input_path=None):
op = mha_bwd(B, H, S, D, causal, dtype, tune=tune)
benchmark = mha_bwd_benchmark(B, H, S, D, causal, dtype)

inputs = benchmark.gen_inputs()
inputs = benchmark.gen_inputs(input_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The gen_inputs method of mha_bwd_benchmark (defined in benchmarks/flash_attn/mha.py) does not accept an input_path argument. This call will raise a TypeError. You need to update mha_bwd_benchmark.gen_inputs to accept input_path and load data from it, similar to mha_fwd_benchmark.gen_inputs. Also, remember that for the backward pass, the loaded tensors for Q, K, and V will need to have requires_grad set to True.

benchmark.check(op, *inputs)
benchmark.profile(op, *inputs)

Expand All @@ -34,10 +34,16 @@ def test_mha_bwd(B, S, H, D, causal, dtype, tune=False):
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
parser.add_argument(
'--disable_bwd', action='store_false', default=True, help='when test fwd profile')
parser.add_argument(
'--input_path',
type=str,
default=None,
help='Path to real input data. Use ";" to separate multiple paths. If None, random inputs will be generated.'
)
args = parser.parse_args()

test_mha_fwd(args.batch, args.seq_len, args.heads, args.dim, args.causal, str2dtype[args.dtype],
args.tune)
args.tune, args.input_path)
if args.disable_bwd:
test_mha_bwd(args.batch, args.seq_len, args.heads, args.dim, args.causal,
str2dtype[args.dtype], args.tune)
str2dtype[args.dtype], args.tune, args.input_path)
27 changes: 27 additions & 0 deletions top/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch

# A mapping from string dtype names to torch dtypes
Expand Down Expand Up @@ -62,3 +63,29 @@ def is_hopper():
def get_sm_version():
major, minor = torch.cuda.get_device_capability()
return major * 10 + minor


def load_input_from_path(path, expected_shape, dtype, device='cuda'):
"""
从文件路径加载输入数据的公共函数

Args:
path: 文件路径
expected_shape: 期望的张量形状
dtype: 数据类型
device: 设备类型

Returns:
加载的张量
"""
Comment on lines +69 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring is written in Chinese, which is inconsistent with the rest of the codebase being in English. To maintain consistency and improve readability for all contributors, please translate it to English.

Suggested change
"""
从文件路径加载输入数据的公共函数
Args:
path: 文件路径
expected_shape: 期望的张量形状
dtype: 数据类型
device: 设备类型
Returns:
加载的张量
"""
"""
Loads input data from a file path.
Args:
path: The file path.
expected_shape: The expected shape of the tensor.
dtype: The data type.
device: The device.
Returns:
The loaded tensor.
"""

if not os.path.exists(path):
raise FileNotFoundError(f"Input file not found: {path}")

tensor = torch.load(path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using torch.load is a security risk because it uses pickle internally, which can lead to arbitrary code execution if a malicious file is loaded. Since the file path is provided via a command-line argument, a user could supply a malicious file. At a minimum, call torch.load(path, weights_only=True) so that only tensors and plain containers are unpickled (this is not the default on older PyTorch releases). If the input files could come from an untrusted source, consider using a safer format for saving and loading tensors, such as safetensors.


if tensor.shape != expected_shape:
raise ValueError(
f"Shape mismatch: expected {expected_shape}, got {tensor.shape} from {path}")

tensor = tensor.to(dtype=dtype, device=device)
return tensor
Loading