diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index f349a912..6bcea200 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -98,6 +98,7 @@ class InputRequest:
   prompt_len: int = 0
   output: str = ""
   output_len: int = 0
+  sample_idx: int = -1


 @dataclass
@@ -119,6 +120,7 @@ def to_dict(self):
         "success": self.success,
         "latency": self.latency,
         "prompt_len": self.prompt_len,
+        "sample_idx": self.input_request.sample_idx,
     }


@@ -180,8 +182,14 @@ def tokenize_dataset(

   n = len(dataset)

-  prompts = [prompt for prompt, _ in dataset]
-  outputs = [output for _, output in dataset]
+  prompts = []
+  outputs = []
+  indices = []
+
+  for prompt, output, idx in dataset:
+    prompts.append(prompt)
+    outputs.append(output)
+    indices.append(idx)

   prompt_token_ids = tokenizer.tokenize(
       prompts
@@ -194,9 +202,15 @@ def tokenize_dataset(
   for i in range(n):
     prompt_len = len(prompt_token_ids[i])
     output_len = len(outputs_token_ids[i])
-    tokenized_dataset.append(
-        (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
+    tokenized_data = (
+        prompts[i],
+        prompt_token_ids[i],
+        outputs[i],
+        prompt_len,
+        output_len,
+        indices[i],
     )
+    tokenized_dataset.append(tokenized_data)

   return tokenized_dataset

@@ -219,6 +233,7 @@ def filter_dataset(
       output,
       prompt_len,
       output_len,
+      sample_idx,
   ) in tokenized_dataset:
     if prompt_len < 4 or output_len < 4:
       # Prune too short sequences.
       continue
@@ -229,7 +244,7 @@
       # Prune too long sequences.
       continue
     request = InputRequest(
-        prompt, prompt_len, output, max_output_length or output_len
+        prompt, prompt_len, output, max_output_length or output_len, sample_idx
     )
     filtered_dataset.append(request)

@@ -249,10 +264,11 @@

   # Original dataset size
   n = len(dataset)
+  dataset_indices = range(n)

   # Create necessary number of requests even if bigger than dataset size
   sampled_indices = random.sample(
-      range(n), min(int(num_requests * oversample_multiplier), n)
+      dataset_indices, min(int(num_requests * oversample_multiplier), n)
   )

   if num_requests > len(sampled_indices):
@@ -267,9 +283,13 @@
   print(f"{len(sampled_indices)=}")

   # some of these will be filtered out, so sample more than we need
-  dataset = [dataset[i] for i in sampled_indices]
-  tokenized_dataset = tokenize_dataset(dataset, tokenizer)
+  sampled_dataset = []
+  for i in sampled_indices:
+    sampled_data = dataset[i] + (dataset_indices[i],)
+    sampled_dataset.append(sampled_data)
+
+  tokenized_dataset = tokenize_dataset(sampled_dataset, tokenizer)

   input_requests = filter_dataset(tokenized_dataset, max_output_length)
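
For reference, here is a minimal sketch of the index-threading pattern this patch applies: each sampled row is paired with its position in the original dataset, and that index rides along through request construction. This assumes a dataset of `(prompt, output)` pairs; `sample_with_indices` and the whitespace-split "token" counts are illustrative stand-ins, not the benchmark's actual sampling or tokenizer.

```python
# Sketch only: mirrors the patched InputRequest, with hypothetical helpers.
import random
from dataclasses import dataclass


@dataclass
class InputRequest:
  prompt: str = ""
  prompt_len: int = 0
  output: str = ""
  output_len: int = 0
  sample_idx: int = -1  # position of the sample in the original dataset


def sample_with_indices(dataset, num_requests):
  # Pair each sampled row with its original index so the index survives
  # later tokenization and filtering, as the patch does in sample_requests.
  indices = random.sample(range(len(dataset)), min(num_requests, len(dataset)))
  return [dataset[i] + (i,) for i in indices]


dataset = [
    ("What is JAX?", "A numerical computing library."),
    ("Define latency.", "Time from request to first response."),
]
requests = [
    InputRequest(prompt, len(prompt.split()), output, len(output.split()), idx)
    for prompt, output, idx in sample_with_indices(dataset, 2)
]
for r in requests:
  print(r.sample_idx, r.prompt)
```

Carrying `sample_idx` end to end means each result emitted by `to_dict` can be joined back to the exact dataset row that produced it, even after oversampling and filtering have reordered or dropped requests.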