
Commit

add sample_idx for debugging (#32)
morgandu authored and jwyang-google committed May 6, 2024
1 parent 1ca2199 commit ee387ad
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions benchmarks/benchmark_serving.py
@@ -100,6 +100,7 @@ class InputRequest:
prompt_len: int = 0
output: str = ""
output_len: int = 0
+ sample_idx: int = -1


@dataclass
@@ -121,6 +122,7 @@ def to_dict(self):
"success": self.success,
"latency": self.latency,
"prompt_len": self.prompt_len,
"sample_idx": self.input_request.sample_idx,
}


@@ -196,8 +198,14 @@ def tokenize_dataset(

n = len(dataset)

- prompts = [prompt for prompt, _ in dataset]
- outputs = [output for _, output in dataset]
+ prompts = []
+ outputs = []
+ indices = []
+
+ for prompt, output, idx in dataset:
+   prompts.append(prompt)
+   outputs.append(output)
+   indices.append(idx)

prompt_token_ids = tokenizer.tokenize(
prompts
@@ -210,9 +218,15 @@
for i in range(n):
prompt_len = len(prompt_token_ids[i])
output_len = len(outputs_token_ids[i])
- tokenized_dataset.append(
-     (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
+ tokenized_data = (
+     prompts[i],
+     prompt_token_ids[i],
+     outputs[i],
+     prompt_len,
+     output_len,
+     indices[i],
  )
+ tokenized_dataset.append(tokenized_data)
return tokenized_dataset
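
As an aside (not part of the commit), here is a minimal sketch of the shapes the reworked tokenize_dataset deals with: it now consumes (prompt, output, original index) triples and emits tuples whose last element is that index. The toy tokenizer and data below are hypothetical stand-ins for the benchmark's real tokenizer and dataset rows, and the sketch assumes tokenize_dataset from benchmarks/benchmark_serving.py is in scope.

class ToyTokenizer:
  """Hypothetical stand-in for the benchmark's real tokenizer."""

  def tokenize(self, texts, *args, **kwargs):
    # Whitespace splitting is enough to exercise the length bookkeeping.
    return [text.split() for text in texts]


# Mirrors what sample_requests now passes in: (prompt, output, dataset index).
sampled_dataset = [
    ("What is JAX?", "A library for array computing.", 17),
    ("Define a TPU.", "A tensor processing unit.", 42),
]

# Assumes tokenize_dataset from benchmarks/benchmark_serving.py is importable.
tokenized = tokenize_dataset(sampled_dataset, ToyTokenizer())
# Each entry is (prompt, prompt_token_ids, output, prompt_len, output_len, sample_idx),
# so tokenized[0][-1] == 17 traces the row back to the original dataset entry.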


@@ -235,6 +249,7 @@ def filter_dataset(
output,
prompt_len,
output_len,
+ sample_idx,
) in tokenized_dataset:
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
@@ -245,7 +260,7 @@
# Prune too long sequences.
continue
request = InputRequest(
-     prompt, prompt_len, output, max_output_length or output_len
+     prompt, prompt_len, output, max_output_length or output_len, sample_idx
)
filtered_dataset.append(request)

@@ -265,10 +280,11 @@ def sample_requests(

# Original dataset size
n = len(dataset)
+ dataset_indices = range(n)

# Create necessary number of requests even if bigger than dataset size
sampled_indices = random.sample(
-     range(n), min(int(num_requests * oversample_multiplier), n)
+     dataset_indices, min(int(num_requests * oversample_multiplier), n)
)

if num_requests > len(sampled_indices):
@@ -283,9 +299,13 @@

print(f"{len(sampled_indices)=}")
# some of these will be filtered out, so sample more than we need
- dataset = [dataset[i] for i in sampled_indices]

- tokenized_dataset = tokenize_dataset(dataset, tokenizer)
+ sampled_dataset = []
+ for i in sampled_indices:
+   sampled_data = dataset[i] + (dataset_indices[i],)
+   sampled_dataset.append(sampled_data)
+
+ tokenized_dataset = tokenize_dataset(sampled_dataset, tokenizer)

input_requests = filter_dataset(tokenized_dataset, max_output_length)
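
Because to_dict() now carries sample_idx, a logged result can be traced back to the exact dataset entry it came from, which is the debugging use the commit title refers to. A rough sketch of that lookup follows; the toy dataset and result dict are illustrative stand-ins, not from the commit, while the dict keys match to_dict() above.

# Toy stand-ins: a dataset of (prompt, output) pairs and one serialized
# result in the shape produced by to_dict() in the diff above.
dataset = [("prompt %d" % i, "output %d" % i) for i in range(100)]
result = {"success": False, "latency": 2.31, "prompt_len": 512, "sample_idx": 17}

idx = result["sample_idx"]
if idx >= 0:  # -1 is the dataclass default, i.e. no index was recorded
  prompt, expected_output = dataset[idx]
  print(f"failed request came from sample {idx}: prompt={prompt!r}")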


