@@ -69,6 +69,42 @@ def __init__(
6969 self .layer_norm_eps = layer_norm_eps
7070
7171
class Sam2PositionEmbeddingConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Sam2PositionEmbedding`]. The
    [`Sam2PositionEmbedding`] module is used to encode the input 2D points and bounding boxes. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the SAM2-hiera-tiny
    [facebook/sam2-hiera-tiny](https://huggingface.co/facebook/sam2-hiera-tiny) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_pos_feats (`int`):
            The number of positional features (the size of the positional embedding).
        temperature (`int`, *optional*, defaults to 10000):
            The temperature used by the sinusoidal position encoding.
        normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the embedding vector.
        scale (`float`, *optional*, defaults to `None`):
            The scale factor applied to the embedding vector.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature=10000,
        normalize=True,
        scale=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
106+
107+
72108class Sam2MaskDecoderConfig (PretrainedConfig ):
73109 r"""
74110 This is the configuration class to store the configuration of a [`Sam2MaskDecoder`]. It is used to instantiate a SAM2
@@ -135,8 +171,8 @@ def __init__(
135171 dynamic_multimask_via_stability = False ,
136172 dynamic_multimask_stability_delta = 0.05 ,
137173 dynamic_multimask_stability_thresh = 0.98 ,
138- pred_obj_scores = False ,
139- pred_obj_scores_mlp = False ,
174+ pred_obj_scores = False ,
175+ pred_obj_scores_mlp = False ,
140176 use_multimask_token_for_obj_ptr = False ,
141177 layer_norm_eps = 1e-6 ,
142178 ** kwargs ,
@@ -151,14 +187,14 @@ def __init__(
151187 self .num_multimask_outputs = num_multimask_outputs
152188 self .iou_head_depth = iou_head_depth
153189 self .iou_head_hidden_dim = iou_head_hidden_dim
154- self .use_high_res_features = use_high_res_features ,
155- self .iou_prediction_use_sigmoid = iou_prediction_use_sigmoid ,
156- self .dynamic_multimask_via_stability = dynamic_multimask_via_stability ,
157- self .dynamic_multimask_stability_delta = dynamic_multimask_stability_delta ,
158- self .dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh ,
159- self .pred_obj_scores = pred_obj_scores ,
160- self .pred_obj_scores_mlp = pred_obj_scores_mlp ,
161- self .use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr ,
190+ self .use_high_res_features = ( use_high_res_features ,)
191+ self .iou_prediction_use_sigmoid = ( iou_prediction_use_sigmoid ,)
192+ self .dynamic_multimask_via_stability = ( dynamic_multimask_via_stability ,)
193+ self .dynamic_multimask_stability_delta = ( dynamic_multimask_stability_delta ,)
194+ self .dynamic_multimask_stability_thresh = ( dynamic_multimask_stability_thresh ,)
195+ self .pred_obj_scores = ( pred_obj_scores ,)
196+ self .pred_obj_scores_mlp = ( pred_obj_scores_mlp ,)
197+ self .use_multimask_token_for_obj_ptr = ( use_multimask_token_for_obj_ptr ,)
162198 self .layer_norm_eps = layer_norm_eps
163199
164200
@@ -175,6 +211,7 @@ class Sam2MemoryAttentionConfig(PretrainedConfig):
175211 Args:
176212
177213 """
214+
178215 def __init__ (
179216 self ,
180217 # TO DO
@@ -198,14 +235,37 @@ class Sam2MemoryEncoderConfig(PretrainedConfig):
198235 Args:
199236
200237 """
238+
201239 def __init__ (
202240 self ,
203- # TO DO
241+ out_dim = 64 ,
242+ positional_encoding_config = None ,
243+ mask_downsmapler_config = None ,
244+ fuser_config = None ,
245+ in_dim = 256 ,
204246 ** kwargs ,
205247 ):
206248 super ().__init__ (** kwargs )
207-
208- # TO DO
249+ if positional_encoding_config is None :
250+ positional_encoding_config = {"num_pos_feats" : 64 , "normalize" : True , "scale" : None , "temperature" : 1000 }
251+ if mask_downsmapler_config is None :
252+ mask_downsmapler_config = {"kernel_size" : 3 , "stride" : 2 , "padding" : 1 }
253+ if fuser_config is None :
254+ fuser_config = {
255+ "layer" : {
256+ "dim" : 256 ,
257+ "kernel_size" : 7 ,
258+ "padding" : 3 ,
259+ "layer_scale_init_value" : 1e-6 ,
260+ "use_dwconv" : True ,
261+ },
262+ "num_layers" : 2 ,
263+ }
264+
265+ self .out_dim = out_dim
266+ self .positional_encoding_config = positional_encoding_config
267+ self .mask_downsmapler_config = mask_downsmapler_config
268+ self .fuser_config = fuser_config
209269
210270
211271# TO DO
@@ -263,48 +323,56 @@ class Sam2VisionConfig(PretrainedConfig):
263323
264324 def __init__ (
265325 self ,
266- hidden_size = 768 ,
267- output_channels = 256 ,
268- num_hidden_layers = 12 ,
269- num_attention_heads = 12 ,
270- num_channels = 3 ,
271- image_size = 1024 ,
272- patch_size = 16 ,
273- hidden_act = "gelu" ,
274- layer_norm_eps = 1e-06 ,
275- attention_dropout = 0.0 ,
276- initializer_range = 1e-10 ,
277- qkv_bias = True ,
278- mlp_ratio = 4.0 ,
279- use_abs_pos = True ,
280- use_rel_pos = True ,
281- window_size = 14 ,
282- global_attn_indexes = [2 , 5 , 8 , 11 ],
283- num_pos_feats = 128 ,
284- mlp_dim = None ,
326+ scalp = 1 ,
327+ hidden_size = 96 ,
328+ num_heads = 1 ,
329+ drop_path_rate = 0 ,
330+ q_pool = 3 ,
331+ q_stride = [2 , 2 ],
332+ stages = [1 , 2 , 7 , 2 ],
333+ dim_mul = 2.0 ,
334+ head_mul = 2.0 ,
335+ window_pos_embed_bkg_spatial_size = [7 , 7 ],
336+ window_spec = [8 , 4 , 14 , 7 ],
337+ global_att_blocks = [5 , 7 , 9 ],
338+ return_interm_layers = False ,
339+ neck_position_encoding_config = None ,
340+ neck_hidden_size = 256 ,
341+ neck_backbone_channel_list = [768 , 384 , 192 , 96 ],
342+ neck_kernel_size = 1 ,
343+ neck_stride = 1 ,
344+ neck_padding = 0 ,
345+ neck_fpn_interp_model = "nearest" ,
346+ neck_fuse_type = "sum" ,
347+ neck_fpn_top_down_level = [2 , 3 ],
285348 ** kwargs ,
286349 ):
287350 super ().__init__ (** kwargs )
351+ if neck_position_encoding_config is None :
352+ neck_position_encoding_config = Sam2PositionEmbeddingConfig (num_pos_feats = 256 )
288353
354+ self .scalp = scalp
289355 self .hidden_size = hidden_size
290- self .output_channels = output_channels
291- self .num_hidden_layers = num_hidden_layers
292- self .num_attention_heads = num_attention_heads
293- self .num_channels = num_channels
294- self .image_size = image_size
295- self .patch_size = patch_size
296- self .hidden_act = hidden_act
297- self .layer_norm_eps = layer_norm_eps
298- self .attention_dropout = attention_dropout
299- self .initializer_range = initializer_range
300- self .qkv_bias = qkv_bias
301- self .mlp_ratio = mlp_ratio
302- self .use_abs_pos = use_abs_pos
303- self .use_rel_pos = use_rel_pos
304- self .window_size = window_size
305- self .global_attn_indexes = global_attn_indexes
306- self .num_pos_feats = num_pos_feats
307- self .mlp_dim = int (hidden_size * mlp_ratio ) if mlp_dim is None else mlp_dim
356+ self .num_heads = num_heads
357+ self .drop_path_rate = drop_path_rate
358+ self .q_pool = q_pool
359+ self .q_stride = q_stride
360+ self .stages = stages
361+ self .dim_mul = dim_mul
362+ self .head_mul = head_mul
363+ self .window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
364+ self .window_spec = window_spec
365+ self .global_att_blocks = global_att_blocks
366+ self .return_interm_layers = return_interm_layers
367+ self .neck_position_encoding_config = neck_position_encoding_config
368+ self .neck_hidden_size = neck_hidden_size
369+ self .neck_backbone_channel_list = neck_backbone_channel_list
370+ self .neck_kernel_size = neck_kernel_size
371+ self .neck_stride = neck_stride
372+ self .neck_padding = neck_padding
373+ self .neck_fpn_interp_model = neck_fpn_interp_model
374+ self .neck_fuse_type = neck_fuse_type
375+ self .neck_fpn_top_down_level = neck_fpn_top_down_level
308376
309377
310378class Sam2Config (PretrainedConfig ):
0 commit comments