diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py
index 89309c79f099..54e3a0f78743 100644
--- a/benchmarks/kernels/benchmark_lora.py
+++ b/benchmarks/kernels/benchmark_lora.py
@@ -500,6 +500,11 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]:
         assert len(o_shape) == 3
         assert o_shape == (num_slices, num_tokens, lora_rank)
 
+        no_lora = torch.all(self.lora_kernel_meta.token_lora_mapping == -1)
+        no_lora_flag_cpu = torch.tensor([no_lora],
+                                        dtype=torch.bool,
+                                        device='cpu')
+
         return {
             "inputs": self.input,
             "lora_a_weights": self.lora_weights_lst,
@@ -511,6 +516,7 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]:
             "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
             "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
             "lora_ids": self.lora_kernel_meta.active_lora_ids,
+            "no_lora_flag_cpu": no_lora_flag_cpu,
             "scaling": 1.0,
         }
 
@@ -539,6 +545,11 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
         assert len(o_shape) == 2
         assert o_shape == (num_tokens, hidden_size * num_slices)
 
+        no_lora = torch.all(self.lora_kernel_meta.token_lora_mapping == -1)
+        no_lora_flag_cpu = torch.tensor([no_lora],
+                                        dtype=torch.bool,
+                                        device='cpu')
+
         return {
             "inputs": self.input,
             "lora_b_weights": self.lora_weights_lst,
@@ -552,6 +563,7 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
             "lora_ids": self.lora_kernel_meta.active_lora_ids,
             "offset_start": 0,
             "add_inputs": add_inputs,
+            "no_lora_flag_cpu": no_lora_flag_cpu,
         }
 
     def bench_fn_kwargs(
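
For context, the new "no_lora_flag_cpu" kwarg is a host-side boolean tensor built once from token_lora_mapping (all entries equal to -1 means no token uses a LoRA adapter), so the caller can decide on the CPU whether a launch is needed at all. The sketch below illustrates that pattern only; it is not the vLLM kernel code, and the names maybe_launch_shrink and shrink_kernel are hypothetical stand-ins.

# Minimal sketch of an early-exit driven by a CPU-side "no LoRA" flag.
# Assumption: shrink_kernel stands in for the real shrink kernel launch.
import torch


def shrink_kernel(inputs: torch.Tensor, lora_a: torch.Tensor,
                  output: torch.Tensor, scaling: float) -> None:
    # Stand-in for the real kernel: output = scaling * inputs @ lora_a.
    output.copy_(scaling * (inputs @ lora_a))


def maybe_launch_shrink(inputs: torch.Tensor, lora_a: torch.Tensor,
                        output: torch.Tensor,
                        token_lora_mapping: torch.Tensor,
                        scaling: float) -> None:
    # Build the flag once on the CPU, mirroring no_lora_flag_cpu above:
    # True when every token maps to -1 (no LoRA adapter assigned).
    no_lora_flag_cpu = torch.tensor(
        [bool(torch.all(token_lora_mapping == -1))],
        dtype=torch.bool, device="cpu")
    if no_lora_flag_cpu.item():
        return  # skip the launch entirely; no token uses LoRA
    shrink_kernel(inputs, lora_a, output, scaling)


if __name__ == "__main__":
    tokens, hidden, rank = 4, 16, 8
    x = torch.randn(tokens, hidden)
    a = torch.randn(hidden, rank)
    out = torch.zeros(tokens, rank)
    mapping = torch.full((tokens,), -1, dtype=torch.long)  # no LoRA anywhere
    maybe_launch_shrink(x, a, out, mapping, scaling=1.0)
    assert torch.count_nonzero(out) == 0  # kernel was skipped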