
Commit dc2cb88

Merge pull request huggingface#2 from SangbumChoi/sam2_config
check
2 parents 66d6fb8 + 289a0c0 commit dc2cb88

File tree: 2 files changed, +124 −52 lines

src/transformers/models/sam2/configuration_sam2.py

Lines changed: 118 additions & 50 deletions
@@ -69,6 +69,42 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
 
 
+class Sam2PositionEmbeddingConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Sam2PositionEmbedding`]. The [`Sam2PositionEmbedding`]
+    module is used to encode the input 2D points and bounding boxes. Instantiating a configuration with the defaults will
+    yield a similar configuration to that of the SAM2-hiera-tiny
+    [facebook/sam2-hiera-tiny](https://huggingface.co/facebook/sam2-hiera-tiny) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_pos_feats (`int`):
+            The feature dimension of the positional features.
+        temperature (`int`, *optional*, defaults to 10000):
+            The temperature used to scale the positional frequencies.
+        normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the embedding vector.
+        scale (`float`, *optional*, defaults to `None`):
+            The scale factor applied to the embedding vector.
+    """
+
+    def __init__(
+        self,
+        num_pos_feats,
+        temperature=10000,
+        normalize=True,
+        scale=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        self.scale = scale
+
+
 class Sam2MaskDecoderConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`Sam2MaskDecoder`]. It is used to instantiate a SAM2
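
For context on the new class: the `Sam2PositionEmbedding` module itself is not part of this diff, but the four parameters match the common DETR-style sine-cosine position encoding. A minimal sketch under that assumption (function name and shapes are illustrative, not from the source):

import math
import torch

def sine_position_encoding(x_embed, y_embed, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
    # x_embed, y_embed: (height, width) grids of x / y coordinates.
    if scale is None:
        scale = 2 * math.pi
    if normalize:
        eps = 1e-6
        y_embed = y_embed / (y_embed[-1:, :] + eps) * scale
        x_embed = x_embed / (x_embed[:, -1:] + eps) * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
    pos_x = x_embed[:, :, None] / dim_t  # (H, W, num_pos_feats)
    pos_y = y_embed[:, :, None] / dim_t
    # Interleave sin on even channels, cos on odd channels.
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    return torch.cat((pos_y, pos_x), dim=2)  # (H, W, 2 * num_pos_feats)

h, w = 8, 8
y_embed = torch.arange(1, h + 1, dtype=torch.float32)[:, None].expand(h, w)
x_embed = torch.arange(1, w + 1, dtype=torch.float32)[None, :].expand(h, w)
print(sine_position_encoding(x_embed, y_embed).shape)  # torch.Size([8, 8, 128])
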
@@ -135,8 +171,8 @@ def __init__(
         dynamic_multimask_via_stability=False,
         dynamic_multimask_stability_delta=0.05,
         dynamic_multimask_stability_thresh=0.98,
-        pred_obj_scores= False,
-        pred_obj_scores_mlp= False,
+        pred_obj_scores=False,
+        pred_obj_scores_mlp=False,
         use_multimask_token_for_obj_ptr=False,
         layer_norm_eps=1e-6,
         **kwargs,
@@ -151,14 +187,14 @@ def __init__(
         self.num_multimask_outputs = num_multimask_outputs
         self.iou_head_depth = iou_head_depth
         self.iou_head_hidden_dim = iou_head_hidden_dim
-        self.use_high_res_features=use_high_res_features,
-        self.iou_prediction_use_sigmoid=iou_prediction_use_sigmoid,
-        self.dynamic_multimask_via_stability=dynamic_multimask_via_stability,
-        self.dynamic_multimask_stability_delta=dynamic_multimask_stability_delta,
-        self.dynamic_multimask_stability_thresh=dynamic_multimask_stability_thresh,
-        self.pred_obj_scores= pred_obj_scores,
-        self.pred_obj_scores_mlp= pred_obj_scores_mlp,
-        self.use_multimask_token_for_obj_ptr=use_multimask_token_for_obj_ptr,
+        self.use_high_res_features = (use_high_res_features,)
+        self.iou_prediction_use_sigmoid = (iou_prediction_use_sigmoid,)
+        self.dynamic_multimask_via_stability = (dynamic_multimask_via_stability,)
+        self.dynamic_multimask_stability_delta = (dynamic_multimask_stability_delta,)
+        self.dynamic_multimask_stability_thresh = (dynamic_multimask_stability_thresh,)
+        self.pred_obj_scores = (pred_obj_scores,)
+        self.pred_obj_scores_mlp = (pred_obj_scores_mlp,)
+        self.use_multimask_token_for_obj_ptr = (use_multimask_token_for_obj_ptr,)
         self.layer_norm_eps = layer_norm_eps
 
 
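Worth flagging for a follow-up commit: each of these assignments ends in a trailing comma, so every attribute is stored as a one-element tuple rather than a scalar; the reformatting above only makes that explicit with parentheses. A quick illustration of why this matters (plain Python, not from the diff):

pred_obj_scores = False

stored = (pred_obj_scores,)      # what the diff assigns: a 1-tuple
print(stored, bool(stored))      # (False,) True -- any non-empty tuple is truthy
intended = pred_obj_scores       # scalar assignment, presumably the intent
print(intended, bool(intended))  # False False

Any downstream check like `if config.pred_obj_scores:` would therefore always pass with the tuple form.
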
@@ -175,6 +211,7 @@ class Sam2MemoryAttentionConfig(PretrainedConfig):
     Args:
 
     """
+
     def __init__(
         self,
         # TO DO
@@ -198,14 +235,37 @@ class Sam2MemoryEncoderConfig(PretrainedConfig):
     Args:
 
     """
+
     def __init__(
         self,
-        # TO DO
+        out_dim=64,
+        positional_encoding_config=None,
+        mask_downsmapler_config=None,
+        fuser_config=None,
+        in_dim=256,
         **kwargs,
     ):
         super().__init__(**kwargs)
-
-        # TO DO
+        if positional_encoding_config is None:
+            positional_encoding_config = {"num_pos_feats": 64, "normalize": True, "scale": None, "temperature": 1000}
+        if mask_downsmapler_config is None:
+            mask_downsmapler_config = {"kernel_size": 3, "stride": 2, "padding": 1}
+        if fuser_config is None:
+            fuser_config = {
+                "layer": {
+                    "dim": 256,
+                    "kernel_size": 7,
+                    "padding": 3,
+                    "layer_scale_init_value": 1e-6,
+                    "use_dwconv": True,
+                },
+                "num_layers": 2,
+            }
+
+        self.out_dim = out_dim
+        self.positional_encoding_config = positional_encoding_config
+        self.mask_downsmapler_config = mask_downsmapler_config
+        self.fuser_config = fuser_config
 
 
 # TO DO
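
A minimal usage sketch of the defaults as written (the import path follows this in-progress module, so treat it as an assumption):

from transformers.models.sam2.configuration_sam2 import Sam2MemoryEncoderConfig

config = Sam2MemoryEncoderConfig()
print(config.out_dim)                                    # 64
print(config.positional_encoding_config["temperature"])  # 1000 (not the 10000 default of Sam2PositionEmbeddingConfig)
print(config.fuser_config["layer"]["dim"])               # 256
print(config.fuser_config["num_layers"])                 # 2

Two things stand out for later cleanup: the `in_dim` argument is accepted but never stored on the instance, and `mask_downsmapler_config` looks like a misspelling of `mask_downsampler_config`.
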
@@ -263,48 +323,56 @@ class Sam2VisionConfig(PretrainedConfig):
 
     def __init__(
         self,
-        hidden_size=768,
-        output_channels=256,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        num_channels=3,
-        image_size=1024,
-        patch_size=16,
-        hidden_act="gelu",
-        layer_norm_eps=1e-06,
-        attention_dropout=0.0,
-        initializer_range=1e-10,
-        qkv_bias=True,
-        mlp_ratio=4.0,
-        use_abs_pos=True,
-        use_rel_pos=True,
-        window_size=14,
-        global_attn_indexes=[2, 5, 8, 11],
-        num_pos_feats=128,
-        mlp_dim=None,
+        scalp=1,
+        hidden_size=96,
+        num_heads=1,
+        drop_path_rate=0,
+        q_pool=3,
+        q_stride=[2, 2],
+        stages=[1, 2, 7, 2],
+        dim_mul=2.0,
+        head_mul=2.0,
+        window_pos_embed_bkg_spatial_size=[7, 7],
+        window_spec=[8, 4, 14, 7],
+        global_att_blocks=[5, 7, 9],
+        return_interm_layers=False,
+        neck_position_encoding_config=None,
+        neck_hidden_size=256,
+        neck_backbone_channel_list=[768, 384, 192, 96],
+        neck_kernel_size=1,
+        neck_stride=1,
+        neck_padding=0,
+        neck_fpn_interp_model="nearest",
+        neck_fuse_type="sum",
+        neck_fpn_top_down_level=[2, 3],
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if neck_position_encoding_config is None:
+            neck_position_encoding_config = Sam2PositionEmbeddingConfig(num_pos_feats=256)
 
+        self.scalp = scalp
         self.hidden_size = hidden_size
-        self.output_channels = output_channels
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_channels = num_channels
-        self.image_size = image_size
-        self.patch_size = patch_size
-        self.hidden_act = hidden_act
-        self.layer_norm_eps = layer_norm_eps
-        self.attention_dropout = attention_dropout
-        self.initializer_range = initializer_range
-        self.qkv_bias = qkv_bias
-        self.mlp_ratio = mlp_ratio
-        self.use_abs_pos = use_abs_pos
-        self.use_rel_pos = use_rel_pos
-        self.window_size = window_size
-        self.global_attn_indexes = global_attn_indexes
-        self.num_pos_feats = num_pos_feats
-        self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
+        self.num_heads = num_heads
+        self.drop_path_rate = drop_path_rate
+        self.q_pool = q_pool
+        self.q_stride = q_stride
+        self.stages = stages
+        self.dim_mul = dim_mul
+        self.head_mul = head_mul
+        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+        self.window_spec = window_spec
+        self.global_att_blocks = global_att_blocks
+        self.return_interm_layers = return_interm_layers
+        self.neck_position_encoding_config = neck_position_encoding_config
+        self.neck_hidden_size = neck_hidden_size
+        self.neck_backbone_channel_list = neck_backbone_channel_list
+        self.neck_kernel_size = neck_kernel_size
+        self.neck_stride = neck_stride
+        self.neck_padding = neck_padding
+        self.neck_fpn_interp_model = neck_fpn_interp_model
+        self.neck_fuse_type = neck_fuse_type
+        self.neck_fpn_top_down_level = neck_fpn_top_down_level
 
 
 class Sam2Config(PretrainedConfig):
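
Sanity check on the new Hiera defaults: with `dim_mul=2.0` the embedding width doubles at each of the four stage transitions starting from `hidden_size=96`, which is exactly `neck_backbone_channel_list` in reverse. A small sketch (the helper name is illustrative, not from the diff):

def hiera_stage_dims(hidden_size=96, dim_mul=2.0, num_stages=4):
    # Hiera-style geometric width growth across stages.
    return [int(hidden_size * dim_mul**i) for i in range(num_stages)]

dims = hiera_stage_dims()
print(dims)        # [96, 192, 384, 768]
print(dims[::-1])  # [768, 384, 192, 96] == neck_backbone_channel_list default
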

src/transformers/models/sam2/convert_sam2_to_hf.py

Lines changed: 6 additions & 2 deletions
@@ -41,10 +41,13 @@ def get_config(model_name):
         vision_config = Sam2VisionConfig()
     elif "sam2_hiera_small" in model_name:
         # TO DO
+        pass
     elif "sam2_hiera_base_plus" in model_name:
         # TO DO
+        pass
     elif "sam2_hiera_large" in model_name:
         # TO DO
+        pass
 
     config = Sam2Config(
         vision_config=vision_config,
@@ -153,13 +156,14 @@ def convert_sam2_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, pu
 
     elif model_name == "sam2_hiera_small":
         # TO DO
-
+        pass
     elif model_name == "sam2_hiera_base_plus":
         # TO DO
+        pass
 
     elif model_name == "sam2_hiera_large":
        # TO DO
-
+        pass
 
     if pytorch_dump_folder is not None:
         processor.save_pretrained(pytorch_dump_folder)
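
The added `pass` statements keep the still-empty branches syntactically valid until the other model sizes are implemented. For reference, the happy path through `get_config` after this change, i.e. what currently works for the tiny variant (the import path is this in-progress module, so treat it as an assumption):

from transformers.models.sam2.configuration_sam2 import Sam2Config, Sam2VisionConfig

vision_config = Sam2VisionConfig()  # tiny-variant defaults shown above
config = Sam2Config(
    vision_config=vision_config,
)
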
