@@ -46,18 +46,18 @@ def __init__(
4646 self ,
4747 parent ,
4848 batch_size = 2 ,
49- image_width = 80 ,
50- image_height = 60 ,
51- stage_num_blocks : list [int ] = [1 , 1 , 1 ],
52- out_features : list [int ] = [32 , 32 , 128 ],
53- stage_stride : list [int ] = [2 , 1 , 2 ],
49+ image_width = 6 , # need to be a multiple of `stage_stride[0] * stage_stride[1]`
50+ image_height = 4 , # need to be a multiple of `stage_stride[0] * stage_stride[1]`
51+ stage_num_blocks : list [int ] = [1 , 1 ],
52+ out_features : list [int ] = [16 , 16 ], # need to be >= 2 to make `config.fine_fusion_dims > 0`
53+ stage_stride : list [int ] = [2 , 1 ],
5454 q_aggregation_kernel_size : int = 1 ,
5555 kv_aggregation_kernel_size : int = 1 ,
5656 q_aggregation_stride : int = 1 ,
5757 kv_aggregation_stride : int = 1 ,
5858 num_attention_layers : int = 2 ,
5959 num_attention_heads : int = 8 ,
60- hidden_size : int = 128 ,
60+ hidden_size : int = 16 ,
6161 coarse_matching_threshold : float = 0.0 ,
6262 fine_kernel_size : int = 2 ,
6363 coarse_matching_border_removal : int = 0 ,
0 commit comments