diff --git a/demo/assets/demo1.jpg b/demo/assets/demo1.jpg
new file mode 100644
index 0000000..b51fde5
Binary files /dev/null and b/demo/assets/demo1.jpg differ
diff --git a/demo/assets/demo2.jpg b/demo/assets/demo2.jpg
new file mode 100644
index 0000000..583f69e
Binary files /dev/null and b/demo/assets/demo2.jpg differ
diff --git a/demo/assets/demo3.jpg b/demo/assets/demo3.jpg
new file mode 100644
index 0000000..83c0c9e
Binary files /dev/null and b/demo/assets/demo3.jpg differ
diff --git a/demo/assets/demo4.jpg b/demo/assets/demo4.jpg
new file mode 100644
index 0000000..deeafdb
Binary files /dev/null and b/demo/assets/demo4.jpg differ
diff --git a/demo/assets/demo5.jpg b/demo/assets/demo5.jpg
new file mode 100644
index 0000000..a204f5a
Binary files /dev/null and b/demo/assets/demo5.jpg differ
diff --git a/demo/assets/demo6.jpg b/demo/assets/demo6.jpg
new file mode 100644
index 0000000..679431b
Binary files /dev/null and b/demo/assets/demo6.jpg differ
diff --git a/demo/assets/demo7.jpg b/demo/assets/demo7.jpg
new file mode 100644
index 0000000..9374e1f
Binary files /dev/null and b/demo/assets/demo7.jpg differ
diff --git a/demo/assets/demo8.jpg b/demo/assets/demo8.jpg
new file mode 100644
index 0000000..20a3789
Binary files /dev/null and b/demo/assets/demo8.jpg differ
diff --git a/demo/assets/demo9.jpg b/demo/assets/demo9.jpg
new file mode 100644
index 0000000..85faeb8
Binary files /dev/null and b/demo/assets/demo9.jpg differ
diff --git a/demo/groundingDINO_batched_float16.ipynb b/demo/groundingDINO_batched_float16.ipynb
new file mode 100644
index 0000000..d12d9b4
--- /dev/null
+++ b/demo/groundingDINO_batched_float16.ipynb
@@ -0,0 +1,292 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Grounding DINO - Batched Half Precision Inference"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Prepare Environment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from PIL import Image\n",
+    "import io\n",
+    "import os\n",
+    "import supervision as sv\n",
+    "import numpy as np\n",
+    "import requests\n",
+    "import cv2\n",
+    "\n",
+    "# Grounding DINO\n",
+    "from groundingdino.util.inference import BatchedModel\n",
+    "import torchvision.transforms.functional as F\n",
+    "from huggingface_hub import hf_hub_download\n",
+    "\n",
+    "# If you have multiple GPUs, you can set the GPU to use here.\n",
+    "# The default is to use the first GPU, which is usually GPU 0.\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load Grounding DINO model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load demo image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def download_image(url, image_file_path):\n",
+    "    r = requests.get(url, timeout=4.0)\n",
+    "    if r.status_code != requests.codes.ok:\n",
+    "        raise RuntimeError('Status code error: {}.'.format(r.status_code))\n",
+    "\n",
+    "    with Image.open(io.BytesIO(r.content)) as im:\n",
+    "        im.save(image_file_path)\n",
+    "\n",
+    "    print('Image downloaded from url: {} and saved to: {}.'.format(url, image_file_path))\n",
+    "\n",
+    "def load_image(image_path):\n",
+    "    image_source = Image.open(image_path).convert(\"RGB\")\n",
+    "    image = np.asarray(image_source)\n",
+    "    image_tensor = F.to_tensor(image)\n",
+    "    return image, image_tensor\n",
+    "\n",
+    "local_image_path = \"assets/demo4.jpg\"\n",
+    "#download_image(image_url, local_image_path)\n",
+    "image_source, image_tensor = load_image(local_image_path)\n",
+    "Image.fromarray(image_source)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Run Grounding DINO for detection"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Fetch the Grounding DINO config and checkpoint from the Hugging Face Hub\n",
+    "# (or download the files yourself and point these paths at them)\n",
+    "ckpt_repo_id = \"ShilongLiu/GroundingDINO\"\n",
+    "ckpt_filename = \"groundingdino_swint_ogc.pth\"\n",
+    "ckpt_config_filename = \"GroundingDINO_SwinT_OGC.cfg.py\"\n",
+    "device = \"cuda\"\n",
+    "\n",
+    "cache_config_file = hf_hub_download(repo_id=ckpt_repo_id, filename=ckpt_config_filename)\n",
+    "cache_file = hf_hub_download(repo_id=ckpt_repo_id, filename=ckpt_filename)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Single Precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch = 2\n",
+    "box_threshold = 0.3\n",
+    "text_threshold = 0.25\n",
+    "iou_threshold = 0.5\n",
+    "\n",
+    "# Batch of prompts\n",
+    "text_prompt = [\n",
+    "    [\"Black dog\", \"Beige dog\"],\n",
+    "    [\"Dog\", \"Stick\"]\n",
+    "]\n",
+    "\n",
+    "dtype = \"float32\"\n",
+    "\n",
+    "# Repeat the single image to form a batch\n",
+    "image_tensor = image_tensor.to(device=device).to(dtype=getattr(torch, dtype))\n",
+    "image_tensor = image_tensor[None, ...].expand(batch, -1, -1, -1)\n",
+    "\n",
+    "# Build the GroundingDINO inference model\n",
+    "grounding_dino_model = BatchedModel(\n",
+    "    model_config_path=cache_config_file,\n",
+    "    model_checkpoint_path=cache_file,\n",
+    "    device=device,\n",
+    "    dtype=dtype,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%timeit -n 10\n",
+    "with torch.no_grad():\n",
+    "    bbox_batch, conf_batch, class_id_batch = grounding_dino_model(\n",
+    "        image_batch=image_tensor,\n",
+    "        text_prompts=text_prompt,\n",
+    "        box_threshold=box_threshold,\n",
+    "        text_threshold=text_threshold,\n",
+    "        nms_threshold=iou_threshold\n",
+    "    )\n",
+    "    bbox_batch = [bbox.cpu().numpy() for bbox in bbox_batch]\n",
+    "    conf_batch = [conf.cpu().numpy() for conf in conf_batch]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Half Precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dtype = \"float16\"\n",
+    "\n",
+    "image_tensor = image_tensor.to(device=device).to(dtype=getattr(torch, dtype))\n",
+    "\n",
+    "# Build the GroundingDINO inference model\n",
+    "grounding_dino_model = BatchedModel(\n",
+    "    model_config_path=cache_config_file,\n",
+    "    model_checkpoint_path=cache_file,\n",
+    "    device=device,\n",
+    "    dtype=dtype\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%timeit -n 10\n",
+    "with torch.no_grad():\n",
+    "    bbox_batch, conf_batch, class_id_batch = grounding_dino_model(\n",
+    "        image_batch=image_tensor,\n",
+    "        text_prompts=text_prompt,\n",
+    "        box_threshold=box_threshold,\n",
+    "        text_threshold=text_threshold,\n",
+    "        nms_threshold=iou_threshold\n",
+    "    )\n",
+    "    bbox_batch = [bbox.cpu().numpy() for bbox in bbox_batch]\n",
+    "    conf_batch = [conf.cpu().numpy() for conf in conf_batch]"
+   ]
+  },
+  {
"cell_type": "markdown", + "metadata": {}, + "source": [ + "# Display result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " bbox_batch, conf_batch, class_id_batch = grounding_dino_model(\n", + " image_batch=image_tensor,\n", + " text_prompts=text_prompt,\n", + " box_threshold=box_threshold,\n", + " text_threshold=text_threshold,\n", + " nms_threshold=iou_threshold\n", + " )\n", + " bbox_batch = [bbox.cpu().numpy() for bbox in bbox_batch]\n", + " conf_batch = [conf.cpu().numpy() for conf in conf_batch]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display\n", + "def annotate(image_source, boxes, logits, phrases) -> np.ndarray:\n", + " detections = sv.Detections(xyxy=boxes)\n", + " labels = [\n", + " f\"{phrase} {logit:.2f}\"\n", + " for phrase, logit\n", + " in zip(phrases, logits)\n", + " ]\n", + " box_annotator = sv.BoxAnnotator()\n", + " annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)\n", + " annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)\n", + " return annotated_frame[...,::-1]\n", + "\n", + "\n", + "for i, (bbox, conf, class_id, class_label) in enumerate(zip(bbox_batch, conf_batch, class_id_batch, text_prompt)):\n", + " annotated_frame = annotate(\n", + " image_source=image_source, \n", + " boxes=bbox,\n", + " logits=conf,\n", + " phrases=np.array(class_label)[class_id]\n", + " )\n", + "\n", + " display(Image.fromarray(annotated_frame))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/groundingdino/models/GroundingDINO/backbone/swin_transformer.py b/groundingdino/models/GroundingDINO/backbone/swin_transformer.py index 1c66194..8a1cf83 100644 --- a/groundingdino/models/GroundingDINO/backbone/swin_transformer.py +++ b/groundingdino/models/GroundingDINO/backbone/swin_transformer.py @@ -159,6 +159,7 @@ def forward(self, x, mask=None): attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: + mask = mask.to(dtype=x.dtype) nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) diff --git a/groundingdino/models/GroundingDINO/ms_deform_attn.py b/groundingdino/models/GroundingDINO/ms_deform_attn.py index 489d501..58129ad 100644 --- a/groundingdino/models/GroundingDINO/ms_deform_attn.py +++ b/groundingdino/models/GroundingDINO/ms_deform_attn.py @@ -100,7 +100,7 @@ def multi_scale_deformable_attn_pytorch( bs, _, num_heads, embed_dims = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) - sampling_grids = 2 * sampling_locations - 1 + sampling_grids = 2 * sampling_locations.to(dtype=value.dtype) - 1 sampling_value_list = [] for level, (H_, W_) in enumerate(value_spatial_shapes): # bs, H_*W_, num_heads, embed_dims -> diff --git a/groundingdino/models/GroundingDINO/transformer.py 
diff --git a/groundingdino/models/GroundingDINO/transformer.py b/groundingdino/models/GroundingDINO/transformer.py
index fcb8742..94844ad 100644
--- a/groundingdino/models/GroundingDINO/transformer.py
+++ b/groundingdino/models/GroundingDINO/transformer.py
@@ -659,6 +659,7 @@ def forward(
 
         output = tgt
         intermediate = []
+        refpoints_unsigmoid = refpoints_unsigmoid.to(dtype=tgt.dtype)
         reference_points = refpoints_unsigmoid.sigmoid()
         ref_points = [reference_points]
 
@@ -667,14 +668,14 @@ def forward(
             if reference_points.shape[-1] == 4:
                 reference_points_input = (
                     reference_points[:, :, None]
-                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
+                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :].to(dtype=tgt.dtype)
                 )  # nq, bs, nlevel, 4
             else:
                 assert reference_points.shape[-1] == 2
-                reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
+                reference_points_input = reference_points[:, :, None] * valid_ratios[None, :].to(dtype=tgt.dtype)
             query_sine_embed = gen_sineembed_for_position(
                 reference_points_input[:, :, 0, :]
-            )  # nq, bs, 256*2
+            ).to(dtype=tgt.dtype)  # nq, bs, 256*2
 
             # conditional query
             raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
diff --git a/groundingdino/models/GroundingDINO/transformer_vanilla.py b/groundingdino/models/GroundingDINO/transformer_vanilla.py
index 10c0920..0026234 100644
--- a/groundingdino/models/GroundingDINO/transformer_vanilla.py
+++ b/groundingdino/models/GroundingDINO/transformer_vanilla.py
@@ -96,7 +96,7 @@ def __init__(
         self.nhead = nhead
 
     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
-        return tensor if pos is None else tensor + pos
+        return tensor if pos is None else tensor + pos.to(dtype=tensor.dtype)
 
     def forward(
         self,
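Unlike the `grid_sample` case, the decoder changes above guard against silent type promotion rather than a hard error at the call site: adding an fp32 positional embedding (or multiplying by fp32 valid ratios) quietly upcasts fp16 activations, and the dtype mismatch then surfaces several layers downstream where it is hard to trace back. A small illustration of the promotion behavior:

```python
import torch

x = torch.randn(2, 4, dtype=torch.float16)    # fp16 activations
pos = torch.randn(2, 4, dtype=torch.float32)  # fp32 positional embedding

# PyTorch type promotion silently upcasts the sum ...
print((x + pos).dtype)  # torch.float32

# ... so a later fp16 op fails far from the real cause:
w = torch.randn(4, 4, dtype=torch.float16)
try:
    (x + pos) @ w
except RuntimeError as err:
    print(err)  # expected scalar type ... Half/Float mismatch (wording varies by version)

# Casting first, as with_pos_embed now does, keeps everything in fp16:
print((x + pos.to(dtype=x.dtype)).dtype)  # torch.float16
```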
diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py
index 58528ed..b376ff5 100644
--- a/groundingdino/util/inference.py
+++ b/groundingdino/util/inference.py
@@ -1,11 +1,13 @@
-from typing import Tuple, List
+from typing import Tuple, List, Any, Union
 
 import cv2
 import numpy as np
 import supervision as sv
 import torch
 from PIL import Image
+import torchvision
 from torchvision.ops import box_convert
+import torchvision.transforms.functional as F
 import bisect
 
 import groundingdino.datasets.transforms as T
@@ -269,3 +271,176 @@ def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
             else:
                 class_ids.append(None)
         return np.array(class_ids)
+
+
+#==============================================================================
+
+
+class BatchedModel(object):
+
+#=====================================================
+
+    def __init__(
+        self,
+        model_config_path: str,
+        model_checkpoint_path: str,
+        device: str = "cuda",
+        dtype: str = "float32",
+        compile: bool = False
+    ) -> None:
+
+        self._device = device
+        self._dtype = getattr(torch, dtype)
+        self._model = load_model(
+            model_config_path=model_config_path,
+            model_checkpoint_path=model_checkpoint_path
+        ).to(device=self._device).to(dtype=self._dtype)
+
+        # Optionally compile the model with torch.compile
+        if compile:
+            self._model = torch.compile(self._model)
+
+#=====================================================
+
+    @staticmethod
+    def preprocess_image(
+        image_batch: torch.Tensor
+    ) -> torch.Tensor:
+
+        # Batch-friendly preprocessing: resize and normalize the whole batch at once
+
+        image_batch = F.resize(image_batch, [800], antialias=True)
+        image_batch = F.normalize(image_batch, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+
+        return image_batch
+
+#=====================================================
+
+    @classmethod
+    def post_process_result(
+        cls,
+        boxes_cxcywh: torch.Tensor,
+        logits: torch.Tensor,
+        nms_threshold: float,
+        source_size: Tuple[int, int],
+        phrases: List[List[str]],
+        text_prompts: List[List[str]]
+    ):
+
+        bbox_batch, conf_batch, class_id_batch = [], [], []
+        source_h, source_w = source_size
+        for bbox_cxcywh, conf, phrase, text_prompt in zip(boxes_cxcywh, logits, phrases, text_prompts):
+            bbox_cxcywh *= torch.Tensor([source_w, source_h, source_w, source_h])
+            bbox_xyxy = box_convert(boxes=bbox_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
+
+            # Perform NMS
+            nms_idx = torchvision.ops.nms(bbox_xyxy.float(), conf.float(), nms_threshold).numpy().tolist()
+            class_id = cls.phrases2classes(phrases=phrase, classes=text_prompt)
+
+            bbox_batch.append(bbox_xyxy[nms_idx])
+            conf_batch.append(conf[nms_idx])
+            class_id_batch.append(class_id[nms_idx])
+
+        return bbox_batch, conf_batch, class_id_batch
+
+#=====================================================
+
+    def _batched_predict(
+        self,
+        image_batch,
+        text_prompts,
+        box_threshold,
+        text_threshold
+    ):
+        # Variant of predict() that operates on a whole batch
+        captions = [preprocess_caption(caption) for caption in text_prompts]
+
+        outputs = self._model(image_batch, captions=captions)
+
+        prediction_logits = outputs["pred_logits"].cpu().sigmoid()  # prediction_logits.shape = (bsz, nq, 256)
+        prediction_boxes = outputs["pred_boxes"].cpu()  # prediction_boxes.shape = (bsz, nq, 4)
+
+        logits_res = []
+        boxs_res = []
+        phrases_list = []
+        tokenizer = self._model.tokenizer
+        for ub_logits, ub_boxes, ub_captions in zip(prediction_logits, prediction_boxes, captions):
+            mask = ub_logits.max(dim=1)[0] > box_threshold
+            logits = ub_logits[mask]  # logits.shape = (n, 256)
+            boxes = ub_boxes[mask]  # boxes.shape = (n, 4)
+            logits_res.append(logits.max(dim=1)[0])
+            boxs_res.append(boxes)
+
+            tokenized = tokenizer(ub_captions)
+            phrases = [
+                get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
+                for logit
+                in logits
+            ]
+            phrases_list.append(phrases)
+
+        return boxs_res, logits_res, phrases_list
+
+    def predict(
+        self,
+        image_batch: torch.Tensor,
+        text_prompts: Union[List[str], List[List[str]]],
+        box_threshold: float = 0.3,
+        text_threshold: float = 0.3,
+        nms_threshold: float = 0.5
+    ):
+
+        # Ensure the batch is on the model's device and dtype
+        image_batch = image_batch.to(device=self._device).to(dtype=self._dtype)
+        source_h, source_w = image_batch.shape[-2:]
+
+        if any(isinstance(i, list) for i in text_prompts):
+            captions = [". ".join(text_prompt) for text_prompt in text_prompts]
+        else:
+            captions = [". ".join(text_prompts)]
+            text_prompts = [text_prompts]
+
+        # Broadcast a single caption / prompt list across the batch
+        if len(captions) == 1:
+            captions *= image_batch.shape[0]
+        if len(text_prompts) == 1:
+            text_prompts *= image_batch.shape[0]
+
+        # Preprocess, run inference, and postprocess
+        processed_image = self.preprocess_image(image_batch)
+        bboxes, logits, phrases = self._batched_predict(
+            processed_image,
+            captions,
+            box_threshold,
+            text_threshold
+        )
+        bbox_batch, conf_batch, class_id_batch = self.post_process_result(
+            bboxes,
+            logits,
+            nms_threshold,
+            (source_h, source_w),
+            phrases,
+            text_prompts
+        )
+
+        return bbox_batch, conf_batch, class_id_batch
+
+    @staticmethod
+    def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
+        class_ids = []
+        for phrase in phrases:
+            for class_ in classes:
+                if class_.lower() in phrase.lower():
+                    class_ids.append(classes.index(class_))
+                    break
+            else:
+                class_ids.append(None)
+        return np.array(class_ids)
+
+
+    def __call__(
+        self,
+        *args,
+        **kwargs
+    ) -> Any:
+        return self.predict(*args, **kwargs)
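`predict` accepts either a flat list of phrases, which is joined into one caption and broadcast to every image, or one phrase list per image. Extracted as a standalone sketch (illustration only, mirroring the normalization logic above):

```python
def normalize_prompts(text_prompts, batch_size):
    # Mirrors predict(): join phrases into "A. B" captions, then broadcast.
    if any(isinstance(i, list) for i in text_prompts):
        captions = [". ".join(p) for p in text_prompts]
    else:
        captions = [". ".join(text_prompts)]
        text_prompts = [text_prompts]
    if len(captions) == 1:
        captions *= batch_size
    if len(text_prompts) == 1:
        text_prompts *= batch_size
    return captions, text_prompts

print(normalize_prompts(["Dog", "Stick"], 2))
# (['Dog. Stick', 'Dog. Stick'], [['Dog', 'Stick'], ['Dog', 'Stick']])
print(normalize_prompts([["Black dog", "Beige dog"], ["Dog", "Stick"]], 2))
# (['Black dog. Beige dog', 'Dog. Stick'], [['Black dog', 'Beige dog'], ['Dog', 'Stick']])
```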
".join(text_prompts)] + text_prompts = [text_prompts] + + # Extend caption to batch + if len(captions) == 1: + captions *= image_batch.shape[0] + if len(text_prompts) == 1: + text_prompts *= image_batch.shape[0] + + # Preprocess, inference and postprocess + processed_image = self.preprocess_image(image_batch) + bboxes, logits, phrases = self._batched_predict( + processed_image, + captions, + box_threshold, + text_threshold + ) + bbox_batch, conf_batch, class_id_batch = self.post_process_result( + bboxes, + logits, + nms_threshold, + (source_h, source_w), + phrases, + text_prompts + ) + + return bbox_batch, conf_batch, class_id_batch + + @staticmethod + def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray: + class_ids = [] + for phrase in phrases: + for class_ in classes: + if class_.lower() in phrase.lower(): + class_ids.append(classes.index(class_)) + break + else: + class_ids.append(None) + return np.array(class_ids) + + + def __call__( + self, + *args, + **kwargs + ) -> Any: + return self.predict(*args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7fe91db --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[build-system] +requires = [ + "setuptools", + "torch", + "wheel" +] +build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index 9282de3..e8cdb84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 torch torchvision transformers diff --git a/setup.py b/setup.py index 275b6fc..31b34ed 100644 --- a/setup.py +++ b/setup.py @@ -189,8 +189,25 @@ def gen_packages_items(): item = "".join(parts) yield item + def filter_index(packages): + + new_packages = [] + dependency_links = [] + for i, requirement in enumerate(packages): + if requirement.startswith("--extra-index-url"): + dependency_links.append(requirement.split()[-1]) + elif requirement.startswith("./dependencies") or requirement.startswith( + "dependencies" + ): + dependency_links.append(requirement) + else: + new_packages.append(requirement) + + return new_packages, dependency_links + packages = list(gen_packages_items()) - return packages + packages, dependency_links = filter_index(packages) + return packages, dependency_links if __name__ == "__main__": @@ -201,6 +218,8 @@ def gen_packages_items(): write_version_file() + install_requires, dependency_links = parse_requirements("requirements.txt") + setup( name="groundingdino", version="0.1.0", @@ -208,7 +227,8 @@ def gen_packages_items(): url="https://github.com/IDEA-Research/GroundingDINO", description="open-set object detector", license=license, - install_requires=parse_requirements("requirements.txt"), + install_requires=install_requires, + dependency_links=dependency_links, packages=find_packages( exclude=( "configs",