diff --git a/.gitignore b/.gitignore index 82f9275..b8d8527 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,10 @@ __pycache__/ # C extensions *.so +# Models and logs +work_dirs +outputs + # Distribution / packaging .Python build/ diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index df93e68..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "sam-hq"] - path = sam-hq - url = https://github.com/SysCV/sam-hq.git diff --git a/README.md b/README.md index ead83dd..03c9515 100644 --- a/README.md +++ b/README.md @@ -6,40 +6,26 @@ ## Requirements -- Python >= 3.8 -- CUDA >= 11.0 (optional but recommended for GPU acceleration) -- All dependencies are listed in the `requirements.txt` file. +Tested on: +- Python 3.11 +- torch 2.4.0 +- torchvision 0.19.0 -## Installation +To set up the environment and install `SAM-HQ`, follow the [instructions](https://github.com/SysCV/sam-hq?tab=readme-ov-file#example-conda-environment-setup). -### 1. Clone the Repository (with Submodules) +## Setup -To properly set up the repository and include the `SAM-HQ` submodule, run the following command: +1. Place your dataset under the `data/` directory, matching the paths configured in `conf/experiments/`. +2. Define your experiment configuration in the `conf/experiments/` directory. -```bash -git clone --recurse-submodules https://github.com/YOUR_ORG/AgIR-FinetuneSAM.git -cd AgIR-FinetuneSAM -``` +## Execution -> **Note**: If you already cloned the repository without the `--recurse-submodules` flag, you can manually initialize and update the submodule: -```bash -git submodule update --init --recursive -``` +To launch training on 3 GPUs, run: -### 2. Install Dependencies - -Install the required Python dependencies: ```bash -pip install -r requirements.txt +torchrun --nproc_per_node=3 train.py ``` -## Troubleshooting - -- **Submodule Not Cloned**: If the submodule did not clone correctly, ensure you used the `--recurse-submodules` flag when cloning, or manually initialize the submodule using: - ```bash - git submodule update --init --recursive - ``` - ## Acknowledgments -- Special thanks to the **SysCV team** for developing the [SAM-HQ](https://github.com/SysCV/sam-hq) repository. +Special thanks to the **SysCV team** for developing the [SAM-HQ](https://github.com/SysCV/sam-hq) repository. \ No newline at end of file diff --git a/conf/config.yaml b/conf/config.yaml new file mode 100644 index 0000000..464b7cc --- /dev/null +++ b/conf/config.yaml @@ -0,0 +1,12 @@ +defaults: + - experiments: default + - _self_ + +world_size: 1 # number of distributed processes +dist_url: env:// # url used to set up distributed training +rank: 0 # global rank of the current process +local_rank: 0 # local rank for distributed training +find_unused_params: false +gpu: None +distributed: false +dist_backend: nccl \ No newline at end of file diff --git a/conf/experiments/default.yaml b/conf/experiments/default.yaml new file mode 100644 index 0000000..d682644 --- /dev/null +++ b/conf/experiments/default.yaml @@ -0,0 +1,33 @@ +output: ./work_dirs/${model_type} # Path to the directory where masks and checkpoints will be output +model_type: vit_h # The type of model to load, in ['vit_h', 'vit_l', 'vit_b'] +checkpoint: ./pretrained_checkpoint/sam_vit_h_4b8939.pth # The path to the SAM checkpoint to fine-tune from. +device: cuda # The device to run training on.
+ + +seed: 42 +learning_rate: 1e-3 +start_epoch: 0 +lr_drop_epoch: 10 +max_epoch_num: 12 +input_size: [1024, 1024] +batch_size_train: 4 +batch_size_valid: 1 +model_save_fre: 1 + +eval: false +visualize: false +restore_model: null + +datasets: + train: + - name: FIELD + im_dir: ./data/FIELD/train/images + gt_dir: ./data/FIELD/train/masks + im_ext: .JPG + gt_ext: .png + valid: + - name: FIELD + im_dir: ./data/FIELD/val/images + gt_dir: ./data/FIELD/val/masks + im_ext: .JPG + gt_ext: .png \ No newline at end of file diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/pretrained_checkpoint/.gitignore b/pretrained_checkpoint/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/pretrained_checkpoint/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/sam-hq b/sam-hq deleted file mode 160000 index ac19724..0000000 --- a/sam-hq +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ac19724c47b13689e5d9596277a6522b371001c8 diff --git a/scripts/download_checkpoints.sh b/scripts/download_checkpoints.sh new file mode 100644 index 0000000..269d9d7 --- /dev/null +++ b/scripts/download_checkpoints.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +echo "Downloading checkpoints..." +wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_b_01ec64.pth +wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_b_maskdecoder.pth +wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_h_4b8939.pth +wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_h_maskdecoder.pth +wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_l_0b3195.pth +wget -P ../pretrained_checkpoint https://huggingface.co/sam-hq-team/sam-hq-training/resolve/main/pretrained_checkpoint/sam_vit_l_maskdecoder.pth diff --git a/segment_anything_training/__init__.py b/segment_anything_training/__init__.py new file mode 100644 index 0000000..5514ce3 --- /dev/null +++ b/segment_anything_training/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) diff --git a/segment_anything_training/build_sam.py b/segment_anything_training/build_sam.py new file mode 100644 index 0000000..07abfca --- /dev/null +++ b/segment_anything_training/build_sam.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam, + "vit_h": build_sam, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/segment_anything_training/modeling/__init__.py b/segment_anything_training/modeling/__init__.py new file mode 100644 index 0000000..38e9062 --- /dev/null +++ b/segment_anything_training/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/segment_anything_training/modeling/common.py b/segment_anything_training/modeling/common.py new file mode 100644 index 0000000..2bf1523 --- /dev/null +++ b/segment_anything_training/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
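A brief aside before the modeling files that follow: the `sam_model_registry` added in `build_sam.py` (and re-exported by `segment_anything_training/__init__.py`) is the entry point for constructing a SAM backbone. A minimal usage sketch, not part of the diff, assuming the checkpoint was fetched by `scripts/download_checkpoints.sh` into the paths used by `conf/experiments/default.yaml`:

```python
# Sketch (not from the diff): building a SAM backbone via the registry in build_sam.py.
# Paths follow conf/experiments/default.yaml and scripts/download_checkpoints.sh.
import torch

from segment_anything_training import sam_model_registry

model_type = "vit_h"                                         # one of: vit_h, vit_l, vit_b
checkpoint = "./pretrained_checkpoint/sam_vit_h_4b8939.pth"  # fetched by download_checkpoints.sh
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=checkpoint)  # builds the model and loads the state dict
sam = sam.to(device)                                         # _build_sam() already called sam.eval()
```

Note that `_build_sam` returns the model in eval mode; in this repository only the HQ decoder introduced later in `train.py` is trained.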
+ +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/segment_anything_training/modeling/image_encoder.py b/segment_anything_training/modeling/image_encoder.py new file mode 100644 index 0000000..d62d877 --- /dev/null +++ b/segment_anything_training/modeling/image_encoder.py @@ -0,0 +1,398 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. 
+ """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + interm_embeddings=[] + for blk in self.blocks: + x = blk(x) + if blk.window_size == 0: + interm_embeddings.append(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x, interm_embeddings + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + return x + + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/segment_anything_training/modeling/mask_decoder.py b/segment_anything_training/modeling/mask_decoder.py new file mode 100644 index 0000000..19632e3 --- /dev/null +++ b/segment_anything_training/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for outptu + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/segment_anything_training/modeling/prompt_encoder.py b/segment_anything_training/modeling/prompt_encoder.py new file mode 100644 index 0000000..c3143f4 --- /dev/null +++ b/segment_anything_training/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. 
+ + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/segment_anything_training/modeling/sam.py b/segment_anything_training/modeling/sam.py new file mode 100644 index 0000000..d1727cb --- /dev/null +++ b/segment_anything_training/modeling/sam.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
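As an aside between files: a small sketch, not part of the diff, of what the `PromptEncoder` above produces for a single box prompt, using the sizes that `_build_sam()` passes in (1024-pixel input, 16-pixel patches, hence a 64x64 embedding grid). The box values are illustrative only:

```python
# Sketch (not from the diff): PromptEncoder output shapes for one box prompt.
import torch

from segment_anything_training.modeling import PromptEncoder

pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
boxes = torch.tensor([[120.0, 80.0, 640.0, 512.0]])      # Bx4, in the 1024x1024 input frame
sparse, dense = pe(points=None, boxes=boxes, masks=None)
print(sparse.shape)  # torch.Size([1, 2, 256]) -- the two embedded box corners
print(dense.shape)   # torch.Size([1, 256, 64, 64]) -- no_mask_embed broadcast over the grid
```

The two sparse tokens are the embedded box corners; because no mask prompt is given, the dense tensor is the learned `no_mask_embed` broadcast over the image-embedding grid.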
+ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input promts, + C is determiend by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. 
+ """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + + image_embeddings, interm_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output + ) + + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + "encoder_embedding": curr_embedding.unsqueeze(0), + "image_pe": self.prompt_encoder.get_dense_pe(), + "sparse_embeddings":sparse_embeddings, + "dense_embeddings":dense_embeddings, + } + ) + + return outputs, interm_embeddings + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/segment_anything_training/modeling/transformer.py b/segment_anything_training/modeling/transformer.py new file mode 100644 index 0000000..f1a2812 --- /dev/null +++ b/segment_anything_training/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attenion layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/segment_anything_training/utils/__init__.py b/segment_anything_training/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/segment_anything_training/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/segment_anything_training/utils/transforms.py b/segment_anything_training/utils/transforms.py new file mode 100644 index 0000000..3ad3466 --- /dev/null +++ b/segment_anything_training/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. 
+ """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/train.py b/train.py new file mode 100644 index 0000000..eb7a3cc --- /dev/null +++ b/train.py @@ -0,0 +1,588 @@ +# Copyright by HQ-SAM team +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import numpy as np +import torch +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import matplotlib.pyplot as plt +import cv2 +import random +from typing import Dict, List, Tuple +import hydra +from omegaconf import DictConfig + +from segment_anything_training import sam_model_registry +from segment_anything_training.modeling import TwoWayTransformer, MaskDecoder + +from utils.dataloader import get_im_gt_name_dict, create_dataloaders, RandomHFlip, Resize, LargeScaleJitter +from utils.loss_mask import loss_masks +import utils.misc as misc + + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + +class MaskDecoderHQ(MaskDecoder): + def __init__(self, model_type): + super().__init__(transformer_dim=256, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=256, + mlp_dim=2048, + num_heads=8, + ), + num_multimask_outputs=3, + activation=nn.GELU, + iou_head_depth= 3, + iou_head_hidden_dim= 256,) + assert model_type in ["vit_b","vit_l","vit_h"] + + checkpoint_dict = {"vit_b":"pretrained_checkpoint/sam_vit_b_maskdecoder.pth", + "vit_l":"pretrained_checkpoint/sam_vit_l_maskdecoder.pth", + 'vit_h':"pretrained_checkpoint/sam_vit_h_maskdecoder.pth"} + checkpoint_path = checkpoint_dict[model_type] + self.load_state_dict(torch.load(checkpoint_path)) + print("HQ Decoder init from SAM MaskDecoder") + for n,p in self.named_parameters(): + p.requires_grad = False + + transformer_dim=256 + vit_dim_dict = {"vit_b":768,"vit_l":1024,"vit_h":1280} + vit_dim = vit_dim_dict[model_type] + + self.hf_token = nn.Embedding(1, transformer_dim) + self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + self.num_mask_tokens = self.num_mask_tokens + 1 + + self.compress_vit_feat = nn.Sequential( + nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2)) + + self.embedding_encoder = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + ) + + self.embedding_maskfeature = nn.Sequential( + nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1)) + + 
def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+          hq_token_only (bool): Whether to return only the HQ mask or both the
+            selected SAM mask and the HQ mask.
+          interm_embeddings (torch.Tensor): early-layer ViT features used to build
+            the HQ features.
+
+        Returns:
+          torch.Tensor: the batched predicted HQ masks (preceded by the selected
+            SAM masks when hq_token_only is False)
+        """
+
+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
+        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
+
+        batch_len = len(image_embeddings)
+        masks = []
+        iou_preds = []
+        for i_batch in range(batch_len):
+            mask, iou_pred = self.predict_masks(
+                image_embeddings=image_embeddings[i_batch].unsqueeze(0),
+                image_pe=image_pe[i_batch],
+                sparse_prompt_embeddings=sparse_prompt_embeddings[i_batch],
+                dense_prompt_embeddings=dense_prompt_embeddings[i_batch],
+                hq_feature = hq_features[i_batch].unsqueeze(0)
+            )
+            masks.append(mask)
+            iou_preds.append(iou_pred)
+        masks = torch.cat(masks,0)
+        iou_preds = torch.cat(iou_preds,0)
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            # mask with highest score
+            mask_slice = slice(1,self.num_mask_tokens-1)
+            iou_preds = iou_preds[:, mask_slice]
+            iou_preds, max_iou_idx = torch.max(iou_preds,dim=1)
+            iou_preds = iou_preds.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
+        else:
+            # single mask output, default
+            mask_slice = slice(0, 1)
+            masks_sam = masks[:,mask_slice]
+
+        masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens), :, :]
+
+        if hq_token_only:
+            return masks_hq
+        else:
+            return masks_sam, masks_hq
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        hq_feature: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks.
See 'forward' for more details.""" + + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + upscaled_embedding_sam = self.output_upscaling(src) + upscaled_embedding_ours = self.embedding_maskfeature(upscaled_embedding_sam) + hq_feature + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + if i < 4: + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + else: + hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :])) + + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding_sam.shape + + masks_sam = (hyper_in[:,:4] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) + masks_ours = (hyper_in[:,4:] @ upscaled_embedding_ours.view(b, c, h * w)).view(b, -1, h, w) + masks = torch.cat([masks_sam,masks_ours],dim=1) + + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + + +def show_anns(masks, input_point, input_box, input_label, filename, image, ious, boundary_ious): + if len(masks) == 0: + return + + for i, (mask, iou, biou) in enumerate(zip(masks, ious, boundary_ious)): + plt.figure(figsize=(10,10)) + plt.imshow(image) + show_mask(mask, plt.gca()) + if input_box is not None: + show_box(input_box, plt.gca()) + if (input_point is not None) and (input_label is not None): + show_points(input_point, input_label, plt.gca()) + + plt.axis('off') + plt.savefig(filename+'_'+str(i)+'.png',bbox_inches='tight',pad_inches=-0.1) + plt.close() + +def show_mask(mask, ax, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30/255, 144/255, 255/255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig): + + net = MaskDecoderHQ(cfg.experiments.model_type) + + train_datasets = cfg.experiments.datasets.train + valid_datasets = cfg.experiments.datasets.valid + + misc.init_distributed_mode(cfg) + print('world size: {}'.format(cfg.world_size)) + print('rank: {}'.format(cfg.rank)) + print('local_rank: {}'.format(cfg.local_rank)) + 
print("cfg: " + str(cfg) + '\n') + + seed = cfg.experiments.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + ### --- Step 1: Train or Valid dataset --- + if not cfg.experiments.eval: + print("--- create training dataloader ---") + train_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train") + train_dataloaders, train_datasets = create_dataloaders(train_im_gt_list, + my_transforms = [ + RandomHFlip(), + LargeScaleJitter() + ], + batch_size = cfg.experiments.batch_size_train, + training = True) + print(len(train_dataloaders), " train dataloaders created") + + print("--- create valid dataloader ---") + valid_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid") + valid_dataloaders, valid_datasets = create_dataloaders(valid_im_gt_list, + my_transforms = [ + Resize(cfg.experiments.input_size) + ], + batch_size=cfg.experiments.batch_size_valid, + training=False) + print(len(valid_dataloaders), " valid dataloaders created") + + ### --- Step 2: DistributedDataParallel--- + if torch.cuda.is_available(): + net.cuda() + net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[cfg.gpu], find_unused_parameters=cfg.find_unused_params) + net_without_ddp = net.module + + + ### --- Step 3: Train or Evaluate --- + if not cfg.experiments.eval: + print("--- define optimizer ---") + optimizer = optim.Adam(net_without_ddp.parameters(), lr=cfg.experiments.learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.experiments.lr_drop_epoch) + lr_scheduler.last_epoch = cfg.experiments.start_epoch + + train(cfg, net, optimizer, train_dataloaders, valid_dataloaders, lr_scheduler) + else: + sam = sam_model_registry[cfg.experiments.model_type](checkpoint=cfg.experiments.checkpoint) + _ = sam.to(device=cfg.experiments.device) + sam = torch.nn.parallel.DistributedDataParallel(sam, device_ids=[cfg.gpu], find_unused_parameters=cfg.find_unused_params) + + if cfg.experiments.restore_model: + print("restore model from:", cfg.experiments.restore_model) + if torch.cuda.is_available(): + net_without_ddp.load_state_dict(torch.load(cfg.experiments.restore_model)) + else: + net_without_ddp.load_state_dict(torch.load(cfg.experiments.restore_model,map_location="cpu")) + + evaluate(cfg, net, sam, valid_dataloaders, cfg.experiments.visualize) + + +def train(cfg, net, optimizer, train_dataloaders, valid_dataloaders, lr_scheduler): + if misc.is_main_process(): + os.makedirs(cfg.experiments.output, exist_ok=True) + + epoch_start = cfg.experiments.start_epoch + epoch_num = cfg.experiments.max_epoch_num + train_num = len(train_dataloaders) + + net.train() + _ = net.to(device=cfg.experiments.device) + + sam = sam_model_registry[cfg.experiments.model_type](checkpoint=cfg.experiments.checkpoint) + _ = sam.to(device=cfg.experiments.device) + sam = torch.nn.parallel.DistributedDataParallel(sam, device_ids=[cfg.gpu], find_unused_parameters=cfg.find_unused_params) + + for epoch in range(epoch_start,epoch_num): + print("epoch: ",epoch, " learning rate: ", optimizer.param_groups[0]["lr"]) + metric_logger = misc.MetricLogger(delimiter=" ") + train_dataloaders.batch_sampler.sampler.set_epoch(epoch) + + for data in metric_logger.log_every(train_dataloaders,1000): + inputs, labels = data['image'], data['label'] + if torch.cuda.is_available(): + inputs = inputs.cuda() + labels = labels.cuda() + + imgs = inputs.permute(0, 2, 3, 1).cpu().numpy() + + # input prompt + input_keys = ['box','point','noise_mask'] + 
labels_box = misc.masks_to_boxes(labels[:,0,:,:]) + try: + labels_points = misc.masks_sample_points(labels[:,0,:,:]) + except: + # less than 10 points + input_keys = ['box','noise_mask'] + labels_256 = F.interpolate(labels, size=(256, 256), mode='bilinear') + labels_noisemask = misc.masks_noise(labels_256) + + batched_input = [] + for b_i in range(len(imgs)): + dict_input = dict() + input_image = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous() + dict_input['image'] = input_image + input_type = random.choice(input_keys) + if input_type == 'box': + dict_input['boxes'] = labels_box[b_i:b_i+1] + elif input_type == 'point': + point_coords = labels_points[b_i:b_i+1] + dict_input['point_coords'] = point_coords + dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None,:] + elif input_type == 'noise_mask': + dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1] + else: + raise NotImplementedError + dict_input['original_size'] = imgs[b_i].shape[:2] + batched_input.append(dict_input) + + with torch.no_grad(): + batched_output, interm_embeddings = sam(batched_input, multimask_output=False) + + batch_len = len(batched_output) + encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0) + image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)] + sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)] + dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)] + + masks_hq = net( + image_embeddings=encoder_embedding, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, + hq_token_only=True, + interm_embeddings=interm_embeddings, + ) + + loss_mask, loss_dice = loss_masks(masks_hq, labels/255.0, len(masks_hq)) + loss = loss_mask + loss_dice + + loss_dict = {"loss_mask": loss_mask, "loss_dice":loss_dice} + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = misc.reduce_dict(loss_dict) + losses_reduced_scaled = sum(loss_dict_reduced.values()) + loss_value = losses_reduced_scaled.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + metric_logger.update(training_loss=loss_value, **loss_dict_reduced) + + + print("Finished epoch: ", epoch) + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0} + + lr_scheduler.step() + test_stats = evaluate(cfg, net, sam, valid_dataloaders) + train_stats.update(test_stats) + + net.train() + + if epoch % cfg.experiments.model_save_fre == 0: + model_name = "/epoch_"+str(epoch)+".pth" + print('come here save at', cfg.experiments.output + model_name) + misc.save_on_master(net.module.state_dict(), cfg.experiments.output + model_name) + + # Finish training + print("Training Reaches The Maximum Epoch Number") + + # merge sam and hq_decoder + if misc.is_main_process(): + sam_ckpt = torch.load(cfg.experiments.checkpoint) + hq_decoder = torch.load(cfg.experiments.output + model_name) + for key in hq_decoder.keys(): + sam_key = 'mask_decoder.'+key + if sam_key not in sam_ckpt.keys(): + sam_ckpt[sam_key] = hq_decoder[key] + model_name = "/sam_hq_epoch_"+str(epoch)+".pth" + torch.save(sam_ckpt, cfg.experiments.output + model_name) + + + +def compute_iou(preds, target): + assert target.shape[1] == 1, 'only 
support one mask per image now' + if(preds.shape[2]!=target.shape[2] or preds.shape[3]!=target.shape[3]): + postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode='bilinear', align_corners=False) + else: + postprocess_preds = preds + iou = 0 + for i in range(0,len(preds)): + iou = iou + misc.mask_iou(postprocess_preds[i],target[i]) + return iou / len(preds) + +def compute_boundary_iou(preds, target): + assert target.shape[1] == 1, 'only support one mask per image now' + if(preds.shape[2]!=target.shape[2] or preds.shape[3]!=target.shape[3]): + postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode='bilinear', align_corners=False) + else: + postprocess_preds = preds + iou = 0 + for i in range(0,len(preds)): + iou = iou + misc.boundary_iou(target[i],postprocess_preds[i]) + return iou / len(preds) + +def evaluate(cfg, net, sam, valid_dataloaders, visualize=False): + net.eval() + print("Validating...") + test_stats = {} + + for k in range(len(valid_dataloaders)): + metric_logger = misc.MetricLogger(delimiter=" ") + valid_dataloader = valid_dataloaders[k] + print('valid_dataloader len:', len(valid_dataloader)) + + for data_val in metric_logger.log_every(valid_dataloader,1000): + imidx_val, inputs_val, labels_val, shapes_val, labels_ori = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape'], data_val['ori_label'] + + if torch.cuda.is_available(): + inputs_val = inputs_val.cuda() + labels_val = labels_val.cuda() + labels_ori = labels_ori.cuda() + + imgs = inputs_val.permute(0, 2, 3, 1).cpu().numpy() + + labels_box = misc.masks_to_boxes(labels_val[:,0,:,:]) + input_keys = ['box'] + batched_input = [] + for b_i in range(len(imgs)): + dict_input = dict() + input_image = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous() + dict_input['image'] = input_image + input_type = random.choice(input_keys) + if input_type == 'box': + dict_input['boxes'] = labels_box[b_i:b_i+1] + elif input_type == 'point': + point_coords = labels_points[b_i:b_i+1] + dict_input['point_coords'] = point_coords + dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None,:] + elif input_type == 'noise_mask': + dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1] + else: + raise NotImplementedError + dict_input['original_size'] = imgs[b_i].shape[:2] + batched_input.append(dict_input) + + with torch.no_grad(): + batched_output, interm_embeddings = sam(batched_input, multimask_output=False) + + batch_len = len(batched_output) + encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0) + image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)] + sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)] + dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)] + + masks_sam, masks_hq = net( + image_embeddings=encoder_embedding, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, + hq_token_only=False, + interm_embeddings=interm_embeddings, + ) + + iou = compute_iou(masks_hq,labels_ori) + boundary_iou = compute_boundary_iou(masks_hq,labels_ori) + + if visualize: + print("visualize") + os.makedirs(cfg.experiments.output, exist_ok=True) + masks_hq_vis = (F.interpolate(masks_hq.detach(), (1024, 1024), mode="bilinear", align_corners=False) > 0).cpu() + for ii in range(len(imgs)): + base = 
data_val['imidx'][ii].item() + print('base:', base) + save_base = os.path.join(cfg.experiments.output, str(k)+'_'+ str(base)) + imgs_ii = imgs[ii].astype(dtype=np.uint8) + show_iou = torch.tensor([iou.item()]) + show_boundary_iou = torch.tensor([boundary_iou.item()]) + show_anns(masks_hq_vis[ii], None, labels_box[ii].cpu(), None, save_base , imgs_ii, show_iou, show_boundary_iou) + + + loss_dict = {"val_iou_"+str(k): iou, "val_boundary_iou_"+str(k): boundary_iou} + loss_dict_reduced = misc.reduce_dict(loss_dict) + metric_logger.update(**loss_dict_reduced) + + + print('============================') + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + resstat = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0} + test_stats.update(resstat) + + + return test_stats + + +if __name__ == "__main__": + main() + diff --git a/utils/dataloader.py b/utils/dataloader.py new file mode 100644 index 0000000..239b42e --- /dev/null +++ b/utils/dataloader.py @@ -0,0 +1,272 @@ +# Copyright by HQ-SAM team +# All rights reserved. + +## data loader +from __future__ import print_function, division + +import numpy as np +import random +from copy import deepcopy +from skimage import io +import os +from glob import glob + +import torch +from torch.utils.data import Dataset, DataLoader, ConcatDataset +from torchvision import transforms, utils +from torchvision.transforms.functional import normalize +import torch.nn.functional as F +from torch.utils.data.distributed import DistributedSampler + +#### --------------------- dataloader online ---------------------#### + +def get_im_gt_name_dict(datasets, flag='valid'): + print("------------------------------", flag, "--------------------------------") + name_im_gt_list = [] + + for i in range(len(datasets)): + print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---") + tmp_im_list, tmp_gt_list = [], [] + tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"]) + print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list)) + + if(datasets[i]["gt_dir"]==""): + print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found') + tmp_gt_list = [] + else: + tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list] + print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list)) + + + name_im_gt_list.append({"dataset_name":datasets[i]["name"], + "im_path":tmp_im_list, + "gt_path":tmp_gt_list, + "im_ext":datasets[i]["im_ext"], + "gt_ext":datasets[i]["gt_ext"]}) + + return name_im_gt_list + +def create_dataloaders(name_im_gt_list, my_transforms=[], batch_size=1, training=False): + gos_dataloaders = [] + gos_datasets = [] + + if(len(name_im_gt_list)==0): + return gos_dataloaders, gos_datasets + + num_workers_ = 1 + if(batch_size>1): + num_workers_ = 2 + if(batch_size>4): + num_workers_ = 4 + if(batch_size>8): + num_workers_ = 8 + + + if training: + for i in range(len(name_im_gt_list)): + gos_dataset = OnlineDataset([name_im_gt_list[i]], transform = transforms.Compose(my_transforms)) + gos_datasets.append(gos_dataset) + + gos_dataset = ConcatDataset(gos_datasets) + sampler = DistributedSampler(gos_dataset) + batch_sampler_train = torch.utils.data.BatchSampler( + sampler, batch_size, drop_last=True) + dataloader = DataLoader(gos_dataset, batch_sampler=batch_sampler_train, 
num_workers=num_workers_) + + gos_dataloaders = dataloader + gos_datasets = gos_dataset + + else: + for i in range(len(name_im_gt_list)): + gos_dataset = OnlineDataset([name_im_gt_list[i]], transform = transforms.Compose(my_transforms), eval_ori_resolution = True) + sampler = DistributedSampler(gos_dataset, shuffle=False) + dataloader = DataLoader(gos_dataset, batch_size, sampler=sampler, drop_last=False, num_workers=num_workers_) + + gos_dataloaders.append(dataloader) + gos_datasets.append(gos_dataset) + + return gos_dataloaders, gos_datasets + +class RandomHFlip(object): + def __init__(self,prob=0.5): + self.prob = prob + def __call__(self,sample): + imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] + + # random horizontal flip + if random.random() >= self.prob: + image = torch.flip(image,dims=[2]) + label = torch.flip(label,dims=[2]) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape} + +class Resize(object): + def __init__(self,size=[320,320]): + self.size = size + def __call__(self,sample): + imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] + + image = torch.squeeze(F.interpolate(torch.unsqueeze(image,0),tuple(self.size),mode='bilinear'),dim=0) + label = torch.squeeze(F.interpolate(torch.unsqueeze(label,0),tuple(self.size),mode='bilinear'),dim=0) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':torch.tensor(self.size)} + +class RandomCrop(object): + def __init__(self,size=[288,288]): + self.size = size + def __call__(self,sample): + imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] + + h, w = image.shape[1:] + new_h, new_w = self.size + + top = np.random.randint(0, h - new_h) + left = np.random.randint(0, w - new_w) + + image = image[:,top:top+new_h,left:left+new_w] + label = label[:,top:top+new_h,left:left+new_w] + + return {'imidx':imidx,'image':image, 'label':label, 'shape':torch.tensor(self.size)} + + +class Normalize(object): + def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]): + self.mean = mean + self.std = std + + def __call__(self,sample): + + imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] + image = normalize(image,self.mean,self.std) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape} + + + +class LargeScaleJitter(object): + """ + implementation of large scale jitter from copy_paste + https://github.com/gaopengcuhk/Pretrained-Pix2Seq/blob/7d908d499212bfabd33aeaa838778a6bfb7b84cc/datasets/transforms.py + """ + + def __init__(self, output_size=1024, aug_scale_min=0.1, aug_scale_max=2.0): + self.desired_size = torch.tensor(output_size) + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + + def pad_target(self, padding, target): + target = target.copy() + if "masks" in target: + target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[1], 0, padding[0])) + return target + + def __call__(self, sample): + imidx, image, label, image_size = sample['imidx'], sample['image'], sample['label'], sample['shape'] + + #resize keep ratio + out_desired_size = (self.desired_size * image_size / max(image_size)).round().int() + + random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min + scaled_size = (random_scale * self.desired_size).round() + + scale = torch.minimum(scaled_size / image_size[0], scaled_size / image_size[1]) + scaled_size = (image_size * 
scale).round().long() + + scaled_image = torch.squeeze(F.interpolate(torch.unsqueeze(image,0),scaled_size.tolist(),mode='bilinear'),dim=0) + scaled_label = torch.squeeze(F.interpolate(torch.unsqueeze(label,0),scaled_size.tolist(),mode='bilinear'),dim=0) + + # random crop + crop_size = (min(self.desired_size, scaled_size[0]), min(self.desired_size, scaled_size[1])) + + margin_h = max(scaled_size[0] - crop_size[0], 0).item() + margin_w = max(scaled_size[1] - crop_size[1], 0).item() + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0].item() + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1].item() + + scaled_image = scaled_image[:,crop_y1:crop_y2, crop_x1:crop_x2] + scaled_label = scaled_label[:,crop_y1:crop_y2, crop_x1:crop_x2] + + # pad + padding_h = max(self.desired_size - scaled_image.size(1), 0).item() + padding_w = max(self.desired_size - scaled_image.size(2), 0).item() + image = F.pad(scaled_image, [0,padding_w, 0,padding_h],value=128) + label = F.pad(scaled_label, [0,padding_w, 0,padding_h],value=0) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':torch.tensor(image.shape[-2:])} + + + + + + +class OnlineDataset(Dataset): + def __init__(self, name_im_gt_list, transform=None, eval_ori_resolution=False): + + self.transform = transform + self.dataset = {} + ## combine different datasets into one + dataset_names = [] + dt_name_list = [] # dataset name per image + im_name_list = [] # image name + im_path_list = [] # im path + gt_path_list = [] # gt path + im_ext_list = [] # im ext + gt_ext_list = [] # gt ext + for i in range(0,len(name_im_gt_list)): + dataset_names.append(name_im_gt_list[i]["dataset_name"]) + # dataset name repeated based on the number of images in this dataset + dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]]) + im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]]) + im_path_list.extend(name_im_gt_list[i]["im_path"]) + gt_path_list.extend(name_im_gt_list[i]["gt_path"]) + im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]]) + gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]]) + + + self.dataset["data_name"] = dt_name_list + self.dataset["im_name"] = im_name_list + self.dataset["im_path"] = im_path_list + self.dataset["ori_im_path"] = deepcopy(im_path_list) + self.dataset["gt_path"] = gt_path_list + self.dataset["ori_gt_path"] = deepcopy(gt_path_list) + self.dataset["im_ext"] = im_ext_list + self.dataset["gt_ext"] = gt_ext_list + + self.eval_ori_resolution = eval_ori_resolution + + def __len__(self): + return len(self.dataset["im_path"]) + def __getitem__(self, idx): + im_path = self.dataset["im_path"][idx] + gt_path = self.dataset["gt_path"][idx] + im = io.imread(im_path) + gt = io.imread(gt_path) + + if len(gt.shape) > 2: + gt = gt[:, :, 0] + if len(im.shape) < 3: + im = im[:, :, np.newaxis] + if im.shape[2] == 1: + im = np.repeat(im, 3, axis=2) + im = torch.tensor(im.copy(), dtype=torch.float32) + im = torch.transpose(torch.transpose(im,1,2),0,1) + gt = torch.unsqueeze(torch.tensor(gt, dtype=torch.float32),0) + + sample = { + "imidx": torch.from_numpy(np.array(idx)), + "image": im, + "label": gt, + "shape": torch.tensor(im.shape[-2:]), + } + + if self.transform: + sample = self.transform(sample) + + if self.eval_ori_resolution: + sample["ori_label"] = 
gt.type(torch.uint8)  # NOTE for evaluation only. And no flip here
+            sample['ori_im_path'] = self.dataset["im_path"][idx]
+            sample['ori_gt_path'] = self.dataset["gt_path"][idx]
+
+        return sample
\ No newline at end of file
diff --git a/utils/loss_mask.py b/utils/loss_mask.py
new file mode 100644
index 0000000..d348bed
--- /dev/null
+++ b/utils/loss_mask.py
@@ -0,0 +1,196 @@
+import torch
+from torch.nn import functional as F
+from typing import List, Optional
+import utils.misc as misc
+
+def point_sample(input, point_coords, **kwargs):
+    """
+    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
+    Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
+    [0, 1] x [0, 1] square.
+    Args:
+        input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on a H x W grid.
+        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
+        [0, 1] x [0, 1] normalized point coordinates.
+    Returns:
+        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
+            features for points in `point_coords`. The features are obtained via bilinear
+            interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
+    """
+    add_dim = False
+    if point_coords.dim() == 3:
+        add_dim = True
+        point_coords = point_coords.unsqueeze(2)
+    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
+    if add_dim:
+        output = output.squeeze(3)
+    return output
+
+def cat(tensors: List[torch.Tensor], dim: int = 0):
+    """
+    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
+    """
+    assert isinstance(tensors, (list, tuple))
+    if len(tensors) == 1:
+        return tensors[0]
+    return torch.cat(tensors, dim)
+
+def get_uncertain_point_coords_with_randomness(
+    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
+):
+    """
+    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
+    are calculated for each point using the 'uncertainty_func' function that takes the point's logit
+    prediction as input.
+    See the PointRend paper for details.
+    Args:
+        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
+            class-specific or class-agnostic prediction.
+        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
+            contains logit predictions for P points and returns their uncertainties as a Tensor of
+            shape (N, 1, P).
+        num_points (int): The number of points P to sample.
+        oversample_ratio (int): Oversampling parameter.
+        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
+    Returns:
+        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
+            sampled points.
+    """
+    assert oversample_ratio >= 1
+    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
+    num_boxes = coarse_logits.shape[0]
+    num_sampled = int(num_points * oversample_ratio)
+    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
+    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
+    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
+    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
+    # to incorrect results.
+ # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between + # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. + # However, if we calculate uncertainties for the coarse predictions first, + # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + num_boxes, num_uncertain_points, 2 + ) + if num_random_points > 0: + point_coords = cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), + ], + dim=1, + ) + return point_coords + +def dice_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + num_masks: float, + ): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + +dice_loss_jit = torch.jit.script( + dice_loss +) # type: torch.jit.ScriptModule + + +def sigmoid_ce_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + num_masks: float, + ): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + + return loss.mean(1).sum() / num_masks + + +sigmoid_ce_loss_jit = torch.jit.script( + sigmoid_ce_loss +) # type: torch.jit.ScriptModule + + +def calculate_uncertainty(logits): + """ + We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the + foreground class in `classes`. + Args: + logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or + class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. The values are logits. + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + assert logits.shape[1] == 1 + gt_class_logits = logits.clone() + return -(torch.abs(gt_class_logits)) + +def loss_masks(src_masks, target_masks, num_masks, oversample_ratio=3.0): + """Compute the losses related to the masks: the focal loss and the dice loss. 
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + + # No need to upsample predictions as we are using normalized coordinates :) + + with torch.no_grad(): + # sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + src_masks, + lambda logits: calculate_uncertainty(logits), + 112 * 112, + oversample_ratio, + 0.75, + ) + # get gt labels + point_labels = point_sample( + target_masks, + point_coords, + align_corners=False, + ).squeeze(1) + + point_logits = point_sample( + src_masks, + point_coords, + align_corners=False, + ).squeeze(1) + + loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks) + loss_dice = dice_loss_jit(point_logits, point_labels, num_masks) + + del src_masks + del target_masks + return loss_mask, loss_dice + + + diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..812d3c1 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,527 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import random +import subprocess +import time +from collections import OrderedDict, defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import json, time +import numpy as np +import torch +import torch.distributed as dist +from torch import Tensor + +import colorsys +import torch.nn.functional as F + +import cv2 + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + if d.shape[0] == 0: + return 0 + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + # print(name, str(meter)) + # import ipdb;ipdb.set_trace() + if meter.count > 0: + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, logger=None): + if logger is None: + print_func = print + else: + print_func = logger.info + + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print_func(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print_func(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print_func('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = 
_run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(cfg): + if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': + local_world_size = int(os.environ['WORLD_SIZE']) + cfg.world_size = cfg.world_size * local_world_size + cfg.gpu = cfg.local_rank = int(os.environ['LOCAL_RANK']) + cfg.rank = cfg.rank * local_world_size + cfg.local_rank + print('world size: {}, rank: {}, local rank: {}'.format(cfg.world_size, cfg.rank, cfg.local_rank)) + print(json.dumps(dict(os.environ), indent=2)) + elif 'SLURM_PROCID' in os.environ: + cfg.rank = int(os.environ['SLURM_PROCID']) + cfg.gpu = cfg.local_rank = int(os.environ['SLURM_LOCALID']) + cfg.world_size = int(os.environ['SLURM_NPROCS']) + + print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(cfg.world_size, cfg.rank, cfg.local_rank, torch.cuda.device_count())) + else: + print('Not using distributed mode') + cfg.distributed = False + cfg.world_size = 1 + cfg.rank = 0 + cfg.local_rank = 0 + return + + print("world_size:{} rank:{} local_rank:{}".format(cfg.world_size, cfg.rank, cfg.local_rank)) + cfg.distributed = True + torch.cuda.set_device(cfg.local_rank) + cfg.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format(cfg.rank, cfg.dist_url), flush=True) + torch.distributed.init_process_group(backend=cfg.dist_backend, init_method=cfg.dist_url, + world_size=cfg.world_size, rank=cfg.rank) + print("Before torch.distributed.barrier()") + torch.distributed.barrier() + print("End torch.distributed.barrier()") + setup_for_distributed(cfg.rank == 0) + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
+ + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + y = y.to(masks) + x = x.to(masks) + + x_mask = ((masks>128) * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks>128), 1e8).flatten(1).min(-1)[0] + + y_mask = ((masks>128) * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks>128), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + +def box_noise(boxes, box_noise_scale=0): + + known_bbox_expand = box_xyxy_to_cxcywh(boxes) + + diff = torch.zeros_like(known_bbox_expand) + diff[:, :2] = known_bbox_expand[:, 2:] / 2 + diff[:, 2:] = known_bbox_expand[:, 2:] + known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),diff).cuda() * box_noise_scale + boxes = box_cxcywh_to_xyxy(known_bbox_expand) + boxes = boxes.clamp(min=0.0, max=1024) + + return boxes + +def masks_sample_points(masks,k=10): + """Sample points on mask + """ + if masks.numel() == 0: + return torch.zeros((0, 2), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + y = y.to(masks) + x = x.to(masks) + + # k = 10 + samples = [] + for b_i in range(len(masks)): + select_mask = (masks[b_i]>128) + x_idx = torch.masked_select(x,select_mask) + y_idx = torch.masked_select(y,select_mask) + + perm = torch.randperm(x_idx.size(0)) + idx = perm[:k] + samples_x = x_idx[idx] + samples_y = y_idx[idx] + samples_xy = torch.cat((samples_x[:,None],samples_y[:,None]),dim=1) + samples.append(samples_xy) + + samples = torch.stack(samples) + return samples + + +# Add noise to mask input +# From Mask Transfiner https://github.com/SysCV/transfiner +def masks_noise(masks): + def get_incoherent_mask(input_masks, sfact): + mask = input_masks.float() + w = input_masks.shape[-1] + h = input_masks.shape[-2] + mask_small = F.interpolate(mask, (h//sfact, w//sfact), mode='bilinear') + mask_recover = F.interpolate(mask_small, (h, w), mode='bilinear') + mask_residue = (mask - mask_recover).abs() + mask_residue = (mask_residue >= 0.01).float() + return mask_residue + gt_masks_vector = masks / 255 + mask_noise = torch.randn(gt_masks_vector.shape, device= gt_masks_vector.device) * 1.0 + inc_masks = get_incoherent_mask(gt_masks_vector, 8) + gt_masks_vector = ((gt_masks_vector + mask_noise * inc_masks) > 0.5).float() + gt_masks_vector = gt_masks_vector * 255 + + return gt_masks_vector + + +def mask_iou(pred_label,label): + ''' + calculate mask iou for pred_label and gt_label + ''' + + pred_label = (pred_label>0)[0].int() + label = (label>128)[0].int() + + intersection = ((label * pred_label) > 0).sum() + union = ((label + pred_label) > 0).sum() + return intersection / union + + + +# General util function to get the boundary of a binary mask. 
+# https://gist.github.com/bowenc0221/71f7a02afee92646ca05efeeb14d687d +def mask_to_boundary(mask, dilation_ratio=0.02): + """ + Convert binary mask to boundary mask. + :param mask (numpy array, uint8): binary mask + :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal + :return: boundary mask (numpy array) + """ + h, w = mask.shape + img_diag = np.sqrt(h ** 2 + w ** 2) + dilation = int(round(dilation_ratio * img_diag)) + if dilation < 1: + dilation = 1 + # Pad image so mask truncated by the image border is also considered as boundary. + new_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) + kernel = np.ones((3, 3), dtype=np.uint8) + new_mask_erode = cv2.erode(new_mask, kernel, iterations=dilation) + mask_erode = new_mask_erode[1 : h + 1, 1 : w + 1] + # G_d intersects G in the paper. + return mask - mask_erode + + +def boundary_iou(gt, dt, dilation_ratio=0.02): + """ + Compute boundary iou between two binary masks. + :param gt (numpy array, uint8): binary mask + :param dt (numpy array, uint8): binary mask + :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal + :return: boundary iou (float) + """ + device = gt.device + dt = (dt>0)[0].cpu().byte().numpy() + gt = (gt>128)[0].cpu().byte().numpy() + + gt_boundary = mask_to_boundary(gt, dilation_ratio) + dt_boundary = mask_to_boundary(dt, dilation_ratio) + intersection = ((gt_boundary * dt_boundary) > 0).sum() + union = ((gt_boundary + dt_boundary) > 0).sum() + boundary_iou = intersection / union + return torch.tensor(boundary_iou).float().to(device) \ No newline at end of file
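+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (assumes only the helpers defined above; not part of the
+    # upstream HQ-SAM utilities): build a synthetic 0/255 mask and exercise the box / IoU
+    # helpers that the training and evaluation loops rely on.
+    toy = torch.zeros(1, 64, 64)
+    toy[:, 16:48, 8:40] = 255  # a 32x32 square of foreground values (255)
+    print("xyxy box     :", masks_to_boxes(toy)[0].tolist())                 # [8.0, 16.0, 39.0, 47.0]
+    print("mask IoU     :", mask_iou((toy > 128).float(), toy).item())       # 1.0 for identical masks
+    print("boundary IoU :", boundary_iou(toy, (toy > 128).float()).item())   # 1.0 for identical masks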