
examples : add sample SAM inference #74

Merged — 8 commits merged from the sam branch into master on Aug 18, 2023
Conversation

@ggerganov (Owner) commented Apr 9, 2023

Initial version: #418 (comment)


PTH tensors for ViT-B
image_encoder.neck.0.weight torch.Size([256, 768, 1, 1])
image_encoder.neck.1.weight torch.Size([256])
image_encoder.neck.1.bias torch.Size([256])
image_encoder.neck.2.weight torch.Size([256, 256, 3, 3])
image_encoder.neck.3.weight torch.Size([256])
image_encoder.neck.3.bias torch.Size([256])
image_encoder.patch_embed.proj.weight torch.Size([768, 3, 16, 16])
image_encoder.patch_embed.proj.bias torch.Size([768])
image_encoder.blocks.0.norm1.weight torch.Size([768])
image_encoder.blocks.0.norm1.bias torch.Size([768])
image_encoder.blocks.0.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.0.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.0.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.0.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.0.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.0.attn.proj.bias torch.Size([768])
image_encoder.blocks.0.norm2.weight torch.Size([768])
image_encoder.blocks.0.norm2.bias torch.Size([768])
image_encoder.blocks.0.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.0.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.0.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.0.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.1.norm1.weight torch.Size([768])
image_encoder.blocks.1.norm1.bias torch.Size([768])
image_encoder.blocks.1.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.1.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.1.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.1.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.1.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.1.attn.proj.bias torch.Size([768])
image_encoder.blocks.1.norm2.weight torch.Size([768])
image_encoder.blocks.1.norm2.bias torch.Size([768])
image_encoder.blocks.1.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.1.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.1.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.1.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.2.norm1.weight torch.Size([768])
image_encoder.blocks.2.norm1.bias torch.Size([768])
image_encoder.blocks.2.attn.rel_pos_h torch.Size([127, 64])
image_encoder.blocks.2.attn.rel_pos_w torch.Size([127, 64])
image_encoder.blocks.2.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.2.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.2.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.2.attn.proj.bias torch.Size([768])
image_encoder.blocks.2.norm2.weight torch.Size([768])
image_encoder.blocks.2.norm2.bias torch.Size([768])
image_encoder.blocks.2.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.2.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.2.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.2.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.3.norm1.weight torch.Size([768])
image_encoder.blocks.3.norm1.bias torch.Size([768])
image_encoder.blocks.3.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.3.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.3.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.3.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.3.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.3.attn.proj.bias torch.Size([768])
image_encoder.blocks.3.norm2.weight torch.Size([768])
image_encoder.blocks.3.norm2.bias torch.Size([768])
image_encoder.blocks.3.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.3.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.3.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.3.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.4.norm1.weight torch.Size([768])
image_encoder.blocks.4.norm1.bias torch.Size([768])
image_encoder.blocks.4.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.4.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.4.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.4.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.4.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.4.attn.proj.bias torch.Size([768])
image_encoder.blocks.4.norm2.weight torch.Size([768])
image_encoder.blocks.4.norm2.bias torch.Size([768])
image_encoder.blocks.4.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.4.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.4.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.4.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.5.norm1.weight torch.Size([768])
image_encoder.blocks.5.norm1.bias torch.Size([768])
image_encoder.blocks.5.attn.rel_pos_h torch.Size([127, 64])
image_encoder.blocks.5.attn.rel_pos_w torch.Size([127, 64])
image_encoder.blocks.5.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.5.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.5.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.5.attn.proj.bias torch.Size([768])
image_encoder.blocks.5.norm2.weight torch.Size([768])
image_encoder.blocks.5.norm2.bias torch.Size([768])
image_encoder.blocks.5.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.5.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.5.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.5.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.6.norm1.weight torch.Size([768])
image_encoder.blocks.6.norm1.bias torch.Size([768])
image_encoder.blocks.6.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.6.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.6.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.6.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.6.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.6.attn.proj.bias torch.Size([768])
image_encoder.blocks.6.norm2.weight torch.Size([768])
image_encoder.blocks.6.norm2.bias torch.Size([768])
image_encoder.blocks.6.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.6.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.6.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.6.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.7.norm1.weight torch.Size([768])
image_encoder.blocks.7.norm1.bias torch.Size([768])
image_encoder.blocks.7.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.7.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.7.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.7.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.7.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.7.attn.proj.bias torch.Size([768])
image_encoder.blocks.7.norm2.weight torch.Size([768])
image_encoder.blocks.7.norm2.bias torch.Size([768])
image_encoder.blocks.7.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.7.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.7.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.7.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.8.norm1.weight torch.Size([768])
image_encoder.blocks.8.norm1.bias torch.Size([768])
image_encoder.blocks.8.attn.rel_pos_h torch.Size([127, 64])
image_encoder.blocks.8.attn.rel_pos_w torch.Size([127, 64])
image_encoder.blocks.8.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.8.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.8.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.8.attn.proj.bias torch.Size([768])
image_encoder.blocks.8.norm2.weight torch.Size([768])
image_encoder.blocks.8.norm2.bias torch.Size([768])
image_encoder.blocks.8.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.8.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.8.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.8.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.9.norm1.weight torch.Size([768])
image_encoder.blocks.9.norm1.bias torch.Size([768])
image_encoder.blocks.9.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.9.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.9.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.9.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.9.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.9.attn.proj.bias torch.Size([768])
image_encoder.blocks.9.norm2.weight torch.Size([768])
image_encoder.blocks.9.norm2.bias torch.Size([768])
image_encoder.blocks.9.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.9.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.9.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.9.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.10.norm1.weight torch.Size([768])
image_encoder.blocks.10.norm1.bias torch.Size([768])
image_encoder.blocks.10.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.10.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.10.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.10.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.10.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.10.attn.proj.bias torch.Size([768])
image_encoder.blocks.10.norm2.weight torch.Size([768])
image_encoder.blocks.10.norm2.bias torch.Size([768])
image_encoder.blocks.10.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.10.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.10.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.10.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.11.norm1.weight torch.Size([768])
image_encoder.blocks.11.norm1.bias torch.Size([768])
image_encoder.blocks.11.attn.rel_pos_h torch.Size([127, 64])
image_encoder.blocks.11.attn.rel_pos_w torch.Size([127, 64])
image_encoder.blocks.11.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.11.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.11.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.11.attn.proj.bias torch.Size([768])
image_encoder.blocks.11.norm2.weight torch.Size([768])
image_encoder.blocks.11.norm2.bias torch.Size([768])
image_encoder.blocks.11.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.11.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.11.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.11.mlp.lin2.bias torch.Size([768])
prompt_encoder.pe_layer.positional_encoding_gaussian_matrix torch.Size([2, 128])
mask_decoder.transformer.layers.0.self_attn.q_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.0.self_attn.q_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.self_attn.k_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.0.self_attn.k_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.self_attn.v_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.0.self_attn.v_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.self_attn.out_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.0.self_attn.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.norm1.weight torch.Size([256])
mask_decoder.transformer.layers.0.norm1.bias torch.Size([256])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.norm2.weight torch.Size([256])
mask_decoder.transformer.layers.0.norm2.bias torch.Size([256])
mask_decoder.transformer.layers.0.mlp.lin1.weight torch.Size([2048, 256])
mask_decoder.transformer.layers.0.mlp.lin1.bias torch.Size([2048])
mask_decoder.transformer.layers.0.mlp.lin2.weight torch.Size([256, 2048])
mask_decoder.transformer.layers.0.mlp.lin2.bias torch.Size([256])
mask_decoder.transformer.layers.0.norm3.weight torch.Size([256])
mask_decoder.transformer.layers.0.norm3.bias torch.Size([256])
mask_decoder.transformer.layers.0.norm4.weight torch.Size([256])
mask_decoder.transformer.layers.0.norm4.bias torch.Size([256])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.self_attn.q_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.1.self_attn.q_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.self_attn.k_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.1.self_attn.k_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.self_attn.v_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.1.self_attn.v_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.self_attn.out_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.1.self_attn.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.norm1.weight torch.Size([256])
mask_decoder.transformer.layers.1.norm1.bias torch.Size([256])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.norm2.weight torch.Size([256])
mask_decoder.transformer.layers.1.norm2.bias torch.Size([256])
mask_decoder.transformer.layers.1.mlp.lin1.weight torch.Size([2048, 256])
mask_decoder.transformer.layers.1.mlp.lin1.bias torch.Size([2048])
mask_decoder.transformer.layers.1.mlp.lin2.weight torch.Size([256, 2048])
mask_decoder.transformer.layers.1.mlp.lin2.bias torch.Size([256])
mask_decoder.transformer.layers.1.norm3.weight torch.Size([256])
mask_decoder.transformer.layers.1.norm3.bias torch.Size([256])
mask_decoder.transformer.layers.1.norm4.weight torch.Size([256])
mask_decoder.transformer.layers.1.norm4.bias torch.Size([256])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.bias torch.Size([256])
mask_decoder.transformer.final_attn_token_to_image.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.final_attn_token_to_image.q_proj.bias torch.Size([128])
mask_decoder.transformer.final_attn_token_to_image.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.final_attn_token_to_image.k_proj.bias torch.Size([128])
mask_decoder.transformer.final_attn_token_to_image.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.final_attn_token_to_image.v_proj.bias torch.Size([128])
mask_decoder.transformer.final_attn_token_to_image.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.final_attn_token_to_image.out_proj.bias torch.Size([256])
mask_decoder.transformer.norm_final_attn.weight torch.Size([256])
mask_decoder.transformer.norm_final_attn.bias torch.Size([256])
prompt_encoder.point_embeddings.0.weight torch.Size([1, 256])
prompt_encoder.point_embeddings.1.weight torch.Size([1, 256])
prompt_encoder.point_embeddings.2.weight torch.Size([1, 256])
prompt_encoder.point_embeddings.3.weight torch.Size([1, 256])
prompt_encoder.not_a_point_embed.weight torch.Size([1, 256])
mask_decoder.output_upscaling.0.weight torch.Size([256, 64, 2, 2])
mask_decoder.output_upscaling.0.bias torch.Size([64])
mask_decoder.output_upscaling.1.weight torch.Size([64])
mask_decoder.output_upscaling.1.bias torch.Size([64])
mask_decoder.output_upscaling.3.weight torch.Size([64, 32, 2, 2])
mask_decoder.output_upscaling.3.bias torch.Size([32])
mask_decoder.output_hypernetworks_mlps.0.layers.0.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.0.layers.0.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.0.layers.1.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.0.layers.1.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.0.layers.2.weight torch.Size([32, 256])
mask_decoder.output_hypernetworks_mlps.0.layers.2.bias torch.Size([32])
mask_decoder.output_hypernetworks_mlps.1.layers.0.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.1.layers.0.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.1.layers.1.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.1.layers.1.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.1.layers.2.weight torch.Size([32, 256])
mask_decoder.output_hypernetworks_mlps.1.layers.2.bias torch.Size([32])
mask_decoder.output_hypernetworks_mlps.2.layers.0.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.2.layers.0.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.2.layers.1.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.2.layers.1.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.2.layers.2.weight torch.Size([32, 256])
mask_decoder.output_hypernetworks_mlps.2.layers.2.bias torch.Size([32])
mask_decoder.output_hypernetworks_mlps.3.layers.0.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.3.layers.0.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.3.layers.1.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.3.layers.1.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.3.layers.2.weight torch.Size([32, 256])
mask_decoder.output_hypernetworks_mlps.3.layers.2.bias torch.Size([32])
prompt_encoder.mask_downscaling.0.weight torch.Size([4, 1, 2, 2])
prompt_encoder.mask_downscaling.0.bias torch.Size([4])
prompt_encoder.mask_downscaling.1.weight torch.Size([4])
prompt_encoder.mask_downscaling.1.bias torch.Size([4])
prompt_encoder.mask_downscaling.3.weight torch.Size([16, 4, 2, 2])
prompt_encoder.mask_downscaling.3.bias torch.Size([16])
prompt_encoder.mask_downscaling.4.weight torch.Size([16])
prompt_encoder.mask_downscaling.4.bias torch.Size([16])
prompt_encoder.mask_downscaling.6.weight torch.Size([256, 16, 1, 1])
prompt_encoder.mask_downscaling.6.bias torch.Size([256])
prompt_encoder.no_mask_embed.weight torch.Size([1, 256])
mask_decoder.iou_prediction_head.layers.0.weight torch.Size([256, 256])
mask_decoder.iou_prediction_head.layers.0.bias torch.Size([256])
mask_decoder.iou_prediction_head.layers.1.weight torch.Size([256, 256])
mask_decoder.iou_prediction_head.layers.1.bias torch.Size([256])
mask_decoder.iou_prediction_head.layers.2.weight torch.Size([4, 256])
mask_decoder.iou_prediction_head.layers.2.bias torch.Size([4])
mask_decoder.iou_token.weight torch.Size([1, 256])
mask_decoder.mask_tokens.weight torch.Size([4, 256])
image_encoder.pos_embed torch.Size([1, 64, 64, 768])
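
For reference, a listing like the one above can be produced with a short PyTorch snippet (a minimal sketch; it assumes the ViT-B checkpoint has been downloaded as sam_vit_b_01ec64.pth):

import torch

# The SAM checkpoints are plain state dicts, so loading one directly gives name -> tensor
state_dict = torch.load("sam_vit_b_01ec64.pth", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tensor.size())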

@ggerganov force-pushed the sam branch 11 times, most recently from a3443b2 to 4556c35 on May 31, 2023 10:44
@ggerganov added the "model" (Model specific) label on Jun 25, 2023
@ggerganov changed the title from "Add sample SAM inference" to "examples : add sample SAM inference" on Jun 25, 2023
@ggerganov self-assigned this on Jun 25, 2023
* Add loading of decoder layers in Model

* Multiply by hypernet_layer_cnt for ctx_size on model load

* Add decoder layers to py conversion script

* Fix wrong and reversed tensor sizes for decoder

* Add decoder transformer implementation

* Add decoder hypernet and iou prediction mlps

* Add transpose convolution operation and unit test

* Finish mask decoder and write the decoder output in the model state

* Output masks to png after removing padding and upsampling to original size

- Also filter based on the IoU threshold (a rough Python sketch of this post-processing follows the commit list)
- Filtering based on the stability score and crop boxes should additionally be done

* Add stb image write in order to output masks from SAM

* Add transpose convolution 2d name and symbol to ggml ops static arrays

* Comment out debug print in transpose convolution test to fix compilation

ggml-ci
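
A rough Python analogue of the mask post-processing described in the commits above (this is not the C++ code from the PR, just a sketch with assumed names; the 0.88 threshold is illustrative):

import numpy as np
from PIL import Image

def postprocess_mask(mask_logits, iou_pred, pad_h, pad_w, out_w, out_h, iou_threshold=0.88):
    # Drop masks with a low predicted IoU
    if iou_pred < iou_threshold:
        return None
    # Remove the padding that was added when the image was squared for the encoder
    h, w = mask_logits.shape
    m = mask_logits[:h - pad_h, :w - pad_w]
    # Binarize, upsample back to the original image size, then the result can be written as PNG
    img = Image.fromarray((m > 0).astype(np.uint8) * 255)
    return img.resize((out_w, out_h), Image.NEAREST)
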
src/ggml.c — review comment (outdated, resolved)
@ggerganov (Owner, Author)

Add README.md with instructions for obtaining and converting the model and we can merge it.

We can continue optimizations from master:

  • reduce memory usage by utilizing the new ggml-alloc
  • remove redundant graph nodes
  • support F16 for heavy F32 ops
  • test quantization
  • support bigger models
  • etc.

@ggerganov marked this pull request as ready for review August 17, 2023 13:10
@YavorGIvanov (Collaborator) commented Aug 17, 2023

> Add README.md with instructions for obtaining and converting the model and we can merge it.
>
> We can continue optimizations from master:
>
>   • reduce memory usage by utilizing the new ggml-alloc
>   • remove redundant graph nodes
>   • support F16 for heavy F32 ops
>   • test quantization
>   • support bigger models
>   • etc.

Additionally I think we should:

  • Trace where the difference in output masks comes from. This will be done by going through the inference tensors step by step and comparing them to the PyTorch version (see the sketch after this list)
  • Filter masks based on stability score and based on boxes which touch crop boundaries
  • Add support for user input (avoid having a hardcoded point)
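
A minimal sketch of that tensor-comparison step (hypothetical file names; it assumes both implementations dump the same intermediate tensor, e.g. the image embeddings, to .npy):

import numpy as np

ggml_out  = np.load("ggml_image_embeddings.npy")     # dumped from the ggml example (hypothetical path)
torch_out = np.load("pytorch_image_embeddings.npy")  # dumped from the reference PyTorch model

diff = np.abs(ggml_out - torch_out)
print("max abs diff :", diff.max())
print("mean abs diff:", diff.mean())
print("allclose     :", np.allclose(ggml_out, torch_out, rtol=1e-3, atol=1e-4))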

Added all those next steps to the README and I am going to start working on them.

@YavorGIvanov self-assigned this Aug 18, 2023
@ggerganov merged commit 8da5be2 into master Aug 18, 2023 (2 checks passed)
@ggerganov deleted the sam branch August 18, 2023 11:50
CCLDArjun pushed a commit to CCLDArjun/ggml that referenced this pull request Dec 18, 2023
@cmp-nct commented Feb 11, 2024

@ggerganov I just stumbled upon your SAM code and this comment:
// TODO: for some reason, this is not numerically identical to pytorch's interpolation
I've solved that here: ggerganov/llama.cpp#5267
It's a precision error that occurs when PyTorch processes the data in 16 bit; with the function below I replicated the normalization values.
Of course, that's only useful when checking whether an architecture involving CLIP behaves the same in ggml, so by default it should be off.

#include <cstdint> // uint16_t / uint32_t

// for replication purposes `.to(model.device, dtype=torch.float16)`
// converts a float to half precision and back to float
float simulateFloat16Precision(float value) {
    // Convert float32 to float16
    uint32_t f32 = *reinterpret_cast<uint32_t*>(&value);
    uint32_t sign = (f32 >> 16) & 0x8000; // Top bit (sign bit)
    uint32_t exponent = ((f32 >> 23) & 0xFF) - 112; // Re-bias: float32 bias (127) minus float16 bias (15) = 112
    uint32_t mantissa = (f32 >> 13) & 0x3FF; // Keep the top 10 mantissa bits (float16 has 10, float32 has 23)

    // Handle overflow/underflow
    if ((f32 & 0x7FFFFFFF) > 0x477FE000) { // Not representable
        exponent = 0x1F;
        mantissa = 0;
    } else if ((f32 & 0x7FFFFFFF) < 0x38800000) { // Too small for normal half precision
        exponent = 0;
        mantissa = 0;
    }

    uint16_t f16 = sign | (exponent << 10) | mantissa;

    // Convert back to float32
    uint32_t sign32 = (f16 & 0x8000) << 16;
    uint32_t exponent32 = ((f16 >> 10) & 0x1F);
    uint32_t mantissa32 = (f16 & 0x3FF) << 13;

    // Adjust bias back
    exponent32 = exponent32 == 0 ? 0 : exponent32 + 112;

    uint32_t f32Result = sign32 | (exponent32 << 23) | mantissa32;
    float result = *reinterpret_cast<float*>(&f32Result);

    return result;
}
// Normalize image to float32 - supports float16 replication as in pytorch .to(model.device, dtype=torch.float16)
// (clip_image_u8 / clip_image_f32 are the image structs from llama.cpp's clip code)
void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3], bool replicate_float16) {
    dst->nx = src->nx;
    dst->ny = src->ny;
    dst->buf.resize(src->buf.size());

    for (size_t i = 0; i < src->buf.size(); ++i) {
        int c = i % 3; // rgb
        dst->buf[i] = (static_cast<float>(src->buf[i]) / 255.0f - mean[c]) / std[c];

        if (replicate_float16) {
            dst->buf[i] = simulateFloat16Precision(dst->buf[i]);
        }
    }
}
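
For a quick cross-check of what the helper above emulates, the same float16 round-trip can be reproduced with NumPy (the mean/std values below are the usual CLIP normalization constants, used only as an example; note the C++ helper truncates the mantissa rather than rounding, so the last bit can differ):

import numpy as np

x = np.float32((123.0 / 255.0 - 0.48145466) / 0.26862954)  # a CLIP-normalized pixel value
x_f16 = np.float32(np.float16(x))                          # what `.to(dtype=torch.float16)` effectively does
print(x, x_f16, abs(x - x_f16))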

@ggerganov (Owner, Author)

@cmp-nct Ah good to know - thanks for looking into this
