@@ -46,18 +46,18 @@ def __init__(
4646        self ,
4747        parent ,
4848        batch_size = 2 ,
49-         image_width = 80 , 
50-         image_height = 60 , 
51-         stage_num_blocks : list [int ] =  [1 , 1 ,  1 ],
52-         out_features : list [int ] =  [32 ,  32 ,  128 ], 
53-         stage_stride : list [int ] =  [2 , 1 ,  2 ],
49+         image_width = 6 ,   # needs to be a multiple of `stage_stride[0] * stage_stride[1]` 
50+         image_height = 4 ,   # needs to be a multiple of `stage_stride[0] * stage_stride[1]` 
51+         stage_num_blocks : list [int ] =  [1 , 1 ],
52+         out_features : list [int ] =  [16 ,  16 ],   # need to be >= 2 to make `config.fine_fusion_dims > 0` 
53+         stage_stride : list [int ] =  [2 , 1 ],
5454        q_aggregation_kernel_size : int  =  1 ,
5555        kv_aggregation_kernel_size : int  =  1 ,
5656        q_aggregation_stride : int  =  1 ,
5757        kv_aggregation_stride : int  =  1 ,
5858        num_attention_layers : int  =  2 ,
5959        num_attention_heads : int  =  8 ,
60-         hidden_size : int  =  128 ,
60+         hidden_size : int  =  16 ,
6161        coarse_matching_threshold : float  =  0.0 ,
6262        fine_kernel_size : int  =  2 ,
6363        coarse_matching_border_removal : int  =  0 ,
0 commit comments