Setup hydra #3

Open · wants to merge 5 commits into develop
4 changes: 4 additions & 0 deletions .gitignore
@@ -6,6 +6,10 @@ __pycache__/
# C extensions
*.so

# Models and logs
work_dirs
outputs

# Distribution / packaging
.Python
build/
3 changes: 0 additions & 3 deletions .gitmodules

This file was deleted.

38 changes: 12 additions & 26 deletions README.md
@@ -6,40 +6,26 @@

## Requirements

- Python >= 3.8
- CUDA >= 11.0 (optional but recommended for GPU acceleration)
- All dependencies are listed in the `requirements.txt` file.
Tested on:
- Python 3.11
- torch 2.4.0
- torchvision 0.19.0

## Installation
To set up the environment and install `SAM-HQ`, follow the [instructions](https://github.com/SysCV/sam-hq?tab=readme-ov-file#example-conda-environment-setup).

### 1. Clone the Repository (with Submodules)
## Setup

To properly set up the repository and include the `SAM-HQ` submodule, run the following command:
1. Place the dataset folders in the `data/` directory.
2. Set up the experiment config in the `conf/experiments/` directory (a quick layout check is sketched below).
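The snippet below is a minimal, hypothetical sanity check (not part of this repository) that verifies the dataset layout expected by `conf/experiments/default.yaml` before launching training:

```python
# Hypothetical helper: verify the dataset layout expected by
# conf/experiments/default.yaml (image/mask directories and extensions).
import glob
import os


def check_split(im_dir: str, gt_dir: str, im_ext: str, gt_ext: str) -> None:
    images = glob.glob(os.path.join(im_dir, f"*{im_ext}"))
    masks = glob.glob(os.path.join(gt_dir, f"*{gt_ext}"))
    print(f"{im_dir}: {len(images)} images | {gt_dir}: {len(masks)} masks")


check_split("./data/FIELD/train/images", "./data/FIELD/train/masks", ".JPG", ".png")
check_split("./data/FIELD/val/images", "./data/FIELD/val/masks", ".JPG", ".png")
```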

```bash
git clone --recurse-submodules https://github.com/YOUR_ORG/AgIR-FinetuneSAM.git
cd AgIR-FinetuneSAM
```
## Execution

> **Note**: If you already cloned the repository without the `--recurse-submodules` flag, you can manually initialize and update the submodule:
```bash
git submodule update --init --recursive
```
To execute the training script with 3 GPUs, use the following command:

### 2. Install Dependencies

Install the required Python dependencies:
```bash
pip install -r requirements.txt
torchrun --nproc_per_node=3 train.py
```
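> **Note**: `--nproc_per_node` should match the number of visible GPUs. `torchrun` exports the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables that the `dist_url: env://` setting in `conf/config.yaml` relies on.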

## Troubleshooting

- **Submodule Not Cloned**: If the submodule did not clone correctly, ensure you used the `--recurse-submodules` flag when cloning, or manually initialize the submodule using:
```bash
git submodule update --init --recursive
```

## Acknowledgments

- Special thanks to the **SysCV team** for developing the [SAM-HQ](https://github.com/SysCV/sam-hq) repository.
Special thanks to the **SysCV team** for developing the [SAM-HQ](https://github.com/SysCV/sam-hq) repository.
12 changes: 12 additions & 0 deletions conf/config.yaml
@@ -0,0 +1,12 @@
defaults:
- experiments: default
- _self_

world_size: 1 # number of distributed processes
dist_url: env:// # URL used to set up distributed training
rank: 0 # global rank of the current process
local_rank: 0 # local rank for distributed training
find_unused_params: false
gpu: null
distributed: false
dist_backend: nccl
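As a rough illustration only (not necessarily how `train.py` in this PR is implemented), a Hydra entry point could consume this config and initialize distributed training from the `env://` variables set by `torchrun`:

```python
# Hypothetical sketch of a Hydra entry point consuming conf/config.yaml.
import os

import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # torchrun exports RANK / LOCAL_RANK / WORLD_SIZE, which match dist_url: env://
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        cfg.rank = int(os.environ["RANK"])
        cfg.local_rank = int(os.environ["LOCAL_RANK"])
        cfg.world_size = int(os.environ["WORLD_SIZE"])
        cfg.distributed = cfg.world_size > 1

    if cfg.distributed:
        torch.cuda.set_device(cfg.local_rank)
        dist.init_process_group(backend=cfg.dist_backend, init_method=cfg.dist_url)


if __name__ == "__main__":
    main()
```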
33 changes: 33 additions & 0 deletions conf/experiments/default.yaml
@@ -0,0 +1,33 @@
output: ./work_dirs/${model_type} # Path to the directory where masks and checkpoints will be output
model_type: vit_h # The type of model to load, in ['vit_h', 'vit_l', 'vit_b']
checkpoint: ./pretrained_checkpoint/sam_vit_h_4b8939.pth # Path to the pretrained SAM checkpoint to fine-tune from.
device: cuda # The device to run training on.


seed: 42
learning_rate: 1e-3
start_epoch: 0
lr_drop_epoch: 10
max_epoch_num: 12
input_size: [1024, 1024]
batch_size_train: 4
batch_size_valid: 1
model_save_fre: 1

eval: false
visualize: false
restore_model: null

datasets:
train:
- name: FIELD
im_dir: ./data/FIELD/train/images
gt_dir: ./data/FIELD/train/masks
im_ext: .JPG
gt_ext: .png
valid:
- name: FIELD
im_dir: ./data/FIELD/val/images
gt_dir: ./data/FIELD/val/masks
im_ext: .JPG
gt_ext: .png
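For illustration, here is a hypothetical snippet (function names are not from this repository) showing how the `datasets` entries above could be expanded into image/mask path pairs:

```python
# Hypothetical sketch: pair each image with its ground-truth mask by shared
# basename, using the directories and extensions from the experiment config.
import glob
import os
from typing import Dict, List, Tuple


def collect_pairs(ds_cfg: Dict) -> List[Tuple[str, str]]:
    images = sorted(glob.glob(os.path.join(ds_cfg["im_dir"], f"*{ds_cfg['im_ext']}")))
    pairs = []
    for im_path in images:
        stem = os.path.splitext(os.path.basename(im_path))[0]
        gt_path = os.path.join(ds_cfg["gt_dir"], stem + ds_cfg["gt_ext"])
        if os.path.isfile(gt_path):
            pairs.append((im_path, gt_path))
    return pairs


# Example with the FIELD training split from default.yaml:
field_train = {
    "im_dir": "./data/FIELD/train/images",
    "gt_dir": "./data/FIELD/train/masks",
    "im_ext": ".JPG",
    "gt_ext": ".png",
}
pairs = collect_pairs(field_train)
```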
2 changes: 2 additions & 0 deletions data/.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore
2 changes: 2 additions & 0 deletions pretrained_checkpoint/.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore
1 change: 0 additions & 1 deletion sam-hq
Submodule sam-hq deleted from ac1972
9 changes: 9 additions & 0 deletions scripts/download_checkpoints.sh
@@ -0,0 +1,9 @@
#!/bin/bash

echo "Downloading checkpoints..."
wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_b_01ec64.pth
wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_b_maskdecoder.pth
wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_h_4b8939.pth
wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_h_maskdecoder.pth
wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_l_0b3195.pth
wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_l_maskdecoder.pth
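Note that the `wget -P ../pretrained_checkpoint` destinations are relative to the working directory, so the script is intended to be run from inside `scripts/` (or the `-P` paths adjusted). Among other files, it fetches the `sam_vit_h_4b8939.pth` checkpoint that `conf/experiments/default.yaml` points to.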
13 changes: 13 additions & 0 deletions segment_anything_training/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .build_sam import (
build_sam,
build_sam_vit_h,
build_sam_vit_l,
build_sam_vit_b,
sam_model_registry,
)
107 changes: 107 additions & 0 deletions segment_anything_training/build_sam.py
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from functools import partial

from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer


def build_sam_vit_h(checkpoint=None):
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None):
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
)


def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)


sam_model_registry = {
"default": build_sam,
"vit_h": build_sam,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
}


def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
sam = Sam(
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
return sam
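A minimal usage sketch of the registry defined above, assuming the matching checkpoint has been downloaded via `scripts/download_checkpoints.sh`:

```python
# Usage sketch: build a SAM model from sam_model_registry and load the
# pretrained ViT-H weights referenced by conf/experiments/default.yaml.
from segment_anything_training import sam_model_registry

model_type = "vit_h"
checkpoint = "./pretrained_checkpoint/sam_vit_h_4b8939.pth"

sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to("cuda")  # matches `device: cuda` in the experiment config
```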
11 changes: 11 additions & 0 deletions segment_anything_training/modeling/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .sam import Sam
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer
43 changes: 43 additions & 0 deletions segment_anything_training/modeling/common.py
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from typing import Type


class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps

def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
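A quick, illustrative shape check of `LayerNorm2d`, which normalizes each spatial position across channels of an NCHW tensor (unlike `nn.LayerNorm`, which expects channels-last input):

```python
# Illustrative shape check for LayerNorm2d from modeling/common.py.
import torch

from segment_anything_training.modeling.common import LayerNorm2d

norm = LayerNorm2d(num_channels=256)
x = torch.randn(2, 256, 64, 64)  # (batch, channels, height, width)
y = norm(x)                      # same shape; per-pixel normalization across channels
assert y.shape == x.shape
```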