@@ -320,6 +320,8 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+        random.seed(self.random_seed)
+        np.random.seed(self.random_seed)
 
     def sample(
         self,
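The constructor change above seeds both Python's `random` module and NumPy's global generator with `self.random_seed`, so repeated benchmark runs with the same `--seed` draw identical request lengths and prompts. A minimal sketch of the effect, assuming the sampling path mixes `random` and `np.random` calls as the benchmark does; the helper below is hypothetical and only illustrates the reproducibility:

```python
import random

import numpy as np


def sample_lengths(seed: int, num_requests: int, input_len: int = 1024) -> list[int]:
    # Seed both generators up front, mirroring RandomDataset.__init__.
    random.seed(seed)
    np.random.seed(seed)
    # Per-request length sampling, similar in spirit to what sample() does.
    return np.random.randint(1, input_len + 1, size=num_requests).tolist()


# Same seed -> identical draws across runs.
assert sample_lengths(42, 4) == sample_lengths(42, 4)
```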
@@ -376,10 +378,11 @@ def sample(
             # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
             # To avoid uncontrolled change of the prompt length,
             # the encoded sequence is truncated before being decode again.
+            total_input_len = prefix_len + int(input_lens[i])
             re_encoded_sequence = tokenizer.encode(
-                prompt, add_special_tokens=False)[:input_lens[i]]
+                prompt, add_special_tokens=False)[:total_input_len]
             prompt = tokenizer.decode(re_encoded_sequence)
-            total_input_len = prefix_len + int(input_lens[i])
+            total_input_len = len(re_encoded_sequence)
             requests.append(
                 SampleRequest(
                     prompt=prompt,
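The reordering above matters because a decode followed by a re-encode is not guaranteed to round-trip to the same number of tokens (the `[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']` comment in the source is one such case). The re-encoded ids are now truncated to the full `prefix_len + input_lens[i]` budget rather than just `input_lens[i]`, and `total_input_len` is recomputed from whatever survived the truncation. A small sketch of the round-trip drift, assuming a Hugging Face GPT-2 tokenizer purely for illustration:

```python
from transformers import AutoTokenizer  # assumption: any HF tokenizer works for the demo

tokenizer = AutoTokenizer.from_pretrained("gpt2")

ids = [1650, 939, 486]  # the example from the source comment
text = tokenizer.decode(ids)
re_ids = tokenizer.encode(text, add_special_tokens=False)

# decode/encode may merge or split tokens, so the two lengths can differ;
# recording len(re_ids) after truncation keeps the reported input length honest.
print(len(ids), len(re_ids))
```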
@@ -692,7 +695,8 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             dataset_path=args.dataset_path).
             sample(tokenizer=tokenizer, num_requests=args.num_prompts),
         "random":
-        lambda: RandomDataset(dataset_path=args.dataset_path).sample(
+        lambda: RandomDataset(random_seed=args.seed,
+                              dataset_path=args.dataset_path).sample(
             tokenizer=tokenizer,
             num_requests=args.num_prompts,
             prefix_len=args.random_prefix_len,