diff --git a/benchmarks/flash_attn/mha.py b/benchmarks/flash_attn/mha.py index 7b7a343..5e34636 100644 --- a/benchmarks/flash_attn/mha.py +++ b/benchmarks/flash_attn/mha.py @@ -3,6 +3,7 @@ import torch from torch.nn import functional as F from torch.nn.attention import sdpa_kernel, SDPBackend +from top.utils.utils import load_input_from_path class mha_fwd_benchmark(Benchmark): @@ -27,13 +28,28 @@ def total_flops(self): def total_memory(self): return 4 * self.batch * self.heads * self.seq_len * self.dim * self.dtype.itemsize - def gen_inputs(self): - Q = torch.randn( - self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) - K = torch.randn( - self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) - V = torch.randn( - self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + def gen_inputs(self, input_path=None): + if input_path is None: + # gen random inputs + print("Gen random inputs!") + Q = torch.randn( + self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + K = torch.randn( + self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + V = torch.randn( + self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + else: + # Load input data from file paths + print("Gen inputs from file!") + paths = input_path.split(';') + if len(paths) != 3: + raise ValueError(f"Expected 3 input paths for Q, K, V, but got {len(paths)}") + + # Load Q, K, V + expected_shape = (self.batch, self.seq_len, self.heads, self.dim) + Q = load_input_from_path(paths[0], expected_shape, self.dtype) + K = load_input_from_path(paths[1], expected_shape, self.dtype) + V = load_input_from_path(paths[2], expected_shape, self.dtype) return Q, K, V def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor): diff --git a/tests/ops/test_mha.py b/tests/ops/test_mha.py index 726fe43..4f81f75 100644 --- a/tests/ops/test_mha.py +++ b/tests/ops/test_mha.py @@ -4,20 +4,20 @@ from benchmarks import mha_fwd_benchmark, mha_bwd_benchmark -def test_mha_fwd(B, S, H, D, causal, dtype, tune=False): +def test_mha_fwd(B, S, H, D, causal, dtype, tune=False, input_path=None): op = mha_fwd(B, H, S, D, causal, dtype, tune=tune) benchmark = mha_fwd_benchmark(B, H, S, D, causal, dtype) - inputs = benchmark.gen_inputs() + inputs = benchmark.gen_inputs(input_path) benchmark.check(op, *inputs) benchmark.profile(op, *inputs) -def test_mha_bwd(B, S, H, D, causal, dtype, tune=False): +def test_mha_bwd(B, S, H, D, causal, dtype, tune=False, input_path=None): op = mha_bwd(B, H, S, D, causal, dtype, tune=tune) benchmark = mha_bwd_benchmark(B, H, S, D, causal, dtype) - inputs = benchmark.gen_inputs() + inputs = benchmark.gen_inputs(input_path) benchmark.check(op, *inputs) benchmark.profile(op, *inputs) @@ -34,10 +34,16 @@ def test_mha_bwd(B, S, H, D, causal, dtype, tune=False): parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') parser.add_argument( '--disable_bwd', action='store_false', default=True, help='when test fwd profile') + parser.add_argument( + '--input_path', + type=str, + default=None, + help='Path to real input data. Use ";" to separate multiple paths. If None, random inputs will be generated.' + ) args = parser.parse_args() test_mha_fwd(args.batch, args.seq_len, args.heads, args.dim, args.causal, str2dtype[args.dtype], - args.tune) + args.tune, args.input_path) if args.disable_bwd: test_mha_bwd(args.batch, args.seq_len, args.heads, args.dim, args.causal, - str2dtype[args.dtype], args.tune) + str2dtype[args.dtype], args.tune, args.input_path) diff --git a/top/utils/utils.py b/top/utils/utils.py index a872e6a..e2918f7 100644 --- a/top/utils/utils.py +++ b/top/utils/utils.py @@ -1,3 +1,4 @@ +import os import torch # A mapping from string dtype names to torch dtypes @@ -62,3 +63,29 @@ def is_hopper(): def get_sm_version(): major, minor = torch.cuda.get_device_capability() return major * 10 + minor + + +def load_input_from_path(path, expected_shape, dtype, device='cuda'): + """ + 从文件路径加载输入数据的公共函数 + + Args: + path: 文件路径 + expected_shape: 期望的张量形状 + dtype: 数据类型 + device: 设备类型 + + Returns: + 加载的张量 + """ + if not os.path.exists(path): + raise FileNotFoundError(f"Input file not found: {path}") + + tensor = torch.load(path) + + if tensor.shape != expected_shape: + raise ValueError( + f"Shape mismatch: expected {expected_shape}, got {tensor.shape} from {path}") + + tensor = tensor.to(dtype=dtype, device=device) + return tensor