@@ -572,6 +572,7 @@ def get_sampling_params(
         # Ensure the lower bound for output length is at least 1 to
         # prevent sampling 0 tokens.
         output_low = max(output_low, 1)
+        output_high = max(output_high, 1)

         if input_low > input_high:
             raise ValueError(
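This second clamp matters for the reranking dataset added below, which calls get_sampling_params with an output length of 0. A minimal sketch of the failure mode it prevents, assuming the bounds are derived as output_len * (1 ± range_ratio) as in the surrounding code (the derivation lines here are reconstructed for illustration, not copied):

```python
# Sketch: why output_high also needs a floor of 1.
output_len, range_ratio = 0, 0.5  # the reranking path passes output_len=0
output_low = int(output_len * (1 - range_ratio))   # 0
output_high = int(output_len * (1 + range_ratio))  # 0
output_low = max(output_low, 1)                    # clamped to 1
# Without the second clamp, output_low (1) > output_high (0), and the
# low/high sanity check would reject every reranking run.
output_high = max(output_high, 1)                  # clamped to 1
```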
@@ -638,6 +639,112 @@ def generate_token_sequence(
         return prompt, total_input_len, token_mismatch


+# -----------------------------------------------------------------------------
+# Random Reranking Dataset Implementation (Synthetic Data)
+# -----------------------------------------------------------------------------
+
+
+class RandomDatasetForReranking(RandomDataset):
+    """
+    Random dataset specialized for the needs of scoring/reranking:
+    - requests are batches of inputs
+    - each input pairs the shared query with a sampled document
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        request_id_prefix: str = "",
+        range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
+        input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
+        batchsize: int = 1,
+        is_reranker: bool = True,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        # Reranker templates join the query and each document with one
+        # separator token; other scoring models send them unjoined.
+        n_sep_tokens = int(is_reranker)
+
+        # For a reranker, roughly half of the per-pair budget goes to the
+        # query; otherwise the query is its own input of full length.
+        query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len
+
+        query_lens, _, query_offsets = self.get_sampling_params(
+            1, range_ratio, query_len_param, 0, tokenizer
+        )
+
+        query_len = int(query_lens[0])
+
+        if not is_reranker:
+            # The query is sent as a standalone input, so it consumes one
+            # slot of the request budget and one slot of each batch.
+            assert num_requests > 1 and batchsize > 1
+            num_requests -= 1
+            batchsize -= 1
+            doc_len_param = input_len
+        else:
+            doc_len_param = input_len - query_len - n_sep_tokens
+
+        doc_lens, _, doc_offsets = self.get_sampling_params(
+            num_requests, range_ratio, doc_len_param, 0, tokenizer
+        )
+        vocab_size = tokenizer.vocab_size
+
+        query_prompt, query_input_len, token_mismatch_total = (
+            self.generate_token_sequence(
+                tokenizer=tokenizer,
+                prefix_token_ids=[],
+                prefix_len=0,
+                vocab_size=vocab_size,
+                input_len=query_len,
+                offset=int(query_offsets[0]),
+                index=0,
+            )
+        )
+
+        requests = []
+        for i in range(num_requests):
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
+                tokenizer=tokenizer,
+                prefix_token_ids=[],
+                prefix_len=0,
+                vocab_size=vocab_size,
+                input_len=int(doc_lens[i]),
+                offset=int(doc_offsets[i]),
+                index=i + 1,
+            )
+            token_mismatch_total += token_mismatch
+            requests.append((prompt, total_input_len))
+
+        batch_requests = []
+        # Create batched requests
+        for i in range(0, num_requests, batchsize):
+            batch = requests[i : i + batchsize]
+            # A reranker scores the query against every document, so the
+            # query (plus separator) is counted once per pair; otherwise
+            # the query token count is incurred once per batch.
+            query_contrib = (
+                (query_input_len + n_sep_tokens) * len(batch)
+                if is_reranker
+                else query_input_len
+            )
+            batch_requests.append(
+                SampleRequest(
+                    prompt=[query_prompt] + [req[0] for req in batch],
+                    prompt_len=query_contrib + sum(req[1] for req in batch),
+                    expected_output_len=0,
+                    request_id=request_id_prefix + str(i // batchsize),
+                )
+            )
+
+        if token_mismatch_total != 0:
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                "more" if token_mismatch_total > 0 else "fewer",
+            )
+
+        return batch_requests
+
+
 # -----------------------------------------------------------------------------
 # MultiModalDataset Implementation
 # -----------------------------------------------------------------------------
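A usage sketch for the new class (the checkpoint name is illustrative; `RandomDatasetForReranking`, `SampleRequest`, and the keyword arguments are the ones defined in this diff):

```python
from transformers import AutoTokenizer

# Illustrative reranker checkpoint, not prescribed by this PR.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")

dataset = RandomDatasetForReranking(random_seed=0)
requests = dataset.sample(
    tokenizer=tokenizer,
    num_requests=8,
    input_len=128,
    batchsize=4,
    is_reranker=True,
)
for req in requests:
    # prompt is [query, doc_1, ..., doc_n]; prompt_len already includes
    # one separator token per (query, document) pair.
    print(req.request_id, len(req.prompt), req.prompt_len)
```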
@@ -1149,6 +1256,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
             "sonnet",
             "random",
             "random-mm",
+            "random-rerank",
             "hf",
             "custom",
             "prefix_repetition",
@@ -1292,6 +1400,14 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         default=1,
         help=("Batch size for random sampling. Only used for embeddings benchmark."),
     )
+    random_group.add_argument(
+        "--no-reranker",
+        action="store_true",
+        help=(
+            "Set if the model does not support reranking natively."
+            " Only used for the reranker benchmark."
+        ),
+    )

     # random multimodal dataset options
     random_mm_group = parser.add_argument_group(
@@ -1678,6 +1794,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         ),
1797+ "random-rerank" : lambda : RandomDatasetForReranking (
1798+ random_seed = args .seed ,
1799+ dataset_path = args .dataset_path ,
1800+ disable_shuffle = args .disable_shuffle ,
1801+ ).sample (
1802+ tokenizer = tokenizer ,
1803+ num_requests = args .num_prompts ,
1804+ input_len = args .random_input_len ,
1805+ range_ratio = args .random_range_ratio ,
1806+ request_id_prefix = args .request_id_prefix ,
1807+ batchsize = args .random_batch_size ,
1808+ is_reranker = not args .no_reranker ,
1809+ ),
16811810 "prefix_repetition" : lambda : PrefixRepetitionRandomDataset (
16821811 random_seed = args .seed ,
16831812 dataset_path = args .dataset_path ,
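End to end, the new dataset should be reachable through the usual benchmark CLI; assuming the standard `vllm bench serve` entry point, an invocation might look like the following (model name illustrative; every flag except `--no-reranker` already exists in this parser):

```bash
vllm bench serve \
  --model BAAI/bge-reranker-base \
  --dataset-name random-rerank \
  --num-prompts 64 \
  --random-input-len 256 \
  --random-batch-size 8
```

With `--no-reranker`, the query is sent as its own input, so both `--num-prompts` and `--random-batch-size` must be greater than 1 (sample() asserts this).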