@@ -223,10 +223,10 @@ def main(
223223
224224if __name__ == "__main__" :
225225 parser = argparse .ArgumentParser ()
226- parser .add_argument ('--batch' , type = int , default = 16 , help = 'batch size' )
227- parser .add_argument ('--heads' , type = int , default = 64 , help = 'heads' )
228- parser .add_argument ('--seq_q' , type = int , default = 298 , help = 'query sequence length' )
229- parser .add_argument ('--seq_kv' , type = int , default = 298 , help = 'key/value sequence length' )
226+ parser .add_argument ('--batch' , type = int , default = 1 , help = 'batch size' )
227+ parser .add_argument ('--heads' , type = int , default = 1 , help = 'heads' )
228+ parser .add_argument ('--seq_q' , type = int , default = 256 , help = 'query sequence length' )
229+ parser .add_argument ('--seq_kv' , type = int , default = 256 , help = 'key/value sequence length' )
230230 parser .add_argument ('--dim' , type = int , default = 64 , help = 'dim' )
231231 parser .add_argument ('--is_causal' , action = 'store_true' , help = 'causal' , default = False )
232232 parser .add_argument ('--tune' , action = 'store_true' , help = 'tune configs' )
0 commit comments