Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#2 from LokeZhou/grouded-sam
Browse files Browse the repository at this point in the history
add grounded-sam
  • Loading branch information
LokeZhou authored Jul 10, 2023
2 parents ff661de + 0525cf4 commit 2a97b1b
Show file tree
Hide file tree
Showing 24 changed files with 3,018 additions and 10 deletions.
193 changes: 193 additions & 0 deletions applications/CVinW/grounded_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
from dataclasses import dataclass, field
import os
import sys
import numpy as np
import requests
from typing import List

import paddle
import paddle.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
from paddlenlp.trainer import PdArgumentParser
from paddlevlp.utils.log import logger

from paddlevlp.processors.groundingdino_processing import GroudingDinoProcessor
from paddlevlp.models.groundingdino.modeling import GroundingDinoModel
from paddlevlp.models.sam.modeling import SamModel
from paddlevlp.processors.sam_processing import SamProcessor
import matplotlib.pyplot as plt


def show_mask(mask, ax, random_color=False):
    """Overlay a segmentation mask on a matplotlib axes.

    Args:
        mask: array whose trailing two dims are (H, W); nonzero entries are masked.
        ax: matplotlib axes to draw on.
        random_color: if True, use a random RGB color with 0.6 alpha;
            otherwise a fixed dodger-blue with 0.6 alpha.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) RGBA color.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_box(box, ax, label):
    """Draw a labeled bounding box on a matplotlib axes.

    Args:
        box: (x0, y0, x1, y1) corner coordinates in pixels.
        ax: matplotlib axes to draw on.
        label: text placed at the box's top-left corner.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    width = x1 - x0
    height = y1 - y0
    rect = plt.Rectangle(
        (x0, y0), width, height, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(rect)
    ax.text(x0, y0, label)




@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `PdArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    # Local path or URL of the image to run detection/segmentation on (required).
    input_image: str = field(
        metadata={"help": "The name of input image."}
    )
    # Text prompt fed to Grounding DINO.
    # NOTE(review): default is None but the annotation is `str`; presumably a
    # prompt is always supplied in practice — confirm before tightening.
    prompt: str = field(
        default=None, metadata={"help": "The prompt of the image to be generated."}
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # Pretrained Grounding DINO checkpoint name or local path.
    dino_model_name_or_path: str = field(
        default="GroundingDino/groundingdino-swint-ogc",
        metadata={"help": "Path to pretrained model or model identifier"},
    )
    # Pretrained SAM checkpoint name or local path.
    sam_model_name_or_path: str = field(
        default="Sam/SamVitH",
        metadata={"help": "Path to pretrained model or model identifier"},
    )
    # Minimum per-query max score for a predicted box to be kept.
    box_threshold: float = field(
        default=0.3,
        metadata={
            "help": "box threshold."
        },
    )
    # Per-token score threshold used when decoding the predicted phrase.
    text_threshold: float = field(
        default=0.25,
        metadata={
            "help": "text threshold."
        },
    )
    # Directory where the visualization image is written.
    output_dir: str = field(
        default="grounded_sam_output",
        metadata={
            "help": "output directory."
        },
    )
    # If True, save a matplotlib overlay of masks and boxes to output_dir.
    visual: bool = field(
        default=True,
        metadata={
            "help": "save visual image."
        },
    )

def main():
    """Run the Grounded-SAM pipeline.

    Detects boxes matching a text prompt with Grounding DINO, feeds those
    boxes to SAM as prompts to produce segmentation masks, and optionally
    saves a matplotlib visualization to ``model_args.output_dir``.
    """
    parser = PdArgumentParser((ModelArguments, DataArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()
    url = (data_args.input_image)
    # build dino processor
    dino_processor = GroudingDinoProcessor.from_pretrained(
        model_args.dino_model_name_or_path
    )

    # build dino model
    logger.info("dino_model: {}".format(model_args.dino_model_name_or_path))
    dino_model = GroundingDinoModel.from_pretrained(model_args.dino_model_name_or_path)
    dino_model.eval()
    # build sam processor
    sam_processor = SamProcessor.from_pretrained(
        model_args.sam_model_name_or_path
    )
    # build sam model; "boxs" selects box-prompt mode (spelling follows the SAM API)
    logger.info("SamModel: {}".format(model_args.sam_model_name_or_path))
    sam_model = SamModel.from_pretrained(model_args.sam_model_name_or_path,input_type="boxs")

    # read image from a local file if it exists, otherwise treat it as a URL
    if os.path.isfile(url):
        # read image
        image_pil = Image.open(data_args.input_image).convert("RGB")
    else:
        image_pil = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    # preprocess image and text prompt into model inputs
    image_tensor,mask,tokenized_out = dino_processor(images=image_pil,text=data_args.prompt)

    with paddle.no_grad():
        outputs = dino_model(image_tensor,mask, input_ids=tokenized_out['input_ids'],
                             attention_mask=tokenized_out['attention_mask'],text_self_attention_masks=tokenized_out['text_self_attention_masks'],
                             position_ids=tokenized_out['position_ids'])

    logits = F.sigmoid(outputs["pred_logits"])[0]  # (nq, 256)
    boxes = outputs["pred_boxes"][0]  # (nq, 4)

    # filter output: keep queries whose best token score clears box_threshold
    logits_filt = logits.clone()
    boxes_filt = boxes.clone()
    filt_mask = logits_filt.max(axis=1) > model_args.box_threshold
    logits_filt = logits_filt[filt_mask]  # num_filt, 256
    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4

    # build pred phrases, e.g. "cat(0.72)" — score truncated to 4 characters
    pred_phrases = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = dino_processor.decode(logit > model_args.text_threshold)
        pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")

    size = image_pil.size
    pred_dict = {
        "boxes": boxes_filt,
        "size": [size[1], size[0]],  # H,W
        "labels": pred_phrases,
    }
    logger.info("dino output{}".format(pred_dict))

    # convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixels
    H,W = size[1], size[0]
    boxes = []
    for box in zip(boxes_filt):
        box = box[0] * paddle.to_tensor([W, H, W, H])
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        x0, y0, x1, y1 = box.numpy()
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
        boxes.append([x0, y0, x1, y1])
    boxes = np.array(boxes)
    # run SAM with the detected boxes as prompts, then upsample masks to image size
    image_seg,prompt = sam_processor(image_pil,input_type="boxs",box=boxes,point_coords=None)
    seg_masks = sam_model(img=image_seg,prompt=prompt)
    seg_masks = sam_processor.postprocess_masks(seg_masks)

    logger.info("Sam finish!")

    if model_args.visual:
        # make dir
        os.makedirs(model_args.output_dir, exist_ok=True)
        # draw output image: masks in random colors, boxes with phrase labels
        plt.figure(figsize=(10, 10))
        plt.imshow(image_pil)
        for mask in seg_masks:
            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        for box, label in zip(boxes, pred_phrases):
            show_box(box, plt.gca(), label)

        plt.axis('off')
        plt.savefig(
            os.path.join(model_args.output_dir, 'mask_pred.jpg'),
            bbox_inches="tight", dpi=300, pad_inches=0.0
        )

    logger.info("finish!")


# Script entry point.
if __name__ == "__main__":
    main()
33 changes: 33 additions & 0 deletions deploy/groundingdino/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Grounding DINO

## 1. 模型简介

Paddle implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector.


## 2. Demo

## 2.1 prepare
```bash
#Multi-scale deformable attention custom OP compilation
cd /paddlevlp/models/groundingdino/csrc/
python setup_ms_deformable_attn_op.py install

```
## 2.2 Export model for static inference
```bash
#export grounding dino model
python export.py


#inference
python predict.py \
--text_encoder_type GroundingDino/groundingdino-swint-ogc \
--model_path output_groundingdino \
--input_image image_you_want_to_detect.jpg \
--output_dir "dir you want to save the output" \
--prompt "Detect Cat"

```


71 changes: 71 additions & 0 deletions deploy/groundingdino/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import argparse
import os
import paddle
from paddle.static import InputSpec

from paddlevlp.models.groundingdino.modeling import GroundingDinoModel



def _prune_input_spec(input_spec, program, targets):
    """Return the subset of ``input_spec`` that survives pruning ``program``.

    The program is cloned and pruned down to ``targets``; any spec whose
    variable no longer exists in the pruned global block is dropped.
    Pruning only works in static mode, so we switch in and back out,
    restoring the current device around the transition.
    """
    device = paddle.get_device()
    paddle.enable_static()
    paddle.set_device(device)
    kept = [{}]
    pruned_program = program.clone()._prune(targets=targets)
    block = pruned_program.global_block()
    for spec in input_spec:
        try:
            name = spec.name
            block.var(name)  # raises if the variable was pruned away
            kept[0][name] = spec
        except Exception:
            pass
    paddle.disable_static(place=device)
    return kept

def apply_to_static(model):
    """Convert ``model`` to a static graph via ``paddle.jit.to_static``.

    Returns:
        tuple: (static_model, input_spec), where input_spec describes the
        forward inputs — image tensor, image mask, and tokenizer outputs.
    """
    # (name, shape, dtype) for each forward argument; None = dynamic dim.
    spec_table = [
        ("x", [None, 3, None, None], "float32"),               # image tensor
        ("m", [None, None, None], "int64"),                    # image mask
        ("input_ids", [None, None], "int64"),
        ("attention_mask", [None, None], "int64"),
        ("text_self_attention_masks", [None, None, None], "int64"),
        ("position_ids", [None, None], "int64"),
    ]
    input_spec = [
        InputSpec(shape=shape, name=name, dtype=dtype)
        for name, shape, dtype in spec_table
    ]
    static_model = paddle.jit.to_static(model, input_spec=input_spec)
    return static_model, input_spec


# Export entry point: load a pretrained Grounding DINO model, trace it to a
# static graph, and serialize it for deployment inference.
if __name__ == "__main__":

    parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
    parser.add_argument("--dino_type", "-dt", type=str, default="GroundingDino/groundingdino-swint-ogc", help="dino type")
    parser.add_argument(
        "--output_dir", "-o", type=str, default="output_groundingdino", help="output directory"
    )
    args = parser.parse_args()

    output_dir = args.output_dir
    # load pretrained weights and switch to inference mode
    model = GroundingDinoModel.from_pretrained(args.dino_type)
    model.eval()

    # trace the dynamic model into a static program with fixed input specs
    static_model,input_spec = apply_to_static(model)

    # save the static model (program + params) under output_dir
    paddle.jit.save(
        static_model,
        os.path.join(output_dir, 'groundingdino_model'),
        input_spec=input_spec)
Loading

0 comments on commit 2a97b1b

Please sign in to comment.