Skip to content

Commit

Permalink
Adds filtering for sharegpt based on conversation starter. (#17)
Browse files: browse the repository at this point in the history
  • Loading branch information
patemotter authored Mar 25, 2024
1 parent 1c153b1 commit 0637309
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions benchmarks/benchmark_serving.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -126,12 +126,17 @@ def sample_requests(
num_requests: int,
tokenizer: Any,
max_output_length: int,
conversation_starter: str,
) -> List[InputRequest]:
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]

# Filter based on conversation starter
if conversation_starter != "both":
dataset = [data for data in dataset if data["conversations"][0]["from"] == conversation_starter]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
Expand Down Expand Up @@ -169,8 +174,8 @@ def sample_requests(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
reqeust = InputRequest(prompt, prompt_len, output, max_output_length)
filtered_dataset.append(reqeust)
request = InputRequest(prompt, prompt_len, output, max_output_length)
filtered_dataset.append(request)

# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
Expand Down Expand Up @@ -409,7 +414,13 @@ def main(args: argparse.Namespace):
if tokenizer == "test" or args.dataset == "test":
input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
else:
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_output_length)
input_requests = sample_requests(
args.dataset,
args.num_prompts,
tokenizer,
args.max_output_length,
args.conversation_starter,
)

if args.warmup_first:
print('Warm up start:' )
Expand Down Expand Up @@ -597,6 +608,15 @@ def main(args: argparse.Namespace):
"Whether to send warmup req first"
),
)
# Choose which speaker must hold the first turn of a conversation for it to be
# sampled: sample_requests compares this value against
# data["conversations"][0]["from"], and "both" skips that filter entirely.
parser.add_argument(
"--conversation-starter",
type=str,
default="human",
choices=["human", "gpt", "both"],
help=(
"What entity should be the one starting the conversations."
),
)

args = parser.parse_args()
main(args)

0 comments on commit 0637309

Please sign in to comment.